Fixed policy evaluation for autoregressive

This commit is contained in:
2024-02-29 23:23:11 +01:00
parent fe1e388ffb
commit 34335cd9fe
10 changed files with 191 additions and 95 deletions

View File

@@ -155,31 +155,38 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
generated_samples = {}
with torch.no_grad():
total_samples = len(dataloader.dataset) - 96
for _, _, idx_batch in tqdm(dataloader):
idx_batch = [idx for idx in idx_batch if idx < total_samples]
total_samples = len(dataloader.dataset)
print(
"Full day valid indices: ",
len(dataloader.dataset.full_day_valid_indices),
)
print(
"Valid indices: ",
len(dataloader.dataset.valid_indices),
)
if len(idx_batch) == 0:
continue
print(dataloader.dataset.valid_indices)
for idx in tqdm(idx_batch):
computed_idx_batch = [idx] * 100
initial, _, samples, targets = self.auto_regressive(
dataloader.dataset, idx_batch=computed_idx_batch
)
for i in tqdm(dataloader.dataset.full_day_valid_indices):
idx = dataloader.dataset.valid_indices.index(i)
generated_samples[idx.item()] = (
self.data_processor.inverse_transform(initial),
self.data_processor.inverse_transform(samples),
)
computed_idx_batch = [idx] * 100
initial, _, samples, targets = self.auto_regressive(
dataloader.dataset, idx_batch=computed_idx_batch
)
samples = samples.unsqueeze(0)
targets = targets.squeeze(-1)
targets = targets[0].unsqueeze(0)
generated_samples[idx] = (
self.data_processor.inverse_transform(initial),
self.data_processor.inverse_transform(samples),
)
crps = crps_from_samples(samples, targets)
samples = samples.unsqueeze(0)
targets = targets.squeeze(-1)
targets = targets[0].unsqueeze(0)
crps_from_samples_metric.append(crps[0].mean().item())
crps = crps_from_samples(samples, targets)
crps_from_samples_metric.append(crps[0].mean().item())
task.get_logger().report_scalar(
title="CRPS_from_samples",
@@ -190,10 +197,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
# using the policy evaluator, evaluate the policy with the generated samples
if self.policy_evaluator is not None:
_, test_loader = self.data_processor.get_dataloaders(
predict_sequence_length=self.model.output_size, full_day_skip=True
)
self.policy_evaluator.evaluate_test_set(generated_samples, test_loader)
self.policy_evaluator.evaluate_test_set(generated_samples, dataloader)
df = self.policy_evaluator.get_profits_as_scalars()
# for each row, report the profits