Saving samples plot as png at end of training
This commit is contained in:
@@ -45,7 +45,6 @@ class AutoRegressiveTrainer(Trainer):
|
|||||||
else:
|
else:
|
||||||
initial, _, predictions, target = auto_regressive_output
|
initial, _, predictions, target = auto_regressive_output
|
||||||
|
|
||||||
# keep one initial
|
|
||||||
initial = initial[0]
|
initial = initial[0]
|
||||||
target = target[0]
|
target = target[0]
|
||||||
|
|
||||||
@@ -55,6 +54,7 @@ class AutoRegressiveTrainer(Trainer):
|
|||||||
initial, target, predictions, show_legend=(0 == 0)
|
initial, target, predictions, show_legend=(0 == 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if epoch > 0:
|
||||||
task.get_logger().report_matplotlib_figure(
|
task.get_logger().report_matplotlib_figure(
|
||||||
title="Training" if train else "Testing",
|
title="Training" if train else "Testing",
|
||||||
series=f"Sample {actual_idx}",
|
series=f"Sample {actual_idx}",
|
||||||
@@ -70,6 +70,23 @@ class AutoRegressiveTrainer(Trainer):
|
|||||||
report_interactive=False,
|
report_interactive=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
fig.savefig(f"sample_{actual_idx}_plot.png")
|
||||||
|
task.get_logger().report_image(
|
||||||
|
title="Final Training Plot",
|
||||||
|
series=f"Sample {actual_idx}",
|
||||||
|
iteration=epoch,
|
||||||
|
local_path=f"sample_{actual_idx}_plot.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
fig2.savefig(f"sample_{actual_idx}_samples_plot.png")
|
||||||
|
task.get_logger().report_image(
|
||||||
|
title="Final Training Samples Plot",
|
||||||
|
series=f"Sample {actual_idx} samples",
|
||||||
|
iteration=epoch,
|
||||||
|
local_path=f"sample_{actual_idx}_samples_plot.png",
|
||||||
|
)
|
||||||
|
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
def auto_regressive(self, data_loader, idx, sequence_length: int = 96):
|
def auto_regressive(self, data_loader, idx, sequence_length: int = 96):
|
||||||
|
|||||||
@@ -630,9 +630,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
|||||||
name=metric_name, value=metric_value
|
name=metric_name, value=metric_value
|
||||||
)
|
)
|
||||||
|
|
||||||
def debug_plots(
|
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
|
||||||
self, task, train: bool, data_loader, sample_indices, epoch, final=False
|
|
||||||
):
|
|
||||||
for actual_idx, idx in sample_indices.items():
|
for actual_idx, idx in sample_indices.items():
|
||||||
features, target, _ = data_loader.dataset[idx]
|
features, target, _ = data_loader.dataset[idx]
|
||||||
|
|
||||||
@@ -653,6 +651,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
|||||||
features[:96], target, samples, show_legend=(0 == 0)
|
features[:96], target, samples, show_legend=(0 == 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if epoch != -1:
|
||||||
task.get_logger().report_matplotlib_figure(
|
task.get_logger().report_matplotlib_figure(
|
||||||
title="Training" if train else "Testing",
|
title="Training" if train else "Testing",
|
||||||
series=f"Sample {actual_idx}",
|
series=f"Sample {actual_idx}",
|
||||||
@@ -668,7 +667,8 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
|||||||
report_interactive=False,
|
report_interactive=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final:
|
else:
|
||||||
|
print("Saving figs")
|
||||||
# fig to PIL image
|
# fig to PIL image
|
||||||
fig.savefig(f"sample_{actual_idx}_plot.png")
|
fig.savefig(f"sample_{actual_idx}_plot.png")
|
||||||
task.get_logger().report_image(
|
task.get_logger().report_image(
|
||||||
@@ -789,7 +789,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
|||||||
for i in range(10):
|
for i in range(10):
|
||||||
ax2.plot(predictions_np[i], label=f"Sample {i}")
|
ax2.plot(predictions_np[i], label=f"Sample {i}")
|
||||||
|
|
||||||
ax2.plot(next_day_np, label="Real NRV", linewidth=3)
|
ax2.plot(next_day_np, label="Real NRV", linewidth=4, color="orange")
|
||||||
ax2.legend()
|
ax2.legend()
|
||||||
|
|
||||||
ax2.set_ylim(-1500, 1500)
|
ax2.set_ylim(-1500, 1500)
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ class Trainer:
|
|||||||
|
|
||||||
if task:
|
if task:
|
||||||
self.finish_training(task=task)
|
self.finish_training(task=task)
|
||||||
|
self.debug_plots(task, False, test_loader, test_samples, -1)
|
||||||
# task.close()
|
# task.close()
|
||||||
except Exception:
|
except Exception:
|
||||||
if task:
|
if task:
|
||||||
@@ -342,21 +343,8 @@ class Trainer:
|
|||||||
features[:96], target, predictions, show_legend=(0 == 0)
|
features[:96], target, predictions, show_legend=(0 == 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
task.get_logger().report_matplotlib_figure(
|
|
||||||
title="Training" if train else "Testing",
|
|
||||||
series=f"Sample {actual_idx}",
|
|
||||||
iteration=epoch,
|
|
||||||
figure=fig,
|
|
||||||
)
|
|
||||||
|
|
||||||
task.get_logger().report_matplotlib_figure(
|
|
||||||
title="Training Samples" if train else "Testing Samples",
|
|
||||||
series=f"Sample {actual_idx} samples",
|
|
||||||
iteration=epoch,
|
|
||||||
figure=fig2,
|
|
||||||
report_interactive=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
def debug_scatter_plot(self, task, train: bool, samples, epoch):
|
def debug_scatter_plot(self, task, train: bool, samples, epoch):
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ from src.utils.clearml import ClearMLHelper
|
|||||||
|
|
||||||
#### ClearML ####
|
#### ClearML ####
|
||||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||||
task = clearml_helper.get_task(task_name="AQR: Linear Baseline + Quarter Trigonometric")
|
task = clearml_helper.get_task(
|
||||||
|
task_name="AQR: Linear Baseline + Load + PV + Wind + Net Position + Quarter"
|
||||||
|
)
|
||||||
task.execute_remotely(queue_name="default", exit_process=True)
|
task.execute_remotely(queue_name="default", exit_process=True)
|
||||||
|
|
||||||
from src.policies.PolicyEvaluator import PolicyEvaluator
|
from src.policies.PolicyEvaluator import PolicyEvaluator
|
||||||
@@ -70,16 +72,16 @@ model_parameters = {
|
|||||||
"hidden_size": 256,
|
"hidden_size": 256,
|
||||||
"num_layers": 2,
|
"num_layers": 2,
|
||||||
"dropout": 0.2,
|
"dropout": 0.2,
|
||||||
"time_feature_embedding": 2,
|
"time_feature_embedding": 5,
|
||||||
}
|
}
|
||||||
|
|
||||||
model_parameters = task.connect(model_parameters, name="model_parameters")
|
model_parameters = task.connect(model_parameters, name="model_parameters")
|
||||||
|
|
||||||
# time_embedding = TimeEmbedding(
|
time_embedding = TimeEmbedding(
|
||||||
# data_processor.get_time_feature_size(), model_parameters["time_feature_embedding"]
|
data_processor.get_time_feature_size(), model_parameters["time_feature_embedding"]
|
||||||
# )
|
)
|
||||||
|
|
||||||
time_embedding = TrigonometricTimeEmbedding(data_processor.get_time_feature_size())
|
# time_embedding = TrigonometricTimeEmbedding(data_processor.get_time_feature_size())
|
||||||
|
|
||||||
# lstm_model = GRUModel(
|
# lstm_model = GRUModel(
|
||||||
# time_embedding.output_dim(inputDim),
|
# time_embedding.output_dim(inputDim),
|
||||||
|
|||||||
Reference in New Issue
Block a user