From 3b35547da0cc258fc41097e8926a8d66cde5fd66 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Jul 2023 21:30:37 +0900 Subject: [PATCH] fix dtype for vae --- sdxl_minimal_inference.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 02a4af9c..25b8e51b 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -129,11 +129,11 @@ if __name__ == "__main__": unet.to(DEVICE, dtype=DTYPE) unet.eval() + vae_dtype = DTYPE if DTYPE == torch.float16: print("use float32 for vae") - vae.to(DEVICE, torch.float32) # avoid black image, same as no-half-vae - else: - vae.to(DEVICE, DTYPE) + vae_dtype = torch.float32 + vae.to(DEVICE, dtype=vae_dtype) vae.eval() text_model1.to(DEVICE, dtype=DTYPE) @@ -278,7 +278,7 @@ if __name__ == "__main__": # latents = 1 / 0.18215 * latents latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents - latents = latents.to(torch.float32) + latents = latents.to(vae_dtype) image = vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1)