other changes
This commit is contained in:
@@ -87,7 +87,10 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
crps_from_samples_metric = []
|
||||
|
||||
with torch.no_grad():
|
||||
total_samples = len(dataloader.dataset) - 96
|
||||
for _, _, idx_batch in tqdm(dataloader):
|
||||
idx_batch = [idx for idx in idx_batch if idx < total_samples]
|
||||
|
||||
if len(idx_batch) == 0:
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user