Added training script to remotely train diffusion model

This commit is contained in:
Victor Mylle
2023-12-29 09:03:09 +00:00
parent 3264b5ac53
commit 1b209a4562
3 changed files with 120 additions and 8 deletions

View File

@@ -2,7 +2,7 @@ from clearml import Task
import torch
import torch.nn as nn
from torchinfo import summary
from tqdm import tqdm
from src.losses.crps_metric import crps_from_samples
from src.data.preprocessing import DataProcessor
from src.models.diffusion_model import DiffusionModel
@@ -51,7 +51,7 @@ class DiffusionTrainer:
model.eval()
with torch.no_grad():
x = torch.randn(inputs.shape[0], self.ts_length).to(self.device)
for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
for i in reversed(range(1, self.noise_steps)):
t = (torch.ones(inputs.shape[0]) * i).long().to(self.device)
predicted_noise = model(x, t, inputs)
alpha = self.alpha[t][:, None]
@@ -127,6 +127,9 @@ class DiffusionTrainer:
running_loss /= len(train_loader.dataset)
if epoch % 20 == 0 and epoch != 0:
self.test(test_loader, epoch, task)
if task:
task.get_logger().report_scalar(
title=criterion.__class__.__name__,
@@ -196,12 +199,32 @@ class DiffusionTrainer:
plt.close()
def test(self, data_loader: torch.utils.data.DataLoader):
def test(self, data_loader: torch.utils.data.DataLoader, epoch: int, task: Task = None):
all_crps = []
for inputs, targets, _ in data_loader:
inputs, targets = inputs.to(self.device), targets.to(self.device)
sample = self.sample(self.model, 10, inputs)
number_of_samples = 100
sample = self.sample(self.model, number_of_samples, inputs)
# reduce sample from (batch_size, time_steps) to (batch_size / 10, time_steps) by taking mean of each 10 samples
sample = sample.view(-1, 10, self.ts_length)
sample = torch.mean(sample, dim=1)
# reduce samples from (batch_size*number_of_samples, time_steps) to (batch_size, number_of_samples, time_steps)
samples_batched = sample.reshape(inputs.shape[0], number_of_samples, 96)
# calculate crps
crps = crps_from_samples(samples_batched, targets)
crps_mean = crps.mean(axis=1)
# add all values from crps_mean to all_crps
all_crps.extend(crps_mean.tolist())
all_crps = np.array(all_crps)
mean_crps = all_crps.mean()
if task:
task.get_logger().report_scalar(
title="CRPS",
series='test',
value=mean_crps,
iteration=epoch
)