from clearml import OutputModel
import torch
from src.data.preprocessing import DataProcessor
from src.utils.clearml import ClearMLHelper
from src.utils.autoregressive import predict_auto_regressive
import plotly.graph_objects as go
import numpy as np
import plotly.subplots as sp
from plotly.subplots import make_subplots
from src.trainers.trainer import Trainer
from tqdm import tqdm

# NOTE(review): OutputModel, ClearMLHelper, predict_auto_regressive, go, np and sp
# are not referenced in this file — possibly used elsewhere or leftovers; verify
# before removing.


class AutoRegressiveTrainer(Trainer):
    """Trainer that evaluates/visualizes a model auto-regressively.

    After an initial window of 96 real time steps, each further step feeds the
    model's own previous prediction back into the input sequence.  Inherits the
    training loop itself from ``Trainer``; this subclass only adds debug plotting
    and final auto-regressive metric logging.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        input_dim: tuple,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        data_processor: DataProcessor,
        device: torch.device,
        debug: bool = True,
    ):
        """Forward all arguments to the base ``Trainer`` and force single-step output.

        Args:
            model: the network to train/evaluate.
            input_dim: input dimensionality, passed through to ``Trainer``.
            optimizer: optimizer instance for ``model``'s parameters.
            criterion: loss module (also used for the per-sample loss in plots).
            data_processor: provides ``inverse_transform`` to undo scaling.
            device: device all tensors are moved to.
            debug: passed through to ``Trainer``; enables debug behavior there.
        """
        super().__init__(
            model=model,
            input_dim=input_dim,
            optimizer=optimizer,
            criterion=criterion,
            data_processor=data_processor,
            device=device,
            debug=debug,
        )
        # Auto-regressive roll-out feeds one predicted value back per step,
        # so the model must emit exactly one output per forward pass.
        self.model.output_size = 1

    def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
        """Report a Plotly figure of auto-regressive roll-outs to ClearML.

        One subplot row per sample index; a second column with an error plot is
        added when the (base-class) helper ``get_plot_error`` exists.  The
        per-sample loss is written into each subplot title.

        Args:
            task: ClearML task whose logger receives the figure.
            train: selects the "Training"/"Test" figure title.
            data_loader: loader whose ``.dataset`` is rolled out.
            sample_indices: dataset indices to visualize.
            epoch: ClearML iteration number for the report.
        """
        num_samples = len(sample_indices)
        rows = num_samples  # One row per sample since we only want one column
        # check if self has get_plot_error (provided by some base/mixin classes)
        if hasattr(self, "get_plot_error"):
            cols = 2
            print("Using get_plot_error")
        else:
            cols = 1
            print("Using get_plot")
        fig = make_subplots(
            rows=rows,
            cols=cols,
            subplot_titles=[f"Sample {i+1}" for i in range(num_samples)],
        )
        for i, idx in enumerate(sample_indices):
            # NOTE(review): `idx` is wrapped in a list here, but auto_regressive()
            # later computes `idx + i + 1`, which raises TypeError for a list.
            # log_final_metrics() passes a plain int — confirm which is intended.
            auto_regressive_output = self.auto_regressive(data_loader.dataset, [idx])
            # NOTE(review): auto_regressive() as written always returns 3 values;
            # the 4-tuple branch below looks like support for an older/overridden
            # variant — verify it is still reachable.
            if len(auto_regressive_output) == 3:
                initial, predictions, target = auto_regressive_output
            else:
                initial, predictions, _, target = auto_regressive_output
            # Drop a leading batch dimension if present (squeeze(0) is a no-op
            # when dim 0 is not of size 1).
            initial = initial.squeeze(0)
            predictions = predictions.squeeze(0)
            target = target.squeeze(0)
            # Legend only once (first row) to avoid duplicated legend entries.
            sub_fig = self.get_plot(initial, target, predictions, show_legend=(i == 0))
            row = i + 1
            col = 1
            # Plotly subplots cannot embed a whole figure; copy its traces over.
            for trace in sub_fig.data:
                fig.add_trace(trace, row=row, col=col)
            if cols == 2:
                error_sub_fig = self.get_plot_error(target, predictions)
                for trace in error_sub_fig.data:
                    fig.add_trace(trace, row=row, col=col + 1)
            loss = self.criterion(
                predictions.to(self.device), target.to(self.device)
            ).item()
            # Overwrite the i-th subplot title with the criterion name and loss.
            fig["layout"]["annotations"][i].update(
                text=f"{self.criterion.__class__.__name__}: {loss:.6f}"
            )
        # y axis same for all plots
        # fig.update_yaxes(range=[-1, 1], col=1)
        fig.update_layout(height=1000 * rows)
        task.get_logger().report_plotly(
            title=f"{'Training' if train else 'Test'} Samples",
            series="full_day",
            iteration=epoch,
            figure=fig,
        )

    def auto_regressive(self, data_loader, idx, sequence_length: int = 96):
        """Roll the model out auto-regressively for ``sequence_length`` steps.

        Starting from the real 96-step window at ``idx``, each subsequent input
        drops the oldest step and appends the model's last prediction, optionally
        concatenated with exogenous features fetched per step.

        Args:
            data_loader: NOTE(review): despite the name, both call sites in this
                file pass ``*.dataset`` here, and this method then accesses
                ``data_loader.dataset`` again — this only works if the dataset
                itself wraps another dataset (e.g. ``torch.utils.data.Subset``).
                Confirm and rename/fix accordingly.
            idx: starting dataset index (int; see note in ``debug_plots``).
            sequence_length: number of auto-regressive steps (default 96,
                presumably one day at 15-minute resolution — TODO confirm).

        Returns:
            Tuple of CPU tensors:
            ``(initial_sequence, stacked_predictions, stacked_targets)``.
        """
        self.model.eval()
        target_full = []
        predictions_full = []
        prev_features, target = data_loader.dataset[idx]
        prev_features = prev_features.to(self.device)
        # Keep the first 96 real steps for plotting/reference.
        initial_sequence = prev_features[:96]
        target_full.append(target)
        # First prediction from the fully real input window.
        with torch.no_grad():
            prediction = self.model(prev_features.unsqueeze(0))
        predictions_full.append(prediction.squeeze(-1))
        for i in range(sequence_length - 1):
            # Slide the window: drop step 0, append the model's own prediction.
            new_features = torch.cat(
                (
                    prev_features[1:96].cpu(),
                    prediction.squeeze(-1).cpu(),
                ),
                dim=0,
            )
            # get the other needed features (exogenous inputs for the next step;
            # may be None when the dataset provides none)
            other_features, new_target = data_loader.dataset.random_day_autoregressive(
                idx + i + 1
            )
            if other_features is not None:
                prev_features = torch.cat((new_features, other_features), dim=0)
            else:
                prev_features = new_features
            # add target to target_full
            target_full.append(new_target)
            # predict
            with torch.no_grad():
                prediction = self.model(prev_features.unsqueeze(0).to(self.device))
            predictions_full.append(prediction.squeeze(-1))
        return (
            initial_sequence.cpu(),
            torch.stack(predictions_full).cpu(),
            torch.stack(target_full).cpu(),
        )

    def log_final_metrics(self, task, dataloader, train: bool = True):
        """Average each tracked metric over all auto-regressive roll-outs.

        Reports two families of single values to ClearML: metrics on the
        inverse-transformed (original-scale) series, and ``transformed_``
        metrics on the normalized series, prefixed ``train_``/``test_``.

        Args:
            task: ClearML task whose logger receives the values.
            dataloader: loader whose ``.dataset`` is iterated.
            train: selects the ``train_``/``test_`` name prefix.
        """
        metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
        transformed_metrics = {
            metric.__class__.__name__: 0.0 for metric in self.metrics_to_track
        }
        with torch.no_grad():
            # iterate idx over dataset; -95 so every roll-out start still has a
            # full 96-step window available
            total_amount_samples = len(dataloader.dataset) - 95
            for idx in tqdm(range(total_amount_samples)):
                _, outputs, targets = self.auto_regressive(dataloader.dataset, idx)
                # Undo the data scaling so metrics are on the original scale.
                inversed_outputs = torch.tensor(
                    self.data_processor.inverse_transform(outputs)
                )
                # NOTE(review): misnamed — these are inverse-transformed *targets*,
                # not inputs.
                inversed_inputs = torch.tensor(
                    self.data_processor.inverse_transform(targets)
                )
                outputs = outputs.to(self.device)
                targets = targets.to(self.device)
                for metric in self.metrics_to_track:
                    # Normalized-scale metric.
                    transformed_metrics[metric.__class__.__name__] += metric(
                        outputs, targets
                    )
                    # Original-scale metric.
                    metrics[metric.__class__.__name__] += metric(
                        inversed_outputs, inversed_inputs
                    )
        # Convert running sums into averages over all roll-outs.
        for metric in self.metrics_to_track:
            metrics[metric.__class__.__name__] /= total_amount_samples
            transformed_metrics[metric.__class__.__name__] /= total_amount_samples
        for metric_name, metric_value in metrics.items():
            if train:
                metric_name = f"train_{metric_name}"
            else:
                metric_name = f"test_{metric_name}"
            task.get_logger().report_single_value(name=metric_name, value=metric_value)
        for metric_name, metric_value in transformed_metrics.items():
            if train:
                metric_name = f"train_transformed_{metric_name}"
            else:
                metric_name = f"test_transformed_{metric_name}"
            task.get_logger().report_single_value(name=metric_name, value=metric_value)