Sped up sampling 20x
This commit is contained in:
@@ -15,7 +15,7 @@ class CRPSLoss(nn.Module):
|
||||
# preds shape: [batch_size, num_quantiles]
|
||||
|
||||
# unsqueeze target
|
||||
target = target.unsqueeze(-1)
|
||||
# target = target.unsqueeze(-1)
|
||||
|
||||
mask = (preds > target).float()
|
||||
test = self.quantiles_tensor - mask
|
||||
|
||||
Reference in New Issue
Block a user