Fixed policy evaluation for autoregressive

This commit is contained in:
2024-02-29 23:23:11 +01:00
parent fe1e388ffb
commit 34335cd9fe
10 changed files with 191 additions and 95 deletions

View File

@@ -45,7 +45,8 @@ class PolicyEvaluator:
):
idx = test_loader.dataset.get_idx_for_date(date.date())
print("Evaluated for idx: ", idx)
if idx not in idx_samples:
print("No samples for idx: ", idx, date)
(initial, samples) = idx_samples[idx]
if len(initial.shape) == 2:
@@ -98,16 +99,17 @@ class PolicyEvaluator:
def evaluate_test_set(self, idx_samples, test_loader):
self.profits = []
try:
for date in tqdm(self.dates):
self.evaluate_for_date(date, idx_samples, test_loader)
except KeyboardInterrupt:
print("Interrupted")
raise KeyboardInterrupt
except Exception as e:
print(e)
pass
for date in tqdm(self.dates):
try:
self.evaluate_for_date(date, idx_samples, test_loader)
except KeyboardInterrupt:
print("Interrupted")
raise KeyboardInterrupt
except Exception as e:
print(e)
pass
self.profits = pd.DataFrame(
self.profits,