"""Quantile-regression trainers.

Two trainers around pinball-loss quantile models:

* :class:`AutoRegressiveQuantileTrainer` — one-step-ahead model rolled out
  autoregressively over a 96-step day, feeding its own samples back in.
* :class:`NonAutoRegressiveQuantileRegression` — predicts all 96 steps of
  quantiles in a single forward pass.

Both share inverse-CDF sampling from predicted quantiles
(:func:`sample_from_dist`) and common plotting helpers.
"""
import torch
from tqdm import tqdm
from src.policies.PolicyEvaluator import PolicyEvaluator
from src.losses.crps_metric import crps_from_samples
from src.trainers.trainer import Trainer
from src.trainers.autoregressive_trainer import AutoRegressiveTrainer
from src.data.preprocessing import DataProcessor
from src.utils.clearml import ClearMLHelper
from src.losses import PinballLoss, NonAutoRegressivePinballLoss, CRPSLoss
import plotly.graph_objects as go
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline
import seaborn as sns
import matplotlib.patches as mpatches


def sample_from_dist(quantiles, preds):
    """Draw one random sample per row from predicted quantile values.

    Fits a cubic spline through (quantile level -> predicted value) for each
    row of ``preds`` and evaluates it at an independent uniform(0, 1)
    probability, i.e. inverse-CDF sampling from the implied distribution.

    Args:
        quantiles: strictly increasing sequence of quantile levels, length Q.
        preds: (N, Q) tensor or array of predicted quantile values.  Inputs
            with more than two dimensions are flattened to (N, Q) first.

    Returns:
        ``np.ndarray`` of shape (N,) with one sample per row.
    """
    if isinstance(preds, torch.Tensor):
        preds = preds.detach().cpu()
    # Collapse any leading dimensions so each row is one quantile vector.
    if len(preds.shape) > 2:
        preds = preds.reshape(-1, preds.shape[-1])
    # asarray handles both torch tensors (post-.cpu()) and plain arrays.
    preds = np.asarray(preds)
    # One uniform probability per row.  (The previous implementation drew a
    # fixed 1000 probabilities, which silently truncated batches of more than
    # 1000 rows when taking the diagonal below.)
    probs = np.random.random(preds.shape[0])
    spline = CubicSpline(quantiles, preds, axis=1)
    # spline(probs) evaluates every row at every probability -> (N, N);
    # the diagonal pairs row i with its own probability probs[i].
    samples = np.diag(spline(probs))
    return samples


