mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
support Diffusers' based SDXL LoRA key for inference
This commit is contained in:
@@ -755,6 +755,52 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
|
||||
return block_idx
|
||||
|
||||
|
||||
def convert_diffusers_to_sai_if_needed(weights_sd):
|
||||
# only supports U-Net LoRA modules
|
||||
|
||||
found_up_down_blocks = False
|
||||
for k in list(weights_sd.keys()):
|
||||
if "down_blocks" in k:
|
||||
found_up_down_blocks = True
|
||||
break
|
||||
if "up_blocks" in k:
|
||||
found_up_down_blocks = True
|
||||
break
|
||||
if not found_up_down_blocks:
|
||||
return
|
||||
|
||||
from library.sdxl_model_util import make_unet_conversion_map
|
||||
|
||||
unet_conversion_map = make_unet_conversion_map()
|
||||
unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
|
||||
|
||||
# # add extra conversion
|
||||
# unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
|
||||
|
||||
logger.info(f"Converting LoRA keys from Diffusers to SAI")
|
||||
lora_unet_prefix = "lora_unet_"
|
||||
for k in list(weights_sd.keys()):
|
||||
if not k.startswith(lora_unet_prefix):
|
||||
continue
|
||||
|
||||
unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]
|
||||
|
||||
# search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
|
||||
for hf_module_name, sd_module_name in unet_conversion_map.items():
|
||||
if hf_module_name in unet_module_name:
|
||||
new_key = (
|
||||
lora_unet_prefix
|
||||
+ unet_module_name.replace(hf_module_name, sd_module_name)
|
||||
+ k[len(lora_unet_prefix) + len(unet_module_name) :]
|
||||
)
|
||||
weights_sd[new_key] = weights_sd.pop(k)
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
logger.warning(f"Key {k} is not found in unet_conversion_map")
|
||||
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
|
||||
@@ -768,6 +814,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# if keys are Diffusers based, convert to SAI based
|
||||
convert_diffusers_to_sai_if_needed(weights_sd)
|
||||
|
||||
# get dim/alpha mapping
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
|
||||
Reference in New Issue
Block a user