Compare commits

...

1 Commit

Author SHA1 Message Date
kohya-ss
dbd835ee4b train: Optimize VAE encoding by handling batch sizes for images 2025-04-08 21:57:16 +09:00

View File

@@ -640,14 +640,23 @@ def train(args):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
else:
chunks = [
batch["images"][i : i + args.vae_batch_size]
for i in range(0, len(batch["images"]), args.vae_batch_size)
]
list_latents = []
for chunk in chunks:
with torch.no_grad():
# latentに変換
list_latents.append(
vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
)
latents = torch.cat(list_latents, dim=0)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)