fix: remove duplicated latent normalization in decoding

This commit is contained in:
Kohya S
2025-07-15 21:58:03 +09:00
parent 25771a5180
commit c0c36a4e2f

View File

@@ -158,7 +158,7 @@ def generate_image(
# 5. Decode latents
#
logger.info("Decoding image...")
latents = latents / ae.scale_factor + ae.shift_factor
# latents = latents / ae.scale_factor + ae.shift_factor
with torch.no_grad():
image = ae.decode(latents.to(ae_dtype))
image = (image / 2 + 0.5).clamp(0, 1)