Sped up sampling 20x
This commit is contained in:
@@ -15,7 +15,7 @@ class CRPSLoss(nn.Module):
|
||||
# preds shape: [batch_size, num_quantiles]
|
||||
|
||||
# unsqueeze target
|
||||
target = target.unsqueeze(-1)
|
||||
# target = target.unsqueeze(-1)
|
||||
|
||||
mask = (preds > target).float()
|
||||
test = self.quantiles_tensor - mask
|
||||
|
||||
@@ -1,24 +1,27 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PinballLoss(nn.Module):
|
||||
def __init__(self, quantiles):
|
||||
super(PinballLoss, self).__init__()
|
||||
self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)
|
||||
|
||||
self.quantiles = self.quantiles_tensor.tolist()
|
||||
|
||||
def forward(self, pred, target):
|
||||
error = target - pred
|
||||
upper = self.quantiles_tensor * error
|
||||
lower = (self.quantiles_tensor - 1) * error
|
||||
lower = (self.quantiles_tensor - 1) * error
|
||||
losses = torch.max(lower, upper)
|
||||
loss = torch.mean(torch.mean(losses, dim=0))
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
class NonAutoRegressivePinballLoss(nn.Module):
|
||||
def __init__(self, quantiles):
|
||||
super(NonAutoRegressivePinballLoss, self).__init__()
|
||||
self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)
|
||||
self.quantiles = self.quantiles_tensor.tolist()
|
||||
|
||||
def forward(self, pred, target):
|
||||
pred = pred.reshape(-1, 96, len(self.quantiles_tensor))
|
||||
|
||||
Reference in New Issue
Block a user