def auto_regressive(dataset, model, quantiles, idx_batch, sequence_length: int = 96):
    """Roll a one-step quantile model forward for ``sequence_length`` steps.

    At every step, one sample is drawn from the predicted quantile
    distribution and fed back as the newest input value; features are
    shifted one step forward via ``dataset.get_batch_autoregressive``.

    Args:
        dataset: dataset exposing ``get_batch`` and ``get_batch_autoregressive``.
        model: one-step model mapping features -> (batch, num_quantiles).
        quantiles: quantile levels matching the model output.
        idx_batch: batch of starting indices.
        sequence_length: number of autoregressive steps (default one day, 96).

    Returns:
        Tuple of
        ``(initial_sequence, predictions_full, predictions_samples, targets)``
        with shapes (batch, 96), (batch, seq, Q), (batch, seq) and
        (batch, seq, 1) respectively.
    """
    device = next(model.parameters()).device
    prev_features, targets = dataset.get_batch(idx_batch)
    prev_features = prev_features.to(device)
    targets = targets.to(device)
    # 2-D features: flat (batch, 96 + extras); 3-D: (batch, 96, channels)
    # with the NRV history in channel 0.
    if len(list(prev_features.shape)) == 2:
        initial_sequence = prev_features[:, :96]
    else:
        initial_sequence = prev_features[:, :, 0]
    target_full = targets[:, 0].unsqueeze(1)  # (batch_size, 1)
    with torch.no_grad():
        new_predictions_full = model(prev_features)  # (batch_size, quantiles)
    samples = (
        torch.tensor(sample_from_dist(quantiles, new_predictions_full))
        .unsqueeze(1)
        .to(device)
    )  # (batch_size, 1)
    predictions_samples = samples
    predictions_full = new_predictions_full.unsqueeze(1)
    for i in range(sequence_length - 1):
        if len(list(prev_features.shape)) == 2:
            # Shift the 96-step history window by one and append the sample.
            new_features = torch.cat(
                (prev_features[:, 1:96], samples), dim=1
            )  # (batch_size, 96)
            new_features = new_features.float()
            other_features, new_targets = dataset.get_batch_autoregressive(
                np.array(idx_batch) + i + 1
            )  # (batch_size, new_features)
            if other_features is not None:
                prev_features = torch.cat(
                    (new_features.to(device), other_features.to(device)), dim=1
                )  # (batch_size, 96 + new_features)
            else:
                prev_features = new_features
        else:
            other_features, new_targets = dataset.get_batch_autoregressive(
                np.array(idx_batch) + i + 1
            )  # (batch_size, 1, new_features)
            # Replace the NRV channel of the incoming step with our sample.
            other_features[:, 0, 0] = samples.squeeze(-1)
            # Make sure everything is on the same device before concatenating.
            other_features = other_features.to(device)
            prev_features = prev_features.to(device)
            prev_features = torch.cat(
                (prev_features[:, 1:, :], other_features), dim=1
            )  # (batch_size, 96, new_features)
        target_full = torch.cat(
            (target_full, new_targets.to(device)), dim=1
        )  # (batch_size, sequence_length)
        with torch.no_grad():
            new_predictions_full = model(prev_features)  # (batch_size, quantiles)
        predictions_full = torch.cat(
            (predictions_full, new_predictions_full.unsqueeze(1)), dim=1
        )  # (batch_size, sequence_length, quantiles)
        samples = (
            torch.tensor(sample_from_dist(quantiles, new_predictions_full))
            .unsqueeze(-1)
            .to(device)
        )  # (batch_size, 1)
        predictions_samples = torch.cat((predictions_samples, samples), dim=1)
    return (
        initial_sequence,
        predictions_full,
        predictions_samples,
        target_full.unsqueeze(-1),
    )


def _plot_prediction_intervals(data_processor, next_day, predictions, retransform):
    """Build the two diagnostic figures shared by both trainers.

    Figure 1: mean of the sampled NRV trajectories with 50/90/95/99%
    empirical intervals plus the realised NRV.  Figure 2: ten individual
    sample trajectories against the realised NRV.

    Args:
        data_processor: processor providing ``inverse_transform``.
        next_day: realised target sequence (tensor).
        predictions: (num_samples, 96) sampled trajectories (tensor).
        retransform: whether to map values back to the original scale.

    Returns:
        ``(fig, fig2)`` matplotlib figures.
    """
    next_day_np = next_day.view(-1).cpu().numpy()
    predictions_np = predictions.cpu().numpy()
    if retransform:
        next_day_np = data_processor.inverse_transform(next_day_np)
        predictions_np = data_processor.inverse_transform(predictions_np)

    # Empirical central intervals across the sample dimension.
    ci_99_upper = np.quantile(predictions_np, 0.995, axis=0)
    ci_99_lower = np.quantile(predictions_np, 0.005, axis=0)
    ci_95_upper = np.quantile(predictions_np, 0.975, axis=0)
    ci_95_lower = np.quantile(predictions_np, 0.025, axis=0)
    ci_90_upper = np.quantile(predictions_np, 0.95, axis=0)
    ci_90_lower = np.quantile(predictions_np, 0.05, axis=0)
    ci_50_lower = np.quantile(predictions_np, 0.25, axis=0)
    ci_50_upper = np.quantile(predictions_np, 0.75, axis=0)

    sns.set_theme()
    time_steps = np.arange(0, 96)
    fig, ax = plt.subplots(figsize=(20, 10))
    ax.plot(
        time_steps,
        predictions_np.mean(axis=0),
        label="Mean of NRV samples",
        linewidth=3,
    )
    for lower, upper, label in (
        (ci_99_lower, ci_99_upper, "99% Interval"),
        (ci_95_lower, ci_95_upper, "95% Interval"),
        (ci_90_lower, ci_90_upper, "90% Interval"),
        (ci_50_lower, ci_50_upper, "50% Interval"),
    ):
        # Same alpha for every band: overlaps darken towards the median.
        ax.fill_between(time_steps, lower, upper, color="b", alpha=0.2, label=label)
    ax.plot(next_day_np, label="Real NRV", linewidth=3)
    # Legend patches use increasing alpha to mimic the stacked band shades.
    ci_99_patch = mpatches.Patch(color="b", alpha=0.3, label="99% Interval")
    ci_95_patch = mpatches.Patch(color="b", alpha=0.4, label="95% Interval")
    ci_90_patch = mpatches.Patch(color="b", alpha=0.5, label="90% Interval")
    ci_50_patch = mpatches.Patch(color="b", alpha=0.6, label="50% Interval")
    ax.legend(
        handles=[
            ci_99_patch,
            ci_95_patch,
            ci_90_patch,
            ci_50_patch,
            ax.lines[0],  # mean of samples
            ax.lines[1],  # realised NRV
        ]
    )
    ax.set_ylim(-1500, 1500)

    fig2, ax2 = plt.subplots(figsize=(20, 10))
    for i in range(10):
        ax2.plot(predictions_np[i], label=f"Sample {i}")
    ax2.plot(next_day_np, label="Real NRV", linewidth=4, color="orange")
    ax2.legend()
    ax2.set_ylim(-1500, 1500)
    return fig, fig2


