7 Commits

Author SHA1 Message Date
Tibe Habils
40c16548b2 Merge branch 'WES-184-New-letter-variants' into 'dev'
WES-184 Train the SPOTER model on the new letter variants

See merge request wesign/sign-predictor!18
2023-05-06 19:20:57 +00:00
RobbeDeWaele
17251edfda WES-184 Train the SPOTER model on the new letter variants 2023-04-28 16:00:23 +02:00
RobbeDeWaele
bfef06d720 Fixed model.py 2023-04-28 15:03:34 +02:00
Victor Mylle
7cf35d7357 Merge branch 'WES-155-mirror-augmentation' into 'dev'
Resolve WES-155 "Mirror augmentation"

See merge request wesign/sign-predictor!16
2023-04-24 12:06:32 +00:00
Robbe De Waele
65d478ef1b Resolve WES-155 "Mirror augmentation" 2023-04-24 12:06:32 +00:00
Victor Mylle
cd9cc8ce8b Merge branch 'WES-123-rotation-augmentation' into 'dev'
Rotation augmentation class added

See merge request wesign/sign-predictor!15
2023-04-24 11:57:19 +00:00
RobbeDeWaele
0af9320571 Rotation augmentation class added 2023-03-30 16:13:03 +02:00
8 changed files with 1549 additions and 1298 deletions

BIN
models/model_A-Z_v2.onnx Normal file

Binary file not shown.

BIN
models/model_A-Z_v2.pth Normal file

Binary file not shown.

View File

@@ -36,11 +36,28 @@ def circle_intersection(x0, y0, r0, x1, y1, r1):
class MirrorKeypoints: class MirrorKeypoints:
def __call__(self, sample): def __call__(self, sample):
if sample.shape[0] == 0:
return sample
if random.random() > 0.5: if random.random() > 0.5:
return sample return sample
# flip the keypoints tensor
sample = 1 - sample # flip the x coordinates
sample[:, :, 0] *= -1
# switch hands (left becomes right and vice versa)
left, right, n = 12, 33, 21
if isinstance(sample, np.ndarray): # For testing purposes only
sample[:, left:left+n, :], sample[:, right:right+n, :] = sample[: , right:right+n, :], sample[:, left:left+n, :].copy()
else:
sample[:, left:left+n, :], sample[:, right:right+n, :] = sample[: , right:right+n, :], sample[:, left:left+n, :].clone()
# switch pose keypoints
sample[:, [1, 2], :] = sample[:, [2, 1], :] #eye
sample[:, [3, 4], :] = sample[:, [4, 3], :] #ear
sample[:, [6, 7], :] = sample[:, [7, 6], :] #shoulder
sample[:, [8, 9], :] = sample[:, [9, 8], :] #elbow
sample[:, [10, 11], :] = sample[:, [11, 10], :] #wrist
return sample return sample
@@ -124,4 +141,16 @@ class NoiseAugmentation:
def __call__(self, sample): def __call__(self, sample):
# add noise to the keypoints # add noise to the keypoints
sample = sample + torch.randn(sample.shape) * self.noise sample = sample + torch.randn(sample.shape) * self.noise
return sample return sample
# augmentation to rotate all keypoints around 0,0
class RotateAugmentation:
    """Rotate every keypoint in the sample around the origin (0, 0) by one
    random angle drawn uniformly from [-13, 13] degrees.

    Works on any array/tensor shaped (frames, keypoints, 2) that supports
    slicing and elementwise arithmetic — presumably torch.Tensor here, but
    numpy.ndarray works too (TODO confirm against the dataset pipeline).
    Mutates `sample` in place and returns it.
    """

    def __call__(self, sample):
        # generate a random angle between -13 and 13 degrees
        angle_max = 13.0
        angle = math.radians(random.uniform(a=-angle_max, b=angle_max))
        cos_a = math.cos(angle)
        sin_a = math.sin(angle)
        # Read both coordinates BEFORE writing either one. The previous
        # version aliased `new_sample = sample` and overwrote x in place,
        # so the y rotation was computed from the already-rotated x and
        # the transform was not a true rotation.
        x = sample[:, :, 0]
        y = sample[:, :, 1]
        rotated_x = x * cos_a - y * sin_a  # arithmetic yields fresh arrays,
        rotated_y = x * sin_a + y * cos_a  # so no aliasing with `sample`
        sample[:, :, 0] = rotated_x
        sample[:, :, 1] = rotated_y
        return sample

View File

