Fix to work without latent cache #1758

This commit is contained in:
Kohya S
2024-11-06 21:33:28 +09:00
parent 5e32ee26a1
commit 43849030cf

View File

@@ -885,7 +885,9 @@ def train(args):
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = vae.encode(batch["images"])
latents = vae.encode(batch["images"].to(vae.device, dtype=vae.dtype)).to(
accelerator.device, dtype=weight_dtype
)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
@@ -927,7 +929,7 @@ def train(args):
if t5_out is None:
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
with torch.set_grad_enabled(train_t5xxl):
input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
input_ids_t5xxl = input_ids_t5xxl.to("cpu")
_, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
)