Fixing some stuff
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from src.losses.crps_metric import crps_from_samples
|
||||
from src.trainers.trainer import Trainer
|
||||
from src.trainers.autoregressive_trainer import AutoRegressiveTrainer
|
||||
from src.data.preprocessing import DataProcessor
|
||||
@@ -81,21 +83,63 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
def calculate_crps_from_samples(self, task, dataloader, epoch: int):
|
||||
crps_from_samples_metric = []
|
||||
|
||||
with torch.no_grad():
|
||||
for _, _, idx_batch in tqdm(dataloader):
|
||||
if len(idx_batch) == 0:
|
||||
continue
|
||||
|
||||
for idx in tqdm(idx_batch):
|
||||
computed_idx_batch = [idx] * 100
|
||||
_, _, samples, targets = self.auto_regressive(
|
||||
dataloader.dataset, idx_batch=computed_idx_batch
|
||||
)
|
||||
samples = samples.unsqueeze(0)
|
||||
targets = targets.squeeze(-1)
|
||||
targets = targets[0].unsqueeze(0)
|
||||
|
||||
crps = crps_from_samples(samples, targets)
|
||||
|
||||
crps_from_samples_metric.append(crps[0].mean().item())
|
||||
|
||||
task.get_logger().report_scalar(
|
||||
title="CRPS_from_samples", series="test", value=np.mean(crps_from_samples_metric), iteration=epoch
|
||||
)
|
||||
|
||||
def log_final_metrics(self, task, dataloader, train: bool = True):
|
||||
metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
|
||||
transformed_metrics = {
|
||||
metric.__class__.__name__: 0.0 for metric in self.metrics_to_track
|
||||
}
|
||||
|
||||
crps_from_samples_metric = []
|
||||
|
||||
with torch.no_grad():
|
||||
total_samples = len(dataloader.dataset) - 96
|
||||
batches = 0
|
||||
for _, _, idx_batch in dataloader:
|
||||
for _, _, idx_batch in tqdm(dataloader):
|
||||
idx_batch = [idx for idx in idx_batch if idx < total_samples]
|
||||
|
||||
if len(idx_batch) == 0:
|
||||
continue
|
||||
|
||||
if train == False:
|
||||
for idx in tqdm(idx_batch):
|
||||
computed_idx_batch = [idx] * 100
|
||||
_, outputs, samples, targets = self.auto_regressive(
|
||||
dataloader.dataset, idx_batch=computed_idx_batch
|
||||
)
|
||||
samples = samples.unsqueeze(0)
|
||||
targets = targets.squeeze(-1)
|
||||
targets = targets[0].unsqueeze(0)
|
||||
|
||||
crps = crps_from_samples(samples, targets)
|
||||
|
||||
crps_from_samples_metric.append(crps[0].mean().item())
|
||||
|
||||
|
||||
_, outputs, samples, targets = self.auto_regressive(
|
||||
dataloader.dataset, idx_batch=idx_batch
|
||||
)
|
||||
@@ -147,6 +191,11 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
)
|
||||
task.get_logger().report_single_value(name=name, value=metric_value)
|
||||
|
||||
if train == False:
|
||||
task.get_logger().report_single_value(
|
||||
name="test_CRPS_from_samples_transformed", value=np.mean(crps_from_samples_metric)
|
||||
)
|
||||
|
||||
def get_plot_error(
|
||||
self,
|
||||
next_day,
|
||||
|
||||
Reference in New Issue
Block a user