@@ -7,7 +7,7 @@ from src.model import SPOTER
from src.identifiers import LANDMARKS from src.identifiers import LANDMARKS
# set parameters of the model # set parameters of the model
model_name = 'model_A-Z' model_name = 'model_A-Z_v2'
num_classes = 26 num_classes = 26
# load PyTorch model from .pth file # load PyTorch model from .pth file

View File

@@ -1,7 +1,6 @@
### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data" ### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data"
import copy import copy
import math
from typing import Optional from typing import Optional
import torch import torch
@@ -39,20 +38,7 @@ class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer):
return tgt return tgt
class PositionalEmbedding(nn.Module):
def __init__(self, d_model, max_len=60):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:x.size(0), :]
class SPOTER(nn.Module): class SPOTER(nn.Module):
""" """
Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from sequence Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from sequence
@@ -62,9 +48,8 @@ class SPOTER(nn.Module):
def __init__(self, num_classes, hidden_dim=55): def __init__(self, num_classes, hidden_dim=55):
super().__init__() super().__init__()
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
self.pos = PositionalEmbedding(hidden_dim) self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0))
self.class_query = nn.Parameter(torch.rand(1, hidden_dim)) self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
self.transformer = nn.Transformer(hidden_dim, 9, 6, 6) self.transformer = nn.Transformer(hidden_dim, 9, 6, 6)
self.linear_class = nn.Linear(hidden_dim, num_classes) self.linear_class = nn.Linear(hidden_dim, num_classes)
@@ -76,13 +61,7 @@ class SPOTER(nn.Module):
def forward(self, inputs): def forward(self, inputs):
h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float() h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()
# add positional encoding h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1)
h = self.pos(h)
# add class query
h = self.transformer(h, self.class_query.unsqueeze(0)).transpose(0, 1)
# get class prediction
res = self.linear_class(h) res = self.linear_class(h)
return res return res

View File

@@ -8,7 +8,7 @@ import torch.optim as optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
from src.augmentations import MirrorKeypoints, Z_augmentation, NoiseAugmentation from src.augmentations import MirrorKeypoints, Z_augmentation, NoiseAugmentation, RotateAugmentation
from src.datasets.finger_spelling_dataset import FingerSpellingDataset from src.datasets.finger_spelling_dataset import FingerSpellingDataset
from src.identifiers import LANDMARKS from src.identifiers import LANDMARKS
from src.model import SPOTER from src.model import SPOTER
@@ -29,12 +29,16 @@ def train():
g = torch.Generator() g = torch.Generator()
g.manual_seed(379) g.manual_seed(379)
device = torch.device("cuda:0")
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) *2) spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) *2)
# use cuda if available
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
spoter_model.train(True) spoter_model.train(True)
spoter_model.to(device) spoter_model.to(device)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
criterion_bad = CustomLoss() criterion_bad = CustomLoss()
@@ -45,7 +49,7 @@ def train():
if not os.path.exists("checkpoints"): if not os.path.exists("checkpoints"):
os.makedirs("checkpoints") os.makedirs("checkpoints")
transform = transforms.Compose([MirrorKeypoints(), NoiseAugmentation(noise=0.1)]) transform = transforms.Compose([MirrorKeypoints(), NoiseAugmentation(noise=0.1), RotateAugmentation()])
train_set = FingerSpellingDataset("data/fingerspelling/data/", bad_data_folder="", keypoints_identifier=LANDMARKS, subset="train", transform=transform) train_set = FingerSpellingDataset("data/fingerspelling/data/", bad_data_folder="", keypoints_identifier=LANDMARKS, subset="train", transform=transform)
train_loader = DataLoader(train_set, shuffle=True, generator=g) train_loader = DataLoader(train_set, shuffle=True, generator=g)
@@ -124,9 +128,9 @@ def train():
if val_acc > best_val_acc: if val_acc > best_val_acc:
best_val_acc = val_acc best_val_acc = val_acc
epochs_without_improvement = 0 epochs_without_improvement = 0
if epoch > 55: if epoch > 45:
top_val_acc = val_acc top_val_acc = val_acc
top_train_acc = train_acc top_train_acc = pred_correct / pred_all
checkpoint_index = epoch checkpoint_index = epoch
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth") torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
else: else:

File diff suppressed because one or more lines are too long

View File

@@ -27,7 +27,7 @@ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
keypoints = [] keypoints = []
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) * 2) spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) * 2)
spoter_model.load_state_dict(torch.load('models/spoter_76.pth', map_location=torch.device('cpu'))) spoter_model.load_state_dict(torch.load('models/model_A-Z_v2.pth', map_location=torch.device('cpu')))
# get values of the landmarks as a list of integers # get values of the landmarks as a list of integers
values = [] values = []