Updated training scripts

This commit is contained in:
2024-03-18 12:15:06 +01:00
parent 34335cd9fe
commit 1a8e735cbc
10 changed files with 487 additions and 308 deletions

View File

@@ -9,6 +9,8 @@ import plotly.subplots as sp
from plotly.subplots import make_subplots
from src.trainers.trainer import Trainer
from tqdm import tqdm
import matplotlib.pyplot as plt
class AutoRegressiveTrainer(Trainer):
def __init__(
@@ -34,28 +36,41 @@ class AutoRegressiveTrainer(Trainer):
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
for actual_idx, idx in sample_indices.items():
auto_regressive_output = self.auto_regressive(data_loader.dataset, [idx]*1000)
print(f"Plotting sample {actual_idx}")
auto_regressive_output = self.auto_regressive(
data_loader.dataset, [idx] * 1000
)
if len(auto_regressive_output) == 3:
initial, predictions, target = auto_regressive_output
else:
initial, _, predictions, target = auto_regressive_output
# keep one initial
initial = initial[0]
target = target[0]
predictions = predictions
fig = self.get_plot(initial, target, predictions, show_legend=(0 == 0))
fig, fig2 = self.get_plot(
initial, target, predictions, show_legend=(0 == 0)
)
task.get_logger().report_matplotlib_figure(
title="Training" if train else "Testing",
series=f'Sample {actual_idx}',
series=f"Sample {actual_idx}",
iteration=epoch,
figure=fig,
)
task.get_logger().report_matplotlib_figure(
title="Training Samples" if train else "Testing Samples",
series=f"Sample {actual_idx} samples",
iteration=epoch,
figure=fig2,
report_interactive=False,
)
plt.close()
def auto_regressive(self, data_loader, idx, sequence_length: int = 96):
self.model.eval()