From 1f77bb6e73573d0bde67b94c76b549590833ece9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 20 Jan 2024 10:57:42 +0900 Subject: [PATCH] fix to work sample generation in fp8 ref #1057 --- library/sdxl_lpw_stable_diffusion.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index e03ee405..0562e88a 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: if up1 is not None: uncond_pool = up1 - dtype = self.unet.dtype + unet_dtype = self.unet.dtype + dtype = unet_dtype + if dtype.itemsize == 1: # fp8 + dtype = torch.float16 + self.unet.to(dtype) # 4. Preprocess image and mask if isinstance(image, PIL.Image.Image): @@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: if is_cancelled_callback is not None and is_cancelled_callback(): return None + self.unet.to(unet_dtype) return latents def latents_to_image(self, latents):