29 lines
662 B
Python
29 lines
662 B
Python
import torch
|
|
from torch import nn
|
|
import torch
|
|
from properscoring import crps_ensemble
|
|
|
|
|
|
class CRPSLoss(nn.Module):
    """Batch-mean Continuous Ranked Probability Score of an ensemble forecast.

    Thin wrapper around ``properscoring.crps_ensemble``. Tensor inputs are
    detached and moved to the CPU before scoring, so the returned value
    carries no autograd graph: this module works as an evaluation metric,
    not as a training loss that can be backpropagated.
    """

    def __init__(self):
        super().__init__()

    def forward(self, preds, target):
        """Return the mean CRPS of ``preds`` against ``target``.

        Args:
            preds: Forecasts (tensor or array-like). The last axis is scored
                as the ensemble axis (``crps_ensemble``'s default
                ``axis=-1``), e.g. ``[batch_size, num_members]``.
                NOTE(review): an earlier comment described this axis as
                "num_quantiles"; ``crps_ensemble`` treats its entries as
                equally weighted ensemble members, so quantile forecasts are
                only approximated -- confirm this is intended.
            target: Observations (tensor or array-like), shaped like
                ``preds`` without the last axis (e.g. ``[batch_size]``) or
                with a trailing singleton axis (e.g. ``[batch_size, 1]``).

        Returns:
            The mean of the per-observation CRPS values returned by
            ``crps_ensemble`` (a NumPy scalar).

        Raises:
            ValueError: Propagated from ``crps_ensemble`` when the shapes of
                ``preds`` and ``target`` are incompatible.
        """
        import numpy as np  # properscoring computes on NumPy arrays

        # Detach and move tensors to the CPU so NumPy can read them; any
        # other array-like input is converted as-is.
        if isinstance(preds, torch.Tensor):
            preds = preds.detach().cpu()
        if isinstance(target, torch.Tensor):
            target = target.detach().cpu()
        preds = np.asarray(preds)
        target = np.asarray(target)

        # Drop target's trailing axis only when it is a singleton mirroring
        # the ensemble axis of preds (e.g. [B, 1] vs [B, M]). An
        # unconditional squeeze(-1) turned a [1] target into a 0-d scalar
        # (mismatching a [1, M] forecast) and raises for NumPy inputs whose
        # last axis is not of size 1.
        if (
            target.ndim > 0
            and target.ndim == preds.ndim
            and target.shape[-1] == 1
        ):
            target = target[..., 0]

        # Per-observation CRPS; the ensemble axis of preds is reduced away.
        scores = crps_ensemble(target, preds)

        # Mean over all remaining (batch) entries.
        return scores.mean()
|