Fixing some stuff
This commit is contained in:
@@ -99,6 +99,10 @@ class Trainer:
|
||||
|
||||
def train(self, epochs: int, remotely: bool = False, task: Task = None):
|
||||
try:
|
||||
_, full_day_skip_test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.model.output_size, full_day_skip=True
|
||||
)
|
||||
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.model.output_size
|
||||
)
|
||||
@@ -178,6 +182,11 @@ class Trainer:
|
||||
# task, test_loader, False, epoch, True
|
||||
# )
|
||||
|
||||
if hasattr(self, "calculate_crps_from_samples"):
|
||||
self.calculate_crps_from_samples(
|
||||
task, full_day_skip_test_loader, epoch
|
||||
)
|
||||
|
||||
if task:
|
||||
self.finish_training(task=task)
|
||||
task.close()
|
||||
@@ -243,12 +252,15 @@ class Trainer:
|
||||
self.model.load_state_dict(torch.load("checkpoint.pt"))
|
||||
self.model.eval()
|
||||
|
||||
|
||||
# set full day skip
|
||||
self.data_processor.set_full_day_skip(True)
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.model.output_size
|
||||
)
|
||||
|
||||
if not hasattr(self, "plot_quantile_percentages"):
|
||||
self.log_final_metrics(task, train_loader, train=True)
|
||||
# if not hasattr(self, "plot_quantile_percentages"):
|
||||
# self.log_final_metrics(task, train_loader, train=True)
|
||||
|
||||
self.log_final_metrics(task, test_loader, train=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user