Added training loop and model
This commit is contained in:
@@ -1,10 +1,9 @@
|
||||
import torch
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
import json
|
||||
from keypoint_extractor import KeypointExtractor
|
||||
from collections import OrderedDict
|
||||
from identifiers import LANDMARKS
|
||||
import numpy as np
|
||||
|
||||
class WLASLDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, json_file: str, missing: str, keypoint_extractor: KeypointExtractor, subset:str="train", keypoints_identifier: dict = None, transform=None):
|
||||
@@ -50,10 +49,15 @@ class WLASLDataset(torch.utils.data.Dataset):
|
||||
# filter the keypoints by the identified subset
|
||||
if self.keypoints_to_keep:
|
||||
keypoints_df = keypoints_df[self.keypoints_to_keep]
|
||||
|
||||
current_row = np.empty(shape=(keypoints_df.shape[0], keypoints_df.shape[1] // 2, 2))
|
||||
for i in range(0, keypoints_df.shape[1], 2):
|
||||
current_row[:, i//2, 0] = keypoints_df.iloc[:,i]
|
||||
current_row[:, i//2, 1] = keypoints_df.iloc[:,i+1]
|
||||
|
||||
# TODO: convert keypoints to tensor and return
|
||||
label = self.data[video_id]["action"][0]
|
||||
|
||||
k = KeypointExtractor("data/videos/")
|
||||
d = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS)
|
||||
# data to tensor
|
||||
data = torch.from_numpy(current_row)
|
||||
|
||||
d.__getitem__(0)
|
||||
return data, label
|
||||
@@ -35,7 +35,7 @@ LANDMARKS = {
|
||||
"right_foot_index": 32,
|
||||
|
||||
# Left Hand Landmarks
|
||||
"left_wrist": 33,
|
||||
"left_wrist2": 33,
|
||||
"left_thumb_cmc": 34,
|
||||
"left_thumb_mcp": 35,
|
||||
"left_thumb_ip": 36,
|
||||
@@ -58,7 +58,7 @@ LANDMARKS = {
|
||||
"left_pinky_tip": 53,
|
||||
|
||||
# Right Hand Landmarks
|
||||
"right_wrist": 54,
|
||||
"right_wrist2": 54,
|
||||
"right_thumb_cmc": 55,
|
||||
"right_thumb_mcp": 56,
|
||||
"right_thumb_ip": 57,
|
||||
|
||||
@@ -45,7 +45,7 @@ class KeypointExtractor:
|
||||
# check if cache file exists and return
|
||||
if os.path.exists(self.cache_folder + "/" + video + ".npy"):
|
||||
# create dataframe from cache
|
||||
return pd.DataFrame(np.load(self.cache_folder + "/" + video + ".npy"), columns=self.columns)
|
||||
return pd.DataFrame(np.load(self.cache_folder + "/" + video + ".npy", allow_pickle=True), columns=self.columns)
|
||||
|
||||
# open video
|
||||
cap = cv2.VideoCapture(self.video_folder + video)
|
||||
@@ -89,7 +89,6 @@ class KeypointExtractor:
|
||||
:return: the keypoints and the frame with keypoints on if draw is set to True
|
||||
:rtype: np.ndarray
|
||||
"""
|
||||
print("Extracting keypoints from frame")
|
||||
# Convert the BGR image to RGB and process it with MediaPipe Pose.
|
||||
results = self.holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
|
||||
68
src/model.py
Normal file
68
src/model.py
Normal file
@@ -0,0 +1,68 @@
|
||||
### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data"
|
||||
|
||||
import copy
|
||||
import torch
|
||||
|
||||
import torch.nn as nn
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _get_clones(mod, n):
|
||||
return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)])
|
||||
|
||||
|
||||
class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer):
|
||||
"""
|
||||
Edited TransformerDecoderLayer implementation omitting the redundant self-attention operation as opposed to the
|
||||
standard implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
|
||||
super(SPOTERTransformerDecoderLayer, self).__init__(d_model, nhead, dim_feedforward, dropout, activation)
|
||||
|
||||
del self.self_attn
|
||||
|
||||
def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None,
|
||||
memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
tgt = tgt + self.dropout1(tgt)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
|
||||
return tgt
|
||||
|
||||
|
||||
class SPOTER(nn.Module):
    """
    Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from sequence
    of skeletal data.

    :param num_classes: number of gesture classes produced by the final linear head
    :param hidden_dim: transformer model width; must equal the flattened per-frame
        keypoint size (landmarks * 2) and be divisible by the 10 attention heads.
        NOTE(review): the historical default of 55 is kept for interface
        compatibility but is rejected below (55 % 10 != 0) — callers must pass an
        explicit valid value (train.py uses 2 * 75 = 150).
    """

    def __init__(self, num_classes, hidden_dim=55):
        super().__init__()

        # Fail fast with a clear message instead of torch's opaque internal
        # assertion from MultiheadAttention (embed_dim % num_heads == 0).
        if hidden_dim % 10 != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by the number of attention heads (10)"
            )

        # Learned positional rows; only row 0 is consumed (via self.pos below).
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
        # Single positional embedding, broadcast over all frames in forward().
        # (Equivalent to the original cat/repeat/flatten chain over row 0.)
        self.pos = nn.Parameter(self.row_embed[0].clone().reshape(1, 1, hidden_dim))
        # Learned query the decoder turns into the class representation.
        self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
        self.transformer = nn.Transformer(hidden_dim, 10, 6, 6)
        self.linear_class = nn.Linear(hidden_dim, num_classes)

        # Deactivate the initial attention decoder mechanism: swap the decoder
        # layers for the SPOTER variant that skips self-attention.
        custom_decoder_layer = SPOTERTransformerDecoderLayer(self.transformer.d_model, self.transformer.nhead, 2048,
                                                            0.1, "relu")
        self.transformer.decoder.layers = _get_clones(custom_decoder_layer, self.transformer.decoder.num_layers)

    def forward(self, inputs):
        """Map a keypoint sequence to logits of shape (1, 1, num_classes).

        ``inputs`` is flattened per frame, so its trailing dimensions must
        multiply out to ``hidden_dim`` (e.g. (frames, 75, 2) for hidden_dim=150).
        """
        # (frames, landmarks, 2) -> (frames, 1, hidden_dim): batch of size 1.
        h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()

        # Encoder sees position-augmented frames; decoder queries with the
        # single class token, then batch and sequence dims are swapped.
        h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1)
        res = self.linear_class(h)

        return res
|
||||
110
src/train.py
Normal file
110
src/train.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import logging
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as ticker
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from pathlib import Path
|
||||
|
||||
from keypoint_extractor import KeypointExtractor
|
||||
from identifiers import LANDMARKS
|
||||
from model import SPOTER
|
||||
from dataset import WLASLDataset
|
||||
|
||||
def train():
    """Train SPOTER on the WLASL-100 keypoint dataset.

    Runs 100 epochs of SGD with a plateau LR schedule, evaluates on the
    validation split after every epoch, and checkpoints the model whenever
    validation accuracy reaches a new best.

    Side effects: reads dataset/keypoint files under ``data/``, writes
    ``checkpoints/spoter_<epoch>.pth``, prints progress to stdout.
    """
    # Seed every RNG source so runs are reproducible.
    random.seed(379)
    np.random.seed(379)
    os.environ['PYTHONHASHSEED'] = str(379)
    torch.manual_seed(379)
    torch.cuda.manual_seed(379)
    torch.cuda.manual_seed_all(379)
    torch.backends.cudnn.deterministic = True
    g = torch.Generator()
    g.manual_seed(379)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 75 landmarks with (x, y) coordinates per frame -> transformer width 150.
    spoter_model = SPOTER(num_classes=100, hidden_dim=2*75)
    spoter_model.train(True)
    spoter_model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(spoter_model.parameters(), lr=0.001, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)

    # Ensure the checkpoint directory exists before the first torch.save.
    os.makedirs("checkpoints", exist_ok=True)

    # TODO: transformations + augmentations

    k = KeypointExtractor("data/videos/")

    train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train")
    train_loader = DataLoader(train_set, shuffle=True, generator=g)

    val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val")
    val_loader = DataLoader(val_set, shuffle=True, generator=g)

    test_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="test")
    test_loader = DataLoader(test_set, shuffle=True, generator=g)

    train_acc, val_acc = 0, 0
    losses, train_accs, val_accs = [], [], []
    lr_progress = []
    top_train_acc, top_val_acc = 0, 0
    checkpoint_index = 0

    for epoch in range(100):

        # ---- training pass ----
        spoter_model.train(True)
        for i, (inputs, labels) in enumerate(train_loader):
            # DataLoader prepends a batch dim of 1; the model wants (frames, landmarks, 2).
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            # Model emits (1, 1, num_classes); expand keeps an explicit batch of 1.
            outputs = spoter_model(inputs).expand(1, -1, -1)
            loss = criterion(outputs[0], labels)
            loss.backward()
            optimizer.step()

            # Argmax over the class dimension of the (1, num_classes) logits.
            # (The old torch.max(outputs.data, 1) reduced over a singleton dim,
            # yielding all-zero "predictions" broadcast against labels.)
            _, predicted = torch.max(outputs[0], 1)
            train_acc = (predicted == labels).sum().item() / labels.size(0)

            losses.append(loss.item())
            train_accs.append(train_acc)

            if i % 100 == 0:
                print(f"Epoch: {epoch} | Batch: {i} | Loss: {loss.item()} | Train Acc: {train_acc}")

        # ---- validation pass (no grads, dropout disabled) ----
        spoter_model.train(False)
        with torch.no_grad():
            for i, (inputs, labels) in enumerate(val_loader):
                # Same shape handling as the training loop (was missing squeeze(0)).
                inputs = inputs.squeeze(0).to(device)
                labels = labels.to(device)

                outputs = spoter_model(inputs).expand(1, -1, -1)
                _, predicted = torch.max(outputs[0], 1)
                val_acc = (predicted == labels).sum().item() / labels.size(0)

                val_accs.append(val_acc)
        spoter_model.train(True)

        # NOTE(review): the plateau scheduler is fed the *last training batch*
        # loss; a mean validation loss would be conventional — confirm intent.
        scheduler.step(loss)

        # Checkpoint whenever validation accuracy reaches a new best.
        if val_acc > top_val_acc:
            top_val_acc = val_acc
            top_train_acc = train_acc
            checkpoint_index = epoch
            torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")

        print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")
        lr_progress.append(optimizer.param_groups[0]['lr'])
|
||||
|
||||
# Guard the entry point so importing this module does not start a training run.
if __name__ == "__main__":
    train()
|
||||
Reference in New Issue
Block a user