Non-autoregressive GRU model load

This commit is contained in:
2024-05-06 16:11:15 +02:00
parent 19ab597ae6
commit d7f4c1849b
7 changed files with 55 additions and 22 deletions

View File

@@ -1,7 +1,15 @@
import torch
class LSTMModel(torch.nn.Module):
def __init__(self, inputSize, output_size, num_layers: int, hidden_size: int, dropout: float = 0.2):
def __init__(
self,
inputSize,
output_size,
num_layers: int,
hidden_size: int,
dropout: float = 0.2,
):
super(LSTMModel, self).__init__()
self.inputSize = inputSize
self.output_size = output_size
@@ -10,20 +18,34 @@ class LSTMModel(torch.nn.Module):
self.hidden_size = hidden_size
self.dropout = dropout
self.lstm = torch.nn.LSTM(input_size=inputSize[-1], hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, batch_first=True)
self.lstm = torch.nn.LSTM(
input_size=inputSize[-1],
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
batch_first=True,
)
self.linear = torch.nn.Linear(hidden_size, output_size)
def forward(self, x):
# Forward pass through the LSTM layers
_, (hidden_state, _) = self.lstm(x)
# Use the hidden state from the last time step for the output
output = self.linear(hidden_state[-1])
return output
class GRUModel(torch.nn.Module):
def __init__(self, inputSize, output_size, num_layers: int, hidden_size: int, dropout: float = 0.2):
def __init__(
self,
inputSize,
output_size,
num_layers: int,
hidden_size: int,
dropout: float = 0.2,
):
super(GRUModel, self).__init__()
self.inputSize = inputSize
self.output_size = output_size
@@ -32,14 +54,24 @@ class GRUModel(torch.nn.Module):
self.hidden_size = hidden_size
self.dropout = dropout
self.gru = torch.nn.GRU(input_size=inputSize[-1], hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, batch_first=True)
self.gru = torch.nn.GRU(
input_size=inputSize[-1],
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
batch_first=True,
)
self.linear = torch.nn.Linear(hidden_size, output_size)
def forward(self, x):
# if dimension is 2, add batch dimension to 1
if x.dim() == 2:
x = x.unsqueeze(0)
# Forward pass through the GRU layers
x, _ = self.gru(x)
x = x[:, -1, :]
# Use the hidden state from the last time step for the output
output = self.linear(x)
return output