Added diffusion validation set
This commit is contained in:
@@ -15,7 +15,17 @@ class PolicyEvaluator:
|
||||
self.baseline_policy = baseline_policy
|
||||
|
||||
self.ipc = ImbalancePriceCalculator(data_path="")
|
||||
|
||||
self.dates = baseline_policy.test_data["DateTime"].dt.date.unique()
|
||||
|
||||
# also add dates from last 2 months of 2023
|
||||
self.dates = np.append(
|
||||
self.dates,
|
||||
pd.date_range(
|
||||
start="2022-11-01", end="2022-12-31", freq="D"
|
||||
).to_pydatetime(),
|
||||
)
|
||||
|
||||
self.dates = pd.to_datetime(self.dates)
|
||||
|
||||
### Load Imbalance Prices ###
|
||||
@@ -116,6 +126,10 @@ class PolicyEvaluator:
|
||||
# Calculate the gradient (difference) between the simulated and target charge cycles
|
||||
gradient = simulated_charge_cycles - target_charge_cycles
|
||||
|
||||
if abs(gradient) < tolerance:
|
||||
print(f"Optimal penalty found after {iteration+1} iterations")
|
||||
break
|
||||
|
||||
# Optionally, adjust learning rate based on the change of gradient direction to avoid oscillation
|
||||
if previous_gradient is not None and gradient * previous_gradient < 0:
|
||||
learning_rate *= learning_rate_decay
|
||||
@@ -129,9 +143,7 @@ class PolicyEvaluator:
|
||||
previous_gradient = gradient
|
||||
|
||||
# Check if the charge cycles are close enough to the target
|
||||
if abs(gradient) < tolerance:
|
||||
print(f"Optimal penalty found after {iteration+1} iterations")
|
||||
break
|
||||
|
||||
|
||||
else:
|
||||
print(
|
||||
@@ -218,7 +230,7 @@ class PolicyEvaluator:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# print(e)
|
||||
pass
|
||||
|
||||
self.profits = pd.DataFrame(
|
||||
@@ -243,6 +255,8 @@ class PolicyEvaluator:
|
||||
|
||||
loggings = []
|
||||
|
||||
total_dates = 0
|
||||
|
||||
for date in tqdm(self.dates):
|
||||
try:
|
||||
(
|
||||
@@ -272,15 +286,18 @@ class PolicyEvaluator:
|
||||
}
|
||||
|
||||
loggings.append(new_info)
|
||||
total_dates += 1
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Interrupted")
|
||||
raise KeyboardInterrupt
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# print(e)
|
||||
pass
|
||||
|
||||
print(f"Total Evaluated Dates: {total_dates}")
|
||||
|
||||
if log_metrics:
|
||||
log_df = pd.DataFrame(loggings)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user