Policy evaluation during training
This commit is contained in:
@@ -18,9 +18,11 @@ class PolicyEvaluator:
|
||||
self.dates = pd.to_datetime(self.dates)
|
||||
|
||||
### Load Imbalance Prices ###
|
||||
imbalance_prices = pd.read_csv('data/imbalance_prices.csv', sep=';')
|
||||
imbalance_prices["DateTime"] = pd.to_datetime(imbalance_prices['DateTime'], utc=True)
|
||||
self.imbalance_prices = imbalance_prices.sort_values(by=['DateTime'])
|
||||
imbalance_prices = pd.read_csv("data/imbalance_prices.csv", sep=";")
|
||||
imbalance_prices["DateTime"] = pd.to_datetime(
|
||||
imbalance_prices["DateTime"], utc=True
|
||||
)
|
||||
self.imbalance_prices = imbalance_prices.sort_values(by=["DateTime"])
|
||||
|
||||
self.penalties = [0, 100, 300, 500, 800, 1000, 1500]
|
||||
self.profits = []
|
||||
@@ -28,30 +30,46 @@ class PolicyEvaluator:
|
||||
self.task = task
|
||||
|
||||
def get_imbanlance_prices_for_date(self, date):
    """Return the positive imbalance prices recorded on ``date``.

    Args:
        date: Value compared for equality against the calendar date
            (``.dt.date``) of every ``"DateTime"`` entry in
            ``self.imbalance_prices``.

    Returns:
        The ``"Positive imbalance price"`` values of the matching rows, in the
        row order of ``self.imbalance_prices``. Empty when no row matches.
    """
    # Boolean mask over all rows; the date part is taken after the timestamps
    # were parsed, so the comparison is purely on calendar dates.
    on_requested_day = self.imbalance_prices["DateTime"].dt.date == date
    imbalance_prices_day = self.imbalance_prices[on_requested_day]
    return imbalance_prices_day["Positive imbalance price"].values
