mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
1 Commits
feature-ch
...
vae_batch_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dbd835ee4b |
@@ -640,14 +640,23 @@ def train(args):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
||||
else:
|
||||
chunks = [
|
||||
batch["images"][i : i + args.vae_batch_size]
|
||||
for i in range(0, len(batch["images"]), args.vae_batch_size)
|
||||
]
|
||||
list_latents = []
|
||||
for chunk in chunks:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
list_latents.append(
|
||||
vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
|
||||
)
|
||||
latents = torch.cat(list_latents, dim=0)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
|
||||
Reference in New Issue
Block a user