fix dtype for vae

This commit is contained in:
Kohya S
2023-07-04 21:30:37 +09:00
parent 6aa62b9b66
commit 3b35547da0

View File

@@ -129,11 +129,11 @@ if __name__ == "__main__":
unet.to(DEVICE, dtype=DTYPE)
unet.eval()
vae_dtype = DTYPE
if DTYPE == torch.float16:
print("use float32 for vae")
vae.to(DEVICE, torch.float32) # avoid black image, same as no-half-vae
else:
vae.to(DEVICE, DTYPE)
vae_dtype = torch.float32
vae.to(DEVICE, dtype=vae_dtype)
vae.eval()
text_model1.to(DEVICE, dtype=DTYPE)
@@ -278,7 +278,7 @@ if __name__ == "__main__":
# latents = 1 / 0.18215 * latents
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
latents = latents.to(torch.float32)
latents = latents.to(vae_dtype)
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)