Saving whole model instead of weights only

This commit is contained in:
Victor Mylle
2024-01-15 11:07:57 +00:00
parent c26ae76951
commit a977021dfc
2 changed files with 2 additions and 2 deletions

View File

@@ -236,7 +236,7 @@ class DiffusionTrainer:
) )
def save_checkpoint(self, val_loss, task, iteration: int): def save_checkpoint(self, val_loss, task, iteration: int):
torch.save(self.model.state_dict(), "checkpoint.pt") torch.save(self.model, "checkpoint.pt")
task.update_output_model( task.update_output_model(
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
) )

View File

@@ -279,7 +279,7 @@ class Trainer:
return test_loss return test_loss
def save_checkpoint(self, val_loss, task, iteration: int): def save_checkpoint(self, val_loss, task, iteration: int):
torch.save(self.model.state_dict(), "checkpoint.pt") torch.save(self.model, "checkpoint.pt")
task.update_output_model( task.update_output_model(
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
) )