Fixed losses and multiple other things
This commit is contained in:
@@ -36,7 +36,14 @@ class AutoRegressiveTrainer(Trainer):
|
||||
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
|
||||
num_samples = len(sample_indices)
|
||||
rows = num_samples # One row per sample since we only want one column
|
||||
cols = 1
|
||||
|
||||
# check if self has get_plot_error
|
||||
if hasattr(self, "get_plot_error"):
|
||||
cols = 2
|
||||
print("Using get_plot_error")
|
||||
else:
|
||||
cols = 1
|
||||
print("Using get_plot")
|
||||
|
||||
fig = make_subplots(
|
||||
rows=rows,
|
||||
@@ -63,6 +70,13 @@ class AutoRegressiveTrainer(Trainer):
|
||||
for trace in sub_fig.data:
|
||||
fig.add_trace(trace, row=row, col=col)
|
||||
|
||||
if cols == 2:
|
||||
error_sub_fig = self.get_plot_error(
|
||||
target, predictions
|
||||
)
|
||||
for trace in error_sub_fig.data:
|
||||
fig.add_trace(trace, row=row, col=col + 1)
|
||||
|
||||
loss = self.criterion(
|
||||
predictions.to(self.device), target.to(self.device)
|
||||
).item()
|
||||
|
||||
@@ -147,6 +147,37 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
)
|
||||
task.get_logger().report_single_value(name=name, value=metric_value)
|
||||
|
||||
def get_plot_error(
|
||||
self,
|
||||
next_day,
|
||||
predictions,
|
||||
):
|
||||
metric = PinballLoss(quantiles=self.quantiles)
|
||||
fig = go.Figure()
|
||||
|
||||
next_day_np = next_day.view(-1).cpu().numpy()
|
||||
predictions_np = predictions.cpu().numpy()
|
||||
|
||||
if True:
|
||||
next_day_np = self.data_processor.inverse_transform(next_day_np)
|
||||
predictions_np = self.data_processor.inverse_transform(predictions_np)
|
||||
|
||||
# for each time step, calculate the error using the metric
|
||||
errors = []
|
||||
for i in range(96):
|
||||
|
||||
target_tensor = torch.tensor(next_day_np[i]).unsqueeze(0)
|
||||
prediction_tensor = torch.tensor(predictions_np[i]).unsqueeze(0)
|
||||
|
||||
errors.append(metric(prediction_tensor, target_tensor))
|
||||
|
||||
# plot the error
|
||||
fig.add_trace(go.Scatter(x=np.arange(96), y=errors, name=metric.__class__.__name__))
|
||||
fig.update_layout(title=f"Error of {metric.__class__.__name__} for each time step")
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def get_plot(
|
||||
self,
|
||||
current_day,
|
||||
@@ -364,25 +395,24 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
input_dim: tuple,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
data_processor: DataProcessor,
|
||||
quantiles: list,
|
||||
device: torch.device,
|
||||
clearml_helper: ClearMLHelper = None,
|
||||
debug: bool = True,
|
||||
):
|
||||
quantiles_tensor = torch.tensor(quantiles)
|
||||
quantiles_tensor = quantiles_tensor.to(device)
|
||||
self.quantiles = quantiles
|
||||
|
||||
criterion = NonAutoRegressivePinballLoss(quantiles=quantiles_tensor)
|
||||
|
||||
criterion = NonAutoRegressivePinballLoss(quantiles=quantiles)
|
||||
super().__init__(
|
||||
model=model,
|
||||
input_dim=input_dim,
|
||||
optimizer=optimizer,
|
||||
criterion=criterion,
|
||||
data_processor=data_processor,
|
||||
device=device,
|
||||
clearml_helper=clearml_helper,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
@@ -398,7 +428,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
|
||||
outputs = self.model(inputs)
|
||||
outputted_samples = [
|
||||
sample_from_dist(self.quantiles.cpu(), output.cpu().numpy())
|
||||
sample_from_dist(self.quantiles, output.cpu().numpy())
|
||||
for output in outputs
|
||||
]
|
||||
|
||||
|
||||
@@ -313,6 +313,7 @@ class Trainer:
|
||||
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
|
||||
num_samples = len(sample_indices)
|
||||
rows = num_samples # One row per sample since we only want one column
|
||||
|
||||
cols = 1
|
||||
|
||||
fig = make_subplots(
|
||||
@@ -341,6 +342,7 @@ class Trainer:
|
||||
for trace in sub_fig.data:
|
||||
fig.add_trace(trace, row=row, col=col)
|
||||
|
||||
|
||||
# loss = self.criterion(predictions.to(self.device), target.squeeze(-1).to(self.device)).item()
|
||||
|
||||
# fig['layout']['annotations'][i].update(text=f"{loss.__class__.__name__}: {loss:.6f}")
|
||||
|
||||
Reference in New Issue
Block a user