From 68162172ebf9afa21ad526fc833fcc04f74aeb5f Mon Sep 17 00:00:00 2001 From: gesen2egee <79357052+gesen2egee@users.noreply.github.com> Date: Sun, 4 Aug 2024 15:03:56 +0800 Subject: [PATCH] Update train_db.py --- train_db.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_db.py b/train_db.py index 9f8ec777..e98434db 100644 --- a/train_db.py +++ b/train_db.py @@ -209,10 +209,10 @@ def train(args): vae.eval() with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) - vae.to("cpu") if val_dataset_group is not None: print("Cache validation latents...") - val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) + vae.to("cpu") clean_memory_on_device(accelerator.device) accelerator.wait_for_everyone()