Fixed model.py
This commit is contained in:
27
src/model.py
27
src/model.py
@@ -1,7 +1,6 @@
|
||||
### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data"
|
||||
|
||||
import copy
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
@@ -39,20 +38,7 @@ class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer):
|
||||
|
||||
return tgt
|
||||
|
||||
class PositionalEmbedding(nn.Module):
|
||||
def __init__(self, d_model, max_len=60):
|
||||
super().__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.pe[:x.size(0), :]
|
||||
|
||||
class SPOTER(nn.Module):
|
||||
"""
|
||||
Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from sequence
|
||||
@@ -62,9 +48,8 @@ class SPOTER(nn.Module):
|
||||
def __init__(self, num_classes, hidden_dim=55):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.pos = PositionalEmbedding(hidden_dim)
|
||||
|
||||
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
|
||||
self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0))
|
||||
self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
|
||||
self.transformer = nn.Transformer(hidden_dim, 9, 6, 6)
|
||||
self.linear_class = nn.Linear(hidden_dim, num_classes)
|
||||
@@ -76,13 +61,7 @@ class SPOTER(nn.Module):
|
||||
|
||||
def forward(self, inputs):
|
||||
h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()
|
||||
# add positional encoding
|
||||
h = self.pos(h)
|
||||
|
||||
# add class query
|
||||
h = self.transformer(h, self.class_query.unsqueeze(0)).transpose(0, 1)
|
||||
|
||||
# get class prediction
|
||||
h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1)
|
||||
res = self.linear_class(h)
|
||||
|
||||
return res
|
||||
Reference in New Issue
Block a user