Autoregressive Quantile Training with Policy Evaluation

@@ -1,5 +1,6 @@
 import torch
 from tqdm import tqdm
+from src.policies.PolicyEvaluator import PolicyEvaluator
 from src.losses.crps_metric import crps_from_samples
 from src.trainers.trainer import Trainer
 from src.trainers.autoregressive_trainer import AutoRegressiveTrainer
@@ -131,10 +132,13 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
         data_processor: DataProcessor,
         quantiles: list,
         device: torch.device,
+        policy_evaluator: PolicyEvaluator = None,
         debug: bool = True,
     ):
 
         self.quantiles = quantiles
+        self.test_set_samples = {}
+        self.policy_evaluator = policy_evaluator
 
         criterion = PinballLoss(quantiles=quantiles)
         super().__init__(
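For context, PinballLoss is the standard quantile-regression criterion: for quantile q it weights under-prediction by q and over-prediction by (1 - q), so minimizing it drives each output head toward its target quantile. The actual src.losses implementation is not part of this diff; a minimal sketch, assuming one prediction column per quantile:

import torch

class PinballLoss(torch.nn.Module):
    # Minimal sketch; the repo's real PinballLoss interface is assumed,
    # not confirmed by this diff.
    def __init__(self, quantiles: list):
        super().__init__()
        self.quantiles = quantiles

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # preds: (..., len(quantiles)); target: same shape as preds[..., 0]
        losses = []
        for i, q in enumerate(self.quantiles):
            error = target - preds[..., i]
            # penalize under-prediction by q, over-prediction by (1 - q)
            losses.append(torch.max(q * error, (q - 1.0) * error))
        return torch.stack(losses).mean()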
@@ -149,6 +153,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
 
     def calculate_crps_from_samples(self, task, dataloader, epoch: int):
         crps_from_samples_metric = []
+        generated_samples = {}
 
         with torch.no_grad():
             total_samples = len(dataloader.dataset) - 96
@@ -160,9 +165,12 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
             for idx in tqdm(idx_batch):
                 computed_idx_batch = [idx] * 100
-                _, _, samples, targets = self.auto_regressive(
+                initial, _, samples, targets = self.auto_regressive(
                     dataloader.dataset, idx_batch=computed_idx_batch
                 )
 
+                generated_samples[idx.item()] = (initial, self.data_processor.inverse_transform(samples))
+
                 samples = samples.unsqueeze(0)
                 targets = targets.squeeze(-1)
                 targets = targets[0].unsqueeze(0)
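The per-index scores come from crps_from_samples, which for a sample-based forecast is usually the empirical estimator CRPS(X, y) = E|X - y| - 0.5 * E|X - X'|. The real src.losses.crps_metric is not shown here; a sketch under assumed (batch, num_samples, horizon) shapes:

import torch

def crps_from_samples(samples: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Empirical CRPS estimator (sketch; shapes are assumptions).
    # samples: (batch, num_samples, horizon); targets: (batch, horizon).
    # E|X - y|: mean absolute error of each sampled trajectory vs the target
    abs_err = (samples - targets.unsqueeze(1)).abs().mean(dim=1)
    # E|X - X'|: mean pairwise distance between samples (forecast spread)
    spread = (samples.unsqueeze(1) - samples.unsqueeze(2)).abs().mean(dim=(1, 2))
    return (abs_err - 0.5 * spread).mean()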
@@ -175,6 +183,20 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
                 title="CRPS_from_samples", series="test", value=np.mean(crps_from_samples_metric), iteration=epoch
             )
 
+        # using the policy evaluator, evaluate the policy with the generated samples
+        if self.policy_evaluator is not None:
+            _, test_loader = self.data_processor.get_dataloaders(
+                predict_sequence_length=self.model.output_size)
+            self.policy_evaluator.evaluate_test_set(generated_samples, test_loader)
+            df = self.policy_evaluator.get_profits_as_scalars()
+
+            # for each row, report the profits
+            for idx, row in df.iterrows():
+                task.get_logger().report_scalar(
+                    title="Profit", series=f"penalty_{row['Penalty']}", value=row["Total Profit"], iteration=epoch
+                )
+
+
     def log_final_metrics(self, task, dataloader, train: bool = True):
         metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
         transformed_metrics = {
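report_scalar is ClearML's per-iteration scalar logger; giving every penalty level its own series under the shared "Profit" title stacks all profit-vs-epoch curves on one plot. A hypothetical helper showing the DataFrame contract assumed above ("Penalty" and "Total Profit" columns, one row per penalty setting):

import pandas as pd

def report_profits(task, df: pd.DataFrame, epoch: int) -> None:
    # df is assumed to hold one row per penalty setting, with the
    # "Penalty" and "Total Profit" columns used in the diff above.
    for _, row in df.iterrows():
        # ClearML groups series sharing a title onto one plot, so each
        # penalty level becomes a separate profit-vs-epoch curve.
        task.get_logger().report_scalar(
            title="Profit",
            series=f"penalty_{row['Penalty']}",
            value=row["Total Profit"],
            iteration=epoch,
        )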
@@ -194,10 +216,14 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
 
         if train == False:
             for idx in tqdm(idx_batch):
-                computed_idx_batch = [idx] * 100
-                _, outputs, samples, targets = self.auto_regressive(
+                computed_idx_batch = [idx] * 250
+                initial, outputs, samples, targets = self.auto_regressive(
                     dataloader.dataset, idx_batch=computed_idx_batch
                 )
 
+                # save the samples for the idx, these will be used for evaluating the policy
+                self.test_set_samples[idx.item()] = (initial, self.data_processor.inverse_transform(samples))
+
                 samples = samples.unsqueeze(0)
                 targets = targets.squeeze(-1)
                 targets = targets[0].unsqueeze(0)
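Note the final test-set pass now draws 250 trajectories per index instead of 100, presumably trading compute for a lower-variance Monte Carlo estimate. Both loops store (initial, inverse-transformed samples) keyed by dataset index, which implies the forecasts leave the model in normalized units and must be mapped back to the original scale before the policy can price them. The real DataProcessor is not in this diff; a sketch of the assumed scaling round-trip:

import torch

class DataProcessor:
    # Sketch of the transform/inverse_transform pair assumed above; the
    # actual DataProcessor (and its fitted statistics) is not in this diff.
    def __init__(self, mean: float, std: float):
        self.mean, self.std = mean, std

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        # map raw values to normalized model units
        return (x - self.mean) / self.std

    def inverse_transform(self, x: torch.Tensor) -> torch.Tensor:
        # map sampled trajectories back to the original scale
        return x * self.std + self.mean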
@@ -196,7 +196,7 @@ class Trainer:
 
             if task:
                 self.finish_training(task=task)
-                task.close()
+                # task.close()
         except Exception:
             if task:
                 task.close()
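This last hunk stops closing the ClearML task immediately after finish_training while keeping the close in the exception handler; if the intent is a single guaranteed close, a try/finally expresses that directly. A hypothetical sketch, not part of the commit:

def run_with_task(task, train_fn):
    # Hypothetical refactor: close the ClearML task exactly once, on both
    # the success and the failure path, instead of per-branch close() calls.
    try:
        train_fn()
    finally:
        if task is not None:
            task.close()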