Improved policy executor
This commit is contained in:
@@ -9,56 +9,38 @@ from src.losses import PinballLoss, NonAutoRegressivePinballLoss, CRPSLoss
|
||||
import plotly.graph_objects as go
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.interpolate import CubicSpline
|
||||
|
||||
|
||||
def sample_from_dist(quantiles, output_values):
|
||||
# check if tensor:
|
||||
if isinstance(quantiles, torch.Tensor):
|
||||
quantiles = quantiles.cpu().numpy()
|
||||
def sample_from_dist(quantiles, preds):
|
||||
if isinstance(preds, torch.Tensor):
|
||||
preds = preds.detach().cpu()
|
||||
|
||||
if isinstance(output_values, torch.Tensor):
|
||||
output_values = output_values.cpu().numpy()
|
||||
# if preds more than 2 dimensions, flatten to 2
|
||||
if len(preds.shape) > 2:
|
||||
preds = preds.reshape(-1, preds.shape[-1])
|
||||
# target will be reshaped from (1024, 96, 15) to (1024*96, 15)
|
||||
# our target (1024, 96) also needs to be reshaped to (1024*96, 1)
|
||||
target = target.reshape(-1, 1)
|
||||
|
||||
if isinstance(quantiles, list):
|
||||
quantiles = np.array(quantiles)
|
||||
# preds and target as numpy
|
||||
preds = preds.numpy()
|
||||
|
||||
reshaped_values = output_values.reshape(-1, len(quantiles))
|
||||
# random probabilities of (1000, 1)
|
||||
import random
|
||||
probs = np.array([random.random() for _ in range(1000)])
|
||||
|
||||
uniform_random_numbers = np.random.uniform(0, 1, (reshaped_values.shape[0], 1000))
|
||||
spline = CubicSpline(quantiles, preds, axis=1)
|
||||
|
||||
samples = spline(probs)
|
||||
|
||||
idx_below = np.searchsorted(quantiles, uniform_random_numbers, side="right") - 1
|
||||
idx_above = np.clip(idx_below + 1, 0, len(quantiles) - 1)
|
||||
# get the diagonal
|
||||
samples = np.diag(samples)
|
||||
|
||||
# handle edge case where idx_below is -1
|
||||
idx_below = np.clip(idx_below, 0, len(quantiles) - 1)
|
||||
return samples
|
||||
|
||||
y_below = reshaped_values[np.arange(reshaped_values.shape[0])[:, None], idx_below]
|
||||
y_above = reshaped_values[np.arange(reshaped_values.shape[0])[:, None], idx_above]
|
||||
|
||||
# Calculate the slopes for interpolation
|
||||
x_below = quantiles[idx_below]
|
||||
x_above = quantiles[idx_above]
|
||||
|
||||
# Interpolate
|
||||
# Ensure all variables are NumPy arrays
|
||||
x_below_np = x_below.cpu().numpy() if isinstance(x_below, torch.Tensor) else x_below
|
||||
x_above_np = x_above.cpu().numpy() if isinstance(x_above, torch.Tensor) else x_above
|
||||
y_below_np = y_below.cpu().numpy() if isinstance(y_below, torch.Tensor) else y_below
|
||||
y_above_np = y_above.cpu().numpy() if isinstance(y_above, torch.Tensor) else y_above
|
||||
|
||||
# Compute slopes for interpolation
|
||||
slopes_np = (y_above_np - y_below_np) / (
|
||||
np.clip(x_above_np - x_below_np, 1e-6, np.inf)
|
||||
)
|
||||
|
||||
# Perform the interpolation
|
||||
new_samples = y_below_np + slopes_np * (uniform_random_numbers - x_below_np)
|
||||
|
||||
# Return the mean of the samples
|
||||
return np.mean(new_samples, axis=1)
|
||||
|
||||
def auto_regressive(dataset, model, idx_batch, sequence_length: int = 96):
|
||||
device = model.device
|
||||
def auto_regressive(dataset, model, quantiles, idx_batch, sequence_length: int = 96):
|
||||
device = next(model.parameters()).device
|
||||
prev_features, targets = dataset.get_batch(idx_batch)
|
||||
prev_features = prev_features.to(device)
|
||||
targets = targets.to(device)
|
||||
@@ -72,7 +54,7 @@ def auto_regressive(dataset, model, idx_batch, sequence_length: int = 96):
|
||||
with torch.no_grad():
|
||||
new_predictions_full = model(prev_features) # (batch_size, quantiles)
|
||||
samples = (
|
||||
torch.tensor(sample_from_dist( new_predictions_full))
|
||||
torch.tensor(sample_from_dist(quantiles, new_predictions_full))
|
||||
.unsqueeze(1)
|
||||
.to(device)
|
||||
) # (batch_size, 1)
|
||||
@@ -125,7 +107,7 @@ def auto_regressive(dataset, model, idx_batch, sequence_length: int = 96):
|
||||
) # (batch_size, sequence_length, quantiles)
|
||||
|
||||
samples = (
|
||||
torch.tensor(sample_from_dist(new_predictions_full))
|
||||
torch.tensor(sample_from_dist(quantiles, new_predictions_full))
|
||||
.unsqueeze(-1)
|
||||
.to(device)
|
||||
) # (batch_size, 1)
|
||||
@@ -353,7 +335,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
return fig
|
||||
|
||||
def auto_regressive(self, dataset, idx_batch, sequence_length: int = 96):
|
||||
return auto_regressive(dataset, self.model, idx_batch, sequence_length)
|
||||
return auto_regressive(dataset, self.model, self.quantiles, idx_batch, sequence_length)
|
||||
|
||||
def plot_quantile_percentages(
|
||||
self, task, data_loader, train: bool = True, iteration: int = None, full_day: bool = False
|
||||
|
||||
Reference in New Issue
Block a user