import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import seaborn as sns
import torch
from scipy.interpolate import CubicSpline
from tqdm import tqdm

from src.data.preprocessing import DataProcessor
from src.losses import CRPSLoss, NonAutoRegressivePinballLoss, PinballLoss
from src.losses.crps_metric import crps_from_samples
from src.trainers.autoregressive_trainer import AutoRegressiveTrainer
from src.trainers.trainer import Trainer
from src.utils.clearml import ClearMLHelper  # noqa: F401 (kept: part of this module's import surface)


def sample_from_dist(quantiles, preds):
    """Draw one random sample per row from the distribution implied by
    quantile predictions.

    The predicted quantile values are interpolated with a cubic spline to
    approximate the inverse CDF, which is then evaluated at one uniform
    random probability per row.

    Args:
        quantiles: sequence of quantile levels, e.g. [0.05, 0.25, ...].
        preds: (..., num_quantiles) tensor or ndarray of quantile forecasts;
            any leading dimensions are flattened.

    Returns:
        1-D np.ndarray with one sample per (flattened) row of ``preds``.
    """
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu().numpy()
    else:
        preds = np.asarray(preds)
    # Flatten leading dimensions so preds is (num_rows, num_quantiles).
    if preds.ndim > 2:
        preds = preds.reshape(-1, preds.shape[-1])
    # One uniform probability per row. spline(probs) has shape
    # (num_rows, num_rows): entry [i, j] is row i's inverse CDF evaluated at
    # probs[j], so the diagonal picks sample i from row i's own distribution.
    probs = np.random.rand(preds.shape[0])
    spline = CubicSpline(quantiles, preds, axis=1)
    return np.diag(spline(probs))


def auto_regressive(dataset, model, quantiles, idx_batch, sequence_length: int = 96):
    """Roll `model` forward autoregressively for `sequence_length` steps.

    At each step one value is sampled from the predicted quantile
    distribution and fed back into the feature window for the next step.

    Args:
        dataset: dataset exposing `get_batch` / `get_batch_autoregressive`.
        model: single-step quantile model; output is (batch, num_quantiles).
        quantiles: quantile levels matching the model's output.
        idx_batch: start indices for each rollout in the batch.
        sequence_length: number of steps to roll out (default one day, 96).

    Returns:
        Tuple of (initial_sequence, predictions_full, predictions_samples,
        target_full) where predictions_full is (batch, seq, num_quantiles),
        predictions_samples is (batch, seq) sampled trajectories, and
        target_full is (batch, seq, 1).
    """
    device = next(model.parameters()).device
    prev_features, targets = dataset.get_batch(idx_batch)
    prev_features = prev_features.to(device)
    targets = targets.to(device)
    # 2-D features: flat window of the last 96 values (plus optional extras);
    # 3-D features: (batch, window, channels) with the series in channel 0.
    if len(list(prev_features.shape)) == 2:
        initial_sequence = prev_features[:, :96]
    else:
        initial_sequence = prev_features[:, :, 0]
    target_full = targets[:, 0].unsqueeze(1)  # (batch_size, 1)
    with torch.no_grad():
        new_predictions_full = model(prev_features)  # (batch_size, quantiles)
    samples = (
        torch.tensor(sample_from_dist(quantiles, new_predictions_full))
        .unsqueeze(1)
        .to(device)
    )  # (batch_size, 1)
    predictions_samples = samples
    predictions_full = new_predictions_full.unsqueeze(1)
    for i in range(sequence_length - 1):
        if len(list(prev_features.shape)) == 2:
            # Slide the window left and append the newly sampled value.
            new_features = torch.cat((prev_features[:, 1:96], samples), dim=1)
            new_features = new_features.float()
            other_features, new_targets = dataset.get_batch_autoregressive(
                np.array(idx_batch) + i + 1
            )  # (batch_size, new_features)
            if other_features is not None:
                prev_features = torch.cat(
                    (new_features.to(device), other_features.to(device)), dim=1
                )  # (batch_size, 96 + new_features)
            else:
                prev_features = new_features
        else:
            other_features, new_targets = dataset.get_batch_autoregressive(
                np.array(idx_batch) + i + 1
            )  # (batch_size, 1, new_features)
            # Overwrite the series channel (NRV) with the sampled value.
            other_features[:, 0, 0] = samples.squeeze(-1)
            other_features = other_features.to(device)
            prev_features = prev_features.to(device)
            prev_features = torch.cat(
                (prev_features[:, 1:, :], other_features), dim=1
            )  # (batch_size, 96, new_features)
        target_full = torch.cat(
            (target_full, new_targets.to(device)), dim=1
        )  # (batch_size, sequence_length)
        with torch.no_grad():
            new_predictions_full = model(prev_features)  # (batch_size, quantiles)
        predictions_full = torch.cat(
            (predictions_full, new_predictions_full.unsqueeze(1)), dim=1
        )  # (batch_size, sequence_length, quantiles)
        samples = (
            torch.tensor(sample_from_dist(quantiles, new_predictions_full))
            .unsqueeze(-1)
            .to(device)
        )  # (batch_size, 1)
        predictions_samples = torch.cat((predictions_samples, samples), dim=1)
    return (
        initial_sequence,
        predictions_full,
        predictions_samples,
        target_full.unsqueeze(-1),
    )


