Added diffusion validation set

This commit is contained in:
Victor Mylle
2024-05-17 16:11:17 +00:00
parent 11ae0e1949
commit 8a219d0d19
24 changed files with 64 additions and 36 deletions

View File

@@ -15,7 +15,17 @@ class PolicyEvaluator:
self.baseline_policy = baseline_policy
self.ipc = ImbalancePriceCalculator(data_path="")
self.dates = baseline_policy.test_data["DateTime"].dt.date.unique()
# also add dates from last 2 months of 2023
self.dates = np.append(
self.dates,
pd.date_range(
start="2022-11-01", end="2022-12-31", freq="D"
).to_pydatetime(),
)
self.dates = pd.to_datetime(self.dates)
### Load Imbalance Prices ###
@@ -116,6 +126,10 @@ class PolicyEvaluator:
# Calculate the gradient (difference) between the simulated and target charge cycles
gradient = simulated_charge_cycles - target_charge_cycles
if abs(gradient) < tolerance:
print(f"Optimal penalty found after {iteration+1} iterations")
break
# Optionally, adjust learning rate based on the change of gradient direction to avoid oscillation
if previous_gradient is not None and gradient * previous_gradient < 0:
learning_rate *= learning_rate_decay
@@ -129,9 +143,7 @@ class PolicyEvaluator:
previous_gradient = gradient
# Check if the charge cycles are close enough to the target
if abs(gradient) < tolerance:
print(f"Optimal penalty found after {iteration+1} iterations")
break
else:
print(
@@ -218,7 +230,7 @@ class PolicyEvaluator:
raise KeyboardInterrupt
except Exception as e:
print(e)
# print(e)
pass
self.profits = pd.DataFrame(
@@ -243,6 +255,8 @@ class PolicyEvaluator:
loggings = []
total_dates = 0
for date in tqdm(self.dates):
try:
(
@@ -272,15 +286,18 @@ class PolicyEvaluator:
}
loggings.append(new_info)
total_dates += 1
except KeyboardInterrupt:
print("Interrupted")
raise KeyboardInterrupt
except Exception as e:
print(e)
# print(e)
pass
print(f"Total Evaluated Dates: {total_dates}")
if log_metrics:
log_df = pd.DataFrame(loggings)