Save the diffusion model checkpoint whenever the CRPS score improves
This commit is contained in:
@@ -28,6 +28,8 @@ class DiffusionTrainer:
|
|||||||
self.alpha = 1. - self.beta
|
self.alpha = 1. - self.beta
|
||||||
self.alpha_hat = torch.cumprod(self.alpha, dim=0)
|
self.alpha_hat = torch.cumprod(self.alpha, dim=0)
|
||||||
|
|
||||||
|
self.best_score = None
|
||||||
|
|
||||||
def noise_time_series(self, x: torch.tensor, t: int):
|
def noise_time_series(self, x: torch.tensor, t: int):
|
||||||
""" Add noise to time series
|
""" Add noise to time series
|
||||||
Args:
|
Args:
|
||||||
@@ -93,6 +95,7 @@ class DiffusionTrainer:
|
|||||||
self.data_processor = task.connect(self.data_processor, name="data_processor")
|
self.data_processor = task.connect(self.data_processor, name="data_processor")
|
||||||
|
|
||||||
def train(self, epochs: int, learning_rate: float, task: Task = None):
|
def train(self, epochs: int, learning_rate: float, task: Task = None):
|
||||||
|
self.best_score = None
|
||||||
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
|
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
|
||||||
criterion = nn.MSELoss()
|
criterion = nn.MSELoss()
|
||||||
self.model.to(self.device)
|
self.model.to(self.device)
|
||||||
@@ -220,6 +223,9 @@ class DiffusionTrainer:
|
|||||||
all_crps = np.array(all_crps)
|
all_crps = np.array(all_crps)
|
||||||
mean_crps = all_crps.mean()
|
mean_crps = all_crps.mean()
|
||||||
|
|
||||||
|
if self.best_score is None or mean_crps < self.best_score:
|
||||||
|
self.save_checkpoint(mean_crps, task, epoch)
|
||||||
|
|
||||||
if task:
|
if task:
|
||||||
task.get_logger().report_scalar(
|
task.get_logger().report_scalar(
|
||||||
title="CRPS",
|
title="CRPS",
|
||||||
@@ -227,4 +233,11 @@ class DiffusionTrainer:
|
|||||||
value=mean_crps,
|
value=mean_crps,
|
||||||
iteration=epoch
|
iteration=epoch
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def save_checkpoint(self, val_loss, task, iteration: int):
    """Persist the current model weights and record the new best score.

    Writes ``self.model.state_dict()`` to ``checkpoint.pt`` in the current
    working directory and, when a task is supplied, registers that file as
    the task's output model for the given iteration.

    Args:
        val_loss: Score that triggered this checkpoint; stored as
            ``self.best_score``.
        task: Object exposing ``update_output_model`` (presumably a ClearML
            ``Task`` -- confirm against the caller), or ``None`` when no
            experiment tracking is in use.
        iteration: Iteration number forwarded to ``update_output_model``.
    """
    torch.save(self.model.state_dict(), "checkpoint.pt")

    # ``train`` defaults ``task`` to None and the call site is not inside
    # its ``if task:`` guard, so the upload must be optional here; otherwise
    # an untracked run fails on ``None.update_output_model`` and never
    # records the best score.
    if task is not None:
        task.update_output_model(
            model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
        )

    self.best_score = val_loss
|
||||||
|
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ data_config.DAY_OF_WEEK = False
|
|||||||
|
|
||||||
data_config.NOMINAL_NET_POSITION = True
|
data_config.NOMINAL_NET_POSITION = True
|
||||||
|
|
||||||
data_config = Task.connect(data_config, name="data_features")
|
data_config = task.connect(data_config, name="data_features")
|
||||||
|
|
||||||
data_processor = DataProcessor(data_config, path="", lstm=False)
|
data_processor = DataProcessor(data_config, path="", lstm=False)
|
||||||
data_processor.set_batch_size(8192)
|
data_processor.set_batch_size(8192)
|
||||||
|
|||||||
Reference in New Issue
Block a user