Fixed the non autoregressive final metric calculations
This commit is contained in:
@@ -558,18 +558,23 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
inputs, targets = inputs.to(self.device), targets.to(self.device)
|
||||
|
||||
outputs = self.model(inputs)
|
||||
|
||||
outputs = outputs.reshape(-1, len(self.quantiles))
|
||||
outputs = outputs.reshape(-1, 96, len(self.quantiles))
|
||||
|
||||
outputted_samples = [
|
||||
sample_from_dist(self.quantiles, output.cpu()) for output in outputs
|
||||
sample_from_dist(self.quantiles, output.cpu()) for _ in range(100) for output in outputs
|
||||
]
|
||||
|
||||
|
||||
outputted_samples = torch.tensor(outputted_samples)
|
||||
inversed_outputs_samples = self.data_processor.inverse_transform(
|
||||
outputted_samples
|
||||
)
|
||||
|
||||
expanded_targets = targets.unsqueeze(1).repeat(1, 100, 1).reshape(-1, 96)
|
||||
inversed_expanded_targets = self.data_processor.inverse_transform(
|
||||
expanded_targets
|
||||
)
|
||||
|
||||
outputs = outputs.reshape(inputs.shape[0], -1, len(self.quantiles))
|
||||
inversed_outputs = self.data_processor.inverse_transform(outputs)
|
||||
inversed_targets = self.data_processor.inverse_transform(targets)
|
||||
@@ -579,13 +584,17 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
outputted_samples = outputted_samples.to(self.device)
|
||||
inversed_outputs = inversed_outputs.to(self.device)
|
||||
|
||||
expanded_targets = expanded_targets.to(self.device)
|
||||
inversed_expanded_targets = inversed_expanded_targets.to(self.device)
|
||||
|
||||
|
||||
for metric in self.metrics_to_track:
|
||||
if metric.__class__ != PinballLoss and metric.__class__ != CRPSLoss:
|
||||
transformed_metrics[metric.__class__.__name__] += metric(
|
||||
outputted_samples, targets
|
||||
outputted_samples, expanded_targets
|
||||
)
|
||||
metrics[metric.__class__.__name__] += metric(
|
||||
inversed_outputs_samples, inversed_targets
|
||||
inversed_outputs_samples, inversed_expanded_targets
|
||||
)
|
||||
else:
|
||||
transformed_metrics[metric.__class__.__name__] += metric(
|
||||
|
||||
Reference in New Issue
Block a user