Improved policy executor

This commit is contained in:
Victor Mylle
2024-01-16 23:22:05 +00:00
parent d1074281c4
commit b87ad1bf42
7 changed files with 1328 additions and 101 deletions

View File

@@ -9,56 +9,38 @@ from src.losses import PinballLoss, NonAutoRegressivePinballLoss, CRPSLoss
import plotly.graph_objects as go
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import CubicSpline
def sample_from_dist(quantiles, output_values):
# check if tensor:
if isinstance(quantiles, torch.Tensor):
quantiles = quantiles.cpu().numpy()
def sample_from_dist(quantiles, preds):
if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu()
if isinstance(output_values, torch.Tensor):
output_values = output_values.cpu().numpy()
# if preds more than 2 dimensions, flatten to 2
if len(preds.shape) > 2:
preds = preds.reshape(-1, preds.shape[-1])
# target will be reshaped from (1024, 96, 15) to (1024*96, 15)
# our target (1024, 96) also needs to be reshaped to (1024*96, 1)
target = target.reshape(-1, 1)
if isinstance(quantiles, list):
quantiles = np.array(quantiles)
# preds and target as numpy
preds = preds.numpy()
reshaped_values = output_values.reshape(-1, len(quantiles))
# random probabilities of (1000, 1)
import random
probs = np.array([random.random() for _ in range(1000)])
uniform_random_numbers = np.random.uniform(0, 1, (reshaped_values.shape[0], 1000))
spline = CubicSpline(quantiles, preds, axis=1)
samples = spline(probs)
idx_below = np.searchsorted(quantiles, uniform_random_numbers, side="right") - 1
idx_above = np.clip(idx_below + 1, 0, len(quantiles) - 1)
# get the diagonal
samples = np.diag(samples)
# handle edge case where idx_below is -1
idx_below = np.clip(idx_below, 0, len(quantiles) - 1)
return samples
y_below = reshaped_values[np.arange(reshaped_values.shape[0])[:, None], idx_below]
y_above = reshaped_values[np.arange(reshaped_values.shape[0])[:, None], idx_above]
# Calculate the slopes for interpolation
x_below = quantiles[idx_below]
x_above = quantiles[idx_above]
# Interpolate
# Ensure all variables are NumPy arrays
x_below_np = x_below.cpu().numpy() if isinstance(x_below, torch.Tensor) else x_below
x_above_np = x_above.cpu().numpy() if isinstance(x_above, torch.Tensor) else x_above
y_below_np = y_below.cpu().numpy() if isinstance(y_below, torch.Tensor) else y_below
y_above_np = y_above.cpu().numpy() if isinstance(y_above, torch.Tensor) else y_above
# Compute slopes for interpolation
slopes_np = (y_above_np - y_below_np) / (
np.clip(x_above_np - x_below_np, 1e-6, np.inf)
)
# Perform the interpolation
new_samples = y_below_np + slopes_np * (uniform_random_numbers - x_below_np)
# Return the mean of the samples
return np.mean(new_samples, axis=1)
def auto_regressive(dataset, model, idx_batch, sequence_length: int = 96):
device = model.device
def auto_regressive(dataset, model, quantiles, idx_batch, sequence_length: int = 96):
device = next(model.parameters()).device
prev_features, targets = dataset.get_batch(idx_batch)
prev_features = prev_features.to(device)
targets = targets.to(device)
@@ -72,7 +54,7 @@ def auto_regressive(dataset, model, idx_batch, sequence_length: int = 96):
with torch.no_grad():
new_predictions_full = model(prev_features) # (batch_size, quantiles)
samples = (
torch.tensor(sample_from_dist( new_predictions_full))
torch.tensor(sample_from_dist(quantiles, new_predictions_full))
.unsqueeze(1)
.to(device)
) # (batch_size, 1)
@@ -125,7 +107,7 @@ def auto_regressive(dataset, model, idx_batch, sequence_length: int = 96):
) # (batch_size, sequence_length, quantiles)
samples = (
torch.tensor(sample_from_dist(new_predictions_full))
torch.tensor(sample_from_dist(quantiles, new_predictions_full))
.unsqueeze(-1)
.to(device)
) # (batch_size, 1)
@@ -353,7 +335,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
return fig
def auto_regressive(self, dataset, idx_batch, sequence_length: int = 96):
return auto_regressive(dataset, self.model, idx_batch, sequence_length)
return auto_regressive(dataset, self.model, self.quantiles, idx_batch, sequence_length)
def plot_quantile_percentages(
self, task, data_loader, train: bool = True, iteration: int = None, full_day: bool = False