mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
fix dtype for vae
This commit is contained in:
@@ -129,11 +129,11 @@ if __name__ == "__main__":
|
||||
unet.to(DEVICE, dtype=DTYPE)
|
||||
unet.eval()
|
||||
|
||||
vae_dtype = DTYPE
|
||||
if DTYPE == torch.float16:
|
||||
print("use float32 for vae")
|
||||
vae.to(DEVICE, torch.float32) # avoid black image, same as no-half-vae
|
||||
else:
|
||||
vae.to(DEVICE, DTYPE)
|
||||
vae_dtype = torch.float32
|
||||
vae.to(DEVICE, dtype=vae_dtype)
|
||||
vae.eval()
|
||||
|
||||
text_model1.to(DEVICE, dtype=DTYPE)
|
||||
@@ -278,7 +278,7 @@ if __name__ == "__main__":
|
||||
|
||||
# latents = 1 / 0.18215 * latents
|
||||
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
|
||||
latents = latents.to(torch.float32)
|
||||
latents = latents.to(vae_dtype)
|
||||
image = vae.decode(latents).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user