Update train_db.py

This commit is contained in:
gesen2egee
2024-08-04 15:03:56 +08:00
committed by GitHub
parent 1db495127f
commit 68162172eb

View File

@@ -209,10 +209,10 @@ def train(args):
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if val_dataset_group is not None:
print("Cache validation latents...")
val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()