|
||||
|
||||
def evaluate_for_date(self, date, idx_samples, test_loader):
|
||||
charge_thresholds = np.arange(-100, 250, 25)
|
||||
discharge_thresholds = np.arange(-100, 250, 25)
|
||||
|
||||
idx = test_loader.dataset.get_idx_for_date(date.date())
|
||||
|
||||
print("Evaluated for idx: ", idx)
|
||||
(initial, samples) = idx_samples[idx]
|
||||
|
||||
initial = initial.cpu().numpy()[0][-1]
|
||||
if len(initial.shape) == 2:
|
||||
initial = initial.cpu().numpy()[0][-1]
|
||||
else:
|
||||
initial = initial.cpu().numpy()[-1]
|
||||
samples = samples.cpu().numpy()
|
||||
|
||||
initial = np.repeat(initial, samples.shape[0])
|
||||
combined = np.concatenate((initial.reshape(-1, 1), samples), axis=1)
|
||||
|
||||
reconstructed_imbalance_prices = self.ipc.get_imbalance_prices_2023_for_date_vectorized(date, combined)
|
||||
reconstructed_imbalance_prices = torch.tensor(reconstructed_imbalance_prices, device="cuda")
|
||||
reconstructed_imbalance_prices = (
|
||||
self.ipc.get_imbalance_prices_2023_for_date_vectorized(date, combined)
|
||||
)
|
||||
reconstructed_imbalance_prices = torch.tensor(
|
||||
reconstructed_imbalance_prices, device="cuda"
|
||||
)
|
||||
|
||||
real_imbalance_prices = self.get_imbanlance_prices_for_date(date.date())
|
||||
|
||||
for penalty in self.penalties:
|
||||
found_charge_thresholds, found_discharge_thresholds = self.baseline_policy.get_optimal_thresholds(
|
||||
reconstructed_imbalance_prices, charge_thresholds, discharge_thresholds, penalty
|
||||
found_charge_thresholds, found_discharge_thresholds = (
|
||||
self.baseline_policy.get_optimal_thresholds(
|
||||
reconstructed_imbalance_prices,
|
||||
charge_thresholds,
|
||||
discharge_thresholds,
|
||||
penalty,
|
||||
)
|
||||
)
|
||||
|
||||
predicted_charge_threshold = found_charge_thresholds.mean(axis=0)
|
||||
@@ -59,13 +77,25 @@ class PolicyEvaluator:
|
||||
|
||||
### Determine Profits and Charge Cycles ###
|
||||
simulated_profit, simulated_charge_cycles = self.baseline_policy.simulate(
|
||||
torch.tensor([[real_imbalance_prices]]), torch.tensor([predicted_charge_threshold]), torch.tensor([predicted_discharge_threshold])
|
||||
torch.tensor([[real_imbalance_prices]]),
|
||||
torch.tensor([predicted_charge_threshold]),
|
||||
torch.tensor([predicted_discharge_threshold]),
|
||||
)
|
||||
self.profits.append([date, penalty, simulated_profit[0][0].item(), simulated_charge_cycles[0][0].item(), predicted_charge_threshold.item(), predicted_discharge_threshold.item()])
|
||||
|
||||
self.profits.append(
|
||||
[
|
||||
date,
|
||||
penalty,
|
||||
simulated_profit[0][0].item(),
|
||||
simulated_charge_cycles[0][0].item(),
|
||||
predicted_charge_threshold.item(),
|
||||
predicted_discharge_threshold.item(),
|
||||
]
|
||||
)
|
||||
|
||||
def evaluate_test_set(self, idx_samples, test_loader):
|
||||
self.profits = []
|
||||
try:
|
||||
print(self.dates)
|
||||
for date in tqdm(self.dates):
|
||||
self.evaluate_for_date(date, idx_samples, test_loader)
|
||||
except KeyboardInterrupt:
|
||||
@@ -76,14 +106,31 @@ class PolicyEvaluator:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
self.profits = pd.DataFrame(self.profits, columns=["Date", "Penalty", "Profit", "Charge Cycles", "Charge Threshold", "Discharge Threshold"])
|
||||
self.profits = pd.DataFrame(
|
||||
self.profits,
|
||||
columns=[
|
||||
"Date",
|
||||
"Penalty",
|
||||
"Profit",
|
||||
"Charge Cycles",
|
||||
"Charge Threshold",
|
||||
"Discharge Threshold",
|
||||
],
|
||||
)
|
||||
|
||||
print("Profits calculated")
|
||||
print(self.profits.head())
|
||||
|
||||
def plot_profits_table(self):
|
||||
# Check if task or penalties are not set
|
||||
if self.task is None or not hasattr(self, 'penalties') or not hasattr(self, 'profits'):
|
||||
if (
|
||||
self.task is None
|
||||
or not hasattr(self, "penalties")
|
||||
or not hasattr(self, "profits")
|
||||
):
|
||||
print("Task, penalties, or profits not defined.")
|
||||
return
|
||||
|
||||
|
||||
if self.profits.empty:
|
||||
print("Profits DataFrame is empty.")
|
||||
return
|
||||
@@ -92,23 +139,32 @@ class PolicyEvaluator:
|
||||
aggregated = self.profits.groupby("Penalty").agg(
|
||||
Total_Profit=("Profit", "sum"),
|
||||
Total_Charge_Cycles=("Charge Cycles", "sum"),
|
||||
Num_Days=("Date", "nunique")
|
||||
Num_Days=("Date", "nunique"),
|
||||
)
|
||||
aggregated["Profit_Per_Year"] = (
|
||||
aggregated["Total_Profit"] / aggregated["Num_Days"] * 365
|
||||
)
|
||||
aggregated["Charge_Cycles_Per_Year"] = (
|
||||
aggregated["Total_Charge_Cycles"] / aggregated["Num_Days"] * 365
|
||||
)
|
||||
aggregated["Profit_Per_Year"] = aggregated["Total_Profit"] / aggregated["Num_Days"] * 365
|
||||
aggregated["Charge_Cycles_Per_Year"] = aggregated["Total_Charge_Cycles"] / aggregated["Num_Days"] * 365
|
||||
|
||||
# Reset index to make 'Penalty' a column again and drop unnecessary columns
|
||||
final_df = aggregated.reset_index().drop(columns=["Total_Profit", "Total_Charge_Cycles", "Num_Days"])
|
||||
final_df = aggregated.reset_index().drop(
|
||||
columns=["Total_Profit", "Total_Charge_Cycles", "Num_Days"]
|
||||
)
|
||||
|
||||
# Rename columns to match expected output
|
||||
final_df.columns = ["Penalty", "Total Profit", "Total Charge Cycles"]
|
||||
|
||||
# Profits till 400
|
||||
profits_till_400 = self.get_profits_till_400()
|
||||
|
||||
# aggregate the final_df and profits_till_400 with columns: Penalty, total profit, total charge cycles, profit till 400, total charge cycles
|
||||
final_df = final_df.merge(profits_till_400, on="Penalty")
|
||||
|
||||
# Log the final results table
|
||||
self.task.get_logger().report_table(
|
||||
"Policy Results",
|
||||
"Policy Results",
|
||||
iteration=0,
|
||||
table_plot=final_df
|
||||
"Policy Results", "Policy Results", iteration=0, table_plot=final_df
|
||||
)
|
||||
|
||||
def plot_thresholds_per_day(self):
|
||||
@@ -116,10 +172,10 @@ class PolicyEvaluator:
|
||||
return
|
||||
|
||||
fig = px.line(
|
||||
self.profits[self.profits["Penalty"] == 0],
|
||||
x="Date",
|
||||
y=["Charge Threshold", "Discharge Threshold"],
|
||||
title="Charge and Discharge Thresholds per Day"
|
||||
self.profits[self.profits["Penalty"] == 0],
|
||||
x="Date",
|
||||
y=["Charge Threshold", "Discharge Threshold"],
|
||||
title="Charge and Discharge Thresholds per Day",
|
||||
)
|
||||
|
||||
fig.update_layout(
|
||||
@@ -129,24 +185,62 @@ class PolicyEvaluator:
|
||||
)
|
||||
|
||||
self.task.get_logger().report_plotly(
|
||||
"Thresholds per Day",
|
||||
"Thresholds per Day",
|
||||
iteration=0,
|
||||
figure=fig
|
||||
"Thresholds per Day", "Thresholds per Day", iteration=0, figure=fig
|
||||
)
|
||||
|
||||
def get_profits_as_scalars(self):
    """Aggregate ``self.profits`` per penalty and extrapolate to a full year.

    For every distinct ``"Penalty"`` the profit and charge cycles are summed,
    divided by the number of distinct ``"Date"`` values for that penalty and
    multiplied by 365.

    Returns:
        A DataFrame with the columns ``"Penalty"``, ``"Total Profit"`` and
        ``"Total Charge Cycles"``. NOTE: despite their names, the last two
        columns hold the per-year extrapolations, not the raw sums; the raw
        sums are dropped before returning.
    """
    per_penalty = self.profits.groupby("Penalty").agg(
        Total_Profit=("Profit", "sum"),
        Total_Charge_Cycles=("Charge Cycles", "sum"),
        Num_Days=("Date", "nunique"),
    )

    # Scale the totals of the evaluated days to a 365-day year.
    per_penalty["Profit_Per_Year"] = (
        per_penalty["Total_Profit"] / per_penalty["Num_Days"] * 365
    )
    per_penalty["Charge_Cycles_Per_Year"] = (
        per_penalty["Total_Charge_Cycles"] / per_penalty["Num_Days"] * 365
    )

    # Turn 'Penalty' back into a regular column and keep only the yearly values.
    final_df = per_penalty.reset_index().drop(
        columns=["Total_Profit", "Total_Charge_Cycles", "Num_Days"]
    )

    # Column names expected by the callers that log these values.
    final_df.columns = ["Penalty", "Total Profit", "Total Charge Cycles"]
    return final_df
|
||||
|
||||
def get_profits_till_400(self):
    """Accumulate profit per penalty until a 400-cycles-per-year budget is used.

    The budget is pro-rated to the evaluated period:
    ``400 / 365 * number_of_distinct_dates``. Rows of ``self.profits`` are
    walked in their stored order; a row is counted in full as long as the
    cycles accumulated so far for its penalty are still below the budget, so
    the reported cycle total can overshoot the budget by at most one row.

    Returns:
        A DataFrame with the columns ``"Penalty"``, ``"Profit_till_400"`` and
        ``"Cycles_till_400"``, one row per penalty in order of first
        appearance in ``self.profits``.
    """
    number_of_days = len(self.profits["Date"].unique())
    usable_charge_cycles = (400 / 365) * number_of_days

    penalty_profits = {}
    penalty_charge_cycles = {}

    # Iterate the three needed columns in lockstep instead of materialising a
    # Series per row with iterrows().
    rows = zip(
        self.profits["Penalty"],
        self.profits["Profit"],
        self.profits["Charge Cycles"],
    )
    for penalty, profit, charge_cycles in rows:
        if penalty not in penalty_profits:
            penalty_profits[penalty] = 0
            penalty_charge_cycles[penalty] = 0

        # Only rows that start below the budget contribute.
        if penalty_charge_cycles[penalty] < usable_charge_cycles:
            penalty_profits[penalty] += profit
            penalty_charge_cycles[penalty] += charge_cycles

    return pd.DataFrame(
        [
            (penalty, penalty_profits[penalty], penalty_charge_cycles[penalty])
            for penalty in penalty_profits
        ],
        columns=["Penalty", "Profit_till_400", "Cycles_till_400"],
    )
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from clearml import Task
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from src.policies.PolicyEvaluator import PolicyEvaluator
|
||||
from torchinfo import summary
|
||||
from src.losses.crps_metric import crps_from_samples
|
||||
from src.data.preprocessing import DataProcessor
|
||||
@@ -13,10 +14,18 @@ import seaborn as sns
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
|
||||
def sample_diffusion(model: DiffusionModel, n: int, inputs: torch.tensor, noise_steps=1000, beta_start=1e-4, beta_end=0.02, ts_length=96):
|
||||
def sample_diffusion(
|
||||
model: DiffusionModel,
|
||||
n: int,
|
||||
inputs: torch.tensor,
|
||||
noise_steps=1000,
|
||||
beta_start=1e-4,
|
||||
beta_end=0.02,
|
||||
ts_length=96,
|
||||
):
|
||||
device = next(model.parameters()).device
|
||||
beta = torch.linspace(beta_start, beta_end, noise_steps).to(device)
|
||||
alpha = 1. - beta
|
||||
alpha = 1.0 - beta
|
||||
alpha_hat = torch.cumprod(alpha, dim=0)
|
||||
|
||||
if len(inputs.shape) == 2:
|
||||
@@ -39,13 +48,24 @@ def sample_diffusion(model: DiffusionModel, n: int, inputs: torch.tensor, noise_
|
||||
else:
|
||||
noise = torch.zeros_like(x)
|
||||
|
||||
x = 1/torch.sqrt(_alpha) * (x-((1-_alpha) / (torch.sqrt(1 - _alpha_hat))) * predicted_noise) + torch.sqrt(_beta) * noise
|
||||
x = (
|
||||
1
|
||||
/ torch.sqrt(_alpha)
|
||||
* (x - ((1 - _alpha) / (torch.sqrt(1 - _alpha_hat))) * predicted_noise)
|
||||
+ torch.sqrt(_beta) * noise
|
||||
)
|
||||
x = torch.clamp(x, -1.0, 1.0)
|
||||
return x
|
||||
|
||||
|
||||
class DiffusionTrainer:
|
||||
def __init__(self, model: nn.Module, data_processor: DataProcessor, device: torch.device):
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
data_processor: DataProcessor,
|
||||
device: torch.device,
|
||||
policy_evaluator: PolicyEvaluator = None,
|
||||
):
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
@@ -53,39 +73,49 @@ class DiffusionTrainer:
|
||||
self.beta_start = 0.0001
|
||||
self.beta_end = 0.02
|
||||
self.ts_length = 96
|
||||
|
||||
|
||||
self.data_processor = data_processor
|
||||
|
||||
self.beta = torch.linspace(self.beta_start, self.beta_end, self.noise_steps).to(self.device)
|
||||
self.alpha = 1. - self.beta
|
||||
self.beta = torch.linspace(self.beta_start, self.beta_end, self.noise_steps).to(
|
||||
self.device
|
||||
)
|
||||
self.alpha = 1.0 - self.beta
|
||||
self.alpha_hat = torch.cumprod(self.alpha, dim=0)
|
||||
|
||||
self.best_score = None
|
||||
self.policy_evaluator = policy_evaluator
|
||||
|
||||
def noise_time_series(self, x: torch.Tensor, t: torch.Tensor):
    """Add Gaussian noise to a batch of time series at the given timesteps.

    Args:
        x (torch.Tensor): shape (batch_size, time_steps)
        t (torch.Tensor): 1-D tensor of timestep indices into
            ``self.alpha_hat``, one per batch element. A plain scalar index
            does not work because the indexed result is expanded with
            ``[:, None]``.

    Returns:
        tuple: ``(noised_x, noise)`` where ``noise`` is the standard-normal
        noise that was drawn (same shape as ``x``) and
        ``noised_x = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * noise``.
    """
    alpha_hat_t = self.alpha_hat[t]
    # Add a trailing axis so the per-sample factors broadcast over time steps.
    sqrt_alpha_hat = torch.sqrt(alpha_hat_t)[:, None]
    sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - alpha_hat_t)[:, None]
    noise = torch.randn_like(x)
    return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise
