from clearml import Task
from clearml.config import running_remotely
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
from torchinfo import summary

from src.data.preprocessing import DataProcessor
from src.utils.clearml import ClearMLHelper


class Trainer:
    """Train a PyTorch model and report progress to an optional ClearML task.

    The trainer runs a standard epoch loop (forward, loss, backward, step),
    evaluates on the test loader after every epoch, optionally applies early
    stopping with checkpointing, and -- when a ClearML ``Task`` is given --
    logs scalar losses, sample plots and final metrics to it.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        input_dim: tuple,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        data_processor: DataProcessor,
        device: torch.device,
        debug: bool = True,
    ):
        """Store the training components and move the model to ``device``.

        Args:
            model: network to train; ``model.output_size`` is read when the
                dataloaders are built.
            input_dim: input shape passed to ``torchinfo.summary``.
            optimizer: optimizer stepping ``model``'s parameters.
            criterion: loss module used for training and evaluation.
            data_processor: provides the dataloaders and ``inverse_transform``.
            device: device on which the model and batches are placed.
            debug: when true, the ClearML task is tagged "Debug".
        """
        self.input_dim = input_dim
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.debug = debug
        self.metrics_to_track = []
        self.data_processor = data_processor
        # Early stopping stays disabled until early_stopping() is called.
        self.patience = None
        self.delta = None
        self.plot_every_n_epochs = 1

        self.model.to(self.device)

    def plot_every(self, n: int):
        """Report the debug plots every ``n`` epochs (default: every epoch)."""
        self.plot_every_n_epochs = n

    def early_stopping(self, patience: int = 5, delta: float = 0.0):
        """Enable early stopping on the test loss.

        Args:
            patience: number of consecutive epochs without improvement after
                which training stops.
            delta: minimum decrease of the test loss, relative to the best
                loss so far, that counts as an improvement.
        """
        self.patience = patience
        self.delta = delta

    def add_metrics_to_track(self, loss):
        """Register one metric callable, or a list of them, for final reporting."""
        if isinstance(loss, list):
            self.metrics_to_track.extend(loss)
        else:
            self.metrics_to_track.append(loss)

    def init_clearml_task(self, task):
        """Tag ``task`` and connect the trainer's components to it.

        Does nothing when ``task`` is None.
        """
        if task is None:
            return

        if self.debug:
            task.add_tags("Debug")
        task.add_tags(self.model.__class__.__name__)
        task.add_tags(self.criterion.__class__.__name__)
        task.add_tags(self.optimizer.__class__.__name__)
        task.add_tags(self.__class__.__name__)

        task.set_configuration_object("model", str(summary(self.model, self.input_dim)))

        # Expose a readable class name alongside the connected parameters.
        self.optimizer.name = self.optimizer.__class__.__name__
        self.criterion.name = self.criterion.__class__.__name__

        self.optimizer = task.connect(self.optimizer, name="optimizer")
        self.criterion = task.connect(self.criterion, name="criterion")
        self.data_processor = task.connect(self.data_processor, name="data_processor")
        task.connect(self, name="trainer")
        # Drop the "quantiles" entry from the connected trainer parameters.
        task.delete_parameter("trainer/quantiles", force=True)

    def random_samples(self, train: bool = True, num_samples: int = 10, loader=None):
        """Return ``num_samples`` random dataset indices, drawn with replacement.

        Args:
            train: sample from the training set when true, else from the test
                set. Ignored when ``loader`` is given.
            num_samples: number of indices to draw.
            loader: optional dataloader to sample from; lets callers that
                already hold the dataloaders avoid rebuilding them.
        """
        if loader is None:
            train_loader, test_loader = self.data_processor.get_dataloaders(
                predict_sequence_length=self.model.output_size
            )
            loader = train_loader if train else test_loader

        # randint's upper bound is exclusive: len(dataset) keeps the last
        # index reachable.
        return np.random.randint(0, len(loader.dataset), size=num_samples)

    def train(self, epochs: int, remotely: bool = False, task: Task = None):
        """Run the training loop for up to ``epochs`` epochs.

        Args:
            epochs: maximum number of epochs to run.
            remotely: call ``task.execute_remotely`` on the "default" queue
                with ``exit_process=True`` (requires ``task``).
            task: optional ClearML task used for logging. It is closed when
                training finishes; on any exception it is closed, archived,
                and the exception is re-raised.
        """
        try:
            train_loader, test_loader = self.data_processor.get_dataloaders(
                predict_sequence_length=self.model.output_size
            )

            # Fixed sample indices so the debug plots track the same examples.
            train_samples = self.random_samples(train=True, loader=train_loader)
            test_samples = self.random_samples(train=False, loader=test_loader)

            self.init_clearml_task(task)

            if remotely:
                task.execute_remotely(queue_name="default", exit_process=True)

            self.best_score = None
            counter = 0

            for epoch in range(1, epochs + 1):
                self.model.train()
                running_loss = 0.0

                for inputs, targets, _ in train_loader:
                    inputs, targets = inputs.to(self.device), targets.to(self.device)

                    self.optimizer.zero_grad()
                    output = self.model(inputs)
                    loss = self.criterion(output, targets)
                    loss.backward()
                    self.optimizer.step()

                    running_loss += loss.item()

                # NOTE(review): with a mean-reduced criterion this divides a sum
                # of per-batch means by the number of samples rather than
                # batches; test() normalises the same way -- confirm intended.
                running_loss /= len(train_loader.dataset)

                test_loss = self.test(test_loader)

                if self.patience is not None:
                    # Only a loss lower than the best by more than ``delta`` is
                    # an improvement; anything else consumes patience.
                    if self.best_score is None or test_loss < self.best_score - self.delta:
                        self.save_checkpoint(test_loss, task, epoch)
                        counter = 0
                    else:
                        counter += 1
                        if counter >= self.patience:
                            print("Early stopping triggered")
                            break

                if task:
                    logger = task.get_logger()
                    logger.report_scalar(
                        title=self.criterion.__class__.__name__,
                        series="train",
                        value=running_loss,
                        iteration=epoch,
                    )
                    logger.report_scalar(
                        title=self.criterion.__class__.__name__,
                        series="test",
                        value=test_loss,
                        iteration=epoch,
                    )

                    if epoch % self.plot_every_n_epochs == 0:
                        self.debug_plots(task, True, train_loader, train_samples, epoch)
                        self.debug_plots(task, False, test_loader, test_samples, epoch)

                        # Hook provided by subclasses that plot quantile coverage.
                        if hasattr(self, "plot_quantile_percentages"):
                            self.plot_quantile_percentages(
                                task, train_loader, True, epoch, False
                            )
                            self.plot_quantile_percentages(
                                task, test_loader, False, epoch, False
                            )

            if task:
                self.finish_training(task=task)
                task.close()
        except Exception:
            if task:
                task.close()
                task.set_archived(True)
            raise

    def log_final_metrics(self, task, dataloader, train: bool = True):
        """Report the per-batch mean of every tracked metric over ``dataloader``.

        Each metric is evaluated twice per batch: on the raw model outputs and
        targets (reported as ``<split>_transformed_<Metric>``) and on their
        ``data_processor.inverse_transform``-ed versions (reported as
        ``<split>_<Metric>``), where ``<split>`` is "train" or "test".
        """

        def _as_reportable(value):
            # Metric callables may return 0-d tensors; the logger expects a scalar.
            if isinstance(value, torch.Tensor) and value.numel() == 1:
                return value.item()
            return value

        metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
        transformed_metrics = {
            metric.__class__.__name__: 0.0 for metric in self.metrics_to_track
        }

        with torch.no_grad():
            for inputs, targets, _ in dataloader:
                inputs = inputs.to(self.device)
                outputs = self.model(inputs)

                inversed_outputs = torch.tensor(
                    self.data_processor.inverse_transform(outputs)
                )
                inversed_targets = torch.tensor(
                    self.data_processor.inverse_transform(targets)
                )

                for metric in self.metrics_to_track:
                    name = metric.__class__.__name__
                    transformed_metrics[name] += metric(outputs, targets.to(self.device))
                    metrics[name] += metric(inversed_outputs, inversed_targets)

        # Average once per metric name (names are the dict keys).
        num_batches = len(dataloader)
        for name in metrics:
            metrics[name] /= num_batches
            transformed_metrics[name] /= num_batches

        prefix = "train" if train else "test"
        logger = task.get_logger()
        for name, value in metrics.items():
            logger.report_single_value(
                name=f"{prefix}_{name}", value=_as_reportable(value)
            )
        for name, value in transformed_metrics.items():
            logger.report_single_value(
                name=f"{prefix}_transformed_{name}", value=_as_reportable(value)
            )

    def finish_training(self, task):
        """Restore the best checkpoint (if one was saved) and log final metrics.

        Final metrics are skipped for subclasses that define
        ``plot_quantile_percentages``.
        """
        # getattr: best_score only exists once train() has started.
        if getattr(self, "best_score", None) is not None:
            self.model.load_state_dict(
                torch.load("checkpoint.pt", map_location=self.device)
            )
        self.model.eval()

        train_loader, test_loader = self.data_processor.get_dataloaders(
            predict_sequence_length=self.model.output_size
        )

        if not hasattr(self, "plot_quantile_percentages"):
            self.log_final_metrics(task, train_loader, train=True)
            self.log_final_metrics(task, test_loader, train=False)

    def test(self, test_loader: torch.utils.data.DataLoader):
        """Return the summed criterion over ``test_loader`` divided by its dataset size.

        Leaves the model in eval mode.
        """
        self.model.eval()
        test_loss = 0
        with torch.no_grad():
            for data, target, _ in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                test_loss += self.criterion(output, target).item()

        test_loss /= len(test_loader.dataset)
        return test_loss

    def save_checkpoint(self, val_loss, task, iteration: int):
        """Save model weights to ``checkpoint.pt`` and record ``val_loss`` as best.

        When ``task`` is given, the file is also registered as the task's
        output model for ``iteration``.
        """
        torch.save(self.model.state_dict(), "checkpoint.pt")
        if task is not None:
            task.update_output_model(
                model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
            )
        self.best_score = val_loss

    def get_plot(
        self,
        current_day,
        next_day,
        predictions,
        show_legend: bool = True,
        retransform: bool = True,
    ):
        """Build a line plot of an input window followed by target and prediction.

        Args:
            current_day: input values, plotted at x = 0..95.
            next_day: ground-truth values, plotted at x = 96..191.
            predictions: model outputs, plotted at x = 96..191.
            show_legend: whether the traces appear in the figure legend.
            retransform: first apply ``data_processor.inverse_transform`` to
                all three series.
        """
        if retransform:
            current_day = self.data_processor.inverse_transform(current_day)
            next_day = self.data_processor.inverse_transform(next_day)
            predictions = self.data_processor.inverse_transform(predictions)

        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=np.arange(96),
                y=current_day.view(-1).cpu().numpy(),
                name="Current Day",
                showlegend=show_legend,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=96 + np.arange(96),
                y=next_day.view(-1).cpu().numpy(),
                name="Next Day",
                showlegend=show_legend,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=96 + np.arange(96),
                y=predictions.reshape(-1),
                name="Predictions",
                showlegend=show_legend,
            )
        )

        fig.update_layout(title="Predictions of the Linear Model")
        return fig

    def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
        """Report one subplot per sampled example: input, target and prediction."""
        num_samples = len(sample_indices)
        rows = num_samples  # single column: one row per sample
        cols = 1

        fig = make_subplots(
            rows=rows,
            cols=cols,
            subplot_titles=[f"Sample {i+1}" for i in range(num_samples)],
        )

        self.model.eval()
        for i, idx in enumerate(sample_indices):
            features, target, _ = data_loader.dataset[idx]
            features = features.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                predictions = self.model(features).cpu()

            # Only the first subplot contributes legend entries.
            sub_fig = self.get_plot(
                features[:96], target, predictions, show_legend=(i == 0)
            )
            for trace in sub_fig.data:
                fig.add_trace(trace, row=i + 1, col=1)

        fig.update_layout(height=1000 * rows)

        task.get_logger().report_plotly(
            title=f"{'Training' if train else 'Test'} Samples",
            series="full_day",
            iteration=epoch,
            figure=fig,
        )

    def debug_scatter_plot(self, task, train: bool, samples, epoch):
        """Report a two-column grid comparing single-value predictions to targets.

        ``samples`` is an ``(X, y)`` pair; only ``y[:, 0]`` is compared.
        """
        X, y = samples
        X = X.to(self.device)
        y = y.to(self.device)
        y = y[:, 0]

        self.model.eval()
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            predictions = self.model(X)

        num_samples = len(X)
        cols = 2
        rows = -(-num_samples // cols)  # ceiling division for an odd count

        fig = make_subplots(
            rows=rows,
            cols=cols,
            subplot_titles=[f"Sample {i+1}" for i in range(num_samples)],
        )

        for i, (current_day, next_value, pred) in enumerate(zip(X, y, predictions)):
            sub_fig = self.scatter_plot(current_day, pred, next_value)
            row = (i // cols) + 1
            col = (i % cols) + 1
            for trace in sub_fig.data:
                fig.add_trace(trace, row=row, col=col)

        fig.update_layout(height=300 * rows)

        task.get_logger().report_plotly(
            title=f"{'Training' if train else 'Test'} Samples",
            series="scatter",
            iteration=epoch,
            figure=fig,
        )

    def scatter_plot(self, x, y, real_y):
        """Plot 96 input values plus the predicted and real value at x = 96."""
        fig = go.Figure()

        fig.add_trace(
            go.Scatter(x=np.arange(96), y=x.view(-1).cpu().numpy(), name="Current Day")
        )
        fig.add_trace(go.Scatter(x=[96], y=[y.item()], name="Next Day"))
        fig.add_trace(go.Scatter(x=[96], y=[real_y.item()], name="Real Next Day"))

        fig.update_layout(title="Predictions of the Linear Model")
        return fig