Saving samples plot as png at end of training

This commit is contained in:
2024-04-19 14:05:20 +02:00
parent 4e713ef564
commit 46c7c6f7e5
4 changed files with 60 additions and 53 deletions

View File

@@ -630,9 +630,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
name=metric_name, value=metric_value
)
def debug_plots(
self, task, train: bool, data_loader, sample_indices, epoch, final=False
):
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
for actual_idx, idx in sample_indices.items():
features, target, _ = data_loader.dataset[idx]
@@ -653,22 +651,24 @@ class NonAutoRegressiveQuantileRegression(Trainer):
features[:96], target, samples, show_legend=(0 == 0)
)
task.get_logger().report_matplotlib_figure(
title="Training" if train else "Testing",
series=f"Sample {actual_idx}",
iteration=epoch,
figure=fig,
)
if epoch != -1:
task.get_logger().report_matplotlib_figure(
title="Training" if train else "Testing",
series=f"Sample {actual_idx}",
iteration=epoch,
figure=fig,
)
task.get_logger().report_matplotlib_figure(
title="Training Samples" if train else "Testing Samples",
series=f"Sample {actual_idx} samples",
iteration=epoch,
figure=fig2,
report_interactive=False,
)
task.get_logger().report_matplotlib_figure(
title="Training Samples" if train else "Testing Samples",
series=f"Sample {actual_idx} samples",
iteration=epoch,
figure=fig2,
report_interactive=False,
)
if final:
else:
print("Saving figs")
# fig to PIL image
fig.savefig(f"sample_{actual_idx}_plot.png")
task.get_logger().report_image(
@@ -789,7 +789,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
for i in range(10):
ax2.plot(predictions_np[i], label=f"Sample {i}")
ax2.plot(next_day_np, label="Real NRV", linewidth=3)
ax2.plot(next_day_np, label="Real NRV", linewidth=4, color="orange")
ax2.legend()
ax2.set_ylim(-1500, 1500)