Merge branch 'dev' into sd3

This commit is contained in:
Kohya S
2026-02-13 08:14:21 +09:00

View File

@@ -15,6 +15,12 @@ import random
import re
import diffusers
# Compatible import for diffusers old/new UNet path
try:
from diffusers.models.unet_2d_condition import UNet2DConditionModel
except ImportError:
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
import numpy as np
import torch
@@ -80,7 +86,7 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
"""
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
logger.info("Enable memory efficient attention for U-Net")