class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
    """Trainer for single-step quantile models evaluated via autoregressive
    rollout: each step's prediction is sampled and fed back as input."""

    def __init__(
        self,
        model: torch.nn.Module,
        input_dim: tuple,
        optimizer: torch.optim.Optimizer,
        data_processor: DataProcessor,
        quantiles: list,
        device: torch.device,
        debug: bool = True,
    ):
        self.quantiles = quantiles
        criterion = PinballLoss(quantiles=quantiles)
        super().__init__(
            model=model,
            input_dim=input_dim,
            optimizer=optimizer,
            criterion=criterion,
            data_processor=data_processor,
            device=device,
            debug=debug,
        )

    def _crps_for_index(self, dataset, idx):
        """Estimate CRPS for one start index from 100 sampled rollouts.

        Returns the CRPS averaged over the rollout horizon as a float.
        """
        computed_idx_batch = [idx] * 100
        _, _, samples, targets = self.auto_regressive(
            dataset, idx_batch=computed_idx_batch
        )
        samples = samples.unsqueeze(0)
        targets = targets.squeeze(-1)
        # All 100 rollouts share the same target; keep one copy.
        targets = targets[0].unsqueeze(0)
        crps = crps_from_samples(samples, targets)
        return crps[0].mean().item()

    def calculate_crps_from_samples(self, task, dataloader, epoch: int):
        """Report the mean sample-based CRPS over the loader to ClearML."""
        crps_from_samples_metric = []
        with torch.no_grad():
            # Indices within 96 steps of the end cannot be rolled out fully.
            total_samples = len(dataloader.dataset) - 96
            for _, _, idx_batch in tqdm(dataloader):
                idx_batch = [idx for idx in idx_batch if idx < total_samples]
                if len(idx_batch) == 0:
                    continue
                for idx in tqdm(idx_batch):
                    crps_from_samples_metric.append(
                        self._crps_for_index(dataloader.dataset, idx)
                    )
        task.get_logger().report_scalar(
            title="CRPS_from_samples",
            series="test",
            value=np.mean(crps_from_samples_metric),
            iteration=epoch,
        )

    def log_final_metrics(self, task, dataloader, train: bool = True):
        """Compute every tracked metric over full autoregressive rollouts and
        report final values (transformed and inverse-transformed) to ClearML.

        On the test set (`train=False`) additionally reports a per-index
        sample-based CRPS estimate.
        """
        metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
        transformed_metrics = {
            metric.__class__.__name__: 0.0 for metric in self.metrics_to_track
        }
        crps_from_samples_metric = []
        with torch.no_grad():
            # Indices within 96 steps of the end cannot be rolled out fully.
            total_samples = len(dataloader.dataset) - 96
            batches = 0
            for _, _, idx_batch in tqdm(dataloader):
                idx_batch = [idx for idx in idx_batch if idx < total_samples]
                if len(idx_batch) == 0:
                    continue
                if not train:
                    for idx in tqdm(idx_batch):
                        crps_from_samples_metric.append(
                            self._crps_for_index(dataloader.dataset, idx)
                        )
                _, outputs, samples, targets = self.auto_regressive(
                    dataloader.dataset, idx_batch=idx_batch
                )
                samples = samples.to(self.device)
                outputs = outputs.to(self.device)
                targets = targets.to(self.device)
                inversed_samples = self.data_processor.inverse_transform(samples)
                inversed_targets = self.data_processor.inverse_transform(targets)
                inversed_outputs = self.data_processor.inverse_transform(outputs)
                inversed_samples = inversed_samples.to(self.device)
                inversed_targets = inversed_targets.to(self.device)
                inversed_outputs = inversed_outputs.to(self.device)
                for metric in self.metrics_to_track:
                    name = metric.__class__.__name__
                    if metric.__class__ != PinballLoss and metric.__class__ != CRPSLoss:
                        # Sample-based metrics compare trajectories to targets.
                        transformed_metrics[name] += metric(
                            samples, targets.squeeze(-1)
                        )
                        metrics[name] += metric(
                            inversed_samples, inversed_targets.squeeze(-1)
                        )
                    else:
                        # Quantile losses compare predicted quantiles to targets.
                        transformed_metrics[name] += metric(outputs, targets)
                        metrics[name] += metric(inversed_outputs, inversed_targets)
                batches += 1
        for metric in self.metrics_to_track:
            metrics[metric.__class__.__name__] /= batches
            transformed_metrics[metric.__class__.__name__] /= batches
        prefix = "train" if train else "test"
        for metric_name, metric_value in metrics.items():
            # Pinball losses are skipped on inverse-transformed data.
            if PinballLoss.__name__ in metric_name:
                continue
            task.get_logger().report_single_value(
                name=f"{prefix}_{metric_name}", value=metric_value
            )
        for metric_name, metric_value in transformed_metrics.items():
            task.get_logger().report_single_value(
                name=f"{prefix}_transformed_{metric_name}", value=metric_value
            )
        if not train:
            task.get_logger().report_single_value(
                name="test_CRPS_from_samples_transformed",
                value=np.mean(crps_from_samples_metric),
            )

    def get_plot(
        self,
        current_day,
        next_day,
        predictions,
        show_legend: bool = True,
        retransform: bool = True,
    ):
        """Plot sampled rollout trajectories against the true series.

        Args:
            current_day: conditioning window (not drawn in the current figure;
                kept for interface compatibility).
            next_day: true target series for the predicted day.
            predictions: (num_samples, 96) sampled trajectories.
            show_legend: kept for interface compatibility.
            retransform: if True, map all values back to the original scale.

        Returns:
            A matplotlib Figure with the sample mean, empirical confidence
            bands, and the real series.
        """
        # Convert to numpy for plotting.
        current_day_np = current_day.view(-1).cpu().numpy()
        next_day_np = next_day.view(-1).cpu().numpy()
        predictions_np = predictions.cpu().numpy()
        if retransform:
            current_day_np = self.data_processor.inverse_transform(current_day_np)
            next_day_np = self.data_processor.inverse_transform(next_day_np)
            predictions_np = self.data_processor.inverse_transform(predictions_np)
        # Empirical central intervals across the sampled trajectories.
        ci_99_lower = np.quantile(predictions_np, 0.005, axis=0)
        ci_99_upper = np.quantile(predictions_np, 0.995, axis=0)
        ci_95_lower = np.quantile(predictions_np, 0.025, axis=0)
        ci_95_upper = np.quantile(predictions_np, 0.975, axis=0)
        ci_90_lower = np.quantile(predictions_np, 0.05, axis=0)
        ci_90_upper = np.quantile(predictions_np, 0.95, axis=0)
        ci_50_lower = np.quantile(predictions_np, 0.25, axis=0)
        ci_50_upper = np.quantile(predictions_np, 0.75, axis=0)
        sns.set_theme()
        time_steps = np.arange(0, 96)
        fig, ax = plt.subplots(figsize=(20, 10))
        ax.plot(
            time_steps,
            predictions_np.mean(axis=0),
            label="Mean of NRV samples",
            linewidth=3,
        )
        ax.fill_between(
            time_steps, ci_99_lower, ci_99_upper, color='b', alpha=0.2,
            label='99% Interval',
        )
        ax.fill_between(
            time_steps, ci_95_lower, ci_95_upper, color='b', alpha=0.2,
            label='95% Interval',
        )
        ax.fill_between(
            time_steps, ci_90_lower, ci_90_upper, color='b', alpha=0.2,
            label='90% Interval',
        )
        ax.fill_between(
            time_steps, ci_50_lower, ci_50_upper, color='b', alpha=0.2,
            label='50% Interval',
        )
        ax.plot(next_day_np, label="Real NRV", linewidth=3)
        # Proxy patches so the legend reflects the stacked band opacities.
        ci_99_patch = mpatches.Patch(color='b', alpha=0.3, label='99% Interval')
        ci_95_patch = mpatches.Patch(color='b', alpha=0.4, label='95% Interval')
        ci_90_patch = mpatches.Patch(color='b', alpha=0.5, label='90% Interval')
        ci_50_patch = mpatches.Patch(color='b', alpha=0.6, label='50% Interval')
        ax.legend(
            handles=[
                ci_99_patch,
                ci_95_patch,
                ci_90_patch,
                ci_50_patch,
                ax.lines[0],
                ax.lines[1],
            ]
        )
        return fig

    def auto_regressive(self, dataset, idx_batch, sequence_length: int = 96):
        """Delegate to the module-level autoregressive rollout helper."""
        return auto_regressive(
            dataset, self.model, self.quantiles, idx_batch, sequence_length
        )

    def plot_quantile_percentages(
        self,
        task,
        data_loader,
        train: bool = True,
        iteration: int = None,
        full_day: bool = False,
    ):
        """Plot empirical coverage of each predicted quantile against its
        nominal level and report the figure to ClearML.

        Args:
            task: ClearML task used for logging.
            data_loader: loader yielding (inputs, targets, idx_batch).
            train: label the figure as training (True) or test (False) set.
            iteration: ClearML iteration to attach the figure to.
            full_day: if True, evaluate full autoregressive rollouts instead
                of single-step predictions.
        """
        quantiles = self.quantiles
        total = 0
        quantile_counter = {q: 0 for q in quantiles}
        self.model.eval()
        with torch.no_grad():
            # Indices within 96 steps of the end cannot be rolled out fully.
            total_samples = len(data_loader.dataset) - 96
            for inputs, targets, idx_batch in data_loader:
                idx_batch = [idx for idx in idx_batch if idx < total_samples]
                if full_day:
                    _, outputs, samples, targets = self.auto_regressive(
                        data_loader.dataset, idx_batch=idx_batch
                    )
                    # outputs: (batch, sequence_length, num_quantiles)
                    # targets: (batch, sequence_length, 1)
                    outputs = outputs.reshape(-1, len(quantiles))
                    targets = targets.reshape(-1)
                    outputs = outputs.cpu().numpy()
                    targets = targets.cpu().numpy()
                else:
                    inputs = inputs.to(self.device)
                    outputs = self.model(inputs).cpu().numpy()  # (batch, num_quantiles)
                    targets = targets.squeeze(-1).cpu().numpy()  # (batch,)
                # Count how often the target falls below each quantile forecast.
                for i, q in enumerate(quantiles):
                    quantile_counter[q] += np.sum(targets < outputs[:, i])
                total += len(targets)
        percentages = np.array([quantile_counter[q] / total for q in quantiles])
        bar_width = 0.35
        index = np.arange(len(quantiles))
        # Ideal coverage equals the nominal quantile level itself.
        fig, ax = plt.subplots(figsize=(15, 10))
        ax.bar(index, quantiles, bar_width, label="Ideal", color="brown")
        bar2 = ax.bar(
            index + bar_width, percentages, bar_width, label="NN model", color="blue"
        )
        # Annotate the empirical coverage above each model bar.
        for rect in bar2:
            height = rect.get_height()
            ax.text(
                rect.get_x() + rect.get_width() / 2.0,
                1.005 * height,
                f"{height:.2}",
                ha="center",
                va="bottom",
            )
        series_name = "Training Set" if train else "Test Set"
        full_day_str = "Full Day" if full_day else "Single Step"
        ax.set_xlabel("Quantile")
        ax.set_ylabel("Fraction of data under quantile forecast")
        ax.set_title(f"{series_name} {full_day_str} Quantile Performance Comparison")
        ax.set_xticks(index + bar_width / 2)
        ax.set_xticklabels(quantiles)
        ax.legend()
        task.get_logger().report_matplotlib_figure(
            title="Quantile Performance Comparison",
            series=f"{series_name} {full_day_str}",
            report_image=True,
            figure=plt,
            iteration=iteration,
        )
        plt.close()


