Fixed policy evaluation for autoregressive
This commit is contained in:
@@ -45,7 +45,8 @@ class PolicyEvaluator:
|
||||
):
|
||||
idx = test_loader.dataset.get_idx_for_date(date.date())
|
||||
|
||||
print("Evaluated for idx: ", idx)
|
||||
if idx not in idx_samples:
|
||||
print("No samples for idx: ", idx, date)
|
||||
(initial, samples) = idx_samples[idx]
|
||||
|
||||
if len(initial.shape) == 2:
|
||||
@@ -98,16 +99,17 @@ class PolicyEvaluator:
|
||||
|
||||
def evaluate_test_set(self, idx_samples, test_loader):
|
||||
self.profits = []
|
||||
try:
|
||||
for date in tqdm(self.dates):
|
||||
self.evaluate_for_date(date, idx_samples, test_loader)
|
||||
except KeyboardInterrupt:
|
||||
print("Interrupted")
|
||||
raise KeyboardInterrupt
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
for date in tqdm(self.dates):
|
||||
try:
|
||||
self.evaluate_for_date(date, idx_samples, test_loader)
|
||||
except KeyboardInterrupt:
|
||||
print("Interrupted")
|
||||
raise KeyboardInterrupt
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
self.profits = pd.DataFrame(
|
||||
self.profits,
|
||||
|
||||
Reference in New Issue
Block a user