import torch
from torch import nn
import numpy as np
from scipy.interpolate import CubicSpline

# NumPy 2.0 renamed ``trapz`` to ``trapezoid`` (``trapz`` is deprecated there);
# prefer the new name and fall back to the old one on NumPy 1.x.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz


def _to_numpy(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and copied to CPU."""
    if isinstance(x, torch.Tensor):
        x = x.detach().cpu().numpy()
    return np.asarray(x)


class CRPSLoss(nn.Module):
    """Continuous Ranked Probability Score of a quantile forecast.

    For each forecast, the predictive quantile function is reconstructed by
    fitting a cubic spline through ``(quantile level, predicted value)`` pairs
    and evaluating it on ``n_x`` evenly spaced probability levels in [0, 1].
    The CRPS is the integral over ``x`` of ``(F(x) - 1{x > y})**2``, where
    ``F`` is that reconstructed CDF and ``y`` is the observation.

    This is an evaluation metric: inputs are detached and converted to NumPy,
    so the result is a NumPy float and carries no gradient.

    :param quantiles: strictly increasing quantile levels, one per column of
        ``preds`` (a requirement of ``scipy.interpolate.CubicSpline``).
        Probability levels 0 and 1 are reached by spline extrapolation when
        they are not among ``quantiles``.
    :param n_x: number of probability levels used for the numerical integral.
    """

    def __init__(self, quantiles, n_x=101):
        super().__init__()
        self.quantiles = quantiles
        self.n_x = n_x

    def forward(self, preds, target):
        """Return the mean CRPS over all forecasts.

        :param preds: predicted quantile values, array or tensor of shape
            ``(..., n_quantiles)``. All leading dimensions are flattened, so
            each row is one forecast.
        :param target: observations with one element per forecast row
            (e.g. shape ``preds.shape[:-1]``), array or tensor.
        :return: mean CRPS as a NumPy float.
        """
        preds = _to_numpy(preds)
        target = _to_numpy(target)
        quantiles = _to_numpy(self.quantiles)

        # One forecast per row, (N, n_quantiles), and one observation per row,
        # (N, 1). The target reshape is unconditional so that 2-D preds with
        # 1-D targets broadcast correctly below.
        preds = preds.reshape(-1, preds.shape[-1])
        target = target.reshape(-1, 1)

        probs = np.linspace(0.0, 1.0, self.n_x)
        # Predicted value at every probability level: shape (N, n_x).
        grid = CubicSpline(quantiles, preds, axis=1)(probs)

        # Integrand F(x) - 1{x > y}, sampled at x = grid. The boolean mask is
        # promoted to 0/1 when subtracted from the float probabilities.
        integrand = probs - (grid > target)
        crps_inside = _trapezoid(integrand ** 2, grid, axis=-1)

        # Outside [grid[:, 0], grid[:, -1]] the reconstructed CDF is 0 (below)
        # or 1 (above). The integrand there is therefore exactly 1 between the
        # observation and the nearest end of the forecast range. Without this
        # term, an observation outside the range is under-penalised; a point
        # forecast would always score 0.
        y = target[:, 0]
        crps_outside = (
            np.maximum(grid[:, 0] - y, 0.0) + np.maximum(y - grid[:, -1], 0.0)
        )

        return np.mean(crps_inside + crps_outside)


def crps_from_samples(samples, targets):
    """
    Compute the Continuous Ranked Probability Score (CRPS) from multi-day samples
    and targets using a vectorized approach with PyTorch tensors.

    Uses the energy form CRPS = E|X - y| - 0.5 * E|X - X'|, with empirical
    expectations over the samples. E|X - X'| is the 1/n**2 estimator, which
    includes the zero i == j pairs.

    :param samples: (day, n_samples, n_timesteps) tensor of forecasted samples
    :param targets: (day, n_timesteps) tensor of observed values
    :return: (day, n_timesteps) tensor of CRPS for each timestep for each day
    """
    # The unpack also validates that ``samples`` is 3-D.
    _, n_samples, _ = samples.shape

    # E|X - y|: broadcast the targets along the samples dimension.
    term1 = torch.mean(torch.abs(samples - targets.unsqueeze(1)), dim=1)

    # 0.5 * E|X - X'| without materialising an (n_samples x n_samples) pairwise
    # tensor. For samples sorted ascending, x_(1) <= ... <= x_(n):
    #     sum_{i,j} |x_i - x_j| = 2 * sum_i (2i - n - 1) * x_(i),
    # so half of the pairwise mean is sum_i (2i - n - 1) * x_(i) / n**2.
    sorted_samples, _ = torch.sort(samples, dim=1)
    # The weights sum to zero, so subtracting the per-cell mean leaves the
    # result unchanged while reducing floating-point cancellation.
    centered = sorted_samples - sorted_samples.mean(dim=1, keepdim=True)
    ranks = torch.arange(
        1, n_samples + 1, dtype=samples.dtype, device=samples.device
    )
    weights = (2 * ranks - n_samples - 1).view(1, -1, 1)
    term2 = torch.sum(weights * centered, dim=1) / n_samples ** 2

    # CRPS for each timestep for each day
    return term1 - term2