from clearml import Task
import torch
import torch.nn as nn
from torchinfo import summary
from tqdm import tqdm
from src.data.preprocessing import DataProcessor
from src.models.diffusion_model import DiffusionModel
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches


class DiffusionTrainer:
    """Train and evaluate a conditional DDPM-style diffusion model on time series.

    A linear beta schedule is used. The model is trained to predict the Gaussian
    noise added at a randomly drawn diffusion step, conditioned on an auxiliary
    input (``base_pattern`` / ``features``) coming from the data loaders.
    """

    def __init__(self, model: nn.Module, data_processor: DataProcessor, device: torch.device):
        self.model = model
        self.device = device

        # Linear noise schedule hyper-parameters.
        self.noise_steps = 1000
        self.beta_start = 1e-4
        self.beta_end = 0.02
        # Length of every generated / target series.
        self.ts_length = 96

        self.data_processor = data_processor

        self.beta = torch.linspace(self.beta_start, self.beta_end, self.noise_steps).to(self.device)
        self.alpha = 1. - self.beta
        # alpha_hat[t] = alpha[0] * ... * alpha[t]; lets us noise x_0 to step t in one go.
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def noise_time_series(self, x: torch.Tensor, t: torch.Tensor):
        """Diffuse clean series ``x`` forward to step ``t`` in closed form.

        Args:
            x: series of shape (batch_size, time_steps).
            t: integer index tensor of shape (batch_size,) with one diffusion
                step per row of ``x``.

        Returns:
            Tuple ``(x_t, noise)``: the noised series and the noise that was added.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1. - self.alpha_hat[t])[:, None]
        noise = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def sample_timesteps(self, n: int):
        """Return ``n`` random diffusion-step indices in ``[1, noise_steps)``.

        Args:
            n: number of indices to draw.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model: DiffusionModel, n: int, inputs: torch.Tensor):
        """Generate ``n`` series per conditioning row by running the reverse process.

        ``inputs`` is tiled with ``repeat(n, 1)``. For a 2-D batch of B rows, the
        output row ``k * B + b`` is therefore the k-th draw conditioned on
        ``inputs[b]``.

        Returns:
            Tensor of shape (n * B, ts_length) on ``self.device``. The model is
            left in train mode afterwards.
        """
        inputs = inputs.repeat(n, 1).to(self.device)
        model.eval()
        with torch.no_grad():
            x = torch.randn(inputs.shape[0], self.ts_length).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(inputs.shape[0]) * i).long().to(self.device)
                predicted_noise = model(x, t, inputs)
                alpha = self.alpha[t][:, None]
                alpha_hat = self.alpha_hat[t][:, None]
                beta = self.beta[t][:, None]

                # No fresh noise is injected on the final denoising step.
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)

                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        return x

    def random_samples(self, train: bool = True, num_samples: int = 10):
        """Draw random dataset indices to use for the periodic debug plots.

        Args:
            train: pick from the training dataset if True, else from the test dataset.
            num_samples: number of indices to draw (with replacement).

        Returns:
            numpy array of ``num_samples`` indices in ``[0, len(dataset))``.
        """
        train_loader, test_loader = self.data_processor.get_dataloaders(
            predict_sequence_length=self.ts_length
        )
        loader = train_loader if train else test_loader
        # randint's upper bound is exclusive, so len(dataset) makes every index reachable.
        return np.random.randint(0, len(loader.dataset), size=num_samples)

    def init_clearml_task(self, task):
        """Tag the ClearML task, log a model summary and connect the data processor."""
        task.add_tags(self.model.__class__.__name__)
        task.add_tags(self.__class__.__name__)

        batch_size = 1024
        input_data = torch.randn(batch_size, self.ts_length).to(self.device)
        # Use genuine diffusion-step indices. randn(...).long() yielded mostly 0/+-1,
        # including negative values.
        time_steps = self.sample_timesteps(batch_size).to(self.device)
        other_input_data = torch.randn(batch_size, self.model.other_inputs_dim).to(self.device)
        task.set_configuration_object("model", str(summary(self.model, input_data=[input_data, time_steps, other_input_data])))

        self.data_processor = task.connect(self.data_processor, name="data_processor")

    def train(self, epochs: int, learning_rate: float, task: Task = None):
        """Train the noise-prediction model with MSE loss.

        Args:
            epochs: number of passes over the training loader.
            learning_rate: Adam learning rate.
            task: optional ClearML task. When given, the per-epoch mean training
                loss and (every 100 epochs) debug plots are reported to it, and
                the task is closed at the end.
        """
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        criterion = nn.MSELoss()
        self.model.to(self.device)

        if task:
            self.init_clearml_task(task)

        train_loader, test_loader = self.data_processor.get_dataloaders(
            predict_sequence_length=self.ts_length
        )

        train_sample_indices = self.random_samples(train=True, num_samples=10)
        test_sample_indices = self.random_samples(train=False, num_samples=10)

        for epoch in range(epochs):
            running_loss = 0.0
            for batch in train_loader:
                # Batches are (conditioning input, target series, ...).
                base_pattern, time_series = batch[0], batch[1]
                time_series = time_series.to(self.device)
                base_pattern = base_pattern.to(self.device)

                t = self.sample_timesteps(time_series.shape[0]).to(self.device)
                x_t, noise = self.noise_time_series(time_series, t)
                predicted_noise = self.model(x_t, t, base_pattern)
                loss = criterion(predicted_noise, noise)

                # Weight the batch-mean loss by batch size so the epoch average is per sample.
                running_loss += loss.item() * time_series.shape[0]

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            running_loss /= len(train_loader.dataset)

            if task:
                task.get_logger().report_scalar(
                    title=criterion.__class__.__name__,
                    series='train',
                    iteration=epoch,
                    value=running_loss,
                )

                # debug_plots reports through the task, so it only runs when one exists.
                if epoch % 100 == 0 and epoch != 0:
                    self.debug_plots(task, True, train_loader, train_sample_indices, epoch)
                    self.debug_plots(task, False, test_loader, test_sample_indices, epoch)

        if task:
            task.close()

    def debug_plots(self, task, training: bool, data_loader, sample_indices, epoch):
        """Report fan charts of 100 generated series vs. the real series to ClearML.

        For each dataset index, the chart shows the sample mean, the central
        99/95/90/50% quantile bands across the samples, and the real target.
        """
        time_steps = np.arange(0, self.ts_length)
        # (central coverage, legend alpha). The fills all use alpha 0.2 and stack, so
        # inner bands render darker. The legend patches use rising alphas to mimic that.
        intervals = [(0.99, 0.3), (0.95, 0.4), (0.90, 0.5), (0.50, 0.6)]

        for i, idx in enumerate(sample_indices):
            features, target, _ = data_loader.dataset[idx]
            features = features.to(self.device)

            self.model.eval()
            with torch.no_grad():
                samples = self.sample(self.model, 100, features).cpu().numpy()

            sns.set_theme()
            fig, ax = plt.subplots(figsize=(20, 10))
            ax.plot(time_steps, samples.mean(axis=0), label="Mean of NRV samples", linewidth=3)

            patches = []
            for coverage, legend_alpha in intervals:
                # A central interval covering `coverage` leaves (1 - coverage) / 2 in each tail.
                tail = (1 - coverage) / 2
                lower = np.quantile(samples, tail, axis=0)
                upper = np.quantile(samples, 1 - tail, axis=0)
                label = f"{round(coverage * 100)}% Interval"
                ax.fill_between(time_steps, lower, upper, color='b', alpha=0.2, label=label)
                patches.append(mpatches.Patch(color='b', alpha=legend_alpha, label=label))

            ax.plot(target, label="Real NRV", linewidth=3)
            ax.legend(handles=[*patches, ax.lines[0], ax.lines[1]])

            task.get_logger().report_matplotlib_figure(
                title="Training" if training else "Testing",
                series=f'Sample {i}',
                iteration=epoch,
                figure=fig,
            )

            plt.close(fig)

    def test(self, data_loader: torch.utils.data.DataLoader):
        """Draw 10 samples per input and average them per input.

        NOTE(review): the averaged result is currently not stored or returned.
        """
        n_samples = 10
        for inputs, targets, _ in data_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            sample = self.sample(self.model, n_samples, inputs)

            # sample() tiles the batch with repeat(n, 1): row k*B + b is the k-th draw for
            # input b. Group on the leading (draw) axis and average it out, giving (B, ts_length).
            sample = sample.view(n_samples, -1, self.ts_length)
            sample = torch.mean(sample, dim=0)