diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py
index d52f85a8..1a0f9551 100755
--- a/sdxl_gen_img.py
+++ b/sdxl_gen_img.py
@@ -15,6 +15,12 @@ import random
 import re
 
 import diffusers
+
+# Compatible import for diffusers old/new UNet path
+try:
+    from diffusers.models.unet_2d_condition import UNet2DConditionModel
+except ImportError:
+    from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 import numpy as np
 import torch
@@ -80,7 +86,7 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
 """
 
 
-def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
+def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
     if mem_eff_attn:
         logger.info("Enable memory efficient attention for U-Net")