def _report_quantile_calibration_bars(
    task, quantiles, percentages, train, full_day, iteration
):
    """Render and log the calibration bar chart (ideal vs. observed coverage).

    Args:
        task: ClearML task used for logging.
        quantiles: nominal quantile levels ("Ideal" bars).
        percentages: observed fraction of targets under each quantile forecast.
        train: True for the training set, False for the test set.
        full_day: whether coverage was computed on full-day rollouts.
        iteration: logging iteration.
    """
    bar_width = 0.35
    index = np.arange(len(quantiles))
    fig, ax = plt.subplots(figsize=(15, 10))
    ax.bar(index, quantiles, bar_width, label="Ideal", color="brown")
    model_bars = ax.bar(
        index + bar_width, percentages, bar_width, label="NN model", color="blue"
    )
    # Annotate each model bar with its observed coverage.
    for rect in model_bars:
        height = rect.get_height()
        ax.text(
            rect.get_x() + rect.get_width() / 2.0,
            1.005 * height,
            f"{height:.2}",
            ha="center",
            va="bottom",
        )
    series_name = "Training Set" if train else "Test Set"
    full_day_str = "Full Day" if full_day else "Single Step"
    ax.set_xlabel("Quantile")
    ax.set_ylabel("Fraction of data under quantile forecast")
    ax.set_title(f"{series_name} {full_day_str} Quantile Performance Comparison")
    ax.set_xticks(index + bar_width / 2)
    ax.set_xticklabels(quantiles)
    ax.legend()
    task.get_logger().report_matplotlib_figure(
        title="Quantile Performance Comparison",
        series=f"{series_name} {full_day_str}",
        report_image=True,
        figure=plt,
        iteration=iteration,
    )
    plt.close()