|
||||
|
||||
|
||||
def sample_timesteps(self, n: int):
    """Sample random diffusion timesteps.

    Args:
        n (int): number of timesteps to draw

    Returns:
        torch.Tensor: 1-D integer tensor of length ``n`` with values drawn
        uniformly from ``1`` to ``self.noise_steps - 1`` (inclusive); index 0
        is never returned.
    """
    return torch.randint(low=1, high=self.noise_steps, size=(n,))
|
||||
|
||||
|
||||
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
    """Draw samples from ``model`` using this trainer's diffusion settings.

    Delegates to ``sample_diffusion`` with ``self.noise_steps``,
    ``self.beta_start``, ``self.beta_end`` and ``self.ts_length``, then calls
    ``model.train()`` before returning.

    Args:
        model: Model handed to ``sample_diffusion``.
        n: Forwarded unchanged to ``sample_diffusion``
            (presumably the number of samples per input -- confirm there).
        inputs: Conditioning inputs forwarded to ``sample_diffusion``.

    Returns:
        Whatever ``sample_diffusion`` returns for these arguments.
    """
    x = sample_diffusion(
        model,
        n,
        inputs,
        self.noise_steps,
        self.beta_start,
        self.beta_end,
        self.ts_length,
    )
    # Sampling happens in the middle of training, so make sure the model is in
    # training mode again afterwards.
    model.train()
    return x
