import torch
import torch.nn as nn


class DiffusionModel(nn.Module):
    """Base class for conditional denoising networks.

    Subclasses populate ``self.layers``. The default :meth:`forward` feeds the
    sinusoidal time embedding and the extra conditioning ``inputs`` into every
    ``nn.Linear`` in ``self.layers`` by concatenating them onto that layer's
    input.

    Args:
        time_dim: Size of the sinusoidal time embedding (must be even).
    """

    def __init__(self, time_dim: int = 64):
        super().__init__()
        self.time_dim = time_dim
        self.layers = nn.ModuleList()

    def pos_encoding(self, t: torch.Tensor, channels: int) -> torch.Tensor:
        """Return a sinusoidal embedding of ``t`` with ``channels`` features.

        Args:
            t: Float tensor whose last dimension has size 1, e.g. ``(batch, 1)``.
            channels: Number of output features. Must be even: the first half
                holds sines and the second half cosines.

        Returns:
            Tensor of shape ``t.shape[:-1] + (channels,)`` laid out as
            ``[sin(t * f_0), ..., sin(t * f_{k-1}), cos(t * f_0), ..., cos(t * f_{k-1})]``
            where ``k = channels // 2`` and ``f_i = 1 / 10000 ** (2 * i / channels)``.

        Raises:
            ValueError: If ``channels`` is odd.
        """
        if channels % 2:
            # An odd size cannot be split evenly into sin/cos halves; with
            # broadcasting it would silently yield channels + 1 features.
            raise ValueError(f"channels must be even, got {channels}")
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        # (..., 1) * (channels // 2,) broadcasts to (..., channels // 2).
        angles = t * inv_freq
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)

    def forward(self, x, t, inputs):
        """Run ``self.layers`` on ``x``, conditioning every linear layer.

        Args:
            x: Noisy sample, e.g. ``(batch, features)``; its leading dimensions
                must match those of the time embedding and ``inputs``.
            t: Diffusion step(s). A trailing axis is added and the result is
                cast to float before encoding, so ``(batch,)`` yields a
                ``(batch, time_dim)`` embedding.
            inputs: Extra conditioning features, e.g.
                ``(batch, other_inputs_dim)``.

        Returns:
            Output of the last layer in ``self.layers``.

        Raises:
            RuntimeError: If ``self.layers`` is empty.
        """
        if not self.layers:
            raise RuntimeError(
                "DiffusionModel has no layers; subclasses must populate self.layers"
            )
        t = self.pos_encoding(t.unsqueeze(-1).float(), self.time_dim)
        cond = torch.cat((t, inputs), dim=-1)
        for layer in self.layers:
            # Concatenate right before each Linear (not after the previous
            # one) so the intervening activation never clips the time
            # embedding or the conditioning inputs.
            if isinstance(layer, nn.Linear):
                x = torch.cat((x, cond), dim=-1)
            x = layer(x)
        return x


class SimpleDiffusionModel(DiffusionModel):
    """MLP denoiser: ``Linear -> ReLU`` per hidden size, then a final ``Linear``.

    Each ``Linear`` receives its input concatenated with the time embedding
    (``time_dim`` features) and the conditioning ``inputs``
    (``other_inputs_dim`` features). The output has ``input_size`` features.

    Args:
        input_size: Feature size of ``x`` and of the output.
        hidden_sizes: Widths of the hidden layers. May be empty, in which case
            the model is a single conditioned ``Linear``.
        other_inputs_dim: Feature size of the conditioning ``inputs``.
        time_dim: Size of the sinusoidal time embedding (must be even).
    """

    def __init__(
        self,
        input_size: int,
        hidden_sizes: list,
        other_inputs_dim: int,
        time_dim: int = 64,
    ):
        super().__init__(time_dim)
        self.other_inputs_dim = other_inputs_dim
        cond_dim = time_dim + other_inputs_dim
        prev_size = input_size
        # Layers are appended as Linear, ReLU, ..., Linear. The ModuleList
        # indices become the state_dict keys, so keep this order stable.
        for hidden_size in hidden_sizes:
            self.layers.append(nn.Linear(prev_size + cond_dim, hidden_size))
            self.layers.append(nn.ReLU())
            prev_size = hidden_size
        self.layers.append(nn.Linear(prev_size + cond_dim, input_size))


