Improved policy executer
This commit is contained in:
1128
src/policies/plot_combiner.ipynb
Normal file
1128
src/policies/plot_combiner.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -9,6 +9,7 @@ import datetime
|
||||
from tqdm import tqdm
|
||||
from src.utils.imbalance_price_calculator import ImbalancePriceCalculator
|
||||
import time
|
||||
import plotly.express as px
|
||||
|
||||
### import functions ###
|
||||
from src.trainers.quantile_trainer import auto_regressive as quantile_auto_regressive
|
||||
@@ -19,10 +20,12 @@ from src.utils.clearml import ClearMLHelper
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task_id', type=str, default=None)
|
||||
parser.add_argument('--model_type', type=str, default=None)
|
||||
parser.add_argument('--model_name', type=str, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.task_id is not None, "Please specify task id"
|
||||
assert args.model_type is not None, "Please specify model type"
|
||||
assert args.model_name is not None, "Please specify model name"
|
||||
|
||||
battery = Battery(2, 1)
|
||||
baseline_policy = BaselinePolicy(battery, data_path="")
|
||||
@@ -43,22 +46,28 @@ def load_model(task_id: str):
|
||||
"""
|
||||
task = Task.get_task(task_id=task_id)
|
||||
|
||||
lstm = task.get_parameter("data_processor/lstm")
|
||||
full_day_skip = task.get_parameter("data_processor/full_day_skip")
|
||||
output_size = int(task.get_parameter("data_processor/output_size"))
|
||||
|
||||
print(f"lstm: {lstm}")
|
||||
print(f"full_day_skip: {full_day_skip}")
|
||||
print(f"output_size: {output_size}")
|
||||
|
||||
configuration = task.get_parameters_as_dict()
|
||||
data_features = configuration['data_features']
|
||||
|
||||
### Data Config ###
|
||||
data_config = DataConfig()
|
||||
for key, value in data_features.items():
|
||||
setattr(data_config, key, bool(value))
|
||||
data_config.PV_FORECAST = False
|
||||
data_config.PV_HISTORY = False
|
||||
data_config.QUARTER = False
|
||||
data_config.DAY_OF_WEEK = False
|
||||
setattr(data_config, key, value == "True")
|
||||
print(data_config.__dict__)
|
||||
|
||||
### Data Processor ###
|
||||
data_processor = DataProcessor(data_config, path="", lstm=False)
|
||||
data_processor = DataProcessor(data_config, path="", lstm=lstm=="True")
|
||||
data_processor.set_batch_size(8192)
|
||||
data_processor.set_full_day_skip(True)
|
||||
data_processor.set_full_day_skip(full_day_skip == "True")
|
||||
data_processor.set_output_size(int(output_size))
|
||||
|
||||
### Model ###
|
||||
output_model_id = task.output_models_id["checkpoint"]
|
||||
@@ -72,7 +81,7 @@ def load_model(task_id: str):
|
||||
model.eval()
|
||||
|
||||
_, test_loader = data_processor.get_dataloaders(
|
||||
predict_sequence_length=96
|
||||
predict_sequence_length=output_size
|
||||
)
|
||||
|
||||
return configuration, model, data_processor, test_loader
|
||||
@@ -80,7 +89,7 @@ def load_model(task_id: str):
|
||||
|
||||
def quantile_auto_regressive_predicted_NRV(model, date, data_processor, test_loader):
|
||||
idx = test_loader.dataset.get_idx_for_date(date.date())
|
||||
initial, _, samples, target = quantile_auto_regressive(test_loader.dataset, model, [idx]*500, 96)
|
||||
initial, _, samples, target = quantile_auto_regressive(test_loader.dataset, model, model.quantiles, [idx]*500, 96)
|
||||
samples = samples.cpu().numpy()
|
||||
target = target.cpu().numpy()
|
||||
|
||||
@@ -147,7 +156,7 @@ def get_next_day_profits_for_date(model, data_processor, test_loader, date, ipc,
|
||||
return predicted_nrv_profits_cycles, baseline_profits_cycles
|
||||
|
||||
def next_day_test_set(model, data_processor, test_loader, ipc, predict_NRV: callable):
|
||||
penalties = [0, 10, 50, 150, 250, 350, 500]
|
||||
penalties = [0, 10, 50, 150, 300, 500, 600, 800, 1000, 1500, 2000, 2500]
|
||||
predicted_nrv_profits_cycles = {i: [0, 0] for i in penalties}
|
||||
baseline_profits_cycles = {i: [0, 0] for i in penalties}
|
||||
|
||||
@@ -169,7 +178,6 @@ def next_day_test_set(model, data_processor, test_loader, ipc, predict_NRV: call
|
||||
baseline_profits_cycles[penalty][1] += new_baseline_profits_cycles[penalty][1]
|
||||
|
||||
except Exception as e:
|
||||
# raise e
|
||||
# print(f"Error for date {date}")
|
||||
continue
|
||||
|
||||
@@ -179,14 +187,16 @@ def main():
|
||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||
task = clearml_helper.get_task(task_name="Policy Test")
|
||||
|
||||
task.connect(args, name="Arguments")
|
||||
task.execute_remotely(queue_name="default", exit_process=True)
|
||||
|
||||
configuration, model, data_processor, test_loader = load_model(args.task_id)
|
||||
|
||||
if args.model_type == "quantile":
|
||||
if args.model_type == "autoregressive_quantile":
|
||||
quantiles = configuration["general"]["quantiles"]
|
||||
quantiles = list(map(float, quantiles.strip('[]').split(',')))
|
||||
model.quantiles = quantiles
|
||||
predict_NRV = quantile_auto_regressive_predicted_NRV
|
||||
task.add_tags(["quantile"])
|
||||
task.add_tags(["autoregressive_quantile"])
|
||||
elif args.model_type == "diffusion":
|
||||
predict_NRV = diffusion_predicted_NRV
|
||||
task.add_tags(["diffusion"])
|
||||
@@ -203,7 +213,7 @@ def main():
|
||||
# use concat
|
||||
for penalty in predicted_nrv_profits_cycles.keys():
|
||||
new_rows = pd.DataFrame({
|
||||
"name": [args.model_type, "baseline"],
|
||||
"name": [f"{args.model_type} ({args.model_name})", "baseline"],
|
||||
"penalty": [penalty, penalty],
|
||||
"profit": [predicted_nrv_profits_cycles[penalty][0], baseline_profits_cycles[penalty][0]],
|
||||
"cycles": [predicted_nrv_profits_cycles[penalty][1], baseline_profits_cycles[penalty][1]]
|
||||
@@ -214,7 +224,7 @@ def main():
|
||||
df = df.sort_values(by=["name", "penalty"])
|
||||
|
||||
# calculate profit per cycle
|
||||
df["profit_per_cycle"] = df["profit"] / df["cycles"]
|
||||
df["profit_per_cycle"] = df.apply(lambda row: row["profit"] / row["cycles"] if row["cycles"] != 0 else 0, axis=1)
|
||||
|
||||
task.get_logger().report_table(
|
||||
"Policy Results",
|
||||
@@ -223,6 +233,28 @@ def main():
|
||||
table_plot=df
|
||||
)
|
||||
|
||||
# plotly to show profit on y axis and cycles on x axis (show 2 lines, one for each model)
|
||||
fig = px.line(
|
||||
df,
|
||||
x="cycles",
|
||||
y="profit",
|
||||
color="name",
|
||||
title="Profit vs. Cycles for Each Model",
|
||||
labels={"cycles": "Cycles", "profit": "Profit"},
|
||||
markers=True, # Adds markers to the lines
|
||||
hover_data=["penalty"] # Adds additional hover information
|
||||
)
|
||||
|
||||
fig.update_xaxes(autorange="reversed")
|
||||
|
||||
|
||||
task.get_logger().report_plotly(
|
||||
"Policy Results",
|
||||
"Profit vs. Cycles for Each Model",
|
||||
iteration=0,
|
||||
figure=fig
|
||||
)
|
||||
|
||||
# close task
|
||||
task.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user