Fixed policy evaluation for autoregressive
This commit is contained in:
@@ -155,31 +155,38 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
generated_samples = {}
|
||||
|
||||
with torch.no_grad():
|
||||
total_samples = len(dataloader.dataset) - 96
|
||||
for _, _, idx_batch in tqdm(dataloader):
|
||||
idx_batch = [idx for idx in idx_batch if idx < total_samples]
|
||||
total_samples = len(dataloader.dataset)
|
||||
print(
|
||||
"Full day valid indices: ",
|
||||
len(dataloader.dataset.full_day_valid_indices),
|
||||
)
|
||||
print(
|
||||
"Valid indices: ",
|
||||
len(dataloader.dataset.valid_indices),
|
||||
)
|
||||
|
||||
if len(idx_batch) == 0:
|
||||
continue
|
||||
print(dataloader.dataset.valid_indices)
|
||||
|
||||
for idx in tqdm(idx_batch):
|
||||
computed_idx_batch = [idx] * 100
|
||||
initial, _, samples, targets = self.auto_regressive(
|
||||
dataloader.dataset, idx_batch=computed_idx_batch
|
||||
)
|
||||
for i in tqdm(dataloader.dataset.full_day_valid_indices):
|
||||
idx = dataloader.dataset.valid_indices.index(i)
|
||||
|
||||
generated_samples[idx.item()] = (
|
||||
self.data_processor.inverse_transform(initial),
|
||||
self.data_processor.inverse_transform(samples),
|
||||
)
|
||||
computed_idx_batch = [idx] * 100
|
||||
initial, _, samples, targets = self.auto_regressive(
|
||||
dataloader.dataset, idx_batch=computed_idx_batch
|
||||
)
|
||||
|
||||
samples = samples.unsqueeze(0)
|
||||
targets = targets.squeeze(-1)
|
||||
targets = targets[0].unsqueeze(0)
|
||||
generated_samples[idx] = (
|
||||
self.data_processor.inverse_transform(initial),
|
||||
self.data_processor.inverse_transform(samples),
|
||||
)
|
||||
|
||||
crps = crps_from_samples(samples, targets)
|
||||
samples = samples.unsqueeze(0)
|
||||
targets = targets.squeeze(-1)
|
||||
targets = targets[0].unsqueeze(0)
|
||||
|
||||
crps_from_samples_metric.append(crps[0].mean().item())
|
||||
crps = crps_from_samples(samples, targets)
|
||||
|
||||
crps_from_samples_metric.append(crps[0].mean().item())
|
||||
|
||||
task.get_logger().report_scalar(
|
||||
title="CRPS_from_samples",
|
||||
@@ -190,10 +197,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
|
||||
# using the policy evaluator, evaluate the policy with the generated samples
|
||||
if self.policy_evaluator is not None:
|
||||
_, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.model.output_size, full_day_skip=True
|
||||
)
|
||||
self.policy_evaluator.evaluate_test_set(generated_samples, test_loader)
|
||||
self.policy_evaluator.evaluate_test_set(generated_samples, dataloader)
|
||||
df = self.policy_evaluator.get_profits_as_scalars()
|
||||
|
||||
# for each row, report the profits
|
||||
|
||||
@@ -8,6 +8,7 @@ from plotly.subplots import make_subplots
|
||||
from clearml.config import running_remotely
|
||||
from torchinfo import summary
|
||||
|
||||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -95,13 +96,15 @@ class Trainer:
|
||||
loader = test_loader
|
||||
|
||||
np.random.seed(42)
|
||||
actual_indices = np.random.choice(loader.dataset.full_day_valid_indices, num_samples, replace=False)
|
||||
actual_indices = np.random.choice(
|
||||
loader.dataset.full_day_valid_indices, num_samples, replace=False
|
||||
)
|
||||
indices = {}
|
||||
for i in actual_indices:
|
||||
indices[i] = loader.dataset.valid_indices.index(i)
|
||||
|
||||
print(actual_indices)
|
||||
|
||||
|
||||
return indices
|
||||
|
||||
def train(self, epochs: int, remotely: bool = False, task: Task = None):
|
||||
@@ -190,9 +193,7 @@ class Trainer:
|
||||
# )
|
||||
|
||||
if hasattr(self, "calculate_crps_from_samples"):
|
||||
self.calculate_crps_from_samples(
|
||||
task, full_day_skip_test_loader, epoch
|
||||
)
|
||||
self.calculate_crps_from_samples(task, test_loader, epoch)
|
||||
|
||||
if task:
|
||||
self.finish_training(task=task)
|
||||
@@ -259,7 +260,6 @@ class Trainer:
|
||||
self.model = torch.load("checkpoint.pt")
|
||||
self.model.eval()
|
||||
|
||||
|
||||
# set full day skip
|
||||
self.data_processor.set_full_day_skip(True)
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
@@ -361,7 +361,6 @@ class Trainer:
|
||||
for trace in sub_fig.data:
|
||||
fig.add_trace(trace, row=row, col=col)
|
||||
|
||||
|
||||
# loss = self.criterion(predictions.to(self.device), target.squeeze(-1).to(self.device)).item()
|
||||
|
||||
# fig['layout']['annotations'][i].update(text=f"{loss.__class__.__name__}: {loss:.6f}")
|
||||
|
||||
Reference in New Issue
Block a user