speed up nan replace in sdxl training ref #1009

This commit is contained in:
Kohya S
2023-12-21 21:44:03 +09:00
parent 0676f1a86f
commit 04ef8d395f
3 changed files with 3 additions and 3 deletions

View File

@@ -394,7 +394,7 @@ def train(args):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None: