Non autregressive gru model load

This commit is contained in:
2024-05-06 16:11:15 +02:00
parent 19ab597ae6
commit d7f4c1849b
7 changed files with 55 additions and 22 deletions

View File

@@ -633,6 +633,7 @@ class NonAutoRegressiveQuantileRegression(Trainer):
def debug_plots(self, task, train: bool, data_loader, sample_indices, epoch):
for actual_idx, idx in sample_indices.items():
features, target, _ = data_loader.dataset[idx]
print(features.shape, target.shape)
features = features.to(self.device)
target = target.to(self.device)