Added GRU diffusion model
This commit is contained in:
@@ -45,3 +45,52 @@ class SimpleDiffusionModel(DiffusionModel):
|
|||||||
self.layers.append(nn.ReLU())
|
self.layers.append(nn.ReLU())
|
||||||
|
|
||||||
self.layers.append(nn.Linear(hidden_sizes[-1] + time_dim + other_inputs_dim, input_size))
|
self.layers.append(nn.Linear(hidden_sizes[-1] + time_dim + other_inputs_dim, input_size))
|
||||||
|
|
||||||
|
class GRUDiffusionModel(DiffusionModel):
|
||||||
|
def __init__(self, input_size: int, hidden_sizes: list, other_inputs_dim: int, gru_hidden_size: int, time_dim: int = 64):
|
||||||
|
super(GRUDiffusionModel, self).__init__(time_dim)
|
||||||
|
|
||||||
|
self.other_inputs_dim = other_inputs_dim
|
||||||
|
self.gru_hidden_size = gru_hidden_size
|
||||||
|
|
||||||
|
# GRU layer
|
||||||
|
self.gru = nn.GRU(input_size=input_size + time_dim + other_inputs_dim,
|
||||||
|
hidden_size=gru_hidden_size,
|
||||||
|
num_layers=2,
|
||||||
|
batch_first=True)
|
||||||
|
|
||||||
|
# Fully connected layers after GRU
|
||||||
|
self.fc_layers = nn.ModuleList()
|
||||||
|
prev_size = gru_hidden_size
|
||||||
|
for hidden_size in hidden_sizes:
|
||||||
|
self.fc_layers.append(nn.Linear(prev_size, hidden_size))
|
||||||
|
self.fc_layers.append(nn.ReLU())
|
||||||
|
prev_size = hidden_size
|
||||||
|
|
||||||
|
# Final output layer
|
||||||
|
self.fc_layers.append(nn.Linear(prev_size, input_size))
|
||||||
|
|
||||||
|
def forward(self, x, t, inputs):
|
||||||
|
batch_size, seq_len = x.shape
|
||||||
|
x = x.unsqueeze(-1).repeat(1, 1, seq_len)
|
||||||
|
|
||||||
|
# Positional encoding for each time step
|
||||||
|
t = t.unsqueeze(-1).type(torch.float)
|
||||||
|
t = self.pos_encoding(t, self.time_dim) # Shape: [batch_size, seq_len, time_dim]
|
||||||
|
|
||||||
|
# repeat time encoding for each time step t is shape [batch_size, time_dim], i want [batch_size, seq_len, time_dim]
|
||||||
|
t = t.unsqueeze(1).repeat(1, seq_len, 1)
|
||||||
|
|
||||||
|
# Concatenate x, t, and inputs along the feature dimension
|
||||||
|
x = torch.cat((x, t, inputs), dim=-1) # Shape: [batch_size, seq_len, input_size + time_dim + other_inputs_dim]
|
||||||
|
|
||||||
|
# Pass through GRU
|
||||||
|
output, hidden = self.gru(x) # Hidden Shape: [batch_size, seq_len, 1]
|
||||||
|
|
||||||
|
x = hidden
|
||||||
|
|
||||||
|
# Process each time step's output with fully connected layers
|
||||||
|
for layer in self.fc_layers:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
return x
|
||||||
Reference in New Issue
Block a user