Sped up sampling 20x
This commit is contained in:
@@ -45,12 +45,16 @@ class AutoRegressiveTrainer(Trainer):
|
||||
)
|
||||
|
||||
for i, idx in enumerate(sample_indices):
|
||||
auto_regressive_output = self.auto_regressive(data_loader, idx)
|
||||
auto_regressive_output = self.auto_regressive(data_loader.dataset, [idx])
|
||||
if len(auto_regressive_output) == 3:
|
||||
initial, predictions, target = auto_regressive_output
|
||||
else:
|
||||
initial, predictions, _, target = auto_regressive_output
|
||||
|
||||
initial = initial.squeeze(0)
|
||||
predictions = predictions.squeeze(0)
|
||||
target = target.squeeze(0)
|
||||
|
||||
sub_fig = self.get_plot(initial, target, predictions, show_legend=(i == 0))
|
||||
|
||||
row = i + 1
|
||||
@@ -64,13 +68,13 @@ class AutoRegressiveTrainer(Trainer):
|
||||
).item()
|
||||
|
||||
fig["layout"]["annotations"][i].update(
|
||||
text=f"{loss.__class__.__name__}: {loss:.6f}"
|
||||
text=f"{self.criterion.__class__.__name__}: {loss:.6f}"
|
||||
)
|
||||
|
||||
# y axis same for all plots
|
||||
fig.update_yaxes(range=[-1, 1], col=1)
|
||||
# fig.update_yaxes(range=[-1, 1], col=1)
|
||||
|
||||
fig.update_layout(height=300 * rows)
|
||||
fig.update_layout(height=1000 * rows)
|
||||
task.get_logger().report_plotly(
|
||||
title=f"{'Training' if train else 'Test'} Samples",
|
||||
series="full_day",
|
||||
@@ -140,7 +144,7 @@ class AutoRegressiveTrainer(Trainer):
|
||||
total_amount_samples = len(dataloader.dataset) - 95
|
||||
|
||||
for idx in tqdm(range(total_amount_samples)):
|
||||
_, outputs, targets = self.auto_regressive(dataloader, idx)
|
||||
_, outputs, targets = self.auto_regressive(dataloader.dataset, idx)
|
||||
|
||||
inversed_outputs = torch.tensor(
|
||||
self.data_processor.inverse_transform(outputs)
|
||||
|
||||
Reference in New Issue
Block a user