Sped up sampling 20x

This commit is contained in:
Victor Mylle
2023-11-25 18:09:42 +00:00
parent 5de3f64a1a
commit 300f268286
10 changed files with 498 additions and 238 deletions

View File

@@ -45,12 +45,16 @@ class AutoRegressiveTrainer(Trainer):
)
for i, idx in enumerate(sample_indices):
auto_regressive_output = self.auto_regressive(data_loader, idx)
auto_regressive_output = self.auto_regressive(data_loader.dataset, [idx])
if len(auto_regressive_output) == 3:
initial, predictions, target = auto_regressive_output
else:
initial, predictions, _, target = auto_regressive_output
initial = initial.squeeze(0)
predictions = predictions.squeeze(0)
target = target.squeeze(0)
sub_fig = self.get_plot(initial, target, predictions, show_legend=(i == 0))
row = i + 1
@@ -64,13 +68,13 @@ class AutoRegressiveTrainer(Trainer):
).item()
fig["layout"]["annotations"][i].update(
text=f"{loss.__class__.__name__}: {loss:.6f}"
text=f"{self.criterion.__class__.__name__}: {loss:.6f}"
)
# y axis same for all plots
fig.update_yaxes(range=[-1, 1], col=1)
# fig.update_yaxes(range=[-1, 1], col=1)
fig.update_layout(height=300 * rows)
fig.update_layout(height=1000 * rows)
task.get_logger().report_plotly(
title=f"{'Training' if train else 'Test'} Samples",
series="full_day",
@@ -140,7 +144,7 @@ class AutoRegressiveTrainer(Trainer):
total_amount_samples = len(dataloader.dataset) - 95
for idx in tqdm(range(total_amount_samples)):
_, outputs, targets = self.auto_regressive(dataloader, idx)
_, outputs, targets = self.auto_regressive(dataloader.dataset, idx)
inversed_outputs = torch.tensor(
self.data_processor.inverse_transform(outputs)