diff --git a/src/policies/policy_executer.py b/src/policies/policy_executer.py index 001de94..6225d2d 100644 --- a/src/policies/policy_executer.py +++ b/src/policies/policy_executer.py @@ -176,14 +176,13 @@ def next_day_test_set(model, data_processor, test_loader, ipc, predict_NRV: call return predicted_nrv_profits_cycles, baseline_profits_cycles def main(): - configuration, model, data_processor, test_loader = load_model(args.task_id) - clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast") task = clearml_helper.get_task(task_name="Policy Test") task.connect(args, name="Arguments") task.execute_remotely(queue_name="default", exit_process=True) + configuration, model, data_processor, test_loader = load_model(args.task_id) if args.model_type == "quantile": predict_NRV = quantile_auto_regressive_predicted_NRV @@ -214,6 +213,9 @@ def main(): # sort by name, penalty ascending df = df.sort_values(by=["name", "penalty"]) + # calculate profit per cycle + df["profit_per_cycle"] = df["profit"] / df["cycles"] + task.get_logger().report_table( "Policy Results", "Policy Results",