|
||||
|
||||
|
||||
def random_samples(self, train: bool = True, num_samples: int = 10):
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=96
|
||||
@@ -99,15 +129,17 @@ class DiffusionTrainer:
|
||||
# set seed
|
||||
np.random.seed(42)
|
||||
|
||||
actual_indices = np.random.choice(loader.dataset.full_day_valid_indices, num_samples, replace=False)
|
||||
actual_indices = np.random.choice(
|
||||
loader.dataset.full_day_valid_indices, num_samples, replace=False
|
||||
)
|
||||
indices = {}
|
||||
for i in actual_indices:
|
||||
indices[i] = loader.dataset.valid_indices.index(i)
|
||||
|
||||
print(actual_indices)
|
||||
|
||||
|
||||
return indices
|
||||
|
||||
|
||||
def init_clearml_task(self, task):
|
||||
task.add_tags(self.model.__class__.__name__)
|
||||
task.add_tags(self.__class__.__name__)
|
||||
@@ -117,13 +149,24 @@ class DiffusionTrainer:
|
||||
|
||||
if self.data_processor.lstm:
|
||||
inputDim = self.data_processor.get_input_size()
|
||||
other_input_data = torch.randn(1024, inputDim[1], self.model.other_inputs_dim).to(self.device)
|
||||
other_input_data = torch.randn(
|
||||
1024, inputDim[1], self.model.other_inputs_dim
|
||||
).to(self.device)
|
||||
else:
|
||||
other_input_data = torch.randn(1024, self.model.other_inputs_dim).to(self.device)
|
||||
task.set_configuration_object("model", str(summary(self.model, input_data=[input_data, time_steps, other_input_data])))
|
||||
other_input_data = torch.randn(1024, self.model.other_inputs_dim).to(
|
||||
self.device
|
||||
)
|
||||
task.set_configuration_object(
|
||||
"model",
|
||||
str(
|
||||
summary(
|
||||
self.model, input_data=[input_data, time_steps, other_input_data]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self.data_processor = task.connect(self.data_processor, name="data_processor")
|
||||
|
||||
|
||||
def train(self, epochs: int, learning_rate: float, task: Task = None):
|
||||
self.best_score = None
|
||||
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
|
||||
@@ -157,7 +200,7 @@ class DiffusionTrainer:
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
running_loss /= len(train_loader.dataset)
|
||||
|
||||
if epoch % 40 == 0 and epoch != 0:
|
||||
@@ -166,19 +209,22 @@ class DiffusionTrainer:
|
||||
if task:
|
||||
task.get_logger().report_scalar(
|
||||
title=criterion.__class__.__name__,
|
||||
series='train',
|
||||
series="train",
|
||||
iteration=epoch,
|
||||
value=loss.item(),
|
||||
)
|
||||
|
||||
if epoch % 150 == 0 and epoch != 0:
|
||||
self.debug_plots(task, True, train_loader, train_sample_indices, epoch)
|
||||
self.debug_plots(task, False, test_loader, test_sample_indices, epoch)
|
||||
self.debug_plots(
|
||||
task, True, train_loader, train_sample_indices, epoch
|
||||
)
|
||||
self.debug_plots(
|
||||
task, False, test_loader, test_sample_indices, epoch
|
||||
)
|
||||
|
||||
if task:
|
||||
task.close()
|
||||
|
||||
|
||||
def debug_plots(self, task, training: bool, data_loader, sample_indices, epoch):
|
||||
for actual_idx, idx in sample_indices.items():
|
||||
features, target, _ = data_loader.dataset[idx]
|
||||
@@ -191,7 +237,7 @@ class DiffusionTrainer:
|
||||
samples = self.sample(self.model, 100, features).cpu().numpy()
|
||||
samples = self.data_processor.inverse_transform(samples)
|
||||
target = self.data_processor.inverse_transform(target)
|
||||
|
||||
|
||||
ci_99_upper = np.quantile(samples, 0.995, axis=0)
|
||||
ci_99_lower = np.quantile(samples, 0.005, axis=0)
|
||||
|
||||
@@ -204,49 +250,100 @@ class DiffusionTrainer:
|
||||
ci_50_lower = np.quantile(samples, 0.25, axis=0)
|
||||
ci_50_upper = np.quantile(samples, 0.75, axis=0)
|
||||
|
||||
|
||||
sns.set_theme()
|
||||
time_steps = np.arange(0, 96)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(20, 10))
|
||||
ax.plot(time_steps, samples.mean(axis=0), label="Mean of NRV samples", linewidth=3)
|
||||
ax.plot(
|
||||
time_steps,
|
||||
samples.mean(axis=0),
|
||||
label="Mean of NRV samples",
|
||||
linewidth=3,
|
||||
)
|
||||
# ax.fill_between(time_steps, ci_lower, ci_upper, color='b', alpha=0.2, label='Full Interval')
|
||||
|
||||
ax.fill_between(time_steps, ci_99_lower, ci_99_upper, color='b', alpha=0.2, label='99% Interval')
|
||||
ax.fill_between(time_steps, ci_95_lower, ci_95_upper, color='b', alpha=0.2, label='95% Interval')
|
||||
ax.fill_between(time_steps, ci_90_lower, ci_90_upper, color='b', alpha=0.2, label='90% Interval')
|
||||
ax.fill_between(time_steps, ci_50_lower, ci_50_upper, color='b', alpha=0.2, label='50% Interval')
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_99_lower,
|
||||
ci_99_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="99% Interval",
|
||||
)
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_95_lower,
|
||||
ci_95_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="95% Interval",
|
||||
)
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_90_lower,
|
||||
ci_90_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="90% Interval",
|
||||
)
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_50_lower,
|
||||
ci_50_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="50% Interval",
|
||||
)
|
||||
|
||||
ax.plot(target, label="Real NRV", linewidth=3)
|
||||
# full_interval_patch = mpatches.Patch(color='b', alpha=0.2, label='Full Interval')
|
||||
ci_99_patch = mpatches.Patch(color='b', alpha=0.3, label='99% Interval')
|
||||
ci_95_patch = mpatches.Patch(color='b', alpha=0.4, label='95% Interval')
|
||||
ci_90_patch = mpatches.Patch(color='b', alpha=0.5, label='90% Interval')
|
||||
ci_50_patch = mpatches.Patch(color='b', alpha=0.6, label='50% Interval')
|
||||
ci_99_patch = mpatches.Patch(color="b", alpha=0.3, label="99% Interval")
|
||||
ci_95_patch = mpatches.Patch(color="b", alpha=0.4, label="95% Interval")
|
||||
ci_90_patch = mpatches.Patch(color="b", alpha=0.5, label="90% Interval")
|
||||
ci_50_patch = mpatches.Patch(color="b", alpha=0.6, label="50% Interval")
|
||||
|
||||
|
||||
ax.legend(handles=[ci_99_patch, ci_95_patch, ci_90_patch, ci_50_patch, ax.lines[0], ax.lines[1]])
|
||||
ax.legend(
|
||||
handles=[
|
||||
ci_99_patch,
|
||||
ci_95_patch,
|
||||
ci_90_patch,
|
||||
ci_50_patch,
|
||||
ax.lines[0],
|
||||
ax.lines[1],
|
||||
]
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if training else "Testing",
|
||||
series=f'Sample {actual_idx}',
|
||||
series=f"Sample {actual_idx}",
|
||||
iteration=epoch,
|
||||
figure=fig,
|
||||
)
|
||||
|
||||
plt.close()
|
||||
|
||||
def test(self, data_loader: torch.utils.data.DataLoader, epoch: int, task: Task = None):
|
||||
def test(
|
||||
self, data_loader: torch.utils.data.DataLoader, epoch: int, task: Task = None
|
||||
):
|
||||
all_crps = []
|
||||
for inputs, targets, _ in data_loader:
|
||||
generated_samples = {}
|
||||
for inputs, targets, idx_batch in data_loader:
|
||||
inputs, targets = inputs.to(self.device), targets.to(self.device)
|
||||
|
||||
print(inputs.shape, targets.shape)
|
||||
|
||||
number_of_samples = 100
|
||||
sample = self.sample(self.model, number_of_samples, inputs)
|
||||
|
||||
# reduce samples from (batch_size*number_of_samples, time_steps) to (batch_size, number_of_samples, time_steps)
|
||||
samples_batched = sample.reshape(inputs.shape[0], number_of_samples, 96)
|
||||
|
||||
# add samples to generated_samples generated_samples[idx.item()] = (initial, samples)
|
||||
for i, (idx, samples) in enumerate(zip(idx_batch, samples_batched)):
|
||||
generated_samples[idx.item()] = (
|
||||
self.data_processor.inverse_transform(inputs[i][:96]),
|
||||
self.data_processor.inverse_transform(samples),
|
||||
)
|
||||
|
||||
# calculate crps
|
||||
crps = crps_from_samples(samples_batched, targets)
|
||||
crps_mean = crps.mean(axis=1)
|
||||
@@ -262,16 +359,38 @@ class DiffusionTrainer:
|
||||
|
||||
if task:
|
||||
task.get_logger().report_scalar(
|
||||
title="CRPS",
|
||||
series='test',
|
||||
value=mean_crps,
|
||||
iteration=epoch
|
||||
title="CRPS", series="test", value=mean_crps, iteration=epoch
|
||||
)
|
||||
|
||||
if self.policy_evaluator:
|
||||
_, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.ts_length, full_day_skip=True
|
||||
)
|
||||
|
||||
self.policy_evaluator.evaluate_test_set(generated_samples, test_loader)
|
||||
|
||||
df = self.policy_evaluator.get_profits_as_scalars()
|
||||
|
||||
for idx, row in df.iterrows():
|
||||
task.get_logger().report_scalar(
|
||||
title="Profit",
|
||||
series=f"penalty_{row['Penalty']}",
|
||||
value=row["Total Profit"],
|
||||
iteration=epoch,
|
||||
)
|
||||
|
||||
df = self.policy_evaluator.get_profits_till_400()
|
||||
for idx, row in df.iterrows():
|
||||
task.get_logger().report_scalar(
|
||||
title="Profit_till_400",
|
||||
series=f"penalty_{row['Penalty']}",
|
||||
value=row["Profit_till_400"],
|
||||
iteration=epoch,
|
||||
)
|
||||
|
||||
def save_checkpoint(self, val_loss, task, iteration: int):
    """Save the model to ``checkpoint.pt`` and register it with ``task``.

    Args:
        val_loss: Score stored in ``self.best_score`` after saving.
        task: Object whose ``update_output_model`` is called with the
            checkpoint path; must not be ``None``.
        iteration: Iteration number attached to the uploaded model.
    """
    # The whole module object is serialised, not only its state dict.
    torch.save(self.model, "checkpoint.pt")
    task.update_output_model(
        model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
    )
    self.best_score = val_loss
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import matplotlib.patches as mpatches
|
||||
|
||||
|
||||
def sample_from_dist(quantiles, preds):
|
||||
if isinstance(preds, torch.Tensor):
|
||||
preds = preds.detach().cpu()
|
||||
@@ -31,10 +32,11 @@ def sample_from_dist(quantiles, preds):
|
||||
|
||||
# random probabilities of (1000, 1)
|
||||
import random
|
||||
|
||||
probs = np.array([random.random() for _ in range(1000)])
|
||||
|
||||
spline = CubicSpline(quantiles, preds, axis=1)
|
||||
|
||||
|
||||
samples = spline(probs)
|
||||
|
||||
# get the diagonal
|
||||
@@ -42,6 +44,7 @@ def sample_from_dist(quantiles, preds):
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
def auto_regressive(dataset, model, quantiles, idx_batch, sequence_length: int = 96):
|
||||
device = next(model.parameters()).device
|
||||
prev_features, targets = dataset.get_batch(idx_batch)
|
||||
@@ -65,7 +68,7 @@ def auto_regressive(dataset, model, quantiles, idx_batch, sequence_length: int =
|
||||
predictions_full = new_predictions_full.unsqueeze(1)
|
||||
|
||||
for i in range(sequence_length - 1):
|
||||
if len(list(prev_features.shape)) == 2:
|
||||
if len(list(prev_features.shape)) == 2:
|
||||
new_features = torch.cat(
|
||||
(prev_features[:, 1:96], samples), dim=1
|
||||
) # (batch_size, 96)
|
||||
@@ -102,9 +105,7 @@ def auto_regressive(dataset, model, quantiles, idx_batch, sequence_length: int =
|
||||
) # (batch_size, sequence_length)
|
||||
|
||||
with torch.no_grad():
|
||||
new_predictions_full = model(
|
||||
prev_features
|
||||
) # (batch_size, quantiles)
|
||||
new_predictions_full = model(prev_features) # (batch_size, quantiles)
|
||||
predictions_full = torch.cat(
|
||||
(predictions_full, new_predictions_full.unsqueeze(1)), dim=1
|
||||
) # (batch_size, sequence_length, quantiles)
|
||||
@@ -123,6 +124,7 @@ def auto_regressive(dataset, model, quantiles, idx_batch, sequence_length: int =
|
||||
target_full.unsqueeze(-1),
|
||||
)
|
||||
|
||||
|
||||
class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -162,40 +164,58 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
|
||||
if len(idx_batch) == 0:
|
||||
continue
|
||||
|
||||
|
||||
for idx in tqdm(idx_batch):
|
||||
computed_idx_batch = [idx] * 100
|
||||
initial, _, samples, targets = self.auto_regressive(
|
||||
dataloader.dataset, idx_batch=computed_idx_batch
|
||||
)
|
||||
|
||||
generated_samples[idx.item()] = (initial, self.data_processor.inverse_transform(samples))
|
||||
generated_samples[idx.item()] = (
|
||||
self.data_processor.inverse_transform(initial),
|
||||
self.data_processor.inverse_transform(samples),
|
||||
)
|
||||
|
||||
samples = samples.unsqueeze(0)
|
||||
targets = targets.squeeze(-1)
|
||||
targets = targets[0].unsqueeze(0)
|
||||
|
||||
|
||||
crps = crps_from_samples(samples, targets)
|
||||
|
||||
crps_from_samples_metric.append(crps[0].mean().item())
|
||||
|
||||
task.get_logger().report_scalar(
|
||||
title="CRPS_from_samples", series="test", value=np.mean(crps_from_samples_metric), iteration=epoch
|
||||
title="CRPS_from_samples",
|
||||
series="test",
|
||||
value=np.mean(crps_from_samples_metric),
|
||||
iteration=epoch,
|
||||
)
|
||||
|
||||
# using the policy evaluator, evaluate the policy with the generated samples
|
||||
if self.policy_evaluator is not None:
|
||||
_, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.model.output_size)
|
||||
predict_sequence_length=self.model.output_size, full_day_skip=True
|
||||
)
|
||||
self.policy_evaluator.evaluate_test_set(generated_samples, test_loader)
|
||||
df = self.policy_evaluator.get_profits_as_scalars()
|
||||
|
||||
# for each row, report the profits
|
||||
for idx, row in df.iterrows():
|
||||
task.get_logger().report_scalar(
|
||||
title="Profit", series=f"penalty_{row['Penalty']}", value=row["Total Profit"], iteration=epoch
|
||||
title="Profit",
|
||||
series=f"penalty_{row['Penalty']}",
|
||||
value=row["Total Profit"],
|
||||
iteration=epoch,
|
||||
)
|
||||
|
||||
df = self.policy_evaluator.get_profits_till_400()
|
||||
for idx, row in df.iterrows():
|
||||
task.get_logger().report_scalar(
|
||||
title="Profit_till_400",
|
||||
series=f"penalty_{row['Penalty']}",
|
||||
value=row["Profit_till_400"],
|
||||
iteration=epoch,
|
||||
)
|
||||
|
||||
def log_final_metrics(self, task, dataloader, train: bool = True):
|
||||
metrics = {metric.__class__.__name__: 0.0 for metric in self.metrics_to_track}
|
||||
@@ -222,17 +242,19 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
)
|
||||
|
||||
# save the samples for the idx, these will be used for evaluating the policy
|
||||
self.test_set_samples[idx.item()] = (initial, self.data_processor.inverse_transform(samples))
|
||||
self.test_set_samples[idx.item()] = (
|
||||
self.data_processor.inverse_transform(initial),
|
||||
self.data_processor.inverse_transform(samples),
|
||||
)
|
||||
|
||||
samples = samples.unsqueeze(0)
|
||||
targets = targets.squeeze(-1)
|
||||
targets = targets[0].unsqueeze(0)
|
||||
|
||||
|
||||
crps = crps_from_samples(samples, targets)
|
||||
|
||||
crps_from_samples_metric.append(crps[0].mean().item())
|
||||
|
||||
|
||||
_, outputs, samples, targets = self.auto_regressive(
|
||||
dataloader.dataset, idx_batch=idx_batch
|
||||
)
|
||||
@@ -286,7 +308,8 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
|
||||
if train == False:
|
||||
task.get_logger().report_single_value(
|
||||
name="test_CRPS_from_samples_transformed", value=np.mean(crps_from_samples_metric)
|
||||
name="test_CRPS_from_samples_transformed",
|
||||
value=np.mean(crps_from_samples_metric),
|
||||
)
|
||||
|
||||
# def get_plot_error(
|
||||
@@ -313,13 +336,12 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
|
||||
# errors.append(metric(prediction_tensor, target_tensor))
|
||||
|
||||
# # plot the error
|
||||
# # plot the error
|
||||
# fig.add_trace(go.Scatter(x=np.arange(96), y=errors, name=metric.__class__.__name__))
|
||||
# fig.update_layout(title=f"Error of {metric.__class__.__name__} for each time step")
|
||||
|
||||
# return fig
|
||||
|
||||
|
||||
def get_plot(
|
||||
self,
|
||||
current_day,
|
||||
@@ -376,30 +398,78 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
time_steps = np.arange(0, 96)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(20, 10))
|
||||
ax.plot(time_steps, predictions_np.mean(axis=0), label="Mean of NRV samples", linewidth=3)
|
||||
ax.plot(
|
||||
time_steps,
|
||||
predictions_np.mean(axis=0),
|
||||
label="Mean of NRV samples",
|
||||
linewidth=3,
|
||||
)
|
||||
# ax.fill_between(time_steps, ci_lower, ci_upper, color='b', alpha=0.2, label='Full Interval')
|
||||
|
||||
ax.fill_between(time_steps, ci_99_lower, ci_99_upper, color='b', alpha=0.2, label='99% Interval')
|
||||
ax.fill_between(time_steps, ci_95_lower, ci_95_upper, color='b', alpha=0.2, label='95% Interval')
|
||||
ax.fill_between(time_steps, ci_90_lower, ci_90_upper, color='b', alpha=0.2, label='90% Interval')
|
||||
ax.fill_between(time_steps, ci_50_lower, ci_50_upper, color='b', alpha=0.2, label='50% Interval')
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_99_lower,
|
||||
ci_99_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="99% Interval",
|
||||
)
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_95_lower,
|
||||
ci_95_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="95% Interval",
|
||||
)
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_90_lower,
|
||||
ci_90_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="90% Interval",
|
||||
)
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_50_lower,
|
||||
ci_50_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="50% Interval",
|
||||
)
|
||||
|
||||
ax.plot(next_day_np, label="Real NRV", linewidth=3)
|
||||
# full_interval_patch = mpatches.Patch(color='b', alpha=0.2, label='Full Interval')
|
||||
ci_99_patch = mpatches.Patch(color='b', alpha=0.3, label='99% Interval')
|
||||
ci_95_patch = mpatches.Patch(color='b', alpha=0.4, label='95% Interval')
|
||||
ci_90_patch = mpatches.Patch(color='b', alpha=0.5, label='90% Interval')
|
||||
ci_50_patch = mpatches.Patch(color='b', alpha=0.6, label='50% Interval')
|
||||
ci_99_patch = mpatches.Patch(color="b", alpha=0.3, label="99% Interval")
|
||||
ci_95_patch = mpatches.Patch(color="b", alpha=0.4, label="95% Interval")
|
||||
ci_90_patch = mpatches.Patch(color="b", alpha=0.5, label="90% Interval")
|
||||
ci_50_patch = mpatches.Patch(color="b", alpha=0.6, label="50% Interval")
|
||||
|
||||
|
||||
ax.legend(handles=[ci_99_patch, ci_95_patch, ci_90_patch, ci_50_patch, ax.lines[0], ax.lines[1]])
|
||||
ax.legend(
|
||||
handles=[
|
||||
ci_99_patch,
|
||||
ci_95_patch,
|
||||
ci_90_patch,
|
||||
ci_50_patch,
|
||||
ax.lines[0],
|
||||
ax.lines[1],
|
||||
]
|
||||
)
|
||||
return fig
|
||||
|
||||
def auto_regressive(self, dataset, idx_batch, sequence_length: int = 96):
|
||||
return auto_regressive(dataset, self.model, self.quantiles, idx_batch, sequence_length)
|
||||
return auto_regressive(
|
||||
dataset, self.model, self.quantiles, idx_batch, sequence_length
|
||||
)
|
||||
|
||||
def plot_quantile_percentages(
|
||||
self, task, data_loader, train: bool = True, iteration: int = None, full_day: bool = False
|
||||
self,
|
||||
task,
|
||||
data_loader,
|
||||
train: bool = True,
|
||||
iteration: int = None,
|
||||
full_day: bool = False,
|
||||
):
|
||||
quantiles = self.quantiles
|
||||
total = 0
|
||||
@@ -429,20 +499,18 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
|
||||
else:
|
||||
inputs = inputs.to(self.device)
|
||||
outputs = self.model(inputs).cpu().numpy() # (batch_size, num_quantiles)
|
||||
targets = targets.squeeze(-1).cpu().numpy() # (batch_size, 1)
|
||||
outputs = (
|
||||
self.model(inputs).cpu().numpy()
|
||||
) # (batch_size, num_quantiles)
|
||||
targets = targets.squeeze(-1).cpu().numpy() # (batch_size, 1)
|
||||
|
||||
for i, q in enumerate(quantiles):
|
||||
quantile_counter[q] += np.sum(
|
||||
targets < outputs[:, i]
|
||||
)
|
||||
quantile_counter[q] += np.sum(targets < outputs[:, i])
|
||||
|
||||
total += len(targets)
|
||||
|
||||
# to numpy array of length len(quantiles)
|
||||
percentages = np.array(
|
||||
[quantile_counter[q] / total for q in quantiles]
|
||||
)
|
||||
percentages = np.array([quantile_counter[q] / total for q in quantiles])
|
||||
|
||||
bar_width = 0.35
|
||||
index = np.arange(len(quantiles))
|
||||
@@ -450,9 +518,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
||||
# Plotting the bars
|
||||
fig, ax = plt.subplots(figsize=(15, 10))
|
||||
|
||||
bar1 = ax.bar(
|
||||
index, quantiles, bar_width, label="Ideal", color="brown"
|
||||
)
|
||||
bar1 = ax.bar(index, quantiles, bar_width, label="Ideal", color="brown")
|
||||
bar2 = ax.bar(
|
||||
index + bar_width, percentages, bar_width, label="NN model", color="blue"
|
||||
)
|
||||
@@ -502,7 +568,6 @@ class NonAutoRegressiveQuantileRegression(Trainer):
|
||||
):
|
||||
self.quantiles = quantiles
|
||||
|
||||
|
||||
criterion = NonAutoRegressivePinballLoss(quantiles=quantiles)
|
||||
super().__init__(
|
||||
model=model,
|
||||
|
||||
@@ -1,3 +1,12 @@
|
||||
from src.utils.clearml import ClearMLHelper
|
||||
|
||||
#### ClearML ####
|
||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||
task = clearml_helper.get_task(
|
||||
task_name="Autoregressive Quantile Regression: Non Linear"
|
||||
)
|
||||
task.execute_remotely(queue_name="default", exit_process=True)
|
||||
|
||||
from src.policies.PolicyEvaluator import PolicyEvaluator
|
||||
from src.policies.simple_baseline import BaselinePolicy, Battery
|
||||
from src.models.lstm_model import GRUModel
|
||||
@@ -13,11 +22,6 @@ import torch.nn as nn
|
||||
from src.models.time_embedding_layer import TimeEmbedding
|
||||
|
||||
|
||||
#### ClearML ####
|
||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||
task = clearml_helper.get_task(task_name="Autoregressive Quantile Regression: Non Linear")
|
||||
|
||||
|
||||
#### Data Processor ####
|
||||
data_config = DataConfig()
|
||||
|
||||
@@ -34,7 +38,6 @@ data_config.DAY_OF_WEEK = True
|
||||
data_config.NOMINAL_NET_POSITION = True
|
||||
|
||||
|
||||
|
||||
data_config = task.connect(data_config, name="data_features")
|
||||
|
||||
data_processor = DataProcessor(data_config, path="", lstm=False)
|
||||
@@ -68,9 +71,17 @@ model_parameters = {
|
||||
|
||||
model_parameters = task.connect(model_parameters, name="model_parameters")
|
||||
|
||||
time_embedding = TimeEmbedding(data_processor.get_time_feature_size(), model_parameters["time_feature_embedding"])
|
||||
time_embedding = TimeEmbedding(
|
||||
data_processor.get_time_feature_size(), model_parameters["time_feature_embedding"]
|
||||
)
|
||||
# lstm_model = GRUModel(time_embedding.output_dim(inputDim), len(quantiles), hidden_size=model_parameters["hidden_size"], num_layers=model_parameters["num_layers"], dropout=model_parameters["dropout"])
|
||||
non_linear_model = NonLinearRegression(time_embedding.output_dim(inputDim), len(quantiles), hiddenSize=model_parameters["hidden_size"], numLayers=model_parameters["num_layers"], dropout=model_parameters["dropout"])
|
||||
non_linear_model = NonLinearRegression(
|
||||
time_embedding.output_dim(inputDim),
|
||||
len(quantiles),
|
||||
hiddenSize=model_parameters["hidden_size"],
|
||||
numLayers=model_parameters["num_layers"],
|
||||
dropout=model_parameters["dropout"],
|
||||
)
|
||||
# linear_model = LinearRegression(time_embedding.output_dim(inputDim), len(quantiles))
|
||||
|
||||
model = nn.Sequential(time_embedding, non_linear_model)
|
||||
@@ -103,10 +114,11 @@ trainer.train(task=task, epochs=epochs, remotely=True)
|
||||
### Policy Evaluation ###
|
||||
idx_samples = trainer.test_set_samples
|
||||
_, test_loader = trainer.data_processor.get_dataloaders(
|
||||
predict_sequence_length=trainer.model.output_size)
|
||||
predict_sequence_length=trainer.model.output_size, full_day_skip=True
|
||||
)
|
||||
|
||||
policy_evaluator.evaluate_test_set(idx_samples, test_loader)
|
||||
policy_evaluator.plot_profits_table()
|
||||
policy_evaluator.plot_thresholds_per_day()
|
||||
|
||||
task.close()
|
||||
task.close()
|
||||
|
||||
@@ -1,25 +1,17 @@
|
||||
from clearml import Task
|
||||
from src.data import DataProcessor, DataConfig
|
||||
from src.trainers.trainer import Trainer
|
||||
from src.utils.clearml import ClearMLHelper
|
||||
from src.models import *
|
||||
from src.losses import *
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.nn import MSELoss, L1Loss
|
||||
from datetime import datetime
|
||||
import torch.nn as nn
|
||||
from src.models.time_embedding_layer import TimeEmbedding
|
||||
from src.models.diffusion_model import GRUDiffusionModel, SimpleDiffusionModel
|
||||
from src.trainers.diffusion_trainer import DiffusionTrainer
|
||||
|
||||
|
||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||
task = clearml_helper.get_task(task_name="Diffusion Training")
|
||||
|
||||
# execute remotely
|
||||
task.execute_remotely(queue_name="default", exit_process=True)
|
||||
print("Running remotely")
|
||||
|
||||
from src.models import *
|
||||
from src.losses import *
|
||||
from src.models.time_embedding_layer import TimeEmbedding
|
||||
from src.models.diffusion_model import GRUDiffusionModel, SimpleDiffusionModel
|
||||
from src.trainers.diffusion_trainer import DiffusionTrainer
|
||||
from src.data import DataProcessor, DataConfig
|
||||
from src.policies.simple_baseline import BaselinePolicy, Battery
|
||||
from src.policies.PolicyEvaluator import PolicyEvaluator
|
||||
|
||||
#### Data Processor ####
|
||||
data_config = DataConfig()
|
||||
@@ -54,11 +46,21 @@ model_parameters = {
|
||||
model_parameters = task.connect(model_parameters, name="model_parameters")
|
||||
|
||||
#### Model ####
|
||||
model = SimpleDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[1], time_dim=model_parameters["time_dim"])
|
||||
model = SimpleDiffusionModel(
|
||||
96,
|
||||
model_parameters["hidden_sizes"],
|
||||
other_inputs_dim=inputDim[1],
|
||||
time_dim=model_parameters["time_dim"],
|
||||
)
|
||||
# model = GRUDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[2], time_dim=model_parameters["time_dim"], gru_hidden_size=128)
|
||||
|
||||
print("Starting training ...")
|
||||
### Policy Evaluator ###
|
||||
battery = Battery(2, 1)
|
||||
baseline_policy = BaselinePolicy(battery, data_path="")
|
||||
policy_evaluator = PolicyEvaluator(baseline_policy, task)
|
||||
|
||||
#### Trainer ####
|
||||
trainer = DiffusionTrainer(model, data_processor, "cuda")
|
||||
trainer.train(model_parameters["epochs"], model_parameters["learning_rate"], task)
|
||||
trainer = DiffusionTrainer(
|
||||
model, data_processor, "cuda", policy_evaluator=policy_evaluator
|
||||
)
|
||||
trainer.train(model_parameters["epochs"], model_parameters["learning_rate"], task)
|
||||
|
||||
Reference in New Issue
Block a user