Implemented Non Autorgressive Quantile Regression

This commit is contained in:
Victor Mylle
2023-11-18 17:42:06 +00:00
parent 75f1f64c38
commit 1268af47a6
9 changed files with 196493 additions and 161 deletions

View File

@@ -1 +1 @@
from .pinball_loss import PinballLoss
from .pinball_loss import PinballLoss, NonAutoRegressivePinballLoss

View File

@@ -2,32 +2,30 @@ import torch
from torch import nn
class PinballLoss(nn.Module):
"""
Calculates the quantile loss function.
Attributes
----------
self.pred : torch.tensor
Predictions.
self.target : torch.tensor
Target to predict.
self.quantiles : torch.tensor
"""
def __init__(self, quantiles):
super(PinballLoss, self).__init__()
self.quantiles_tensor = quantiles
self.quantiles = quantiles.tolist()
self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)
def forward(self, pred, target):
"""
Computes the loss for the given prediction.
"""
error = target - pred
upper = self.quantiles_tensor * error
upper = self.quantiles_tensor * error
lower = (self.quantiles_tensor - 1) * error
losses = torch.max(lower, upper)
loss = torch.mean(torch.sum(losses, dim=1))
loss = torch.mean(torch.mean(losses, dim=0))
return loss
class NonAutoRegressivePinballLoss(nn.Module):
def __init__(self, quantiles):
super(NonAutoRegressivePinballLoss, self).__init__()
self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)
def forward(self, pred, target):
pred = pred.reshape(-1, 96, len(self.quantiles_tensor))
target_expanded = target.unsqueeze(2)
error = target_expanded - pred
upper = self.quantiles_tensor * error
lower = (self.quantiles_tensor - 1) * error
losses = torch.max(lower, upper)
loss = torch.mean(losses)
return loss