class NonAutoRegressiveQuantileRegression(Trainer):
    """Trainer for a model that predicts the quantiles of every step of the
    horizon in a single forward pass."""

    def __init__(
        self,
        model: torch.nn.Module,
        input_dim: tuple,
        optimizer: torch.optim.Optimizer,
        data_processor: DataProcessor,
        quantiles: list,
        device: torch.device,
        debug: bool = True,
    ):
        self.quantiles = quantiles
        criterion = NonAutoRegressivePinballLoss(quantiles=quantiles)
        super().__init__(
            model=model,
            input_dim=input_dim,
            optimizer=optimizer,
            criterion=criterion,
            data_processor=data_processor,
            device=device,
            debug=debug,
        )

    def log_final_metrics(self, task, dataloader, train: bool = True):
        """Compute every tracked metric over the loader and report final
        values (transformed and inverse-transformed) to ClearML."""
        metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
        transformed_metrics = {
            metric.__class__.__name__: 0.0 for metric in self.metrics_to_track
        }
        with torch.no_grad():
            for inputs, targets, _ in dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                # Draw one sampled trajectory per element of the batch.
                outputted_samples = torch.tensor(
                    [
                        sample_from_dist(self.quantiles, output.cpu().numpy())
                        for output in outputs
                    ]
                )
                inversed_outputs_samples = self.data_processor.inverse_transform(
                    outputted_samples
                )
                outputs = outputs.reshape(inputs.shape[0], -1, len(self.quantiles))
                inversed_outputs = self.data_processor.inverse_transform(outputs)
                inversed_targets = self.data_processor.inverse_transform(targets)
                inversed_outputs_samples = inversed_outputs_samples.to(self.device)
                inversed_targets = inversed_targets.to(self.device)
                outputted_samples = outputted_samples.to(self.device)
                inversed_outputs = inversed_outputs.to(self.device)
                for metric in self.metrics_to_track:
                    name = metric.__class__.__name__
                    if metric.__class__ != PinballLoss and metric.__class__ != CRPSLoss:
                        # Sample-based metrics compare trajectories to targets.
                        transformed_metrics[name] += metric(outputted_samples, targets)
                        metrics[name] += metric(
                            inversed_outputs_samples, inversed_targets
                        )
                    else:
                        # Quantile losses compare predicted quantiles to targets.
                        transformed_metrics[name] += metric(
                            outputs, targets.unsqueeze(-1)
                        )
                        metrics[name] += metric(
                            inversed_outputs, inversed_targets.unsqueeze(-1)
                        )
        for metric in self.metrics_to_track:
            metrics[metric.__class__.__name__] /= len(dataloader)
            transformed_metrics[metric.__class__.__name__] /= len(dataloader)
        prefix = "train" if train else "test"
        for metric_name, metric_value in metrics.items():
            task.get_logger().report_single_value(
                name=f"{prefix}_{metric_name}", value=metric_value
            )
        for metric_name, metric_value in transformed_metrics.items():
            task.get_logger().report_single_value(
                name=f"{prefix}_transformed_{metric_name}", value=metric_value
            )

    def get_plot(self, current_day, next_day, predictions, show_legend: bool = True):
        """Plot the conditioning day, the true next day, and each predicted
        quantile curve as an interactive plotly figure.

        Returns:
            A plotly Figure.
        """
        fig = go.Figure()
        # Convert to numpy for plotting.
        current_day_np = current_day.view(-1).cpu().numpy()
        next_day_np = next_day.view(-1).cpu().numpy()
        # Reshape predictions to (n, len(quantiles)).
        predictions_np = predictions.cpu().numpy().reshape(-1, len(self.quantiles))
        # Add traces for current and next day.
        fig.add_trace(go.Scatter(x=np.arange(96), y=current_day_np, name="Current Day"))
        fig.add_trace(go.Scatter(x=96 + np.arange(96), y=next_day_np, name="Next Day"))
        for i, q in enumerate(self.quantiles):
            fig.add_trace(
                go.Scatter(
                    x=96 + np.arange(96),
                    y=predictions_np[:, i],
                    name=f"Prediction (Q={q})",
                    line=dict(dash="dash"),
                )
            )
        # Update the layout.
        fig.update_layout(title="Predictions and Quantiles", showlegend=show_legend)
        return fig