Fixed losses and multiple other things

This commit is contained in:
Victor Mylle
2023-12-12 11:02:42 +00:00
parent d3bf04d68c
commit c06cc10aa6
12 changed files with 5093936 additions and 122 deletions

View File

@@ -147,6 +147,37 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
)
task.get_logger().report_single_value(name=name, value=metric_value)
def get_plot_error(
self,
next_day,
predictions,
):
metric = PinballLoss(quantiles=self.quantiles)
fig = go.Figure()
next_day_np = next_day.view(-1).cpu().numpy()
predictions_np = predictions.cpu().numpy()
if True:
next_day_np = self.data_processor.inverse_transform(next_day_np)
predictions_np = self.data_processor.inverse_transform(predictions_np)
# for each time step, calculate the error using the metric
errors = []
for i in range(96):
target_tensor = torch.tensor(next_day_np[i]).unsqueeze(0)
prediction_tensor = torch.tensor(predictions_np[i]).unsqueeze(0)
errors.append(metric(prediction_tensor, target_tensor))
# plot the error
fig.add_trace(go.Scatter(x=np.arange(96), y=errors, name=metric.__class__.__name__))
fig.update_layout(title=f"Error of {metric.__class__.__name__} for each time step")
return fig
def get_plot(
self,
current_day,
@@ -364,25 +395,24 @@ class NonAutoRegressiveQuantileRegression(Trainer):
def __init__(
self,
model: torch.nn.Module,
input_dim: tuple,
optimizer: torch.optim.Optimizer,
data_processor: DataProcessor,
quantiles: list,
device: torch.device,
clearml_helper: ClearMLHelper = None,
debug: bool = True,
):
quantiles_tensor = torch.tensor(quantiles)
quantiles_tensor = quantiles_tensor.to(device)
self.quantiles = quantiles
criterion = NonAutoRegressivePinballLoss(quantiles=quantiles_tensor)
criterion = NonAutoRegressivePinballLoss(quantiles=quantiles)
super().__init__(
model=model,
input_dim=input_dim,
optimizer=optimizer,
criterion=criterion,
data_processor=data_processor,
device=device,
clearml_helper=clearml_helper,
debug=debug,
)
@@ -398,7 +428,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
outputs = self.model(inputs)
outputted_samples = [
sample_from_dist(self.quantiles.cpu(), output.cpu().numpy())
sample_from_dist(self.quantiles, output.cpu().numpy())
for output in outputs
]