Fixed crps + more inputs

This commit is contained in:
Victor Mylle
2023-12-05 00:08:17 +00:00
parent 120b6aa5bd
commit d3bf04d68c
13 changed files with 128426 additions and 70 deletions

View File

@@ -270,7 +270,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
)
def plot_quantile_percentages(
self, task, data_loader, train: bool = True, iteration: int = None
self, task, data_loader, train: bool = True, iteration: int = None, full_day: bool = False
):
quantiles = self.quantiles
total = 0
@@ -278,16 +278,34 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
self.model.eval()
with torch.no_grad():
for inputs, targets, _ in data_loader:
inputs = inputs.to(self.device)
output = self.model(inputs).cpu().numpy()
targets = targets.squeeze(-1).cpu().numpy()
total_samples = len(data_loader.dataset) - 96
for inputs, targets, idx_batch in data_loader:
idx_batch = [idx for idx in idx_batch if idx < total_samples]
if full_day:
_, outputs, samples, targets = self.auto_regressive(
data_loader.dataset, idx_batch=idx_batch
)
# outputs: (batch, sequence_length, num_quantiles)
# targets: (batch, sequence_length, 1)
# reshape to (batch_size * sequence_length, num_quantiles)
outputs = outputs.reshape(-1, len(quantiles))
targets = targets.reshape(-1)
# to cpu
outputs = outputs.cpu().numpy()
targets = targets.cpu().numpy()
else:
inputs = inputs.to(self.device)
outputs = self.model(inputs).cpu().numpy() # (batch_size, num_quantiles)
targets = targets.squeeze(-1).cpu().numpy() # (batch_size, 1)
# output shape: (batch_size, num_quantiles)
# target shape: (batch_size, 1)
for i, q in enumerate(quantiles):
quantile_counter[q] += np.sum(
targets < output[:, i]
targets < outputs[:, i]
)
total += len(targets)
@@ -322,18 +340,19 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
) # Format the number as a percentage
series_name = "Training Set" if train else "Test Set"
full_day_str = "Full Day" if full_day else "Single Step"
# Adding labels and title
ax.set_xlabel("Quantile")
ax.set_ylabel("Fraction of data under quantile forecast")
ax.set_title(f"Quantile Performance Comparison ({series_name})")
ax.set_title(f"{series_name} {full_day_str} Quantile Performance Comparison")
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(quantiles)
ax.legend()
task.get_logger().report_matplotlib_figure(
title="Quantile Performance Comparison",
series=series_name,
series=f"{series_name} {full_day_str}",
report_image=True,
figure=plt,
iteration=iteration,