Changed steps in diffusion model

This commit is contained in:
Victor Mylle
2024-01-20 09:44:14 +00:00
parent c6fa17fa40
commit acaad2710a
6 changed files with 106 additions and 25 deletions

View File

@@ -56,7 +56,7 @@ class GRUDiffusionModel(DiffusionModel):
# GRU layer
self.gru = nn.GRU(input_size=input_size + time_dim + other_inputs_dim,
hidden_size=gru_hidden_size,
num_layers=2,
num_layers=3,
batch_first=True)
# Fully connected layers after GRU
@@ -87,7 +87,8 @@ class GRUDiffusionModel(DiffusionModel):
# Pass through GRU
output, hidden = self.gru(x) # Hidden Shape: [batch_size, seq_len, 1]
x = hidden
# Get last hidden state
x = hidden[-1]
# Process each time step's output with fully connected layers
for layer in self.fc_layers: