other changes

This commit is contained in:
Victor Mylle
2024-01-15 12:31:56 +00:00
parent 5f2418a205
commit 67cc6d4bb9
7 changed files with 855 additions and 482 deletions

View File

@@ -87,7 +87,10 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
crps_from_samples_metric = []
with torch.no_grad():
total_samples = len(dataloader.dataset) - 96
for _, _, idx_batch in tqdm(dataloader):
idx_batch = [idx for idx in idx_batch if idx < total_samples]
if len(idx_batch) == 0:
continue