Added training loop and model

This commit is contained in:
2023-02-21 23:24:51 +00:00
parent 1e05c02a7e
commit 98f29f683e
6 changed files with 191 additions and 10 deletions

View File

@@ -1,10 +1,9 @@
import torch
import pandas as pd
from PIL import Image
import json
from keypoint_extractor import KeypointExtractor
from collections import OrderedDict
from identifiers import LANDMARKS
import numpy as np
class WLASLDataset(torch.utils.data.Dataset):
def __init__(self, json_file: str, missing: str, keypoint_extractor: KeypointExtractor, subset:str="train", keypoints_identifier: dict = None, transform=None):
@@ -50,10 +49,15 @@ class WLASLDataset(torch.utils.data.Dataset):
# filter the keypoints by the identified subset
if self.keypoints_to_keep:
keypoints_df = keypoints_df[self.keypoints_to_keep]
current_row = np.empty(shape=(keypoints_df.shape[0], keypoints_df.shape[1] // 2, 2))
for i in range(0, keypoints_df.shape[1], 2):
current_row[:, i//2, 0] = keypoints_df.iloc[:,i]
current_row[:, i//2, 1] = keypoints_df.iloc[:,i+1]
# TODO: convert keypoints to tensor and return
label = self.data[video_id]["action"][0]
k = KeypointExtractor("data/videos/")
d = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS)
# data to tensor
data = torch.from_numpy(current_row)
d.__getitem__(0)
return data, label

View File

@@ -35,7 +35,7 @@ LANDMARKS = {
"right_foot_index": 32,
# Left Hand Landmarks
"left_wrist": 33,
"left_wrist2": 33,
"left_thumb_cmc": 34,
"left_thumb_mcp": 35,
"left_thumb_ip": 36,
@@ -58,7 +58,7 @@ LANDMARKS = {
"left_pinky_tip": 53,
# Right Hand Landmarks
"right_wrist": 54,
"right_wrist2": 54,
"right_thumb_cmc": 55,
"right_thumb_mcp": 56,
"right_thumb_ip": 57,

View File

@@ -45,7 +45,7 @@ class KeypointExtractor:
# check if cache file exists and return
if os.path.exists(self.cache_folder + "/" + video + ".npy"):
# create dataframe from cache
return pd.DataFrame(np.load(self.cache_folder + "/" + video + ".npy"), columns=self.columns)
return pd.DataFrame(np.load(self.cache_folder + "/" + video + ".npy", allow_pickle=True), columns=self.columns)
# open video
cap = cv2.VideoCapture(self.video_folder + video)
@@ -89,7 +89,6 @@ class KeypointExtractor:
:return: the keypoints and the frame with keypoints on if draw is set to True
:rtype: np.ndarray
"""
print("Extracting keypoints from frame")
# Convert the BGR image to RGB and process it with MediaPipe Pose.
results = self.holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

68
src/model.py Normal file
View File

@@ -0,0 +1,68 @@
### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data"
import copy
import torch
import torch.nn as nn
from typing import Optional
def _get_clones(mod, n):
return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)])
class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer):
    """
    Decoder layer variant used by SPOTER: the standard decoder's
    self-attention sub-layer is removed, while the cross-attention and
    feed-forward sub-layers (each with its own residual + layer norm)
    are kept.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
        super(SPOTERTransformerDecoderLayer, self).__init__(d_model, nhead, dim_feedforward, dropout, activation)
        # forward() below never touches self-attention, so drop the module
        # (and its parameters) entirely.
        del self.self_attn

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None,
                memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None,
                memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # NOTE(review): the first residual adds a dropped-out copy of tgt to
        # itself instead of a self-attention output — this mirrors the
        # published SPOTER implementation; confirm against the reference.
        x = self.norm1(tgt + self.dropout1(tgt))
        # Cross-attention: the target queries attend over the encoder memory.
        attn_out, _ = self.multihead_attn(x, memory, memory, attn_mask=memory_mask,
                                          key_padding_mask=memory_key_padding_mask)
        x = self.norm2(x + self.dropout2(attn_out))
        # Position-wise feed-forward sub-layer.
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.norm3(x + self.dropout3(ff_out))
class SPOTER(nn.Module):
    """
    SPOTER (Sign POse-based TransformER): transformer architecture for sign
    language recognition from sequences of skeletal keypoint data.
    """

    def __init__(self, num_classes, hidden_dim=55):
        super().__init__()
        # NOTE: statement order is significant — each torch.rand / module
        # construction consumes global RNG state, so reordering would change
        # the (seeded) initialization.
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
        # Positional term built from the first row embedding, shaped
        # (1, 1, hidden_dim) so it broadcasts over the sequence in forward().
        pos_init = torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0)
        self.pos = nn.Parameter(pos_init)
        # Single learnable query token; its decoder output is classified.
        self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
        self.transformer = nn.Transformer(hidden_dim, 10, 6, 6)
        self.linear_class = nn.Linear(hidden_dim, num_classes)

        # Deactivate the initial attention decoder mechanism: replace every
        # decoder layer with one whose self-attention is removed.
        decoder_layer = SPOTERTransformerDecoderLayer(self.transformer.d_model, self.transformer.nhead, 2048,
                                                      0.1, "relu")
        self.transformer.decoder.layers = _get_clones(decoder_layer, self.transformer.decoder.num_layers)

    def forward(self, inputs):
        # Flatten each frame's keypoints to one vector: (seq, 1, hidden_dim).
        seq = inputs.flatten(start_dim=1).unsqueeze(1).float()
        # Encode the position-augmented sequence, decode the class query,
        # then move the batch dimension first before classification.
        decoded = self.transformer(self.pos + seq, self.class_query.unsqueeze(0))
        return self.linear_class(decoded.transpose(0, 1))

110
src/train.py Normal file
View File

@@ -0,0 +1,110 @@
import os
import argparse
import random
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path
from keypoint_extractor import KeypointExtractor
from identifiers import LANDMARKS
from model import SPOTER
from dataset import WLASLDataset
def train():
    """
    Train the SPOTER model on the WLASL-100 keypoint dataset.

    Seeds every RNG source for reproducibility, builds train/val/test
    datasets and loaders, then runs 100 epochs of SGD with a
    ReduceLROnPlateau scheduler.  A checkpoint is written whenever the
    epoch-level validation accuracy improves on the best seen so far.
    """
    # Reproducibility: seed every RNG source the pipeline touches.
    seed = 379
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    g = torch.Generator()
    g.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # hidden_dim = 2 * 75: presumably (x, y) per landmark for 75 landmarks
    # per frame — TODO confirm against the dataset's keypoint layout.
    spoter_model = SPOTER(num_classes=100, hidden_dim=2 * 75)
    spoter_model.train(True)
    spoter_model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(spoter_model.parameters(), lr=0.001, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)

    # Ensure the checkpoint directory exists before the first torch.save.
    os.makedirs("checkpoints", exist_ok=True)

    # TODO: transformations + augmentations
    k = KeypointExtractor("data/videos/")
    train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train")
    train_loader = DataLoader(train_set, shuffle=True, generator=g)
    val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val")
    val_loader = DataLoader(val_set, shuffle=True, generator=g)
    # Test split is prepared for later evaluation; not used in the loop yet.
    test_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="test")
    test_loader = DataLoader(test_set, shuffle=True, generator=g)

    train_acc, val_acc = 0, 0
    losses, train_accs, val_accs = [], [], []
    lr_progress = []
    top_train_acc, top_val_acc = 0, 0
    checkpoint_index = 0

    for epoch in range(100):
        # ---- training ----
        spoter_model.train()
        for i, (inputs, labels) in enumerate(train_loader):
            # DataLoader prepends a batch dim of 1; drop it so the model
            # sees one (frames, landmarks, 2) sequence.
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = spoter_model(inputs)  # (1, 1, num_classes)
            loss = criterion(outputs[0], labels)
            loss.backward()
            optimizer.step()

            # Argmax over the class dimension.  (The previous
            # torch.max(outputs.data, 1) reduced the size-1 middle dim of a
            # (1, 1, C) tensor, producing a (1, C) "prediction".)
            predicted = outputs[0].argmax(dim=1)
            train_acc = (predicted == labels).sum().item() / labels.size(0)
            losses.append(loss.item())
            train_accs.append(train_acc)
            if i % 100 == 0:
                print(f"Epoch: {epoch} | Batch: {i} | Loss: {loss.item()} | Train Acc: {train_acc}")

        # ---- validation ----
        # Eval mode so dropout layers are inactive during evaluation; the
        # previous code validated with the model still in train mode and
        # without squeezing the batch dim off the inputs.
        spoter_model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.squeeze(0).to(device)
                labels = labels.to(device)
                outputs = spoter_model(inputs)
                predicted = outputs[0].argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
        # Epoch-level accuracy: the old per-batch value made the checkpoint
        # decision depend on the last validation batch only.
        val_acc = correct / total if total else 0.0
        val_accs.append(val_acc)
        spoter_model.train()

        # NOTE(review): steps on the last training-batch loss; a mean
        # validation loss would be a more stable plateau signal.
        scheduler.step(loss)

        # Save a checkpoint whenever validation accuracy improves.
        if val_acc > top_val_acc:
            top_val_acc = val_acc
            top_train_acc = train_acc
            checkpoint_index = epoch
            torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
        print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")
        lr_progress.append(optimizer.param_groups[0]['lr'])
# Guard the entry point so importing this module (e.g. for its train()
# function) does not immediately launch a full training run.
if __name__ == "__main__":
    train()