diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index e03ee405..0562e88a 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: if up1 is not None: uncond_pool = up1 - dtype = self.unet.dtype + unet_dtype = self.unet.dtype + dtype = unet_dtype + if dtype.itemsize == 1: # fp8 + dtype = torch.float16 + self.unet.to(dtype) # 4. Preprocess image and mask if isinstance(image, PIL.Image.Image): @@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: if is_cancelled_callback is not None and is_cancelled_callback(): return None + self.unet.to(unet_dtype) return latents def latents_to_image(self, latents):