Fixed policy evaluation for autoregressive
This commit is contained in:
@@ -45,7 +45,8 @@ class PolicyEvaluator:
|
||||
):
|
||||
idx = test_loader.dataset.get_idx_for_date(date.date())
|
||||
|
||||
print("Evaluated for idx: ", idx)
|
||||
if idx not in idx_samples:
|
||||
print("No samples for idx: ", idx, date)
|
||||
(initial, samples) = idx_samples[idx]
|
||||
|
||||
if len(initial.shape) == 2:
|
||||
@@ -98,16 +99,17 @@ class PolicyEvaluator:
|
||||
|
||||
def evaluate_test_set(self, idx_samples, test_loader):
|
||||
self.profits = []
|
||||
try:
|
||||
for date in tqdm(self.dates):
|
||||
self.evaluate_for_date(date, idx_samples, test_loader)
|
||||
except KeyboardInterrupt:
|
||||
print("Interrupted")
|
||||
raise KeyboardInterrupt
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
for date in tqdm(self.dates):
|
||||
try:
|
||||
self.evaluate_for_date(date, idx_samples, test_loader)
|
||||
except KeyboardInterrupt:
|
||||
print("Interrupted")
|
||||
raise KeyboardInterrupt
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
self.profits = pd.DataFrame(
|
||||
self.profits,
|
||||
|
||||
@@ -151,13 +151,25 @@ class BaselinePolicyEvaluator(PolicyEvaluator):
|
||||
|
||||
return best_thresholds
|
||||
|
||||
def evaluate_test_set(self, thresholds: dict):
|
||||
def evaluate_test_set(self, thresholds: dict, data_processor=None):
|
||||
"""Evaluate the test set using the given thresholds (multiple penalties)
|
||||
|
||||
Args:
|
||||
thresholds (dict): Dictionary with penalties as keys and the corresponding thresholds tuple as values
|
||||
"""
|
||||
self.profits = []
|
||||
|
||||
if data_processor:
|
||||
filtered_dates = []
|
||||
_, test_loader = data_processor.get_dataloaders()
|
||||
for date in self.dates:
|
||||
try:
|
||||
test_loader.dataset.get_idx_for_date(date.date())
|
||||
filtered_dates.append(date)
|
||||
except:
|
||||
pass
|
||||
self.dates = filtered_dates
|
||||
|
||||
try:
|
||||
for date in tqdm(self.dates):
|
||||
real_imbalance_prices = self.get_imbanlance_prices_for_date(date.date())
|
||||
|
||||
@@ -54,14 +54,26 @@ class YesterdayBaselinePolicyEvaluator(PolicyEvaluator):
|
||||
]
|
||||
)
|
||||
|
||||
def evaluate_test_set(self):
|
||||
def evaluate_test_set(self, data_processor):
|
||||
|
||||
if data_processor:
|
||||
filtered_dates = []
|
||||
_, test_loader = data_processor.get_dataloaders()
|
||||
for date in self.dates:
|
||||
try:
|
||||
test_loader.dataset.get_idx_for_date(date.date())
|
||||
filtered_dates.append(date)
|
||||
except:
|
||||
pass
|
||||
self.dates = filtered_dates
|
||||
|
||||
self.profits = []
|
||||
try:
|
||||
for date in tqdm(self.dates):
|
||||
for date in tqdm(self.dates):
|
||||
try:
|
||||
self.evaluate_for_date(date)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
self.profits = pd.DataFrame(
|
||||
self.profits,
|
||||
|
||||
@@ -7,6 +7,25 @@ task.execute_remotely(queue_name="default", exit_process=True)
|
||||
|
||||
from src.policies.baselines.BaselinePolicyEvaluator import BaselinePolicyEvaluator
|
||||
from src.policies.simple_baseline import BaselinePolicy, Battery
|
||||
from src.data import DataProcessor, DataConfig
|
||||
|
||||
### Data Processor ###
|
||||
data_config = DataConfig()
|
||||
data_config.NRV_HISTORY = True
|
||||
data_config.LOAD_HISTORY = True
|
||||
data_config.LOAD_FORECAST = True
|
||||
|
||||
data_config.WIND_FORECAST = True
|
||||
data_config.WIND_HISTORY = True
|
||||
|
||||
data_config.QUARTER = False
|
||||
data_config.DAY_OF_WEEK = False
|
||||
|
||||
data_config.NOMINAL_NET_POSITION = True
|
||||
|
||||
data_processor = DataProcessor(data_config, path="", lstm=False)
|
||||
data_processor.set_batch_size(64)
|
||||
data_processor.set_full_day_skip(True)
|
||||
|
||||
### Policy Evaluator ###
|
||||
battery = Battery(2, 1)
|
||||
@@ -14,7 +33,7 @@ baseline_policy = BaselinePolicy(battery, data_path="")
|
||||
policy_evaluator = BaselinePolicyEvaluator(baseline_policy, task)
|
||||
|
||||
thresholds = policy_evaluator.determine_best_thresholds()
|
||||
policy_evaluator.evaluate_test_set(thresholds)
|
||||
policy_evaluator.evaluate_test_set(thresholds, data_processor=data_processor)
|
||||
|
||||
policy_evaluator.plot_profits_table()
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from src.utils.clearml import ClearMLHelper
|
||||
|
||||
#### ClearML ####
|
||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||
task = clearml_helper.get_task(task_name="Global Thresholds Baseline")
|
||||
task = clearml_helper.get_task(task_name="Yesterday Baseline")
|
||||
task.execute_remotely(queue_name="default", exit_process=True)
|
||||
|
||||
from src.policies.baselines.BaselinePolicyEvaluator import BaselinePolicyEvaluator
|
||||
@@ -10,13 +10,32 @@ from src.policies.simple_baseline import BaselinePolicy, Battery
|
||||
from src.policies.baselines.YesterdayBaselinePolicyExecutor import (
|
||||
YesterdayBaselinePolicyEvaluator,
|
||||
)
|
||||
from src.data import DataProcessor, DataConfig
|
||||
|
||||
### Data Processor ###
|
||||
data_config = DataConfig()
|
||||
data_config.NRV_HISTORY = True
|
||||
data_config.LOAD_HISTORY = True
|
||||
data_config.LOAD_FORECAST = True
|
||||
|
||||
data_config.WIND_FORECAST = True
|
||||
data_config.WIND_HISTORY = True
|
||||
|
||||
data_config.QUARTER = False
|
||||
data_config.DAY_OF_WEEK = False
|
||||
|
||||
data_config.NOMINAL_NET_POSITION = True
|
||||
|
||||
data_processor = DataProcessor(data_config, path="", lstm=False)
|
||||
data_processor.set_batch_size(64)
|
||||
data_processor.set_full_day_skip(True)
|
||||
|
||||
### Policy Evaluator ###
|
||||
battery = Battery(2, 1)
|
||||
baseline_policy = BaselinePolicy(battery, data_path="")
|
||||
policy_evaluator = YesterdayBaselinePolicyEvaluator(baseline_policy, task)
|
||||
|
||||
policy_evaluator.evaluate_test_set()
|
||||
policy_evaluator.evaluate_test_set(data_processor=data_processor)
|
||||
policy_evaluator.plot_profits_table()
|
||||
|
||||
task.close()
|
||||
|
||||
Reference in New Issue
Block a user