Files
Thesis/src/models/diffusion_model.py
2024-01-20 09:44:14 +00:00

97 lines
3.6 KiB
Python

import torch
import torch.nn as nn
class DiffusionModel(nn.Module):
    """Base denoiser built from an ``nn.ModuleList`` of layers.

    The layers are fed with a sinusoidal encoding of ``t`` plus extra
    conditioning ``inputs``. Subclasses populate ``self.layers``. ``forward``
    concatenates ``(t_enc, inputs)`` onto the features before the first layer
    and again after every layer that is not an ``nn.ReLU``.
    """

    def __init__(self, time_dim: int = 64):
        """
        Args:
            time_dim: width of the time encoding produced by ``pos_encoding``
                (must be even).
        """
        super().__init__()
        self.time_dim = time_dim
        self.layers = nn.ModuleList()

    def pos_encoding(self, t, channels):
        """Return a sinusoidal encoding of ``t`` with ``channels`` features.

        Args:
            t: tensor whose last dimension has size 1 (e.g. ``[batch, 1]``);
                that trailing singleton broadcasts against the frequencies.
                A 0-D or 1-D tensor is treated as ``[-1, 1]``.
            channels: output width; must be even (half sine, half cosine).

        Returns:
            Tensor shaped like ``t`` with its last dimension widened to
            ``channels``: ``[sin(t*f_0), ..., sin(t*f_{k-1}), cos(t*f_0), ...,
            cos(t*f_{k-1})]`` where ``f_i = 10000 ** (-2i / channels)`` and
            ``k = channels // 2``.

        Raises:
            ValueError: if ``channels`` is odd.
        """
        if channels % 2 != 0:
            # An odd width would give ceil(channels/2) frequencies but only
            # channels//2 repeats, which previously failed with a shape error.
            raise ValueError(f"channels must be even, got {channels}")
        if t.dim() < 2:
            # Keep the old output layout for scalar / 1-D inputs: [N, channels].
            t = t.reshape(-1, 1)
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels)
        )
        # Broadcasting replaces t.repeat(1, channels // 2): identical for
        # [batch, 1] inputs, and also works with extra leading dimensions.
        angles = t * inv_freq
        return torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)

    def forward(self, x, t, inputs):
        """Run the layer stack on ``x`` conditioned on ``t`` and ``inputs``.

        Args:
            x: feature tensor; concatenated with the time encoding along the
                last dim, so its leading dims must match ``t``'s (presumably
                ``[batch, features]`` with ``t`` of shape ``[batch]`` — confirm
                against callers).
            t: timestep tensor; cast to float and encoded with
                ``pos_encoding(..., self.time_dim)``.
            inputs: conditioning tensor concatenated alongside ``x``.

        Returns:
            Output of the last layer in ``self.layers``.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)
        x = torch.cat((x, t, inputs), dim=-1)
        for layer in self.layers[:-1]:
            x = layer(x)
            # NOTE(review): the conditioning is re-attached after each non-ReLU
            # layer, i.e. *before* the following nn.ReLU. That ReLU therefore
            # also clips the negative halves of the sin/cos encoding and any
            # negative conditioning values. Concatenating after the activation
            # may have been intended. It is left as-is because changing it
            # alters the outputs of already-trained weights.
            if not isinstance(layer, nn.ReLU):
                x = torch.cat((x, t, inputs), dim=-1)
        return self.layers[-1](x)
class SimpleDiffusionModel(DiffusionModel):
    """MLP denoiser whose linear layers each also receive ``(t_enc, inputs)``.

    Layer layout of ``self.layers``, with ``c = time_dim + other_inputs_dim``::

        Linear(input_size + c, h_0), ReLU,
        Linear(h_0 + c, h_1), ReLU,
        ...
        Linear(h_last + c, input_size)

    The extra ``c`` columns on every linear layer account for the time
    encoding and conditioning inputs that the inherited ``forward``
    concatenates after each non-ReLU layer.
    """

    def __init__(self, input_size: int, hidden_sizes: list, other_inputs_dim: int, time_dim: int = 64):
        """
        Args:
            input_size: width of the sample ``x`` and of the model output.
            hidden_sizes: widths of the hidden layers, in order. May be empty,
                in which case the model is a single linear map from
                ``[x, t_enc, inputs]`` to ``input_size``.
            other_inputs_dim: width of the conditioning tensor ``inputs``.
            time_dim: width of the sinusoidal time encoding.
        """
        super().__init__(time_dim)
        self.other_inputs_dim = other_inputs_dim
        cond_dim = time_dim + other_inputs_dim
        # Appending in the same order as before keeps layer indices (and hence
        # state_dict keys) unchanged for non-empty hidden_sizes.
        prev_size = input_size
        for hidden_size in hidden_sizes:
            self.layers.append(nn.Linear(prev_size + cond_dim, hidden_size))
            self.layers.append(nn.ReLU())
            prev_size = hidden_size
        self.layers.append(nn.Linear(prev_size + cond_dim, input_size))
class GRUDiffusionModel(DiffusionModel):
    """Denoiser that runs a GRU over the sequence and maps the top layer's
    final hidden state through an MLP head.

    ``forward`` expects ``x`` of shape ``[batch, seq_len]`` with
    ``seq_len == input_size``. Each step's scalar is repeated ``seq_len``
    times to form that step's feature vector, so the GRU's per-step input
    width is ``input_size + time_dim + other_inputs_dim``.
    """

    def __init__(self, input_size: int, hidden_sizes: list, other_inputs_dim: int, gru_hidden_size: int,
                 time_dim: int = 64, num_gru_layers: int = 3):
        """
        Args:
            input_size: sequence length of ``x`` and width of the output.
            hidden_sizes: widths of the hidden layers of the MLP head.
            other_inputs_dim: width of the conditioning features ``inputs``.
            gru_hidden_size: hidden size of the GRU.
            time_dim: width of the sinusoidal time encoding.
            num_gru_layers: number of stacked GRU layers (default 3, the
                previously hard-coded value).
        """
        super().__init__(time_dim)
        self.other_inputs_dim = other_inputs_dim
        self.gru_hidden_size = gru_hidden_size
        # Per-step GRU input: [x_i repeated input_size times, t_enc, inputs_i].
        self.gru = nn.GRU(
            input_size=input_size + time_dim + other_inputs_dim,
            hidden_size=gru_hidden_size,
            num_layers=num_gru_layers,
            batch_first=True,
        )
        # MLP head applied to the final hidden state of the top GRU layer.
        self.fc_layers = nn.ModuleList()
        prev_size = gru_hidden_size
        for hidden_size in hidden_sizes:
            self.fc_layers.append(nn.Linear(prev_size, hidden_size))
            self.fc_layers.append(nn.ReLU())
            prev_size = hidden_size
        self.fc_layers.append(nn.Linear(prev_size, input_size))

    def forward(self, x, t, inputs):
        """Predict an ``[batch, input_size]`` output from the sequence ``x``.

        Args:
            x: tensor of shape ``[batch, seq_len]``; ``seq_len`` must equal the
                ``input_size`` the model was built with.
            t: tensor of shape ``[batch]``; encoded with ``pos_encoding`` and
                shared across all steps.
            inputs: conditioning features, either per step
                ``[batch, seq_len, other_inputs_dim]`` or per sample
                ``[batch, other_inputs_dim]`` (broadcast over the sequence).

        Raises:
            ValueError: if ``x`` is not 2-D or its length does not match the
                width the GRU was built for.
        """
        if x.dim() != 2:
            raise ValueError(f"x must have shape [batch, seq_len], got {tuple(x.shape)}")
        seq_len = x.shape[1]
        # Derived from the GRU itself so pickled models from older versions
        # (without any extra attributes) still work.
        expected_len = self.gru.input_size - self.time_dim - self.other_inputs_dim
        if seq_len != expected_len:
            raise ValueError(f"x sequence length {seq_len} must equal input_size {expected_len}")

        # Step i sees the scalar x[:, i] broadcast to a length-seq_len vector.
        x = x.unsqueeze(-1).expand(-1, -1, seq_len)  # [batch, seq_len, seq_len]
        t_enc = self.pos_encoding(t.unsqueeze(-1).float(), self.time_dim)  # [batch, time_dim]
        t_enc = t_enc.unsqueeze(1).expand(-1, seq_len, -1)  # [batch, seq_len, time_dim]
        if inputs.dim() == 2:
            inputs = inputs.unsqueeze(1).expand(-1, seq_len, -1)
        x = torch.cat((x, t_enc, inputs), dim=-1)

        # h_n is [num_layers, batch, gru_hidden_size] even with batch_first=True;
        # the per-step outputs are not used.
        _, hidden = self.gru(x)
        x = hidden[-1]  # [batch, gru_hidden_size]
        for layer in self.fc_layers:
            x = layer(x)
        return x