Implemented Non Autorgressive Quantile Regression
This commit is contained in:
@@ -1 +1 @@
|
||||
from .pinball_loss import PinballLoss
|
||||
from .pinball_loss import PinballLoss, NonAutoRegressivePinballLoss
|
||||
@@ -2,32 +2,30 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
class PinballLoss(nn.Module):
|
||||
"""
|
||||
Calculates the quantile loss function.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
self.pred : torch.tensor
|
||||
Predictions.
|
||||
self.target : torch.tensor
|
||||
Target to predict.
|
||||
self.quantiles : torch.tensor
|
||||
"""
|
||||
def __init__(self, quantiles):
|
||||
super(PinballLoss, self).__init__()
|
||||
self.quantiles_tensor = quantiles
|
||||
self.quantiles = quantiles.tolist()
|
||||
self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)
|
||||
|
||||
def forward(self, pred, target):
|
||||
"""
|
||||
Computes the loss for the given prediction.
|
||||
"""
|
||||
|
||||
error = target - pred
|
||||
upper = self.quantiles_tensor * error
|
||||
upper = self.quantiles_tensor * error
|
||||
lower = (self.quantiles_tensor - 1) * error
|
||||
|
||||
losses = torch.max(lower, upper)
|
||||
loss = torch.mean(torch.sum(losses, dim=1))
|
||||
loss = torch.mean(torch.mean(losses, dim=0))
|
||||
return loss
|
||||
|
||||
|
||||
class NonAutoRegressivePinballLoss(nn.Module):
|
||||
def __init__(self, quantiles):
|
||||
super(NonAutoRegressivePinballLoss, self).__init__()
|
||||
self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)
|
||||
|
||||
def forward(self, pred, target):
|
||||
pred = pred.reshape(-1, 96, len(self.quantiles_tensor))
|
||||
target_expanded = target.unsqueeze(2)
|
||||
error = target_expanded - pred
|
||||
upper = self.quantiles_tensor * error
|
||||
lower = (self.quantiles_tensor - 1) * error
|
||||
losses = torch.max(lower, upper)
|
||||
loss = torch.mean(losses)
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user