Fixing some stuff

This commit is contained in:
Victor Mylle
2023-12-30 15:22:32 +00:00
parent ef8b5f49ac
commit c26ae76951
6 changed files with 107 additions and 33 deletions

View File

@@ -1,4 +1,6 @@
import torch
from tqdm import tqdm
from src.losses.crps_metric import crps_from_samples
from src.trainers.trainer import Trainer
from src.trainers.autoregressive_trainer import AutoRegressiveTrainer
from src.data.preprocessing import DataProcessor
@@ -81,21 +83,63 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
debug=debug,
)
def calculate_crps_from_samples(self, task, dataloader, epoch: int):
crps_from_samples_metric = []
with torch.no_grad():
for _, _, idx_batch in tqdm(dataloader):
if len(idx_batch) == 0:
continue
for idx in tqdm(idx_batch):
computed_idx_batch = [idx] * 100
_, _, samples, targets = self.auto_regressive(
dataloader.dataset, idx_batch=computed_idx_batch
)
samples = samples.unsqueeze(0)
targets = targets.squeeze(-1)
targets = targets[0].unsqueeze(0)
crps = crps_from_samples(samples, targets)
crps_from_samples_metric.append(crps[0].mean().item())
task.get_logger().report_scalar(
title="CRPS_from_samples", series="test", value=np.mean(crps_from_samples_metric), iteration=epoch
)
def log_final_metrics(self, task, dataloader, train: bool = True):
metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
transformed_metrics = {
metric.__class__.__name__: 0.0 for metric in self.metrics_to_track
}
crps_from_samples_metric = []
with torch.no_grad():
total_samples = len(dataloader.dataset) - 96
batches = 0
for _, _, idx_batch in dataloader:
for _, _, idx_batch in tqdm(dataloader):
idx_batch = [idx for idx in idx_batch if idx < total_samples]
if len(idx_batch) == 0:
continue
if train == False:
for idx in tqdm(idx_batch):
computed_idx_batch = [idx] * 100
_, outputs, samples, targets = self.auto_regressive(
dataloader.dataset, idx_batch=computed_idx_batch
)
samples = samples.unsqueeze(0)
targets = targets.squeeze(-1)
targets = targets[0].unsqueeze(0)
crps = crps_from_samples(samples, targets)
crps_from_samples_metric.append(crps[0].mean().item())
_, outputs, samples, targets = self.auto_regressive(
dataloader.dataset, idx_batch=idx_batch
)
@@ -147,6 +191,11 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
)
task.get_logger().report_single_value(name=name, value=metric_value)
if train == False:
task.get_logger().report_single_value(
name="test_CRPS_from_samples_transformed", value=np.mean(crps_from_samples_metric)
)
def get_plot_error(
self,
next_day,