support Diffusers' based SDXL LoRA key for inference

This commit is contained in:
Kohya S
2024-05-18 11:05:04 +09:00
parent 153764a687
commit 146edce693

View File

@@ -755,6 +755,52 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
return block_idx
def convert_diffusers_to_sai_if_needed(weights_sd):
# only supports U-Net LoRA modules
found_up_down_blocks = False
for k in list(weights_sd.keys()):
if "down_blocks" in k:
found_up_down_blocks = True
break
if "up_blocks" in k:
found_up_down_blocks = True
break
if not found_up_down_blocks:
return
from library.sdxl_model_util import make_unet_conversion_map
unet_conversion_map = make_unet_conversion_map()
unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
# # add extra conversion
# unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
logger.info(f"Converting LoRA keys from Diffusers to SAI")
lora_unet_prefix = "lora_unet_"
for k in list(weights_sd.keys()):
if not k.startswith(lora_unet_prefix):
continue
unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]
# search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
for hf_module_name, sd_module_name in unet_conversion_map.items():
if hf_module_name in unet_module_name:
new_key = (
lora_unet_prefix
+ unet_module_name.replace(hf_module_name, sd_module_name)
+ k[len(lora_unet_prefix) + len(unet_module_name) :]
)
weights_sd[new_key] = weights_sd.pop(k)
found = True
break
if not found:
logger.warning(f"Key {k} is not found in unet_conversion_map")
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
@@ -768,6 +814,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
else:
weights_sd = torch.load(file, map_location="cpu")
# if keys are Diffusers based, convert to SAI based
convert_diffusers_to_sai_if_needed(weights_sd)
# get dim/alpha mapping
modules_dim = {}
modules_alpha = {}