class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
    """Trainer for a one-step quantile model evaluated via autoregressive rollout."""

    def __init__(
        self,
        model: torch.nn.Module,
        input_dim: tuple,
        optimizer: torch.optim.Optimizer,
        data_processor: DataProcessor,
        quantiles: list,
        device: torch.device,
        policy_evaluator: PolicyEvaluator = None,
        debug: bool = True,
    ):
        """Set up the trainer with a pinball-loss criterion over ``quantiles``."""
        self.quantiles = quantiles
        # Populated by log_final_metrics(train=False) with test-set samples.
        self.test_set_samples = {}
        self.policy_evaluator = policy_evaluator
        criterion = PinballLoss(quantiles=quantiles)
        super().__init__(
            model=model,
            input_dim=input_dim,
            optimizer=optimizer,
            criterion=criterion,
            data_processor=data_processor,
            device=device,
            debug=debug,
        )

    def calculate_crps_from_samples(self, task, dataloader, epoch: int):
        """Compute per-day CRPS from 100 autoregressive sample trajectories.

        For each full day in the validation set, the same start index is
        rolled out 100 times (each rollout draws fresh samples) and the
        empirical CRPS against the realised day is averaged.

        Args:
            task: ClearML task for logging, or None to skip logging.
            dataloader: loader whose dataset exposes ``full_day_valid_indices``
                and ``valid_indices``.
            epoch: logging iteration; None or -1 suppress reporting.

        Returns:
            ``(mean_crps, generated_samples)`` normally, or
            ``(mean_crps, profit, charge_cycles, optimal_penalty,
            generated_samples)`` when the policy evaluator runs.
        """
        crps_from_samples_metric = []
        generated_samples = {}
        with torch.no_grad():
            for i in tqdm(dataloader.dataset.full_day_valid_indices):
                idx = dataloader.dataset.valid_indices.index(i)
                # 100 identical start indices -> 100 independent rollouts.
                computed_idx_batch = [idx] * 100
                initial, _, samples, targets = self.auto_regressive(
                    dataloader.dataset, idx_batch=computed_idx_batch
                )
                generated_samples[idx] = (
                    self.data_processor.inverse_transform(initial),
                    self.data_processor.inverse_transform(samples),
                )
                samples = samples.unsqueeze(0)  # (1, 100, seq)
                targets = targets.squeeze(-1)
                # All rollouts share one ground truth; keep the first row.
                targets = targets[0].unsqueeze(0)  # (1, seq)
                crps = crps_from_samples(samples, targets)
                crps_from_samples_metric.append(crps[0].mean().item())
        if epoch is not None and task is not None:
            task.get_logger().report_scalar(
                title="CRPS_from_samples",
                series="val",
                value=np.mean(crps_from_samples_metric),
                iteration=epoch,
            )
        # Evaluate the trading policy on the generated samples.  The extra
        # task/epoch checks fix a crash: log_final_metrics calls this with
        # task=None and epoch=None, which previously entered this branch
        # (None != -1) and dereferenced task.get_logger() on None.
        if (
            self.policy_evaluator is not None
            and task is not None
            and epoch is not None
            and epoch != -1
        ):
            optimal_penalty, profit, charge_cycles = (
                self.policy_evaluator.optimize_penalty_for_target_charge_cycles(
                    idx_samples=generated_samples,
                    test_loader=dataloader,
                    initial_penalty=900,
                    target_charge_cycles=58 * 400 / 356,
                    initial_learning_rate=5,
                    max_iterations=100,
                    tolerance=1,
                    iteration=epoch,
                )
            )
            print(
                f"Optimal Penalty: {optimal_penalty}, Profit: {profit}, Charge Cycles: {charge_cycles}"
            )
            task.get_logger().report_scalar(
                title="Optimal Penalty",
                series="val",
                value=optimal_penalty,
                iteration=epoch,
            )
            task.get_logger().report_scalar(
                title="Optimal Profit", series="val", value=profit, iteration=epoch
            )
            task.get_logger().report_scalar(
                title="Optimal Charge Cycles",
                series="val",
                value=charge_cycles,
                iteration=epoch,
            )
            return (
                np.mean(crps_from_samples_metric),
                profit,
                charge_cycles,
                optimal_penalty,
                generated_samples,
            )
        return np.mean(crps_from_samples_metric), generated_samples

    def log_final_metrics(self, task, dataloader, train: bool = True):
        """Report final (transformed and original-scale) metrics over rollouts.

        Sample-based metrics are fed autoregressive samples; PinballLoss and
        CRPSLoss receive the raw quantile outputs instead.

        Args:
            task: ClearML task used for reporting.
            dataloader: loader yielding ``(inputs, targets, idx_batch)``.
            train: selects the train_/test_ metric-name prefix; when False,
                test-set CRPS-from-samples is also computed and cached.
        """
        metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
        transformed_metrics = {
            metric.__class__.__name__: 0.0 for metric in self.metrics_to_track
        }
        with torch.no_grad():
            # Indices within 96 steps of the end cannot be rolled out fully.
            total_samples = len(dataloader.dataset) - 96
            batches = 0
            for _, _, idx_batch in tqdm(dataloader):
                idx_batch = [idx for idx in idx_batch if idx < total_samples]
                _, outputs, samples, targets = self.auto_regressive(
                    dataloader.dataset, idx_batch=idx_batch
                )
                samples = samples.to(self.device)
                outputs = outputs.to(self.device)
                targets = targets.to(self.device)
                inversed_samples = self.data_processor.inverse_transform(samples)
                inversed_targets = self.data_processor.inverse_transform(targets)
                inversed_outputs = self.data_processor.inverse_transform(outputs)
                inversed_samples = inversed_samples.to(self.device)
                inversed_targets = inversed_targets.to(self.device)
                inversed_outputs = inversed_outputs.to(self.device)
                for metric in self.metrics_to_track:
                    if metric.__class__ != PinballLoss and metric.__class__ != CRPSLoss:
                        transformed_metrics[metric.__class__.__name__] += metric(
                            samples, targets.squeeze(-1)
                        )
                        metrics[metric.__class__.__name__] += metric(
                            inversed_samples, inversed_targets.squeeze(-1)
                        )
                    else:
                        transformed_metrics[metric.__class__.__name__] += metric(
                            outputs, targets
                        )
                        metrics[metric.__class__.__name__] += metric(
                            inversed_outputs, inversed_targets
                        )
                batches += 1
        for metric in self.metrics_to_track:
            metrics[metric.__class__.__name__] /= batches
            transformed_metrics[metric.__class__.__name__] /= batches
        for metric_name, metric_value in metrics.items():
            # Pinball loss on the original scale is not reported.
            if PinballLoss.__name__ in metric_name:
                continue
            name = f"train_{metric_name}" if train else f"test_{metric_name}"
            task.get_logger().report_single_value(name=name, value=metric_value)
        for metric_name, metric_value in transformed_metrics.items():
            name = (
                f"train_transformed_{metric_name}"
                if train
                else f"test_transformed_{metric_name}"
            )
            task.get_logger().report_single_value(name=name, value=metric_value)
        if not train:
            crps_from_samples_metric, self.test_set_samples = (
                self.calculate_crps_from_samples(None, dataloader, None)
            )
            task.get_logger().report_single_value(
                name="test_CRPS_from_samples_transformed",
                value=np.mean(crps_from_samples_metric),
            )

    def get_plot(
        self,
        current_day,
        next_day,
        predictions,
        show_legend: bool = True,
        retransform: bool = True,
        task=None,
    ):
        """Return the interval figure and sample-trajectory figure.

        ``current_day``, ``show_legend`` and ``task`` are kept for interface
        compatibility; the rendered figures do not use them.
        """
        return _plot_prediction_intervals(
            self.data_processor, next_day, predictions, retransform
        )

    def auto_regressive(self, dataset, idx_batch, sequence_length: int = 96):
        """Roll out this trainer's model; see module-level ``auto_regressive``."""
        return auto_regressive(
            dataset, self.model, self.quantiles, idx_batch, sequence_length
        )

    def plot_quantile_percentages(
        self,
        task,
        data_loader,
        train: bool = True,
        iteration: int = None,
        full_day: bool = False,
    ):
        """Log a calibration chart of empirical vs. nominal quantile coverage.

        Args:
            task: ClearML task used for logging.
            data_loader: loader yielding ``(inputs, targets, idx_batch)``.
            train: selects the series name (training vs. test set).
            iteration: logging iteration.
            full_day: if True, coverage is measured on full-day autoregressive
                rollouts instead of single-step predictions.
        """
        quantiles = self.quantiles
        total = 0
        quantile_counter = {q: 0 for q in quantiles}
        self.model.eval()
        with torch.no_grad():
            total_samples = len(data_loader.dataset) - 96
            for inputs, targets, idx_batch in data_loader:
                idx_batch = [idx for idx in idx_batch if idx < total_samples]
                if full_day:
                    _, outputs, samples, targets = self.auto_regressive(
                        data_loader.dataset, idx_batch=idx_batch
                    )
                    # outputs: (batch, sequence_length, num_quantiles)
                    # targets: (batch, sequence_length, 1)
                    outputs = outputs.reshape(-1, len(quantiles))
                    targets = targets.reshape(-1)
                    outputs = outputs.cpu().numpy()
                    targets = targets.cpu().numpy()
                else:
                    inputs = inputs.to(self.device)
                    outputs = (
                        self.model(inputs).cpu().numpy()
                    )  # (batch_size, num_quantiles)
                    targets = targets.squeeze(-1).cpu().numpy()  # (batch_size, 1)
                for i, q in enumerate(quantiles):
                    quantile_counter[q] += np.sum(targets < outputs[:, i])
                total += len(targets)
        percentages = np.array([quantile_counter[q] / total for q in quantiles])
        _report_quantile_calibration_bars(
            task, quantiles, percentages, train, full_day, iteration
        )


