Fixed policy evaluation for autoregressive
This commit is contained in:
@@ -174,9 +174,6 @@ TODO:
|
|||||||
|
|
||||||
Visualisatie van thresholds over test set voor baselines en complexere modellen -> zonder penalties tonen
|
Visualisatie van thresholds over test set voor baselines en complexere modellen -> zonder penalties tonen
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
1 à 2 case studies (extreme gevallen, thresholds 150, -5, normale misschien)
|
1 à 2 case studies (extreme gevallen, thresholds 150, -5, normale misschien)
|
||||||
|
|
||||||
- Generatie van NRV (echte NRV)
|
- Generatie van NRV (echte NRV)
|
||||||
@@ -197,4 +194,12 @@ Inleiding +
|
|||||||
Literatuurstudie +
|
Literatuurstudie +
|
||||||
Tabellen die we gaan bespreken -> updaten met nieuwe data dan
|
Tabellen die we gaan bespreken -> updaten met nieuwe data dan
|
||||||
|
|
||||||
Nog eens 3e meeting opbrengen voor 2e deel maart.
|
|
||||||
|
!!!!! Fix the test set (maybe save pickle)
|
||||||
|
|
||||||
|
Plot profits per year (maybe with charge cycles) for the different models and baselines.
|
||||||
|
Plot the spread (difference) between the charge and discharge thresholds.
|
||||||
|
|
||||||
|
RESULTATEN FIXEN
|
||||||
|
|
||||||
|
Nog eens 3e meeting opbrengen voor 2e deel maart.
|
||||||
|
|||||||
@@ -25,18 +25,22 @@ class NrvDataset(Dataset):
|
|||||||
self.sequence_length = sequence_length
|
self.sequence_length = sequence_length
|
||||||
self.predict_sequence_length = predict_sequence_length
|
self.predict_sequence_length = predict_sequence_length
|
||||||
|
|
||||||
self.samples_to_skip = self.skip_samples(dataframe=dataframe, full_day_skip=self.full_day_skip)
|
self.samples_to_skip = self.skip_samples(
|
||||||
|
dataframe=dataframe, full_day_skip=self.full_day_skip
|
||||||
|
)
|
||||||
total_indices = set(
|
total_indices = set(
|
||||||
range(len(dataframe) - self.sequence_length - self.predict_sequence_length)
|
range(len(dataframe) - self.sequence_length - self.predict_sequence_length)
|
||||||
)
|
)
|
||||||
self.valid_indices = sorted(list(total_indices - set(self.samples_to_skip)))
|
self.valid_indices = sorted(list(total_indices - set(self.samples_to_skip)))
|
||||||
|
|
||||||
# full day indices
|
full_day_skipped_samples = self.skip_samples(
|
||||||
full_day_skipped_samples = self.skip_samples(dataframe=dataframe, full_day_skip=True)
|
dataframe=dataframe, full_day_skip=True
|
||||||
full_day_total_indices = set(
|
)
|
||||||
range(len(dataframe) - self.sequence_length - self.predict_sequence_length)
|
|
||||||
|
full_day_total_indices = set(range(len(dataframe) - self.sequence_length - 96))
|
||||||
|
self.full_day_valid_indices = sorted(
|
||||||
|
list(full_day_total_indices - set(full_day_skipped_samples))
|
||||||
)
|
)
|
||||||
self.full_day_valid_indices = sorted(list(full_day_total_indices - set(full_day_skipped_samples)))
|
|
||||||
|
|
||||||
self.history_features = []
|
self.history_features = []
|
||||||
if self.data_config.LOAD_HISTORY:
|
if self.data_config.LOAD_HISTORY:
|
||||||
@@ -74,7 +78,7 @@ class NrvDataset(Dataset):
|
|||||||
self.time_feature = torch.tensor(time_feature).float().reshape(-1)
|
self.time_feature = torch.tensor(time_feature).float().reshape(-1)
|
||||||
else:
|
else:
|
||||||
self.time_feature = None
|
self.time_feature = None
|
||||||
|
|
||||||
self.nrv = torch.tensor(dataframe["nrv"].values).float().reshape(-1)
|
self.nrv = torch.tensor(dataframe["nrv"].values).float().reshape(-1)
|
||||||
self.datetime = dataframe["datetime"]
|
self.datetime = dataframe["datetime"]
|
||||||
|
|
||||||
@@ -84,12 +88,7 @@ class NrvDataset(Dataset):
|
|||||||
nan_rows = dataframe[dataframe.isnull().any(axis=1)]
|
nan_rows = dataframe[dataframe.isnull().any(axis=1)]
|
||||||
nan_indices = nan_rows.index
|
nan_indices = nan_rows.index
|
||||||
skip_indices = [
|
skip_indices = [
|
||||||
list(
|
list(range(idx - self.sequence_length - 96, idx + 1)) for idx in nan_indices
|
||||||
range(
|
|
||||||
idx - self.sequence_length - 96, idx + 1
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for idx in nan_indices
|
|
||||||
]
|
]
|
||||||
|
|
||||||
skip_indices = [item for sublist in skip_indices for item in sublist]
|
skip_indices = [item for sublist in skip_indices for item in sublist]
|
||||||
@@ -106,10 +105,12 @@ class NrvDataset(Dataset):
|
|||||||
skip_indices = list(set(skip_indices))
|
skip_indices = list(set(skip_indices))
|
||||||
|
|
||||||
return skip_indices
|
return skip_indices
|
||||||
|
|
||||||
def preprocess_data(self, dataframe):
|
|
||||||
return torch.tensor(dataframe[self.history_features].values).float(), torch.tensor(dataframe[self.forecast_features].values).float()
|
|
||||||
|
|
||||||
|
def preprocess_data(self, dataframe):
|
||||||
|
return (
|
||||||
|
torch.tensor(dataframe[self.history_features].values).float(),
|
||||||
|
torch.tensor(dataframe[self.forecast_features].values).float(),
|
||||||
|
)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.valid_indices)
|
return len(self.valid_indices)
|
||||||
@@ -117,21 +118,38 @@ class NrvDataset(Dataset):
|
|||||||
def _get_all_data(self, idx: int):
|
def _get_all_data(self, idx: int):
|
||||||
history_df = self.dataframe.iloc[idx : idx + self.sequence_length]
|
history_df = self.dataframe.iloc[idx : idx + self.sequence_length]
|
||||||
forecast_df = self.dataframe.iloc[
|
forecast_df = self.dataframe.iloc[
|
||||||
idx + self.sequence_length : idx + self.sequence_length + self.predict_sequence_length
|
idx
|
||||||
|
+ self.sequence_length : idx
|
||||||
|
+ self.sequence_length
|
||||||
|
+ self.predict_sequence_length
|
||||||
]
|
]
|
||||||
return history_df, forecast_df
|
return history_df, forecast_df
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
actual_idx = self.valid_indices[idx]
|
try:
|
||||||
|
actual_idx = self.valid_indices[idx]
|
||||||
|
except IndexError:
|
||||||
|
print(f"Index {idx} not in valid indices.")
|
||||||
|
raise
|
||||||
|
|
||||||
# get nrv history features
|
# get nrv history features
|
||||||
nrv_features = self.nrv[actual_idx : actual_idx + self.sequence_length]
|
nrv_features = self.nrv[actual_idx : actual_idx + self.sequence_length]
|
||||||
|
|
||||||
history_features = self.history_features[actual_idx : actual_idx + self.sequence_length, :]
|
history_features = self.history_features[
|
||||||
forecast_features = self.forecast_features[actual_idx + self.sequence_length : actual_idx + self.sequence_length + self.predict_sequence_length, :]
|
actual_idx : actual_idx + self.sequence_length, :
|
||||||
|
]
|
||||||
|
forecast_features = self.forecast_features[
|
||||||
|
actual_idx
|
||||||
|
+ self.sequence_length : actual_idx
|
||||||
|
+ self.sequence_length
|
||||||
|
+ self.predict_sequence_length,
|
||||||
|
:,
|
||||||
|
]
|
||||||
|
|
||||||
if self.time_feature is not None:
|
if self.time_feature is not None:
|
||||||
time_features = self.time_feature[actual_idx : actual_idx + self.sequence_length]
|
time_features = self.time_feature[
|
||||||
|
actual_idx : actual_idx + self.sequence_length
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
time_features = None
|
time_features = None
|
||||||
|
|
||||||
@@ -154,7 +172,9 @@ class NrvDataset(Dataset):
|
|||||||
all_features_list = [nrv_features.unsqueeze(1)]
|
all_features_list = [nrv_features.unsqueeze(1)]
|
||||||
|
|
||||||
if self.forecast_features.numel() > 0:
|
if self.forecast_features.numel() > 0:
|
||||||
history_forecast_features = self.forecast_features[actual_idx + 1 : actual_idx + self.sequence_length + 1, :]
|
history_forecast_features = self.forecast_features[
|
||||||
|
actual_idx + 1 : actual_idx + self.sequence_length + 1, :
|
||||||
|
]
|
||||||
all_features_list.append(history_forecast_features)
|
all_features_list.append(history_forecast_features)
|
||||||
|
|
||||||
if time_features is not None:
|
if time_features is not None:
|
||||||
@@ -163,7 +183,12 @@ class NrvDataset(Dataset):
|
|||||||
all_features = torch.cat(all_features_list, dim=1)
|
all_features = torch.cat(all_features_list, dim=1)
|
||||||
|
|
||||||
# Target sequence, flattened if necessary
|
# Target sequence, flattened if necessary
|
||||||
nrv_target = self.nrv[actual_idx + self.sequence_length : actual_idx + self.sequence_length + self.predict_sequence_length]
|
nrv_target = self.nrv[
|
||||||
|
actual_idx
|
||||||
|
+ self.sequence_length : actual_idx
|
||||||
|
+ self.sequence_length
|
||||||
|
+ self.predict_sequence_length
|
||||||
|
]
|
||||||
|
|
||||||
# check if nan values are present
|
# check if nan values are present
|
||||||
if torch.isnan(all_features).any():
|
if torch.isnan(all_features).any():
|
||||||
@@ -188,7 +213,6 @@ class NrvDataset(Dataset):
|
|||||||
|
|
||||||
return all_features, nrv_target
|
return all_features, nrv_target
|
||||||
|
|
||||||
|
|
||||||
def get_batch(self, idx: list):
|
def get_batch(self, idx: list):
|
||||||
features = []
|
features = []
|
||||||
targets = []
|
targets = []
|
||||||
@@ -216,8 +240,8 @@ class NrvDataset(Dataset):
|
|||||||
# check if the date is in the valid indices
|
# check if the date is in the valid indices
|
||||||
if date not in self.datetime.dt.date.unique():
|
if date not in self.datetime.dt.date.unique():
|
||||||
raise ValueError(f"Date {date} not in dataset.")
|
raise ValueError(f"Date {date} not in dataset.")
|
||||||
|
|
||||||
idx = self.datetime[self.datetime.dt.date == date].index[0]
|
idx = self.datetime[self.datetime.dt.date == date].index[0]
|
||||||
|
|
||||||
valid_idx = self.valid_indices.index(idx)
|
valid_idx = self.valid_indices.index(idx)
|
||||||
return valid_idx
|
return valid_idx
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ class PolicyEvaluator:
|
|||||||
):
|
):
|
||||||
idx = test_loader.dataset.get_idx_for_date(date.date())
|
idx = test_loader.dataset.get_idx_for_date(date.date())
|
||||||
|
|
||||||
print("Evaluated for idx: ", idx)
|
if idx not in idx_samples:
|
||||||
|
print("No samples for idx: ", idx, date)
|
||||||
(initial, samples) = idx_samples[idx]
|
(initial, samples) = idx_samples[idx]
|
||||||
|
|
||||||
if len(initial.shape) == 2:
|
if len(initial.shape) == 2:
|
||||||
@@ -98,16 +99,17 @@ class PolicyEvaluator:
|
|||||||
|
|
||||||
def evaluate_test_set(self, idx_samples, test_loader):
|
def evaluate_test_set(self, idx_samples, test_loader):
|
||||||
self.profits = []
|
self.profits = []
|
||||||
try:
|
|
||||||
for date in tqdm(self.dates):
|
|
||||||
self.evaluate_for_date(date, idx_samples, test_loader)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("Interrupted")
|
|
||||||
raise KeyboardInterrupt
|
|
||||||
|
|
||||||
except Exception as e:
|
for date in tqdm(self.dates):
|
||||||
print(e)
|
try:
|
||||||
pass
|
self.evaluate_for_date(date, idx_samples, test_loader)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("Interrupted")
|
||||||
|
raise KeyboardInterrupt
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
|
|
||||||
self.profits = pd.DataFrame(
|
self.profits = pd.DataFrame(
|
||||||
self.profits,
|
self.profits,
|
||||||
|
|||||||
@@ -151,13 +151,25 @@ class BaselinePolicyEvaluator(PolicyEvaluator):
|
|||||||
|
|
||||||
return best_thresholds
|
return best_thresholds
|
||||||
|
|
||||||
def evaluate_test_set(self, thresholds: dict):
|
def evaluate_test_set(self, thresholds: dict, data_processor=None):
|
||||||
"""Evaluate the test set using the given thresholds (multiple penalties)
|
"""Evaluate the test set using the given thresholds (multiple penalties)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
thresholds (dict): Dictionary with penalties as keys and the corresponding thresholds tuple as values
|
thresholds (dict): Dictionary with penalties as keys and the corresponding thresholds tuple as values
|
||||||
"""
|
"""
|
||||||
self.profits = []
|
self.profits = []
|
||||||
|
|
||||||
|
if data_processor:
|
||||||
|
filtered_dates = []
|
||||||
|
_, test_loader = data_processor.get_dataloaders()
|
||||||
|
for date in self.dates:
|
||||||
|
try:
|
||||||
|
test_loader.dataset.get_idx_for_date(date.date())
|
||||||
|
filtered_dates.append(date)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
self.dates = filtered_dates
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for date in tqdm(self.dates):
|
for date in tqdm(self.dates):
|
||||||
real_imbalance_prices = self.get_imbanlance_prices_for_date(date.date())
|
real_imbalance_prices = self.get_imbanlance_prices_for_date(date.date())
|
||||||
|
|||||||
@@ -54,14 +54,26 @@ class YesterdayBaselinePolicyEvaluator(PolicyEvaluator):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def evaluate_test_set(self):
|
def evaluate_test_set(self, data_processor):
|
||||||
|
|
||||||
|
if data_processor:
|
||||||
|
filtered_dates = []
|
||||||
|
_, test_loader = data_processor.get_dataloaders()
|
||||||
|
for date in self.dates:
|
||||||
|
try:
|
||||||
|
test_loader.dataset.get_idx_for_date(date.date())
|
||||||
|
filtered_dates.append(date)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
self.dates = filtered_dates
|
||||||
|
|
||||||
self.profits = []
|
self.profits = []
|
||||||
try:
|
for date in tqdm(self.dates):
|
||||||
for date in tqdm(self.dates):
|
try:
|
||||||
self.evaluate_for_date(date)
|
self.evaluate_for_date(date)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.profits = pd.DataFrame(
|
self.profits = pd.DataFrame(
|
||||||
self.profits,
|
self.profits,
|
||||||
|
|||||||
@@ -7,6 +7,25 @@ task.execute_remotely(queue_name="default", exit_process=True)
|
|||||||
|
|
||||||
from src.policies.baselines.BaselinePolicyEvaluator import BaselinePolicyEvaluator
|
from src.policies.baselines.BaselinePolicyEvaluator import BaselinePolicyEvaluator
|
||||||
from src.policies.simple_baseline import BaselinePolicy, Battery
|
from src.policies.simple_baseline import BaselinePolicy, Battery
|
||||||
|
from src.data import DataProcessor, DataConfig
|
||||||
|
|
||||||
|
### Data Processor ###
|
||||||
|
data_config = DataConfig()
|
||||||
|
data_config.NRV_HISTORY = True
|
||||||
|
data_config.LOAD_HISTORY = True
|
||||||
|
data_config.LOAD_FORECAST = True
|
||||||
|
|
||||||
|
data_config.WIND_FORECAST = True
|
||||||
|
data_config.WIND_HISTORY = True
|
||||||
|
|
||||||
|
data_config.QUARTER = False
|
||||||
|
data_config.DAY_OF_WEEK = False
|
||||||
|
|
||||||
|
data_config.NOMINAL_NET_POSITION = True
|
||||||
|
|
||||||
|
data_processor = DataProcessor(data_config, path="", lstm=False)
|
||||||
|
data_processor.set_batch_size(64)
|
||||||
|
data_processor.set_full_day_skip(True)
|
||||||
|
|
||||||
### Policy Evaluator ###
|
### Policy Evaluator ###
|
||||||
battery = Battery(2, 1)
|
battery = Battery(2, 1)
|
||||||
@@ -14,7 +33,7 @@ baseline_policy = BaselinePolicy(battery, data_path="")
|
|||||||
policy_evaluator = BaselinePolicyEvaluator(baseline_policy, task)
|
policy_evaluator = BaselinePolicyEvaluator(baseline_policy, task)
|
||||||
|
|
||||||
thresholds = policy_evaluator.determine_best_thresholds()
|
thresholds = policy_evaluator.determine_best_thresholds()
|
||||||
policy_evaluator.evaluate_test_set(thresholds)
|
policy_evaluator.evaluate_test_set(thresholds, data_processor=data_processor)
|
||||||
|
|
||||||
policy_evaluator.plot_profits_table()
|
policy_evaluator.plot_profits_table()
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from src.utils.clearml import ClearMLHelper
|
|||||||
|
|
||||||
#### ClearML ####
|
#### ClearML ####
|
||||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||||
task = clearml_helper.get_task(task_name="Global Thresholds Baseline")
|
task = clearml_helper.get_task(task_name="Yesterday Baseline")
|
||||||
task.execute_remotely(queue_name="default", exit_process=True)
|
task.execute_remotely(queue_name="default", exit_process=True)
|
||||||
|
|
||||||
from src.policies.baselines.BaselinePolicyEvaluator import BaselinePolicyEvaluator
|
from src.policies.baselines.BaselinePolicyEvaluator import BaselinePolicyEvaluator
|
||||||
@@ -10,13 +10,32 @@ from src.policies.simple_baseline import BaselinePolicy, Battery
|
|||||||
from src.policies.baselines.YesterdayBaselinePolicyExecutor import (
|
from src.policies.baselines.YesterdayBaselinePolicyExecutor import (
|
||||||
YesterdayBaselinePolicyEvaluator,
|
YesterdayBaselinePolicyEvaluator,
|
||||||
)
|
)
|
||||||
|
from src.data import DataProcessor, DataConfig
|
||||||
|
|
||||||
|
### Data Processor ###
|
||||||
|
data_config = DataConfig()
|
||||||
|
data_config.NRV_HISTORY = True
|
||||||
|
data_config.LOAD_HISTORY = True
|
||||||
|
data_config.LOAD_FORECAST = True
|
||||||
|
|
||||||
|
data_config.WIND_FORECAST = True
|
||||||
|
data_config.WIND_HISTORY = True
|
||||||
|
|
||||||
|
data_config.QUARTER = False
|
||||||
|
data_config.DAY_OF_WEEK = False
|
||||||
|
|
||||||
|
data_config.NOMINAL_NET_POSITION = True
|
||||||
|
|
||||||
|
data_processor = DataProcessor(data_config, path="", lstm=False)
|
||||||
|
data_processor.set_batch_size(64)
|
||||||
|
data_processor.set_full_day_skip(True)
|
||||||
|
|
||||||
### Policy Evaluator ###
|
### Policy Evaluator ###
|
||||||
battery = Battery(2, 1)
|
battery = Battery(2, 1)
|
||||||
baseline_policy = BaselinePolicy(battery, data_path="")
|
baseline_policy = BaselinePolicy(battery, data_path="")
|
||||||
policy_evaluator = YesterdayBaselinePolicyEvaluator(baseline_policy, task)
|
policy_evaluator = YesterdayBaselinePolicyEvaluator(baseline_policy, task)
|
||||||
|
|
||||||
policy_evaluator.evaluate_test_set()
|
policy_evaluator.evaluate_test_set(data_processor=data_processor)
|
||||||
policy_evaluator.plot_profits_table()
|
policy_evaluator.plot_profits_table()
|
||||||
|
|
||||||
task.close()
|
task.close()
|
||||||
|
|||||||
@@ -155,31 +155,38 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
|||||||
generated_samples = {}
|
generated_samples = {}
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
total_samples = len(dataloader.dataset) - 96
|
total_samples = len(dataloader.dataset)
|
||||||
for _, _, idx_batch in tqdm(dataloader):
|
print(
|
||||||
idx_batch = [idx for idx in idx_batch if idx < total_samples]
|
"Full day valid indices: ",
|
||||||
|
len(dataloader.dataset.full_day_valid_indices),
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"Valid indices: ",
|
||||||
|
len(dataloader.dataset.valid_indices),
|
||||||
|
)
|
||||||
|
|
||||||
if len(idx_batch) == 0:
|
print(dataloader.dataset.valid_indices)
|
||||||
continue
|
|
||||||
|
|
||||||
for idx in tqdm(idx_batch):
|
for i in tqdm(dataloader.dataset.full_day_valid_indices):
|
||||||
computed_idx_batch = [idx] * 100
|
idx = dataloader.dataset.valid_indices.index(i)
|
||||||
initial, _, samples, targets = self.auto_regressive(
|
|
||||||
dataloader.dataset, idx_batch=computed_idx_batch
|
|
||||||
)
|
|
||||||
|
|
||||||
generated_samples[idx.item()] = (
|
computed_idx_batch = [idx] * 100
|
||||||
self.data_processor.inverse_transform(initial),
|
initial, _, samples, targets = self.auto_regressive(
|
||||||
self.data_processor.inverse_transform(samples),
|
dataloader.dataset, idx_batch=computed_idx_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
samples = samples.unsqueeze(0)
|
generated_samples[idx] = (
|
||||||
targets = targets.squeeze(-1)
|
self.data_processor.inverse_transform(initial),
|
||||||
targets = targets[0].unsqueeze(0)
|
self.data_processor.inverse_transform(samples),
|
||||||
|
)
|
||||||
|
|
||||||
crps = crps_from_samples(samples, targets)
|
samples = samples.unsqueeze(0)
|
||||||
|
targets = targets.squeeze(-1)
|
||||||
|
targets = targets[0].unsqueeze(0)
|
||||||
|
|
||||||
crps_from_samples_metric.append(crps[0].mean().item())
|
crps = crps_from_samples(samples, targets)
|
||||||
|
|
||||||
|
crps_from_samples_metric.append(crps[0].mean().item())
|
||||||
|
|
||||||
task.get_logger().report_scalar(
|
task.get_logger().report_scalar(
|
||||||
title="CRPS_from_samples",
|
title="CRPS_from_samples",
|
||||||
@@ -190,10 +197,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
|
|||||||
|
|
||||||
# using the policy evaluator, evaluate the policy with the generated samples
|
# using the policy evaluator, evaluate the policy with the generated samples
|
||||||
if self.policy_evaluator is not None:
|
if self.policy_evaluator is not None:
|
||||||
_, test_loader = self.data_processor.get_dataloaders(
|
self.policy_evaluator.evaluate_test_set(generated_samples, dataloader)
|
||||||
predict_sequence_length=self.model.output_size, full_day_skip=True
|
|
||||||
)
|
|
||||||
self.policy_evaluator.evaluate_test_set(generated_samples, test_loader)
|
|
||||||
df = self.policy_evaluator.get_profits_as_scalars()
|
df = self.policy_evaluator.get_profits_as_scalars()
|
||||||
|
|
||||||
# for each row, report the profits
|
# for each row, report the profits
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from plotly.subplots import make_subplots
|
|||||||
from clearml.config import running_remotely
|
from clearml.config import running_remotely
|
||||||
from torchinfo import summary
|
from torchinfo import summary
|
||||||
|
|
||||||
|
|
||||||
class Trainer:
|
class Trainer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -95,13 +96,15 @@ class Trainer:
|
|||||||
loader = test_loader
|
loader = test_loader
|
||||||
|
|
||||||
np.random.seed(42)
|
np.random.seed(42)
|
||||||
actual_indices = np.random.choice(loader.dataset.full_day_valid_indices, num_samples, replace=False)
|
actual_indices = np.random.choice(
|
||||||
|
loader.dataset.full_day_valid_indices, num_samples, replace=False
|
||||||
|
)
|
||||||
indices = {}
|
indices = {}
|
||||||
for i in actual_indices:
|
for i in actual_indices:
|
||||||
indices[i] = loader.dataset.valid_indices.index(i)
|
indices[i] = loader.dataset.valid_indices.index(i)
|
||||||
|
|
||||||
print(actual_indices)
|
print(actual_indices)
|
||||||
|
|
||||||
return indices
|
return indices
|
||||||
|
|
||||||
def train(self, epochs: int, remotely: bool = False, task: Task = None):
|
def train(self, epochs: int, remotely: bool = False, task: Task = None):
|
||||||
@@ -190,9 +193,7 @@ class Trainer:
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
if hasattr(self, "calculate_crps_from_samples"):
|
if hasattr(self, "calculate_crps_from_samples"):
|
||||||
self.calculate_crps_from_samples(
|
self.calculate_crps_from_samples(task, test_loader, epoch)
|
||||||
task, full_day_skip_test_loader, epoch
|
|
||||||
)
|
|
||||||
|
|
||||||
if task:
|
if task:
|
||||||
self.finish_training(task=task)
|
self.finish_training(task=task)
|
||||||
@@ -259,7 +260,6 @@ class Trainer:
|
|||||||
self.model = torch.load("checkpoint.pt")
|
self.model = torch.load("checkpoint.pt")
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
|
|
||||||
# set full day skip
|
# set full day skip
|
||||||
self.data_processor.set_full_day_skip(True)
|
self.data_processor.set_full_day_skip(True)
|
||||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||||
@@ -361,7 +361,6 @@ class Trainer:
|
|||||||
for trace in sub_fig.data:
|
for trace in sub_fig.data:
|
||||||
fig.add_trace(trace, row=row, col=col)
|
fig.add_trace(trace, row=row, col=col)
|
||||||
|
|
||||||
|
|
||||||
# loss = self.criterion(predictions.to(self.device), target.squeeze(-1).to(self.device)).item()
|
# loss = self.criterion(predictions.to(self.device), target.squeeze(-1).to(self.device)).item()
|
||||||
|
|
||||||
# fig['layout']['annotations'][i].update(text=f"{loss.__class__.__name__}: {loss:.6f}")
|
# fig['layout']['annotations'][i].update(text=f"{loss.__class__.__name__}: {loss:.6f}")
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ data_processor.set_full_day_skip(False)
|
|||||||
|
|
||||||
|
|
||||||
#### Hyperparameters ####
|
#### Hyperparameters ####
|
||||||
data_processor.set_output_size(96)
|
data_processor.set_output_size(1)
|
||||||
inputDim = data_processor.get_input_size()
|
inputDim = data_processor.get_input_size()
|
||||||
epochs = 300
|
epochs = 300
|
||||||
|
|
||||||
@@ -80,7 +80,7 @@ time_embedding = TimeEmbedding(
|
|||||||
# lstm_model = GRUModel(time_embedding.output_dim(inputDim), len(quantiles), hidden_size=model_parameters["hidden_size"], num_layers=model_parameters["num_layers"], dropout=model_parameters["dropout"])
|
# lstm_model = GRUModel(time_embedding.output_dim(inputDim), len(quantiles), hidden_size=model_parameters["hidden_size"], num_layers=model_parameters["num_layers"], dropout=model_parameters["dropout"])
|
||||||
non_linear_model = NonLinearRegression(
|
non_linear_model = NonLinearRegression(
|
||||||
time_embedding.output_dim(inputDim),
|
time_embedding.output_dim(inputDim),
|
||||||
len(quantiles) * 96,
|
len(quantiles),
|
||||||
hiddenSize=model_parameters["hidden_size"],
|
hiddenSize=model_parameters["hidden_size"],
|
||||||
numLayers=model_parameters["num_layers"],
|
numLayers=model_parameters["num_layers"],
|
||||||
dropout=model_parameters["dropout"],
|
dropout=model_parameters["dropout"],
|
||||||
@@ -97,18 +97,7 @@ baseline_policy = BaselinePolicy(battery, data_path="")
|
|||||||
policy_evaluator = PolicyEvaluator(baseline_policy, task)
|
policy_evaluator = PolicyEvaluator(baseline_policy, task)
|
||||||
|
|
||||||
#### Trainer ####
|
#### Trainer ####
|
||||||
# trainer = AutoRegressiveQuantileTrainer(
|
trainer = AutoRegressiveQuantileTrainer(
|
||||||
# model,
|
|
||||||
# inputDim,
|
|
||||||
# optimizer,
|
|
||||||
# data_processor,
|
|
||||||
# quantiles,
|
|
||||||
# "cuda",
|
|
||||||
# policy_evaluator=policy_evaluator,
|
|
||||||
# debug=False,
|
|
||||||
# )
|
|
||||||
|
|
||||||
trainer = NonAutoRegressiveQuantileRegression(
|
|
||||||
model,
|
model,
|
||||||
inputDim,
|
inputDim,
|
||||||
optimizer,
|
optimizer,
|
||||||
@@ -119,6 +108,17 @@ trainer = NonAutoRegressiveQuantileRegression(
|
|||||||
debug=False,
|
debug=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# trainer = NonAutoRegressiveQuantileRegression(
|
||||||
|
# model,
|
||||||
|
# inputDim,
|
||||||
|
# optimizer,
|
||||||
|
# data_processor,
|
||||||
|
# quantiles,
|
||||||
|
# "cuda",
|
||||||
|
# policy_evaluator=policy_evaluator,
|
||||||
|
# debug=False,
|
||||||
|
# )
|
||||||
|
|
||||||
trainer.add_metrics_to_track(
|
trainer.add_metrics_to_track(
|
||||||
[PinballLoss(quantiles), MSELoss(), L1Loss(), CRPSLoss(quantiles)]
|
[PinballLoss(quantiles), MSELoss(), L1Loss(), CRPSLoss(quantiles)]
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user