Saving samples plot as png at end of training
This commit is contained in:
@@ -45,7 +45,6 @@ class AutoRegressiveTrainer(Trainer):
|
||||
else:
|
||||
initial, _, predictions, target = auto_regressive_output
|
||||
|
||||
# keep one initial
|
||||
initial = initial[0]
|
||||
target = target[0]
|
||||
|
||||
@@ -55,20 +54,38 @@ class AutoRegressiveTrainer(Trainer):
|
||||
initial, target, predictions, show_legend=(0 == 0)
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if train else "Testing",
|
||||
series=f"Sample {actual_idx}",
|
||||
iteration=epoch,
|
||||
figure=fig,
|
||||
)
|
||||
if epoch > 0:
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if train else "Testing",
|
||||
series=f"Sample {actual_idx}",
|
||||
iteration=epoch,
|
||||
figure=fig,
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training Samples" if train else "Testing Samples",
|
||||
series=f"Sample {actual_idx} samples",
|
||||
iteration=epoch,
|
||||
figure=fig2,
|
||||
report_interactive=False,
|
||||
)
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training Samples" if train else "Testing Samples",
|
||||
series=f"Sample {actual_idx} samples",
|
||||
iteration=epoch,
|
||||
figure=fig2,
|
||||
report_interactive=False,
|
||||
)
|
||||
|
||||
else:
|
||||
fig.savefig(f"sample_{actual_idx}_plot.png")
|
||||
task.get_logger().report_image(
|
||||
title="Final Training Plot",
|
||||
series=f"Sample {actual_idx}",
|
||||
iteration=epoch,
|
||||
local_path=f"sample_{actual_idx}_plot.png",
|
||||
)
|
||||
|
||||
fig2.savefig(f"sample_{actual_idx}_samples_plot.png")
|
||||
task.get_logger().report_image(
|
||||
title="Final Training Samples Plot",
|
||||
series=f"Sample {actual_idx} samples",
|
||||
iteration=epoch,
|
||||
local_path=f"sample_{actual_idx}_samples_plot.png",
|
||||
)
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
@@ -630,9 +630,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
name=metric_name, value=metric_value
|
||||
)
|
||||
|
||||
def debug_plots(
|
||||
self, task, train: bool, data_loader, sample_indices, epoch, final=False
|
||||
):
|
||||
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
|
||||
for actual_idx, idx in sample_indices.items():
|
||||
features, target, _ = data_loader.dataset[idx]
|
||||
|
||||
@@ -653,22 +651,24 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
features[:96], target, samples, show_legend=(0 == 0)
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if train else "Testing",
|
||||
series=f"Sample {actual_idx}",
|
||||
iteration=epoch,
|
||||
figure=fig,
|
||||
)
|
||||
if epoch != -1:
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if train else "Testing",
|
||||
series=f"Sample {actual_idx}",
|
||||
iteration=epoch,
|
||||
figure=fig,
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training Samples" if train else "Testing Samples",
|
||||
series=f"Sample {actual_idx} samples",
|
||||
iteration=epoch,
|
||||
figure=fig2,
|
||||
report_interactive=False,
|
||||
)
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training Samples" if train else "Testing Samples",
|
||||
series=f"Sample {actual_idx} samples",
|
||||
iteration=epoch,
|
||||
figure=fig2,
|
||||
report_interactive=False,
|
||||
)
|
||||
|
||||
if final:
|
||||
else:
|
||||
print("Saving figs")
|
||||
# fig to PIL image
|
||||
fig.savefig(f"sample_{actual_idx}_plot.png")
|
||||
task.get_logger().report_image(
|
||||
@@ -789,7 +789,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
for i in range(10):
|
||||
ax2.plot(predictions_np[i], label=f"Sample {i}")
|
||||
|
||||
ax2.plot(next_day_np, label="Real NRV", linewidth=3)
|
||||
ax2.plot(next_day_np, label="Real NRV", linewidth=4, color="orange")
|
||||
ax2.legend()
|
||||
|
||||
ax2.set_ylim(-1500, 1500)
|
||||
|
||||
@@ -198,6 +198,7 @@ class Trainer:
|
||||
|
||||
if task:
|
||||
self.finish_training(task=task)
|
||||
self.debug_plots(task, False, test_loader, test_samples, -1)
|
||||
# task.close()
|
||||
except Exception:
|
||||
if task:
|
||||
@@ -342,21 +343,8 @@ class Trainer:
|
||||
features[:96], target, predictions, show_legend=(0 == 0)
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if train else "Testing",
|
||||
series=f"Sample {actual_idx}",
|
||||
iteration=epoch,
|
||||
figure=fig,
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training Samples" if train else "Testing Samples",
|
||||
series=f"Sample {actual_idx} samples",
|
||||
iteration=epoch,
|
||||
figure=fig2,
|
||||
report_interactive=False,
|
||||
)
|
||||
|
||||
|
||||
)
|
||||
plt.close()
|
||||
|
||||
def debug_scatter_plot(self, task, train: bool, samples, epoch):
|
||||
|
||||
Reference in New Issue
Block a user