import copy
from typing import Optional

import torch
import torch.nn as nn


def _get_clones(mod, n):
    """Return a ModuleList of n independent deep copies of the given module."""
    return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)])


class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer):
    """
    Modified TransformerDecoderLayer that omits the self-attention operation of the standard implementation.
    In SPOTER the decoder receives a single class query token, so self-attention over it is redundant.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout, activation):
        super(SPOTERTransformerDecoderLayer, self).__init__(d_model, nhead, dim_feedforward, dropout, activation)

        del self.self_attn

    def forward(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[torch.Tensor] = None,
                memory_mask: Optional[torch.Tensor] = None, tgt_key_padding_mask: Optional[torch.Tensor] = None,
                memory_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:

        # Self-attention is skipped; the residual branch receives the (dropped-out) target itself.
        tgt = tgt + self.dropout1(tgt)
        tgt = self.norm1(tgt)
        # Cross-attention of the target over the encoded frame sequence.
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # Position-wise feed-forward block.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt


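# Hedged, illustrative sketch (not part of the original module) showing how the layer above could be
# exercised on its own. The d_model/nhead values and tensor shapes are assumptions, chosen only so that
# d_model is divisible by nhead, as nn.MultiheadAttention requires:
#
#   layer = SPOTERTransformerDecoderLayer(d_model=108, nhead=9, dim_feedforward=2048,
#                                         dropout=0.1, activation="relu")
#   tgt = torch.rand(1, 1, 108)      # a single class-query token: (target_len, batch, d_model)
#   memory = torch.rand(60, 1, 108)  # an encoded 60-frame sequence: (source_len, batch, d_model)
#   out = layer(tgt, memory)         # same shape as tgt; no self-attention is applied

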
class SPOTER(nn.Module):
    """
    Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from
    sequences of skeletal data.
    """

    def __init__(self, num_classes, hidden_dim=55):
        super().__init__()

        # Learned positional encoding added to every input frame.
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
        self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0))
        # Single learned class query decoded against the encoded frame sequence.
        self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
        # Note: nn.MultiheadAttention requires hidden_dim to be divisible by the 9 attention heads.
        self.transformer = nn.Transformer(hidden_dim, 9, 6, 6)
        self.linear_class = nn.Linear(hidden_dim, num_classes)

        # Replace the decoder layers with custom ones that skip the redundant self-attention step.
        custom_decoder_layer = SPOTERTransformerDecoderLayer(self.transformer.d_model, self.transformer.nhead, 2048,
                                                             0.1, "relu")
        self.transformer.decoder.layers = _get_clones(custom_decoder_layer, self.transformer.decoder.num_layers)

    def forward(self, inputs):
        # Flatten each frame's landmarks and add a batch dimension: (frames, 1, features).
        h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()
        # Encode the frame sequence and decode the single class query token.
        h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1)
        res = self.linear_class(h)

        return res
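

# Hedged usage sketch (not part of the original module): the class count, hidden_dim, and the 60x108
# dummy input below are assumptions for illustration only; hidden_dim is set to 108 so that it is
# divisible by the 9 attention heads configured in SPOTER.__init__. This targets the torch version the
# repo pins; newer torch releases may pass extra keyword arguments (e.g. is_causal flags) to decoder
# layers that the overridden forward above does not accept.
if __name__ == "__main__":
    model = SPOTER(num_classes=100, hidden_dim=108)
    dummy_sequence = torch.rand(60, 108)  # 60 frames, each flattened to 108 pose features
    logits = model(dummy_sequence)
    print(logits.shape)  # torch.Size([1, 1, 100])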