Fixed policy evaluation for autoregressive

This commit is contained in:
2024-02-29 23:23:11 +01:00
parent fe1e388ffb
commit 34335cd9fe
10 changed files with 191 additions and 95 deletions

View File

@@ -45,7 +45,8 @@ class PolicyEvaluator:
):
idx = test_loader.dataset.get_idx_for_date(date.date())
print("Evaluated for idx: ", idx)
if idx not in idx_samples:
print("No samples for idx: ", idx, date)
(initial, samples) = idx_samples[idx]
if len(initial.shape) == 2:
@@ -98,16 +99,17 @@ class PolicyEvaluator:
def evaluate_test_set(self, idx_samples, test_loader):
self.profits = []
try:
for date in tqdm(self.dates):
self.evaluate_for_date(date, idx_samples, test_loader)
except KeyboardInterrupt:
print("Interrupted")
raise KeyboardInterrupt
except Exception as e:
print(e)
pass
for date in tqdm(self.dates):
try:
self.evaluate_for_date(date, idx_samples, test_loader)
except KeyboardInterrupt:
print("Interrupted")
raise KeyboardInterrupt
except Exception as e:
print(e)
pass
self.profits = pd.DataFrame(
self.profits,

View File

@@ -151,13 +151,25 @@ class BaselinePolicyEvaluator(PolicyEvaluator):
return best_thresholds
def evaluate_test_set(self, thresholds: dict):
def evaluate_test_set(self, thresholds: dict, data_processor=None):
"""Evaluate the test set using the given thresholds (multiple penalties)
Args:
thresholds (dict): Dictionary with penalties as keys and the corresponding thresholds tuple as values
"""
self.profits = []
if data_processor:
filtered_dates = []
_, test_loader = data_processor.get_dataloaders()
for date in self.dates:
try:
test_loader.dataset.get_idx_for_date(date.date())
filtered_dates.append(date)
except:
pass
self.dates = filtered_dates
try:
for date in tqdm(self.dates):
real_imbalance_prices = self.get_imbanlance_prices_for_date(date.date())

View File

@@ -54,14 +54,26 @@ class YesterdayBaselinePolicyEvaluator(PolicyEvaluator):
]
)
def evaluate_test_set(self):
def evaluate_test_set(self, data_processor):
if data_processor:
filtered_dates = []
_, test_loader = data_processor.get_dataloaders()
for date in self.dates:
try:
test_loader.dataset.get_idx_for_date(date.date())
filtered_dates.append(date)
except:
pass
self.dates = filtered_dates
self.profits = []
try:
for date in tqdm(self.dates):
for date in tqdm(self.dates):
try:
self.evaluate_for_date(date)
except Exception as e:
print(e)
pass
except Exception as e:
print(e)
pass
self.profits = pd.DataFrame(
self.profits,

View File

@@ -7,6 +7,25 @@ task.execute_remotely(queue_name="default", exit_process=True)
from src.policies.baselines.BaselinePolicyEvaluator import BaselinePolicyEvaluator
from src.policies.simple_baseline import BaselinePolicy, Battery
from src.data import DataProcessor, DataConfig
### Data Processor ###
data_config = DataConfig()
data_config.NRV_HISTORY = True
data_config.LOAD_HISTORY = True
data_config.LOAD_FORECAST = True
data_config.WIND_FORECAST = True
data_config.WIND_HISTORY = True
data_config.QUARTER = False
data_config.DAY_OF_WEEK = False
data_config.NOMINAL_NET_POSITION = True
data_processor = DataProcessor(data_config, path="", lstm=False)
data_processor.set_batch_size(64)
data_processor.set_full_day_skip(True)
### Policy Evaluator ###
battery = Battery(2, 1)
@@ -14,7 +33,7 @@ baseline_policy = BaselinePolicy(battery, data_path="")
policy_evaluator = BaselinePolicyEvaluator(baseline_policy, task)
thresholds = policy_evaluator.determine_best_thresholds()
policy_evaluator.evaluate_test_set(thresholds)
policy_evaluator.evaluate_test_set(thresholds, data_processor=data_processor)
policy_evaluator.plot_profits_table()

View File

@@ -2,7 +2,7 @@ from src.utils.clearml import ClearMLHelper
#### ClearML ####
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
task = clearml_helper.get_task(task_name="Global Thresholds Baseline")
task = clearml_helper.get_task(task_name="Yesterday Baseline")
task.execute_remotely(queue_name="default", exit_process=True)
from src.policies.baselines.BaselinePolicyEvaluator import BaselinePolicyEvaluator
@@ -10,13 +10,32 @@ from src.policies.simple_baseline import BaselinePolicy, Battery
from src.policies.baselines.YesterdayBaselinePolicyExecutor import (
YesterdayBaselinePolicyEvaluator,
)
from src.data import DataProcessor, DataConfig
### Data Processor ###
data_config = DataConfig()
data_config.NRV_HISTORY = True
data_config.LOAD_HISTORY = True
data_config.LOAD_FORECAST = True
data_config.WIND_FORECAST = True
data_config.WIND_HISTORY = True
data_config.QUARTER = False
data_config.DAY_OF_WEEK = False
data_config.NOMINAL_NET_POSITION = True
data_processor = DataProcessor(data_config, path="", lstm=False)
data_processor.set_batch_size(64)
data_processor.set_full_day_skip(True)
### Policy Evaluator ###
battery = Battery(2, 1)
baseline_policy = BaselinePolicy(battery, data_path="")
policy_evaluator = YesterdayBaselinePolicyEvaluator(baseline_policy, task)
policy_evaluator.evaluate_test_set()
policy_evaluator.evaluate_test_set(data_processor=data_processor)
policy_evaluator.plot_profits_table()
task.close()