class NonAutoRegressiveQuantileRegression(Trainer):
    """Trainer for a model predicting all 96 steps of quantiles in one pass."""

    def __init__(
        self,
        model: torch.nn.Module,
        input_dim: tuple,
        optimizer: torch.optim.Optimizer,
        data_processor: DataProcessor,
        quantiles: list,
        device: torch.device,
        debug: bool = True,
        policy_evaluator: PolicyEvaluator = None,
    ):
        """Set up the trainer with the non-autoregressive pinball criterion."""
        self.quantiles = quantiles
        self.policy_evaluator = policy_evaluator
        criterion = NonAutoRegressivePinballLoss(quantiles=quantiles)
        super().__init__(
            model=model,
            input_dim=input_dim,
            optimizer=optimizer,
            criterion=criterion,
            data_processor=data_processor,
            device=device,
            debug=debug,
        )

    def log_final_metrics(self, task, dataloader, train: bool = True):
        """Report final metrics; sample-based metrics use 100 draws per day.

        Args:
            task: ClearML task used for reporting.
            dataloader: loader yielding ``(inputs, targets, _)``.
            train: selects the train_/test_ metric-name prefix.
        """
        metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
        transformed_metrics = {
            metric.__class__.__name__: 0.0 for metric in self.metrics_to_track
        }
        with torch.no_grad():
            for inputs, targets, _ in dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                outputs = outputs.reshape(-1, 96, len(self.quantiles))
                # 100 sample trajectories per batch element, grouped per
                # element (index b*100 + rep).  The loops were previously
                # swapped (rep-major), which misaligned samples with
                # expanded_targets below for batch sizes > 1.
                outputted_samples = [
                    sample_from_dist(self.quantiles, output.cpu())
                    for output in outputs
                    for _ in range(100)
                ]
                outputted_samples = torch.tensor(outputted_samples)
                inversed_outputs_samples = self.data_processor.inverse_transform(
                    outputted_samples
                )
                # Targets repeated 100x, element-major: index b*100 + rep.
                expanded_targets = (
                    targets.unsqueeze(1).repeat(1, 100, 1).reshape(-1, 96)
                )
                inversed_expanded_targets = self.data_processor.inverse_transform(
                    expanded_targets
                )
                outputs = outputs.reshape(inputs.shape[0], -1, len(self.quantiles))
                inversed_outputs = self.data_processor.inverse_transform(outputs)
                inversed_targets = self.data_processor.inverse_transform(targets)
                inversed_outputs_samples = inversed_outputs_samples.to(self.device)
                inversed_targets = inversed_targets.to(self.device)
                outputted_samples = outputted_samples.to(self.device)
                inversed_outputs = inversed_outputs.to(self.device)
                expanded_targets = expanded_targets.to(self.device)
                inversed_expanded_targets = inversed_expanded_targets.to(self.device)
                for metric in self.metrics_to_track:
                    if metric.__class__ != PinballLoss and metric.__class__ != CRPSLoss:
                        transformed_metrics[metric.__class__.__name__] += metric(
                            outputted_samples, expanded_targets
                        )
                        metrics[metric.__class__.__name__] += metric(
                            inversed_outputs_samples, inversed_expanded_targets
                        )
                    else:
                        transformed_metrics[metric.__class__.__name__] += metric(
                            outputs, targets.unsqueeze(-1)
                        )
                        metrics[metric.__class__.__name__] += metric(
                            inversed_outputs, inversed_targets.unsqueeze(-1)
                        )
        for metric in self.metrics_to_track:
            metrics[metric.__class__.__name__] /= len(dataloader)
            transformed_metrics[metric.__class__.__name__] /= len(dataloader)
        for metric_name, metric_value in metrics.items():
            metric_name = f"train_{metric_name}" if train else f"test_{metric_name}"
            task.get_logger().report_single_value(name=metric_name, value=metric_value)
        for metric_name, metric_value in transformed_metrics.items():
            metric_name = (
                f"train_transformed_{metric_name}"
                if train
                else f"test_transformed_{metric_name}"
            )
            task.get_logger().report_single_value(name=metric_name, value=metric_value)

    def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
        """Render and log interval/sample plots for selected dataset indices.

        Args:
            task: ClearML task used for logging.
            train: selects the "Training"/"Testing" report titles.
            data_loader: loader whose dataset is indexed directly.
            sample_indices: mapping of display index -> dataset index.
            epoch: logging iteration; -1 saves the figures as image files
                instead of interactive matplotlib reports.
        """
        for actual_idx, idx in sample_indices.items():
            features, target, _ = data_loader.dataset[idx]
            print(features.shape, target.shape)
            features = features.to(self.device)
            target = target.to(self.device)
            self.model.eval()
            with torch.no_grad():
                predicted_quantiles = self.model(features)
            predictions = predicted_quantiles.reshape(-1, len(self.quantiles))
            samples = [
                sample_from_dist(self.quantiles, predictions) for _ in range(100)
            ]
            samples = torch.tensor(samples)
            fig, fig2 = self.get_plot(features[:96], target, samples, show_legend=True)
            if epoch != -1:
                task.get_logger().report_matplotlib_figure(
                    title="Training" if train else "Testing",
                    series=f"Sample {actual_idx}",
                    iteration=epoch,
                    figure=fig,
                )
                task.get_logger().report_matplotlib_figure(
                    title="Training Samples" if train else "Testing Samples",
                    series=f"Sample {actual_idx} samples",
                    iteration=epoch,
                    figure=fig2,
                    report_interactive=False,
                )
            else:
                print("Saving figs")
                # Final epoch: persist figures to disk and upload as images.
                fig.savefig(f"sample_{actual_idx}_plot.png", bbox_inches="tight")
                task.get_logger().report_image(
                    title="Final Training Plot",
                    series=f"Sample {actual_idx}",
                    iteration=epoch,
                    local_path=f"sample_{actual_idx}_plot.png",
                )
                fig2.savefig(
                    f"sample_{actual_idx}_samples_plot.png", bbox_inches="tight"
                )
                task.get_logger().report_image(
                    title="Final Training Samples Plot",
                    series=f"Sample {actual_idx} samples",
                    iteration=epoch,
                    local_path=f"sample_{actual_idx}_samples_plot.png",
                )
            plt.close()

    def get_plot(
        self,
        current_day,
        next_day,
        predictions,
        show_legend: bool = True,
        retransform: bool = True,
    ):
        """Return the interval figure and sample-trajectory figure.

        ``current_day`` and ``show_legend`` are kept for interface
        compatibility; the rendered figures do not use them.
        """
        return _plot_prediction_intervals(
            self.data_processor, next_day, predictions, retransform
        )

    def calculate_crps_from_samples(self, task, dataloader, epoch: int):
        """Compute and report CRPS from 100 spline samples per dataset item.

        Also runs the policy evaluator on the generated samples when one is
        configured.  Unlike the autoregressive variant this method only
        reports; it returns ``None``.

        Args:
            task: ClearML task used for reporting (required).
            dataloader: loader yielding ``(_, _, idx_batch)``; items are also
                fetched individually via ``dataloader.dataset[idx]``.
            epoch: logging iteration.
        """
        crps_from_samples_metric = []
        generated_samples = {}
        with torch.no_grad():
            total_samples = len(dataloader.dataset)
            for _, _, idx_batch in tqdm(dataloader):
                idx_batch = [idx for idx in idx_batch if idx < total_samples]
                if len(idx_batch) == 0:
                    continue
                for idx in tqdm(idx_batch):
                    initial, targets, _ = dataloader.dataset[idx]
                    initial = initial.to(self.device)
                    targets = targets.to(self.device)
                    predicted_quantiles = self.model(initial)
                    predictions = predicted_quantiles.reshape(-1, len(self.quantiles))
                    samples = [
                        sample_from_dist(self.quantiles, predictions)
                        for _ in range(100)
                    ]
                    samples = torch.tensor(samples)
                    # idx comes from the dataloader as a 0-d tensor here.
                    generated_samples[idx.item()] = (
                        self.data_processor.inverse_transform(initial),
                        self.data_processor.inverse_transform(samples),
                    )
                    samples = samples.unsqueeze(0)  # (1, 100, 96)
                    targets = targets.squeeze(-1)
                    targets = targets[0].unsqueeze(0)
                    samples = samples.to(self.device)
                    crps = crps_from_samples(samples, targets)
                    crps_from_samples_metric.append(crps[0].mean().item())
        task.get_logger().report_scalar(
            title="CRPS_from_samples",
            series="test",
            value=np.mean(crps_from_samples_metric),
            iteration=epoch,
        )
        # Evaluate the trading policy with the generated samples.
        if self.policy_evaluator is not None:
            optimal_penalty, profit, charge_cycles = (
                self.policy_evaluator.optimize_penalty_for_target_charge_cycles(
                    idx_samples=generated_samples,
                    test_loader=dataloader,
                    initial_penalty=500,
                    target_charge_cycles=283,
                    initial_learning_rate=2,
                    max_iterations=100,
                    tolerance=1,
                )
            )
            print(
                f"Optimal Penalty: {optimal_penalty}, Profit: {profit}, Charge Cycles: {charge_cycles}"
            )
            task.get_logger().report_scalar(
                title="Optimal Penalty",
                series="test",
                value=optimal_penalty,
                iteration=epoch,
            )
            task.get_logger().report_scalar(
                title="Optimal Profit", series="test", value=profit, iteration=epoch
            )
            task.get_logger().report_scalar(
                title="Optimal Charge Cycles",
                series="test",
                value=charge_cycles,
                iteration=epoch,
            )

    def plot_quantile_percentages(
        self,
        task,
        data_loader,
        train: bool = True,
        iteration: int = None,
        full_day: bool = False,
    ):
        """Log a calibration chart of empirical vs. nominal quantile coverage.

        Coverage is always computed on the single-pass 96-step prediction;
        ``full_day`` only affects the chart labels.
        """
        quantiles = self.quantiles
        total = 0
        quantile_counter = {q: 0 for q in quantiles}
        self.model.eval()
        with torch.no_grad():
            total_samples = len(data_loader.dataset) - 96
            for inputs, targets, idx_batch in data_loader:
                idx_batch = [idx for idx in idx_batch if idx < total_samples]
                inputs = inputs.to(self.device)
                outputs = (
                    self.model(inputs).cpu().numpy()
                )  # (batch_size, 96*num_quantiles)
                outputs = outputs.reshape(-1, 96, len(quantiles))
                targets = targets.squeeze(-1).cpu().numpy()  # (batch_size, 96)
                for i, q in enumerate(quantiles):
                    quantile_counter[q] += np.sum(targets < outputs[:, :, i])
                total += len(targets) * 96
        percentages = np.array([quantile_counter[q] / total for q in quantiles])
        _report_quantile_calibration_bars(
            task, quantiles, percentages, train, full_day, iteration
        )