Autoregressive test score calculated on 96 values
This commit is contained in:
@@ -8,6 +8,7 @@ import numpy as np
|
||||
import plotly.subplots as sp
|
||||
from plotly.subplots import make_subplots
|
||||
from trainers.trainer import Trainer
|
||||
from tqdm import tqdm
|
||||
|
||||
class AutoRegressiveTrainer(Trainer):
|
||||
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
|
||||
@@ -77,4 +78,45 @@ class AutoRegressiveTrainer(Trainer):
|
||||
prediction = self.model(new_features.unsqueeze(0).to(self.device))
|
||||
predictions_full.append(prediction.squeeze(-1))
|
||||
|
||||
return initial_sequence.cpu(), torch.stack(predictions_full).cpu(), torch.stack(target_full).cpu()
|
||||
return initial_sequence.cpu(), torch.stack(predictions_full).cpu(), torch.stack(target_full).cpu()
|
||||
|
||||
def log_final_metrics(self, task, dataloader, train: bool = True, window_size: int = 96):
    """Compute and report final autoregressive metrics over a whole dataset.

    Runs the autoregressive rollout (``self.auto_regressive``) for every valid
    starting index, accumulates each tracked metric both in model space
    ("transformed") and in original units (after ``inverse_transform``), and
    reports the per-rollout averages through the ``task`` logger.

    Args:
        task: experiment-tracking task; only ``task.get_logger()`` is used.
        dataloader: loader whose ``dataset`` supports ``len()``; individual
            rollouts are produced via ``self.auto_regressive(dataloader, idx)``.
        train: when True metric names are prefixed ``train_``, otherwise
            ``test_``.
        window_size: number of values each rollout consumes. The default of 96
            reproduces the previous hard-coded ``len(dataset) - 95``.

    Raises:
        ValueError: if the dataset is too small for a single rollout
            (previously this surfaced as a ZeroDivisionError or an empty loop).
    """
    metric_names = [metric.__class__.__name__ for metric in self.metrics_to_track]
    # Accumulators: original-scale metrics and model-space ("transformed") metrics.
    metrics = dict.fromkeys(metric_names, 0.0)
    transformed_metrics = dict.fromkeys(metric_names, 0.0)

    # One rollout starts at every index that still has window_size values ahead.
    total_amount_samples = len(dataloader.dataset) - (window_size - 1)
    if total_amount_samples <= 0:
        raise ValueError(
            f'Dataset of size {len(dataloader.dataset)} is too small for a '
            f'window of {window_size} values.'
        )

    with torch.no_grad():
        for idx in tqdm(range(total_amount_samples)):
            _, outputs, targets = self.auto_regressive(dataloader, idx)

            # Undo the data processor's scaling so metrics are in original units.
            inversed_outputs = torch.tensor(self.data_processor.inverse_transform(outputs))
            inversed_targets = torch.tensor(self.data_processor.inverse_transform(targets))

            outputs = outputs.to(self.device)
            targets = targets.to(self.device)

            for metric in self.metrics_to_track:
                name = metric.__class__.__name__
                transformed_metrics[name] += metric(outputs, targets)
                metrics[name] += metric(inversed_outputs, inversed_targets)

    # Average the accumulated sums over the number of rollouts.
    for name in metric_names:
        metrics[name] /= total_amount_samples
        transformed_metrics[name] /= total_amount_samples

    prefix = 'train' if train else 'test'
    self._report_metrics(task, metrics, prefix)
    self._report_metrics(task, transformed_metrics, f'{prefix}_transformed')

def _report_metrics(self, task, metric_values, prefix):
    """Report each averaged metric as a single value named ``{prefix}_{name}``."""
    for metric_name, metric_value in metric_values.items():
        # NOTE(review): metric_value may be a 0-dim tensor rather than a float;
        # confirm report_single_value accepts tensors, else call .item() here.
        task.get_logger().report_single_value(
            name=f'{prefix}_{metric_name}', value=metric_value
        )
|
||||
@@ -180,7 +180,7 @@ class Trainer:
|
||||
|
||||
transformed_train_loader, transformed_test_loader = self.data_processor.get_dataloaders(predict_sequence_length=self.model.output_size)
|
||||
|
||||
self.log_final_metrics(task, transformed_train_loader, train=True)
|
||||
# self.log_final_metrics(task, transformed_train_loader, train=True)
|
||||
self.log_final_metrics(task, transformed_test_loader, train=False)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user