Fixed crps + more inputs

This commit is contained in:
Victor Mylle
2023-12-05 00:08:17 +00:00
parent 120b6aa5bd
commit d3bf04d68c
13 changed files with 128426 additions and 70 deletions

View File

@@ -94,7 +94,6 @@ class AutoRegressiveTrainer(Trainer):
target_full.append(target)
with torch.no_grad():
print(prev_features.shape)
prediction = self.model(prev_features.unsqueeze(0))
predictions_full.append(prediction.squeeze(-1))
@@ -107,8 +106,6 @@ class AutoRegressiveTrainer(Trainer):
dim=0,
)
print(new_features.shape)
# get the other needed features
other_features, new_target = data_loader.dataset.random_day_autoregressive(
idx + i + 1