Add full-day-skip test loader, hook up CRPS-from-samples evaluation, and skip train-set final metrics logging

This commit is contained in:
Victor Mylle
2023-12-30 15:22:32 +00:00
parent ef8b5f49ac
commit c26ae76951
6 changed files with 107 additions and 33 deletions

View File

@@ -99,6 +99,10 @@ class Trainer:
def train(self, epochs: int, remotely: bool = False, task: Task = None):
try:
_, full_day_skip_test_loader = self.data_processor.get_dataloaders(
predict_sequence_length=self.model.output_size, full_day_skip=True
)
train_loader, test_loader = self.data_processor.get_dataloaders(
predict_sequence_length=self.model.output_size
)
@@ -178,6 +182,11 @@ class Trainer:
# task, test_loader, False, epoch, True
# )
if hasattr(self, "calculate_crps_from_samples"):
self.calculate_crps_from_samples(
task, full_day_skip_test_loader, epoch
)
if task:
self.finish_training(task=task)
task.close()
@@ -243,12 +252,15 @@ class Trainer:
self.model.load_state_dict(torch.load("checkpoint.pt"))
self.model.eval()
# set full day skip
self.data_processor.set_full_day_skip(True)
train_loader, test_loader = self.data_processor.get_dataloaders(
predict_sequence_length=self.model.output_size
)
if not hasattr(self, "plot_quantile_percentages"):
self.log_final_metrics(task, train_loader, train=True)
# if not hasattr(self, "plot_quantile_percentages"):
# self.log_final_metrics(task, train_loader, train=True)
self.log_final_metrics(task, test_loader, train=False)