Fixed losses and multiple other things
This commit is contained in:
@@ -36,7 +36,14 @@ class AutoRegressiveTrainer(Trainer):
|
||||
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
|
||||
num_samples = len(sample_indices)
|
||||
rows = num_samples # One row per sample since we only want one column
|
||||
cols = 1
|
||||
|
||||
# check if self has get_plot_error
|
||||
if hasattr(self, "get_plot_error"):
|
||||
cols = 2
|
||||
print("Using get_plot_error")
|
||||
else:
|
||||
cols = 1
|
||||
print("Using get_plot")
|
||||
|
||||
fig = make_subplots(
|
||||
rows=rows,
|
||||
@@ -63,6 +70,13 @@ class AutoRegressiveTrainer(Trainer):
|
||||
for trace in sub_fig.data:
|
||||
fig.add_trace(trace, row=row, col=col)
|
||||
|
||||
if cols == 2:
|
||||
error_sub_fig = self.get_plot_error(
|
||||
target, predictions
|
||||
)
|
||||
for trace in error_sub_fig.data:
|
||||
fig.add_trace(trace, row=row, col=col + 1)
|
||||
|
||||
loss = self.criterion(
|
||||
predictions.to(self.device), target.to(self.device)
|
||||
).item()
|
||||
|
||||
Reference in New Issue
Block a user