Updated training scripts
This commit is contained in:
@@ -9,6 +9,8 @@ import plotly.subplots as sp
|
||||
from plotly.subplots import make_subplots
|
||||
from src.trainers.trainer import Trainer
|
||||
from tqdm import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class AutoRegressiveTrainer(Trainer):
|
||||
def __init__(
|
||||
@@ -34,28 +36,41 @@ class AutoRegressiveTrainer(Trainer):
|
||||
|
||||
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
|
||||
for actual_idx, idx in sample_indices.items():
|
||||
auto_regressive_output = self.auto_regressive(data_loader.dataset, [idx]*1000)
|
||||
print(f"Plotting sample {actual_idx}")
|
||||
auto_regressive_output = self.auto_regressive(
|
||||
data_loader.dataset, [idx] * 1000
|
||||
)
|
||||
if len(auto_regressive_output) == 3:
|
||||
initial, predictions, target = auto_regressive_output
|
||||
else:
|
||||
initial, _, predictions, target = auto_regressive_output
|
||||
|
||||
|
||||
# keep one initial
|
||||
initial = initial[0]
|
||||
target = target[0]
|
||||
|
||||
predictions = predictions
|
||||
|
||||
fig = self.get_plot(initial, target, predictions, show_legend=(0 == 0))
|
||||
fig, fig2 = self.get_plot(
|
||||
initial, target, predictions, show_legend=(0 == 0)
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if train else "Testing",
|
||||
series=f'Sample {actual_idx}',
|
||||
series=f"Sample {actual_idx}",
|
||||
iteration=epoch,
|
||||
figure=fig,
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training Samples" if train else "Testing Samples",
|
||||
series=f"Sample {actual_idx} samples",
|
||||
iteration=epoch,
|
||||
figure=fig2,
|
||||
report_interactive=False,
|
||||
)
|
||||
|
||||
plt.close()
|
||||
|
||||
def auto_regressive(self, data_loader, idx, sequence_length: int = 96):
|
||||
self.model.eval()
|
||||
|
||||
Reference in New Issue
Block a user