Improved policy executor

This commit is contained in:
Victor Mylle
2024-01-16 23:22:05 +00:00
parent d1074281c4
commit b87ad1bf42
7 changed files with 1328 additions and 101 deletions

File diff suppressed because one or more lines are too long

View File

@@ -9,6 +9,7 @@ import datetime
from tqdm import tqdm
from src.utils.imbalance_price_calculator import ImbalancePriceCalculator
import time
import plotly.express as px
### import functions ###
from src.trainers.quantile_trainer import auto_regressive as quantile_auto_regressive
@@ -19,10 +20,12 @@ from src.utils.clearml import ClearMLHelper
parser = argparse.ArgumentParser()
parser.add_argument('--task_id', type=str, default=None)
parser.add_argument('--model_type', type=str, default=None)
parser.add_argument('--model_name', type=str, default=None)
args = parser.parse_args()
assert args.task_id is not None, "Please specify task id"
assert args.model_type is not None, "Please specify model type"
assert args.model_name is not None, "Please specify model name"
battery = Battery(2, 1)
baseline_policy = BaselinePolicy(battery, data_path="")
@@ -43,22 +46,28 @@ def load_model(task_id: str):
"""
task = Task.get_task(task_id=task_id)
lstm = task.get_parameter("data_processor/lstm")
full_day_skip = task.get_parameter("data_processor/full_day_skip")
output_size = int(task.get_parameter("data_processor/output_size"))
print(f"lstm: {lstm}")
print(f"full_day_skip: {full_day_skip}")
print(f"output_size: {output_size}")
configuration = task.get_parameters_as_dict()
data_features = configuration['data_features']
### Data Config ###
data_config = DataConfig()
for key, value in data_features.items():
setattr(data_config, key, bool(value))
data_config.PV_FORECAST = False
data_config.PV_HISTORY = False
data_config.QUARTER = False
data_config.DAY_OF_WEEK = False
setattr(data_config, key, value == "True")
print(data_config.__dict__)
### Data Processor ###
data_processor = DataProcessor(data_config, path="", lstm=False)
data_processor = DataProcessor(data_config, path="", lstm=lstm=="True")
data_processor.set_batch_size(8192)
data_processor.set_full_day_skip(True)
data_processor.set_full_day_skip(full_day_skip == "True")
data_processor.set_output_size(int(output_size))
### Model ###
output_model_id = task.output_models_id["checkpoint"]
@@ -72,7 +81,7 @@ def load_model(task_id: str):
model.eval()
_, test_loader = data_processor.get_dataloaders(
predict_sequence_length=96
predict_sequence_length=output_size
)
return configuration, model, data_processor, test_loader
@@ -80,7 +89,7 @@ def load_model(task_id: str):
def quantile_auto_regressive_predicted_NRV(model, date, data_processor, test_loader):
idx = test_loader.dataset.get_idx_for_date(date.date())
initial, _, samples, target = quantile_auto_regressive(test_loader.dataset, model, [idx]*500, 96)
initial, _, samples, target = quantile_auto_regressive(test_loader.dataset, model, model.quantiles, [idx]*500, 96)
samples = samples.cpu().numpy()
target = target.cpu().numpy()
@@ -147,7 +156,7 @@ def get_next_day_profits_for_date(model, data_processor, test_loader, date, ipc,
return predicted_nrv_profits_cycles, baseline_profits_cycles
def next_day_test_set(model, data_processor, test_loader, ipc, predict_NRV: callable):
penalties = [0, 10, 50, 150, 250, 350, 500]
penalties = [0, 10, 50, 150, 300, 500, 600, 800, 1000, 1500, 2000, 2500]
predicted_nrv_profits_cycles = {i: [0, 0] for i in penalties}
baseline_profits_cycles = {i: [0, 0] for i in penalties}
@@ -169,7 +178,6 @@ def next_day_test_set(model, data_processor, test_loader, ipc, predict_NRV: call
baseline_profits_cycles[penalty][1] += new_baseline_profits_cycles[penalty][1]
except Exception as e:
# raise e
# print(f"Error for date {date}")
continue
@@ -179,14 +187,16 @@ def main():
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
task = clearml_helper.get_task(task_name="Policy Test")
task.connect(args, name="Arguments")
task.execute_remotely(queue_name="default", exit_process=True)
configuration, model, data_processor, test_loader = load_model(args.task_id)
if args.model_type == "quantile":
if args.model_type == "autoregressive_quantile":
quantiles = configuration["general"]["quantiles"]
quantiles = list(map(float, quantiles.strip('[]').split(',')))
model.quantiles = quantiles
predict_NRV = quantile_auto_regressive_predicted_NRV
task.add_tags(["quantile"])
task.add_tags(["autoregressive_quantile"])
elif args.model_type == "diffusion":
predict_NRV = diffusion_predicted_NRV
task.add_tags(["diffusion"])
@@ -203,7 +213,7 @@ def main():
# use concat
for penalty in predicted_nrv_profits_cycles.keys():
new_rows = pd.DataFrame({
"name": [args.model_type, "baseline"],
"name": [f"{args.model_type} ({args.model_name})", "baseline"],
"penalty": [penalty, penalty],
"profit": [predicted_nrv_profits_cycles[penalty][0], baseline_profits_cycles[penalty][0]],
"cycles": [predicted_nrv_profits_cycles[penalty][1], baseline_profits_cycles[penalty][1]]
@@ -214,7 +224,7 @@ def main():
df = df.sort_values(by=["name", "penalty"])
# calculate profit per cycle
df["profit_per_cycle"] = df["profit"] / df["cycles"]
df["profit_per_cycle"] = df.apply(lambda row: row["profit"] / row["cycles"] if row["cycles"] != 0 else 0, axis=1)
task.get_logger().report_table(
"Policy Results",
@@ -223,6 +233,28 @@ def main():
table_plot=df
)
# plotly to show profit on y axis and cycles on x axis (show 2 lines, one for each model)
fig = px.line(
df,
x="cycles",
y="profit",
color="name",
title="Profit vs. Cycles for Each Model",
labels={"cycles": "Cycles", "profit": "Profit"},
markers=True, # Adds markers to the lines
hover_data=["penalty"] # Adds additional hover information
)
fig.update_xaxes(autorange="reversed")
task.get_logger().report_plotly(
"Policy Results",
"Profit vs. Cycles for Each Model",
iteration=0,
figure=fig
)
# close task
task.close()