diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 02a4af9c..25b8e51b 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -129,11 +129,11 @@ if __name__ == "__main__": unet.to(DEVICE, dtype=DTYPE) unet.eval() + vae_dtype = DTYPE if DTYPE == torch.float16: print("use float32 for vae") - vae.to(DEVICE, torch.float32) # avoid black image, same as no-half-vae - else: - vae.to(DEVICE, DTYPE) + vae_dtype = torch.float32 + vae.to(DEVICE, dtype=vae_dtype) vae.eval() text_model1.to(DEVICE, dtype=DTYPE) @@ -278,7 +278,7 @@ if __name__ == "__main__": # latents = 1 / 0.18215 * latents latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents - latents = latents.to(torch.float32) + latents = latents.to(vae_dtype) image = vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1)