Files
Thesis/src/losses/crps_metric.py
2023-11-28 15:35:35 +00:00

29 lines
662 B
Python

import torch
from torch import nn
import torch
from properscoring import crps_ensemble
class CRPSLoss(nn.Module):
    """Continuous Ranked Probability Score of an ensemble forecast.

    The values along the last axis of ``preds`` are treated as ensemble members
    (e.g. quantile predictions) and scored against ``target`` with
    ``properscoring.crps_ensemble``. Scoring runs on detached CPU copies via
    NumPy, so the result carries no autograd graph: use this as an evaluation
    metric, not as a differentiable training loss.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _to_cpu(values):
        """Return ``values`` (tensor, NumPy array or nested list) as a detached CPU tensor."""
        tensor = torch.as_tensor(values).detach().cpu()
        # NumPy has no bfloat16, so properscoring's np.asarray would fail on it;
        # float32 represents every bfloat16 value exactly.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor

    @staticmethod
    def _prepare(preds, target):
        """Return ``(preds, target)`` as CPU tensors with ``target`` aligned to ``preds``.

        ``target`` loses its trailing axis when that axis has size 1, unless
        ``target`` is already exactly one rank below ``preds``. So ``[batch, 1]``
        becomes ``[batch]``, while an already-aligned ``[1]`` target (batch of
        one) is not collapsed to a scalar.
        """
        preds = CRPSLoss._to_cpu(preds)
        target = CRPSLoss._to_cpu(target)
        already_aligned = target.dim() == preds.dim() - 1
        if target.dim() > 0 and target.shape[-1] == 1 and not already_aligned:
            target = target.squeeze(-1)
        return preds, target

    def forward(self, preds, target):
        """Return the mean CRPS of ``preds`` against ``target``.

        Args:
            preds: Ensemble forecasts whose last axis holds the members,
                e.g. ``[batch_size, num_quantiles]``.
            target: Observations, either one rank below ``preds`` (e.g.
                ``[batch_size]``) or with an extra trailing axis of size 1
                (e.g. ``[batch_size, 1]``).

        Returns:
            The CRPS averaged over all observations, computed by NumPy from the
            ``crps_ensemble`` scores (an empty batch yields NaN).

        Raises:
            ValueError: from ``crps_ensemble`` if the shapes cannot be matched.
        """
        preds, target = self._prepare(preds, target)
        scores = crps_ensemble(target, preds)
        return scores.mean()