import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import plotly.subplots as sp
import torch
from clearml import OutputModel
from plotly.subplots import make_subplots
from tqdm import tqdm

from src.data.preprocessing import DataProcessor
from src.trainers.trainer import Trainer
from src.utils.autoregressive import predict_auto_regressive
from src.utils.clearml import ClearMLHelper


class AutoRegressiveTrainer(Trainer):
    """Trainer that evaluates a single-step model auto-regressively.

    The model is constrained to emit one value per step
    (``model.output_size = 1``); roll-outs feed each prediction back into
    the input window to produce a full forecast horizon, which is then
    plotted (``debug_plots``) and scored (``log_final_metrics``).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        input_dim: tuple,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        data_processor: DataProcessor,
        device: torch.device,
        debug: bool = True,
    ):
        """Initialize the base trainer and force single-step output.

        Args:
            model: network to train; its ``output_size`` is set to 1.
            input_dim: input dimensionality forwarded to the base trainer.
            optimizer: optimizer instance for ``model``'s parameters.
            criterion: training loss module.
            data_processor: provides ``inverse_transform`` for metric logging.
            device: device used for model evaluation.
            debug: forwarded to the base trainer.
        """
        super().__init__(
            model=model,
            input_dim=input_dim,
            optimizer=optimizer,
            criterion=criterion,
            data_processor=data_processor,
            device=device,
            debug=debug,
        )
        # Auto-regressive roll-out requires one prediction per step.
        self.model.output_size = 1

    def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
        """Plot auto-regressive roll-outs for selected samples to ClearML.

        Args:
            task: ClearML task used for figure reporting.
            train: selects the "Training"/"Testing" report titles.
            data_loader: loader whose ``.dataset`` is rolled out.
            sample_indices: mapping of display index -> dataset index.
            epoch: reported as the ClearML iteration.
        """
        for actual_idx, idx in sample_indices.items():
            print(f"Plotting sample {actual_idx}")
            # NOTE(review): `[idx] * 1000` passes a *list* as the dataset
            # index; `auto_regressive` later computes `idx + i + 1`, which
            # raises TypeError for a list. Confirm the dataset's indexing
            # contract — this path looks broken as written.
            auto_regressive_output = self.auto_regressive(
                data_loader.dataset, [idx] * 1000
            )
            # Accept both 3- and 4-tuple returns (an override may insert an
            # extra element in second position).
            if len(auto_regressive_output) == 3:
                initial, predictions, target = auto_regressive_output
            else:
                initial, _, predictions, target = auto_regressive_output
            # Keep a single initial window / target for plotting.
            initial = initial[0]
            target = target[0]
            fig, fig2 = self.get_plot(
                initial, target, predictions, show_legend=True
            )
            task.get_logger().report_matplotlib_figure(
                title="Training" if train else "Testing",
                series=f"Sample {actual_idx}",
                iteration=epoch,
                figure=fig,
            )
            task.get_logger().report_matplotlib_figure(
                title="Training Samples" if train else "Testing Samples",
                series=f"Sample {actual_idx} samples",
                iteration=epoch,
                figure=fig2,
                report_interactive=False,
            )
            # Close both figures so matplotlib state does not leak across
            # samples/epochs (plt.close() alone only closes the current one).
            plt.close(fig)
            plt.close(fig2)

    def auto_regressive(self, data_loader, idx, sequence_length: int = 96):
        """Roll the model forward ``sequence_length`` steps auto-regressively.

        Each predicted step is appended to the input window (oldest step
        dropped) before predicting the next one.

        Args:
            data_loader: object exposing ``.dataset`` with item access and
                ``random_day_autoregressive``. NOTE(review): callers pass
                ``loader.dataset`` here, so this is dereferenced twice
                (``dataset.dataset``) — presumably a Subset-like wrapper;
                confirm.
            idx: starting dataset index.
            sequence_length: length of the conditioning window and of the
                roll-out horizon (was previously hardcoded to 96).

        Returns:
            Tuple of CPU tensors ``(initial_sequence, predictions, targets)``.
        """
        self.model.eval()
        target_full = []
        predictions_full = []
        prev_features, target = data_loader.dataset[idx]
        prev_features = prev_features.to(self.device)
        # The first window is pure ground-truth context; returned for plotting.
        initial_sequence = prev_features[:sequence_length]
        target_full.append(target)
        with torch.no_grad():
            prediction = self.model(prev_features.unsqueeze(0))
        predictions_full.append(prediction.squeeze(-1))
        for i in range(sequence_length - 1):
            # Slide the window: drop the oldest step, append the prediction.
            new_features = torch.cat(
                (
                    prev_features[1:sequence_length].cpu(),
                    prediction.squeeze(-1).cpu(),
                ),
                dim=0,
            )
            # Fetch the exogenous features (and target) for the next step.
            other_features, new_target = data_loader.dataset.random_day_autoregressive(
                idx + i + 1
            )
            if other_features is not None:
                prev_features = torch.cat((new_features, other_features), dim=0)
            else:
                prev_features = new_features
            target_full.append(new_target)
            with torch.no_grad():
                prediction = self.model(prev_features.unsqueeze(0).to(self.device))
            predictions_full.append(prediction.squeeze(-1))
        return (
            initial_sequence.cpu(),
            torch.stack(predictions_full).cpu(),
            torch.stack(target_full).cpu(),
        )

    def log_final_metrics(
        self, task, dataloader, train: bool = True, sequence_length: int = 96
    ):
        """Score auto-regressive roll-outs over the whole dataset.

        Metrics are accumulated both in transformed (model) space and in the
        original space via ``data_processor.inverse_transform``, averaged
        over all valid starting indices, and reported to ClearML.

        Args:
            task: ClearML task used for reporting.
            dataloader: loader whose ``.dataset`` is evaluated.
            train: selects the "train_"/"test_" metric-name prefixes.
            sequence_length: roll-out window length (was a hardcoded 95/96).
        """
        metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
        transformed_metrics = {
            metric.__class__.__name__: 0.0 for metric in self.metrics_to_track
        }
        with torch.no_grad():
            # The last full window starts sequence_length - 1 samples
            # before the end of the dataset.
            total_amount_samples = len(dataloader.dataset) - (sequence_length - 1)
            for idx in tqdm(range(total_amount_samples)):
                _, outputs, targets = self.auto_regressive(
                    dataloader.dataset, idx, sequence_length
                )
                inversed_outputs = torch.tensor(
                    self.data_processor.inverse_transform(outputs)
                )
                inversed_targets = torch.tensor(
                    self.data_processor.inverse_transform(targets)
                )
                outputs = outputs.to(self.device)
                targets = targets.to(self.device)
                for metric in self.metrics_to_track:
                    transformed_metrics[metric.__class__.__name__] += metric(
                        outputs, targets
                    )
                    metrics[metric.__class__.__name__] += metric(
                        inversed_outputs, inversed_targets
                    )
            for metric in self.metrics_to_track:
                metrics[metric.__class__.__name__] /= total_amount_samples
                transformed_metrics[metric.__class__.__name__] /= total_amount_samples
        # Original-space metrics, then transformed-space metrics.
        self._report_metric_dict(task, metrics, "train_" if train else "test_")
        self._report_metric_dict(
            task,
            transformed_metrics,
            "train_transformed_" if train else "test_transformed_",
        )

    def _report_metric_dict(self, task, metric_values, prefix):
        """Report each metric in ``metric_values`` to ClearML under ``prefix``."""
        for metric_name, metric_value in metric_values.items():
            task.get_logger().report_single_value(
                name=f"{prefix}{metric_name}", value=metric_value
            )