Fixed policy evaluation for autoregressive

This commit is contained in:
2024-02-29 23:23:11 +01:00
parent fe1e388ffb
commit 34335cd9fe
10 changed files with 191 additions and 95 deletions

View File

@@ -8,6 +8,7 @@ from plotly.subplots import make_subplots
from clearml.config import running_remotely
from torchinfo import summary
class Trainer:
def __init__(
self,
@@ -95,13 +96,15 @@ class Trainer:
loader = test_loader
np.random.seed(42)
actual_indices = np.random.choice(loader.dataset.full_day_valid_indices, num_samples, replace=False)
actual_indices = np.random.choice(
loader.dataset.full_day_valid_indices, num_samples, replace=False
)
indices = {}
for i in actual_indices:
indices[i] = loader.dataset.valid_indices.index(i)
print(actual_indices)
return indices
def train(self, epochs: int, remotely: bool = False, task: Task = None):
@@ -190,9 +193,7 @@ class Trainer:
# )
if hasattr(self, "calculate_crps_from_samples"):
self.calculate_crps_from_samples(
task, full_day_skip_test_loader, epoch
)
self.calculate_crps_from_samples(task, test_loader, epoch)
if task:
self.finish_training(task=task)
@@ -259,7 +260,6 @@ class Trainer:
self.model = torch.load("checkpoint.pt")
self.model.eval()
# set full day skip
self.data_processor.set_full_day_skip(True)
train_loader, test_loader = self.data_processor.get_dataloaders(
@@ -361,7 +361,6 @@ class Trainer:
for trace in sub_fig.data:
fig.add_trace(trace, row=row, col=col)
# loss = self.criterion(predictions.to(self.device), target.squeeze(-1).to(self.device)).item()
# fig['layout']['annotations'][i].update(text=f"{loss.__class__.__name__}: {loss:.6f}")