From bfef06d72086caeeaf073be5145a47e33fbb85ca Mon Sep 17 00:00:00 2001 From: RobbeDeWaele <71169585+RobbeDeWaele@users.noreply.github.com> Date: Fri, 28 Apr 2023 15:03:34 +0200 Subject: [PATCH] Fixed model.py --- src/model.py | 27 +++------------------------ 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/src/model.py b/src/model.py index 758843c..6ee9cce 100644 --- a/src/model.py +++ b/src/model.py @@ -1,7 +1,6 @@ ### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data" import copy -import math from typing import Optional import torch @@ -39,20 +38,7 @@ class SPOTERTransformerDecoderLayer(nn.TransformerDecoderLayer): return tgt -class PositionalEmbedding(nn.Module): - def __init__(self, d_model, max_len=60): - super().__init__() - pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer('pe', pe) - def forward(self, x): - return x + self.pe[:x.size(0), :] - class SPOTER(nn.Module): """ Implementation of the SPOTER (Sign POse-based TransformER) architecture for sign language recognition from sequence @@ -62,9 +48,8 @@ class SPOTER(nn.Module): def __init__(self, num_classes, hidden_dim=55): super().__init__() - - self.pos = PositionalEmbedding(hidden_dim) - + self.row_embed = nn.Parameter(torch.rand(50, hidden_dim)) + self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0)) self.class_query = nn.Parameter(torch.rand(1, hidden_dim)) self.transformer = nn.Transformer(hidden_dim, 9, 6, 6) self.linear_class = nn.Linear(hidden_dim, num_classes) @@ -76,13 +61,7 @@ class SPOTER(nn.Module): def forward(self, inputs): h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float() - # add positional encoding - h = self.pos(h) - - # add class query - h = self.transformer(h, self.class_query.unsqueeze(0)).transpose(0, 1) - - # get class prediction + h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1) res = self.linear_class(h) return res \ No newline at end of file