diff --git a/src/dataset.py b/src/dataset.py index 1a08b04..1263764 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -1,10 +1,9 @@ import torch -import pandas as pd -from PIL import Image import json from keypoint_extractor import KeypointExtractor from collections import OrderedDict from identifiers import LANDMARKS +import numpy as np class WLASLDataset(torch.utils.data.Dataset): def __init__(self, json_file: str, missing: str, keypoint_extractor: KeypointExtractor, subset:str="train", keypoints_identifier: dict = None, transform=None): @@ -50,10 +49,15 @@ class WLASLDataset(torch.utils.data.Dataset): # filter the keypoints by the identified subset if self.keypoints_to_keep: keypoints_df = keypoints_df[self.keypoints_to_keep] + + current_row = np.empty(shape=(keypoints_df.shape[0], keypoints_df.shape[1] // 2, 2)) + for i in range(0, keypoints_df.shape[1], 2): + current_row[:, i//2, 0] = keypoints_df.iloc[:,i] + current_row[:, i//2, 1] = keypoints_df.iloc[:,i+1] - # TODO: convert keypoints to tensor and return + label = self.data[video_id]["action"][0] -k = KeypointExtractor("data/videos/") -d = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS) + # data to tensor + data = torch.from_numpy(current_row) -d.__getitem__(0) \ No newline at end of file + return data, label \ No newline at end of file diff --git a/src/identifiers.py b/src/identifiers.py index d301e5d..c1f8465 100644 --- a/src/identifiers.py +++ b/src/identifiers.py @@ -35,7 +35,7 @@ LANDMARKS = { "right_foot_index": 32, # Left Hand Landmarks - "left_wrist": 33, + "left_wrist2": 33, "left_thumb_cmc": 34, "left_thumb_mcp": 35, "left_thumb_ip": 36, @@ -58,7 +58,7 @@ LANDMARKS = { "left_pinky_tip": 53, # Right Hand Landmarks - "right_wrist": 54, + "right_wrist2": 54, "right_thumb_cmc": 55, "right_thumb_mcp": 56, "right_thumb_ip": 57, diff --git a/src/keypoint_extractor.py b/src/keypoint_extractor.py index 93804d6..a0ee964 100644 --- a/src/keypoint_extractor.py +++ b/src/keypoint_extractor.py @@ -45,7 +45,7 @@ class KeypointExtractor: # check if cache file exists and return if os.path.exists(self.cache_folder + "/" + video + ".npy"): # create dataframe from cache - return pd.DataFrame(np.load(self.cache_folder + "/" + video + ".npy"), columns=self.columns) + return pd.DataFrame(np.load(self.cache_folder + "/" + video + ".npy", allow_pickle=True), columns=self.columns) # open video cap = cv2.VideoCapture(self.video_folder + video) @@ -89,7 +89,6 @@ class KeypointExtractor: :return: the keypoints and the frame with keypoints on if draw is set to True :rtype: np.ndarray """ - print("Extracting keypoints from frame") # Convert the BGR image to RGB and process it with MediaPipe Pose. results = self.holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000..c2aa1db --- /dev/null +++ b/src/model.py @@ -0,0 +1,68 @@ +### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data" + +import copy +import torch + +import torch.nn as nn +from typing import Optional + + +def _get_clones(mod, n): + return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)]) + + +class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer): + """ + Edited TransformerDecoderLayer implementation omitting the redundant self-attention operation as opposed to the + standard implementation. + """ + + def __init__(self, d_model, nhead, dim_feedforward, dropout, activation): + super(SPOTERTransformerDecoderLayer, self).__init__(d_model, nhead, dim_feedforward, dropout, activation) + + del self.self_attn + + def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + + tgt = tgt + self.dropout1(tgt) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + + return tgt + + +class SPOTER(nn.Module): + """ + Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from sequence + of skeletal data. + """ + + def __init__(self, num_classes, hidden_dim=55): + super().__init__() + + self.row_embed = nn.Parameter(torch.rand(50, hidden_dim)) + self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0)) + self.class_query = nn.Parameter(torch.rand(1, hidden_dim)) + self.transformer = nn.Transformer(hidden_dim, 10, 6, 6) + self.linear_class = nn.Linear(hidden_dim, num_classes) + + # Deactivate the initial attention decoder mechanism + custom_decoder_layer = SPOTERTransformerDecoderLayer(self.transformer.d_model, self.transformer.nhead, 2048, + 0.1, "relu") + self.transformer.decoder.layers = _get_clones(custom_decoder_layer, self.transformer.decoder.num_layers) + + def forward(self, inputs): + h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float() + + h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1) + res = self.linear_class(h) + + return res \ No newline at end of file diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..44c7f34 --- /dev/null +++ b/src/train.py @@ -0,0 +1,110 @@ +import os +import argparse +import random +import logging +import torch + +import numpy as np +import torch.nn as nn +import torch.optim as optim +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker +from torchvision import transforms +from torch.utils.data import DataLoader +from pathlib import Path + +from keypoint_extractor import KeypointExtractor +from identifiers import LANDMARKS +from model import SPOTER +from dataset import WLASLDataset + +def train(): + random.seed(379) + np.random.seed(379) + os.environ['PYTHONHASHSEED'] = str(379) + torch.manual_seed(379) + torch.cuda.manual_seed(379) + torch.cuda.manual_seed_all(379) + torch.backends.cudnn.deterministic = True + g = torch.Generator() + g.manual_seed(379) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + spoter_model = SPOTER(num_classes=100, hidden_dim=2*75) + spoter_model.train(True) + spoter_model.to(device) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(spoter_model.parameters(), lr=0.001, momentum=0.9) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5) + + # TODO: create paths for checkpoints + + # TODO: transformations + augmentations + + k = KeypointExtractor("data/videos/") + + train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train") + train_loader = DataLoader(train_set, shuffle=True, generator=g) + + val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val") + val_loader = DataLoader(val_set, shuffle=True, generator=g) + + test_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="test") + test_loader = DataLoader(test_set, shuffle=True, generator=g) + + + train_acc, val_acc = 0, 0 + losses, train_accs, val_accs = [], [], [] + lr_progress = [] + top_train_acc, top_val_acc = 0, 0 + checkpoint_index = 0 + + for epoch in range(100): + + # train + for i, (inputs, labels) in enumerate(train_loader): + inputs = inputs.squeeze(0).to(device) + labels = labels.to(device) + + optimizer.zero_grad() + outputs = spoter_model(inputs).expand(1, -1, -1) + loss = criterion(outputs[0], labels) + loss.backward() + optimizer.step() + + _, predicted = torch.max(outputs.data, 1) + train_acc = (predicted == labels).sum().item() / labels.size(0) + + losses.append(loss.item()) + train_accs.append(train_acc) + + if i % 100 == 0: + print(f"Epoch: {epoch} | Batch: {i} | Loss: {loss.item()} | Train Acc: {train_acc}") + + # validate + with torch.no_grad(): + for i, (inputs, labels) in enumerate(val_loader): + inputs = inputs.to(device) + labels = labels.to(device) + + outputs = spoter_model(inputs) + _, predicted = torch.max(outputs.data, 1) + val_acc = (predicted == labels).sum().item() / labels.size(0) + + val_accs.append(val_acc) + + scheduler.step(loss) + + # save checkpoint + if val_acc > top_val_acc: + top_val_acc = val_acc + top_train_acc = train_acc + checkpoint_index = epoch + torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth") + + print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}") + lr_progress.append(optimizer.param_groups[0]['lr']) + +train() \ No newline at end of file diff --git a/test_output.jpg b/test_output.jpg deleted file mode 100644 index 17c0844..0000000 Binary files a/test_output.jpg and /dev/null differ