Added trainer for Diffusion model
This commit is contained in:
207
src/trainers/diffusion_trainer.py
Normal file
207
src/trainers/diffusion_trainer.py
Normal file
@@ -0,0 +1,207 @@
|
||||
from clearml import Task
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchinfo import summary
|
||||
from tqdm import tqdm
|
||||
from src.data.preprocessing import DataProcessor
|
||||
|
||||
from src.models.diffusion_model import DiffusionModel
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
class DiffusionTrainer:
|
||||
def __init__(self, model: nn.Module, data_processor: DataProcessor, device: torch.device):
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
self.noise_steps = 1000
|
||||
self.beta_start = 1e-4
|
||||
self.beta_end = 0.02
|
||||
self.ts_length = 96
|
||||
|
||||
self.data_processor = data_processor
|
||||
|
||||
self.beta = torch.linspace(self.beta_start, self.beta_end, self.noise_steps).to(self.device)
|
||||
self.alpha = 1. - self.beta
|
||||
self.alpha_hat = torch.cumprod(self.alpha, dim=0)
|
||||
|
||||
def noise_time_series(self, x: torch.tensor, t: int):
|
||||
""" Add noise to time series
|
||||
Args:
|
||||
x (torch.tensor): shape (batch_size, time_steps)
|
||||
t (int): index of time step
|
||||
"""
|
||||
sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None]
|
||||
sqrt_one_minus_alpha_hat = torch.sqrt(1. - self.alpha_hat[t])[:, None]
|
||||
noise = torch.randn_like(x)
|
||||
return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise
|
||||
|
||||
def sample_timesteps(self, n: int):
|
||||
""" Sample timesteps for noise
|
||||
Args:
|
||||
n (int): number of samples
|
||||
"""
|
||||
return torch.randint(low=1, high=self.noise_steps, size=(n,))
|
||||
|
||||
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
|
||||
inputs = inputs.repeat(n, 1).to(self.device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
x = torch.randn(inputs.shape[0], self.ts_length).to(self.device)
|
||||
for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
|
||||
t = (torch.ones(inputs.shape[0]) * i).long().to(self.device)
|
||||
predicted_noise = model(x, t, inputs)
|
||||
alpha = self.alpha[t][:, None]
|
||||
alpha_hat = self.alpha_hat[t][:, None]
|
||||
beta = self.beta[t][:, None]
|
||||
|
||||
if i > 1:
|
||||
noise = torch.randn_like(x)
|
||||
else:
|
||||
noise = torch.zeros_like(x)
|
||||
|
||||
x = 1/torch.sqrt(alpha) * (x-((1-alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
|
||||
model.train()
|
||||
return x
|
||||
|
||||
def random_samples(self, train: bool = True, num_samples: int = 10):
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=96
|
||||
)
|
||||
|
||||
if train:
|
||||
loader = train_loader
|
||||
else:
|
||||
loader = test_loader
|
||||
|
||||
indices = np.random.randint(0, len(loader.dataset) - 1, size=num_samples)
|
||||
return indices
|
||||
|
||||
def init_clearml_task(self, task):
|
||||
task.add_tags(self.model.__class__.__name__)
|
||||
task.add_tags(self.__class__.__name__)
|
||||
|
||||
input_data = torch.randn(1024, 96).to(self.device)
|
||||
time_steps = torch.randn(1024).long().to(self.device)
|
||||
other_input_data = torch.randn(1024, self.model.other_inputs_dim).to(self.device)
|
||||
|
||||
task.set_configuration_object("model", str(summary(self.model, input_data=[input_data, time_steps, other_input_data])))
|
||||
|
||||
self.data_processor = task.connect(self.data_processor, name="data_processor")
|
||||
|
||||
def train(self, epochs: int, learning_rate: float, task: Task = None):
|
||||
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
|
||||
criterion = nn.MSELoss()
|
||||
self.model.to(self.device)
|
||||
|
||||
if task:
|
||||
self.init_clearml_task(task)
|
||||
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.ts_length
|
||||
)
|
||||
|
||||
train_sample_indices = self.random_samples(train=True, num_samples=10)
|
||||
test_sample_indices = self.random_samples(train=False, num_samples=10)
|
||||
|
||||
for epoch in range(epochs):
|
||||
running_loss = 0.0
|
||||
for i, k in enumerate(train_loader):
|
||||
time_series, base_pattern = k[1], k[0]
|
||||
time_series = time_series.to(self.device)
|
||||
base_pattern = base_pattern.to(self.device)
|
||||
|
||||
t = self.sample_timesteps(time_series.shape[0]).to(self.device)
|
||||
x_t, noise = self.noise_time_series(time_series, t)
|
||||
predicted_noise = self.model(x_t, t, base_pattern)
|
||||
loss = criterion(predicted_noise, noise)
|
||||
|
||||
running_loss += loss.item()
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
running_loss /= len(train_loader.dataset)
|
||||
|
||||
if task:
|
||||
task.get_logger().report_scalar(
|
||||
title=criterion.__class__.__name__,
|
||||
series='train',
|
||||
iteration=epoch,
|
||||
value=loss.item(),
|
||||
)
|
||||
|
||||
if epoch % 100 == 0 and epoch != 0:
|
||||
self.debug_plots(task, True, train_loader, train_sample_indices, epoch)
|
||||
self.debug_plots(task, False, test_loader, test_sample_indices, epoch)
|
||||
|
||||
if task:
|
||||
task.close()
|
||||
|
||||
|
||||
def debug_plots(self, task, training: bool, data_loader, sample_indices, epoch):
|
||||
for i, idx in enumerate(sample_indices):
|
||||
features, target, _ = data_loader.dataset[idx]
|
||||
|
||||
features = features.to(self.device)
|
||||
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
samples = self.sample(self.model, 100, features).cpu().numpy()
|
||||
|
||||
ci_99_upper = np.quantile(samples, 0.99, axis=0)
|
||||
ci_99_lower = np.quantile(samples, 0.01, axis=0)
|
||||
|
||||
ci_95_upper = np.quantile(samples, 0.95, axis=0)
|
||||
ci_95_lower = np.quantile(samples, 0.05, axis=0)
|
||||
|
||||
ci_90_upper = np.quantile(samples, 0.9, axis=0)
|
||||
ci_90_lower = np.quantile(samples, 0.1, axis=0)
|
||||
|
||||
ci_50_upper = np.quantile(samples, 0.5, axis=0)
|
||||
ci_50_lower = np.quantile(samples, 0.5, axis=0)
|
||||
|
||||
sns.set_theme()
|
||||
time_steps = np.arange(0, 96)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(20, 10))
|
||||
ax.plot(time_steps, samples.mean(axis=0), label="Mean of NRV samples", linewidth=3)
|
||||
# ax.fill_between(time_steps, ci_lower, ci_upper, color='b', alpha=0.2, label='Full Interval')
|
||||
|
||||
ax.fill_between(time_steps, ci_99_lower, ci_99_upper, color='b', alpha=0.2, label='99% Interval')
|
||||
ax.fill_between(time_steps, ci_95_lower, ci_95_upper, color='b', alpha=0.2, label='95% Interval')
|
||||
ax.fill_between(time_steps, ci_90_lower, ci_90_upper, color='b', alpha=0.2, label='90% Interval')
|
||||
ax.fill_between(time_steps, ci_50_lower, ci_50_upper, color='b', alpha=0.2, label='50% Interval')
|
||||
|
||||
ax.plot(target, label="Real NRV", linewidth=3)
|
||||
# full_interval_patch = mpatches.Patch(color='b', alpha=0.2, label='Full Interval')
|
||||
ci_99_patch = mpatches.Patch(color='b', alpha=0.3, label='99% Interval')
|
||||
ci_95_patch = mpatches.Patch(color='b', alpha=0.4, label='95% Interval')
|
||||
ci_90_patch = mpatches.Patch(color='b', alpha=0.5, label='90% Interval')
|
||||
ci_50_patch = mpatches.Patch(color='b', alpha=0.6, label='50% Interval')
|
||||
|
||||
|
||||
ax.legend(handles=[ci_99_patch, ci_95_patch, ci_90_patch, ci_50_patch, ax.lines[0], ax.lines[1]])
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if training else "Testing",
|
||||
series=f'Sample {i}',
|
||||
iteration=epoch,
|
||||
figure=fig,
|
||||
)
|
||||
|
||||
plt.close()
|
||||
|
||||
def test(self, data_loader: torch.utils.data.DataLoader):
|
||||
for inputs, targets, _ in data_loader:
|
||||
inputs, targets = inputs.to(self.device), targets.to(self.device)
|
||||
|
||||
sample = self.sample(self.model, 10, inputs)
|
||||
|
||||
# reduce sample from (batch_size, time_steps) to (batch_size / 10, time_steps) by taking mean of each 10 samples
|
||||
sample = sample.view(-1, 10, self.ts_length)
|
||||
sample = torch.mean(sample, dim=1)
|
||||
Reference in New Issue
Block a user