Saving whole model instead of weights only
This commit is contained in:
@@ -236,7 +236,7 @@ class DiffusionTrainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def save_checkpoint(self, val_loss, task, iteration: int):
|
def save_checkpoint(self, val_loss, task, iteration: int):
|
||||||
torch.save(self.model.state_dict(), "checkpoint.pt")
|
torch.save(self.model, "checkpoint.pt")
|
||||||
task.update_output_model(
|
task.update_output_model(
|
||||||
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
|
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ class Trainer:
|
|||||||
return test_loss
|
return test_loss
|
||||||
|
|
||||||
def save_checkpoint(self, val_loss, task, iteration: int):
|
def save_checkpoint(self, val_loss, task, iteration: int):
|
||||||
torch.save(self.model.state_dict(), "checkpoint.pt")
|
torch.save(self.model, "checkpoint.pt")
|
||||||
task.update_output_model(
|
task.update_output_model(
|
||||||
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
|
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user