class GRUDiffusionModel(DiffusionModel):
    """Diffusion denoiser that runs a 2-layer GRU over the noisy sequence.

    Each time step's value is concatenated with the sinusoidal diffusion-time
    embedding (from the ``DiffusionModel`` base, via ``pos_encoding``) and the
    extra conditioning ``inputs``, fed through the GRU, and mapped back to
    ``input_size`` features per step by a small MLP head.
    """

    def __init__(self, input_size: int, hidden_sizes: list, other_inputs_dim: int,
                 gru_hidden_size: int, time_dim: int = 64):
        """Build the GRU encoder and the fully connected output head.

        Args:
            input_size: Feature size of the denoised output per time step.
            hidden_sizes: Widths of the hidden Linear layers after the GRU.
            other_inputs_dim: Feature size of the conditioning ``inputs``.
            gru_hidden_size: Hidden size of the GRU.
            time_dim: Dimension of the diffusion-time positional encoding.
        """
        super(GRUDiffusionModel, self).__init__(time_dim)

        self.other_inputs_dim = other_inputs_dim
        self.gru_hidden_size = gru_hidden_size

        # GRU consumes the concatenation of (x features, time embedding,
        # conditioning inputs) at every sequence step.
        self.gru = nn.GRU(input_size=input_size + time_dim + other_inputs_dim,
                          hidden_size=gru_hidden_size,
                          num_layers=2,
                          batch_first=True)

        # Fully connected head applied independently to every GRU output step:
        # gru_hidden_size -> hidden_sizes... -> input_size.
        self.fc_layers = nn.ModuleList()
        prev_size = gru_hidden_size
        for hidden_size in hidden_sizes:
            self.fc_layers.append(nn.Linear(prev_size, hidden_size))
            self.fc_layers.append(nn.ReLU())
            prev_size = hidden_size

        # Final output layer back to the input feature size.
        self.fc_layers.append(nn.Linear(prev_size, input_size))

    def forward(self, x, t, inputs):
        """Predict the denoised sequence.

        Args:
            x: Noisy sequence, shape [batch_size, seq_len].
            t: Diffusion timestep per sample, shape [batch_size].
            inputs: Conditioning features, assumed
                [batch_size, seq_len, other_inputs_dim] — TODO confirm with caller.

        Returns:
            Tensor of shape [batch_size, seq_len, input_size].
        """
        batch_size, seq_len = x.shape
        # NOTE(review): this repeats each step's scalar to seq_len features, so the
        # concatenation below only matches the GRU's declared input_size when
        # input_size == seq_len — confirm this invariant against the constructor args.
        x = x.unsqueeze(-1).repeat(1, 1, seq_len)

        # Sinusoidal embedding of the diffusion timestep: [batch_size, time_dim].
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        # Broadcast the time embedding across the sequence:
        # [batch_size, time_dim] -> [batch_size, seq_len, time_dim].
        t = t.unsqueeze(1).repeat(1, seq_len, 1)

        # Concatenate along the feature dimension:
        # [batch_size, seq_len, input_size + time_dim + other_inputs_dim].
        x = torch.cat((x, t, inputs), dim=-1)

        # BUG FIX: use the per-timestep outputs, not the final hidden state.
        # `output` is [batch_size, seq_len, gru_hidden_size]; `hidden` is
        # [num_layers=2, batch_size, gru_hidden_size], which the FC head (and the
        # original "process each time step" intent) cannot consume correctly.
        output, _ = self.gru(x)
        x = output

        # Map every time step through the fully connected head.
        for layer in self.fc_layers:
            x = layer(x)

        return x