Fixed losses and multiple other things

This commit is contained in:
Victor Mylle
2023-12-12 11:02:42 +00:00
parent d3bf04d68c
commit c06cc10aa6
12 changed files with 5093936 additions and 122 deletions

View File

@@ -36,7 +36,14 @@ class AutoRegressiveTrainer(Trainer):
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
num_samples = len(sample_indices)
rows = num_samples # One row per sample since we only want one column
cols = 1
# check if self has get_plot_error
if hasattr(self, "get_plot_error"):
cols = 2
print("Using get_plot_error")
else:
cols = 1
print("Using get_plot")
fig = make_subplots(
rows=rows,
@@ -63,6 +70,13 @@ class AutoRegressiveTrainer(Trainer):
for trace in sub_fig.data:
fig.add_trace(trace, row=row, col=col)
if cols == 2:
error_sub_fig = self.get_plot_error(
target, predictions
)
for trace in error_sub_fig.data:
fig.add_trace(trace, row=row, col=col + 1)
loss = self.criterion(
predictions.to(self.device), target.to(self.device)
).item()