import torch
from torch import nn
from properscoring import crps_ensemble


class CRPSLoss(nn.Module):
    """Continuous Ranked Probability Score (CRPS) of an ensemble forecast.

    Thin wrapper around ``properscoring.crps_ensemble``. The last axis of
    ``preds`` is treated as a set of equally weighted ensemble members
    (e.g. samples, or quantile predictions used as samples); the per-element
    scores are then averaged over every remaining element.

    Note: inputs are detached and moved to the CPU before scoring, so this
    module is an evaluation metric -- no gradient flows back through it, and
    the result is a NumPy scalar rather than a tensor.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _to_cpu_float64(x):
        """Return *x* as a detached CPU float64 tensor.

        Casting to float64 also avoids tensor dtypes that NumPy cannot
        represent (e.g. bfloat16), which would otherwise fail when
        ``crps_ensemble`` converts its inputs to arrays.
        """
        if isinstance(x, torch.Tensor):
            return x.detach().to(device="cpu", dtype=torch.float64)
        return torch.as_tensor(x, dtype=torch.float64)

    def forward(self, preds, target):
        """Compute the mean ensemble CRPS.

        Args:
            preds: Ensemble forecasts with members on the last axis, e.g.
                shape ``[batch_size, num_members]``. Tensor or array-like.
            target: Observations matching the leading dimensions of
                ``preds``, optionally with a trailing singleton axis
                (e.g. ``[batch_size]`` or ``[batch_size, 1]``).

        Returns:
            The CRPS averaged over all observations, as a NumPy float.
        """
        preds = self._to_cpu_float64(preds)
        target = self._to_cpu_float64(target)

        # Drop a trailing singleton axis from target only when target has at
        # least as many dims as preds. An unconditional squeeze(-1) turned a
        # batch of one ([1] vs preds [1, m]) into a 0-d tensor that no longer
        # matched preds.shape[:-1], making crps_ensemble raise ValueError.
        if target.ndim >= max(preds.ndim, 1) and target.shape[-1] == 1:
            target = target.squeeze(-1)

        scores = crps_ensemble(target.numpy(), preds.numpy())
        # Average over every element of the score array (batch and any
        # other leading dims).
        return scores.mean()