from clearml import OutputModel
import torch
from data.preprocessing import DataProcessor
from utils.clearml import ClearMLHelper
import plotly.graph_objects as go
import numpy as np
import plotly.subplots as sp
from plotly.subplots import make_subplots


def _to_numpy(values) -> np.ndarray:
    """Flatten a tensor (any device, with or without grad) or array-like to a 1-D NumPy array."""
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().numpy()
    return np.asarray(values).reshape(-1)


class Trainer:
    """Train a PyTorch model and optionally log progress to ClearML.

    Data loaders come from a ``DataProcessor``. Early stopping is driven by
    the test loss, and sample prediction plots are reported to the ClearML
    task whenever one is available.
    """

    # File that the best model (lowest test loss) is written to during early stopping.
    checkpoint_path = 'checkpoint.pt'

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        data_processor: DataProcessor,
        device: torch.device,
        clearml_helper: ClearMLHelper | None = None,
        debug: bool = True,
    ):
        """Store the training components and move the model to ``device``.

        Args:
            model: Network to train. ``model.output_size`` is passed to the
                data processor as the prediction length.
            optimizer: Optimizer that updates ``model``'s parameters.
            criterion: Loss used for training, testing and early stopping.
            data_processor: Provides the train/test data loaders and the
                ``inverse_transform`` used for final metrics.
            device: Device the model and batches are moved to.
            clearml_helper: Optional helper used to create a ClearML task.
                Without it, no ClearML logging happens.
            debug: When True, the ClearML task is tagged ``'Debug'``.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.clearml_helper = clearml_helper
        self.debug = debug
        self.metrics_to_track = []
        self.data_processor = data_processor
        self.patience = None
        self.delta = None
        self.plot_every_n_epochs = 1
        # Best test loss seen so far; updated by save_checkpoint().
        self.best_score = None
        self.model.to(self.device)

    def plot_every(self, n: int):
        """Report debug plots every ``n`` epochs.

        Raises:
            ValueError: If ``n`` is smaller than 1. Zero would otherwise cause
                a ``ZeroDivisionError`` in the middle of training.
        """
        if n < 1:
            raise ValueError(f"plot_every expects n >= 1, got {n}")
        self.plot_every_n_epochs = n

    def early_stopping(self, patience: int = 5, delta: float = 0.0):
        """Enable early stopping on the test loss.

        Args:
            patience: Number of consecutive epochs without improvement after
                which training stops.
            delta: Minimum decrease in test loss (relative to the best so far)
                that counts as an improvement.
        """
        self.patience = patience
        self.delta = delta

    def add_metrics_to_track(self, loss: torch.nn.Module | list[torch.nn.Module]):
        """Register one or more extra metrics evaluated by ``log_final_metrics``."""
        if isinstance(loss, list):
            self.metrics_to_track.extend(loss)
        else:
            self.metrics_to_track.append(loss)

    def init_clearml_task(self):
        """Interactively create and tag a ClearML task.

        The task name and change description are read from stdin.

        Returns:
            The task, or None when no ClearML helper was configured.
        """
        if not self.clearml_helper:
            return None

        task_name = input("Enter a task name: ")
        if task_name == "":
            task_name = "Untitled Task"
        task = self.clearml_helper.get_task(task_name=task_name)

        if self.debug:
            task.add_tags('Debug')

        change_description = input("Enter a change description: ")
        if change_description:
            task.set_comment(change_description)

        for component in (self.model, self.criterion, self.optimizer, self):
            task.add_tags(component.__class__.__name__)

        # Give optimizer/criterion a ``name`` attribute before connecting them to the task.
        self.optimizer.name = self.optimizer.__class__.__name__
        self.criterion.name = self.criterion.__class__.__name__

        task.connect(self.optimizer, name="optimizer")
        task.connect(self.criterion, name="criterion")
        task.connect(self.data_processor, name="data_processor")
        return task

    def random_samples(
        self,
        train: bool = True,
        num_samples: int = 10,
        loader: torch.utils.data.DataLoader | None = None,
    ):
        """Draw random dataset indices, with replacement, for debug plots.

        Args:
            train: Sample from the train loader (True) or the test loader (False).
                Ignored when ``loader`` is given.
            num_samples: Number of indices to draw.
            loader: Optional loader to sample from. When omitted, loaders are
                fetched from the data processor.

        Returns:
            NumPy array of ``num_samples`` indices in ``[0, len(loader.dataset))``.
        """
        if loader is None:
            train_loader, test_loader = self.data_processor.get_dataloaders(
                predict_sequence_length=self.model.output_size
            )
            loader = train_loader if train else test_loader
        # ``high`` is exclusive in np.random.randint, so passing the full length
        # keeps the last sample selectable.
        return np.random.randint(0, len(loader.dataset), size=num_samples)

    def _train_one_epoch(self, train_loader) -> float:
        """Run one optimization pass over ``train_loader`` and return the normalized loss."""
        self.model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(inputs)
            loss = self.criterion(output, targets)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
        # NOTE(review): if the criterion returns a per-batch mean, dividing the sum of
        # batch losses by the dataset size (rather than the batch count) scales the
        # value by about 1/batch_size. Kept consistent with test(); confirm intent.
        return running_loss / len(train_loader.dataset)

    def train(self, epochs: int):
        """Train for up to ``epochs`` epochs, with optional early stopping and ClearML logging.

        When early stopping is enabled, the model with the best test loss is
        checkpointed to ``checkpoint_path``. When a ClearML task exists, that
        checkpoint is reloaded at the end before final metrics are logged.
        """
        train_loader, test_loader = self.data_processor.get_dataloaders(
            predict_sequence_length=self.model.output_size
        )
        # Reuse the loaders above instead of rebuilding them for sampling.
        train_samples = self.random_samples(train=True, loader=train_loader)
        test_samples = self.random_samples(train=False, loader=test_loader)

        task = self.init_clearml_task()
        self.best_score = None
        counter = 0

        for epoch in range(1, epochs + 1):
            running_loss = self._train_one_epoch(train_loader)
            test_loss = self.test(test_loader)

            if self.patience is not None:
                # Improvement means beating the best loss by at least ``delta``.
                if self.best_score is None or test_loss < self.best_score - self.delta:
                    self.save_checkpoint(test_loss, task, epoch)
                    counter = 0
                else:
                    counter += 1
                    if counter >= self.patience:
                        print('Early stopping triggered')
                        break

            if task:
                logger = task.get_logger()
                title = self.criterion.__class__.__name__
                logger.report_scalar(title=title, series="train", value=running_loss, iteration=epoch)
                logger.report_scalar(title=title, series="test", value=test_loss, iteration=epoch)

                if epoch % self.plot_every_n_epochs == 0:
                    self.debug_plots(task, True, train_loader, train_samples, epoch)
                    self.debug_plots(task, False, test_loader, test_samples, epoch)

        if task:
            self.finish_training(task=task)
            task.close()

    def log_final_metrics(self, task, dataloader, train: bool = True):
        """Report the tracked metrics, averaged over batches, as ClearML single values.

        Each metric is computed twice. On the model-space outputs it is
        reported as ``<split>_transformed_<Name>``. After
        ``data_processor.inverse_transform`` it is reported as
        ``<split>_<Name>``.
        """
        names = [metric.__class__.__name__ for metric in self.metrics_to_track]
        metrics = dict.fromkeys(names, 0.0)
        transformed_metrics = dict.fromkeys(names, 0.0)

        with torch.no_grad():
            for inputs, targets in dataloader:
                inputs = inputs.to(self.device)
                outputs = self.model(inputs)
                # NOTE(review): assumes inverse_transform accepts model outputs on
                # self.device and targets on their original device — TODO confirm.
                inversed_outputs = torch.tensor(self.data_processor.inverse_transform(outputs))
                inversed_targets = torch.tensor(self.data_processor.inverse_transform(targets))
                device_targets = targets.to(self.device)
                for metric in self.metrics_to_track:
                    name = metric.__class__.__name__
                    # float() turns 0-dim (possibly GPU) tensors into plain numbers for logging.
                    transformed_metrics[name] += float(metric(outputs, device_targets))
                    metrics[name] += float(metric(inversed_outputs, inversed_targets))

        num_batches = max(len(dataloader), 1)  # avoid ZeroDivisionError on an empty loader
        prefix = 'train' if train else 'test'
        logger = task.get_logger()
        for metric_name, metric_value in metrics.items():
            logger.report_single_value(name=f'{prefix}_{metric_name}', value=metric_value / num_batches)
        for metric_name, metric_value in transformed_metrics.items():
            logger.report_single_value(
                name=f'{prefix}_transformed_{metric_name}', value=metric_value / num_batches
            )

    def finish_training(self, task):
        """Restore the best checkpoint (if one was saved) and log final train/test metrics."""
        if self.best_score is not None:
            # map_location lets a checkpoint saved on one device load on another.
            self.model.load_state_dict(torch.load(self.checkpoint_path, map_location=self.device))
        self.model.eval()
        transformed_train_loader, transformed_test_loader = self.data_processor.get_dataloaders(
            predict_sequence_length=self.model.output_size
        )
        self.log_final_metrics(task, transformed_train_loader, train=True)
        self.log_final_metrics(task, transformed_test_loader, train=False)

    def test(self, test_loader: torch.utils.data.DataLoader):
        """Return the summed criterion value over ``test_loader``, divided by the dataset size."""
        self.model.eval()
        test_loss = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                test_loss += self.criterion(output, target).item()
        test_loss /= len(test_loader.dataset)
        return test_loss

    def save_checkpoint(self, val_loss, task, iteration: int):
        """Save the model weights, register them on the ClearML task (if any), and record ``val_loss`` as best."""
        torch.save(self.model.state_dict(), self.checkpoint_path)
        if task is not None:
            task.update_output_model(
                model_path=self.checkpoint_path, iteration=iteration, auto_delete_file=False
            )
        self.best_score = val_loss

    def get_plot(self, current_day, next_day, predictions, show_legend: bool = True):
        """Build a figure with the input window, the ground-truth continuation and the predictions.

        The continuation and the predictions start on the x-axis right after
        the input window. ``show_legend`` controls whether the traces appear in
        the legend, so stacked subplots do not repeat it.
        """
        current = _to_numpy(current_day)
        future = _to_numpy(next_day)
        predicted = _to_numpy(predictions)
        offset = len(current)

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=np.arange(offset), y=current, name="Current Day", showlegend=show_legend))
        fig.add_trace(go.Scatter(x=offset + np.arange(len(future)), y=future, name="Next Day", showlegend=show_legend))
        fig.add_trace(go.Scatter(x=offset + np.arange(len(predicted)), y=predicted, name="Predictions", showlegend=show_legend))
        fig.update_layout(title="Predictions of the Linear Model")
        return fig

    def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
        """Plot predictions for the given dataset indices (one subplot each) and report them to ClearML.

        Each subplot title shows the criterion's loss for that sample. Does
        nothing without a task or without samples.
        """
        num_samples = len(sample_indices)
        if task is None or num_samples == 0:
            return

        fig = make_subplots(
            rows=num_samples, cols=1, subplot_titles=[f'Sample {i+1}' for i in range(num_samples)]
        )
        criterion_name = self.criterion.__class__.__name__
        self.model.eval()

        for i, idx in enumerate(sample_indices):
            features, target = data_loader.dataset[idx]
            features = features.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                predictions = self.model(features).cpu()
                loss = self.criterion(predictions.to(self.device), target.squeeze(-1)).item()

            # Only the first subplot contributes legend entries.
            sub_fig = self.get_plot(features[:96], target, predictions, show_legend=(i == 0))
            for trace in sub_fig.data:
                fig.add_trace(trace, row=i + 1, col=1)

            fig['layout']['annotations'][i].update(text=f"{criterion_name}: {loss:.6f}")

        # Same y-axis range for all subplots.
        fig.update_yaxes(range=[-1, 1], col=1)
        fig.update_layout(height=300 * num_samples)

        task.get_logger().report_plotly(
            title=f"{'Training' if train else 'Test'} Samples",
            series="full_day",
            iteration=epoch,
            figure=fig,
        )

    def debug_scatter_plot(self, task, train: bool, samples, epoch):
        """Plot single-step predictions for a batch ``(X, y)`` in a two-column grid and report them to ClearML.

        Only ``y[:, 0]`` is used as the ground truth. Does nothing without a
        task or without samples.
        """
        if task is None:
            return
        X, y = samples
        X = X.to(self.device)
        y = y.to(self.device)[:, 0]
        num_samples = len(X)
        if num_samples == 0:
            return

        self.model.eval()
        with torch.no_grad():  # inference only; no autograd graph needed
            predictions = self.model(X)

        cols = 2
        rows = -(-num_samples // cols)  # ceiling division for an odd number of samples
        fig = make_subplots(rows=rows, cols=cols, subplot_titles=[f'Sample {i+1}' for i in range(num_samples)])

        for i, (current_day, next_value, pred) in enumerate(zip(X, y, predictions)):
            sub_fig = self.scatter_plot(current_day, pred, next_value)
            row, col = divmod(i, cols)
            for trace in sub_fig.data:
                fig.add_trace(trace, row=row + 1, col=col + 1)

        fig.update_layout(height=300 * rows)

        task.get_logger().report_plotly(
            title=f"{'Training' if train else 'Test'} Samples",
            series="scatter",
            iteration=epoch,
            figure=fig,
        )

    def scatter_plot(self, x, y, real_y):
        """Plot the input window ``x``, then the predicted value ``y`` and the true value ``real_y`` one step after it."""
        current = _to_numpy(x)
        next_step = len(current)
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=np.arange(next_step), y=current, name="Current Day"))
        fig.add_trace(go.Scatter(x=[next_step], y=[y.item()], name="Next Day"))
        fig.add_trace(go.Scatter(x=[next_step], y=[real_y.item()], name="Real Next Day"))
        fig.update_layout(title="Predictions of the Linear Model")
        return fig