mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Fix training to work without the latent cache (#1758)
This commit is contained in:
@@ -885,7 +885,9 @@ def train(args):
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# encode images to latents. images are [-1, 1]
|
||||
latents = vae.encode(batch["images"])
|
||||
latents = vae.encode(batch["images"].to(vae.device, dtype=vae.dtype)).to(
|
||||
accelerator.device, dtype=weight_dtype
|
||||
)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える (if the latents contain NaN, show a warning and replace them with 0)
|
||||
if torch.any(torch.isnan(latents)):
|
||||
@@ -927,7 +929,7 @@ def train(args):
|
||||
if t5_out is None:
|
||||
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
|
||||
with torch.set_grad_enabled(train_t5xxl):
|
||||
input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
|
||||
input_ids_t5xxl = input_ids_t5xxl.to("cpu")
|
||||
_, t5_out, _, _, _, t5_attn_mask = text_encoding_strategy.encode_tokens(
|
||||
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user