class GRUDiffusionModel(DiffusionModel):
    """GRU-based denoiser that reads ``x`` as a sequence of scalars.

    ``x`` of shape ``(batch, seq_len)`` becomes a sequence of ``seq_len``
    steps. Each step holds that position's value repeated ``seq_len`` times,
    followed by the time embedding and the per-step conditioning. A stacked
    GRU encodes the sequence, and an MLP maps the top layer's final hidden
    state to ``input_size`` outputs.

    NOTE(review): the GRU expects ``input_size + time_dim + other_inputs_dim``
    features per step, while ``x`` contributes ``seq_len`` of them, so in
    practice ``seq_len`` must equal ``input_size``. The repeated scalar looks
    like a workaround for that; confirm it is intended.

    Args:
        input_size: Number of output features (and, per the note above, the
            expected sequence length).
        hidden_sizes: Widths of the MLP hidden layers after the GRU; may be
            empty.
        other_inputs_dim: Per-step feature size of the conditioning ``inputs``.
        gru_hidden_size: Hidden size of the GRU.
        time_dim: Size of the sinusoidal time embedding (must be even).
        gru_num_layers: Number of stacked GRU layers (default 3).
    """

    def __init__(
        self,
        input_size: int,
        hidden_sizes: list,
        other_inputs_dim: int,
        gru_hidden_size: int,
        time_dim: int = 64,
        gru_num_layers: int = 3,
    ):
        super().__init__(time_dim)
        self.other_inputs_dim = other_inputs_dim
        self.gru_hidden_size = gru_hidden_size
        # The inherited ``self.layers`` stays empty; this model uses ``gru``
        # and ``fc_layers`` with its own ``forward``.
        self.gru = nn.GRU(
            input_size=input_size + time_dim + other_inputs_dim,
            hidden_size=gru_hidden_size,
            num_layers=gru_num_layers,
            batch_first=True,
        )
        # MLP head applied to the GRU's final hidden state.
        self.fc_layers = nn.ModuleList()
        prev_size = gru_hidden_size
        for hidden_size in hidden_sizes:
            self.fc_layers.append(nn.Linear(prev_size, hidden_size))
            self.fc_layers.append(nn.ReLU())
            prev_size = hidden_size
        self.fc_layers.append(nn.Linear(prev_size, input_size))

    def forward(self, x, t, inputs):
        """Encode ``x`` with the GRU and map the final state to the output.

        Args:
            x: Noisy sample, ``(batch, seq_len)``.
            t: Diffusion step per sample, 1-D ``(batch,)``.
            inputs: Per-step conditioning, ``(batch, seq_len, other_inputs_dim)``.

        Returns:
            ``(batch, input_size)`` tensor computed from the top GRU layer's
            final hidden state.
        """
        _, seq_len = x.shape
        # (batch, seq_len) -> (batch, seq_len, seq_len): each step repeats its
        # own scalar across the feature axis.
        x = x.unsqueeze(-1).expand(-1, -1, seq_len)
        # (batch,) -> (batch, time_dim) -> (batch, seq_len, time_dim): the same
        # embedding is attached to every step.
        t = self.pos_encoding(t.unsqueeze(-1).float(), self.time_dim)
        t = t.unsqueeze(1).expand(-1, seq_len, -1)
        # -> (batch, seq_len, seq_len + time_dim + other_inputs_dim)
        x = torch.cat((x, t, inputs), dim=-1)
        # h_n is (num_layers, batch, gru_hidden_size); per-step outputs are unused.
        _, hidden = self.gru(x)
        x = hidden[-1]  # final hidden state of the top GRU layer
        for layer in self.fc_layers:
            x = layer(x)
        return x