From 14aa2923cffe8ba4eddb75b76a55841fe4f15046 Mon Sep 17 00:00:00 2001
From: laksjdjf
Date: Thu, 28 Sep 2023 14:39:32 +0900
Subject: [PATCH] Support concat LoRA

---
 networks/merge_lora.py      | 39 +++++++++++++++++++++++++++++++-----
 networks/sdxl_merge_lora.py | 40 ++++++++++++++++++++++++++++++++-----
 2 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/networks/merge_lora.py b/networks/merge_lora.py
index c8d743f5..71492621 100644
--- a/networks/merge_lora.py
+++ b/networks/merge_lora.py
@@ -110,7 +110,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
         module.weight = torch.nn.Parameter(weight)
 
 
-def merge_lora_models(models, ratios, merge_dtype):
+def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
     base_alphas = {}  # alpha for merged model
     base_dims = {}
 
@@ -158,6 +158,12 @@
         for key in lora_sd.keys():
             if "alpha" in key:
                 continue
+            if "lora_up" in key and concat:
+                concat_dim = 1
+            elif "lora_down" in key and concat:
+                concat_dim = 0
+            else:
+                concat_dim = None
 
             lora_module_name = key[: key.rfind(".lora_")]
 
@@ -165,12 +171,16 @@
             alpha = alphas[lora_module_name]
 
             scale = math.sqrt(alpha / base_alpha) * ratio
+            scale = abs(scale) if "lora_up" in key else scale  # handle negative weights
 
             if key in merged_sd:
                 assert (
-                    merged_sd[key].size() == lora_sd[key].size()
+                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
                 ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
-                merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
+                if concat_dim is not None:
+                    merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
+                else:
+                    merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
             else:
                 merged_sd[key] = lora_sd[key] * scale
 
@@ -178,6 +188,13 @@
     for lora_module_name, alpha in base_alphas.items():
         key = lora_module_name + ".alpha"
         merged_sd[key] = torch.tensor(alpha)
+        if shuffle:
+            key_down = lora_module_name + ".lora_down.weight"
+            key_up = lora_module_name + ".lora_up.weight"
+            dim = merged_sd[key_down].shape[0]
+            perm = torch.randperm(dim)
+            merged_sd[key_down] = merged_sd[key_down][perm]
+            merged_sd[key_up] = merged_sd[key_up][:, perm]
 
     print("merged model")
     print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -256,7 +273,7 @@
             args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
         )
     else:
-        state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype)
+        state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
 
     print(f"calculating hashes and creating metadata...")
 
@@ -317,7 +334,19 @@
         help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
         + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
     )
-
+    parser.add_argument(
+        "--concat",
+        action="store_true",
+        help="concat lora instead of merge (the dim (rank) of the output LoRA is the sum of the input dims) / "
+        + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
+    )
+    parser.add_argument(
+        "--shuffle",
+        action="store_true",
+        help="shuffle lora weight. / "
+        + "LoRAの重みをシャッフルする",
+    )
+
     return parser
 
 
diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py
index 0608c01f..c513eb59 100644
--- a/networks/sdxl_merge_lora.py
+++ b/networks/sdxl_merge_lora.py
@@ -113,7 +113,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
         module.weight = torch.nn.Parameter(weight)
 
 
-def merge_lora_models(models, ratios, merge_dtype):
+def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
     base_alphas = {}  # alpha for merged model
     base_dims = {}
 
@@ -161,6 +161,13 @@
         for key in tqdm(lora_sd.keys()):
             if "alpha" in key:
                 continue
+
+            if "lora_up" in key and concat:
+                concat_dim = 1
+            elif "lora_down" in key and concat:
+                concat_dim = 0
+            else:
+                concat_dim = None
 
             lora_module_name = key[: key.rfind(".lora_")]
 
@@ -168,12 +175,16 @@
             alpha = alphas[lora_module_name]
 
             scale = math.sqrt(alpha / base_alpha) * ratio
-
+            scale = abs(scale) if "lora_up" in key else scale  # handle negative weights
+
             if key in merged_sd:
                 assert (
-                    merged_sd[key].size() == lora_sd[key].size()
+                    merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
                 ), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
-                merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
+                if concat_dim is not None:
+                    merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
+                else:
+                    merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
             else:
                 merged_sd[key] = lora_sd[key] * scale
 
@@ -181,6 +192,13 @@
     for lora_module_name, alpha in base_alphas.items():
         key = lora_module_name + ".alpha"
         merged_sd[key] = torch.tensor(alpha)
+        if shuffle:
+            key_down = lora_module_name + ".lora_down.weight"
+            key_up = lora_module_name + ".lora_up.weight"
+            dim = merged_sd[key_down].shape[0]
+            perm = torch.randperm(dim)
+            merged_sd[key_down] = merged_sd[key_down][perm]
+            merged_sd[key_up] = merged_sd[key_up][:, perm]
 
     print("merged model")
     print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -252,7 +270,7 @@
             args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
         )
     else:
-        state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype)
+        state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
 
     print(f"calculating hashes and creating metadata...")
 
@@ -307,6 +325,18 @@
         help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
        + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
     )
+    parser.add_argument(
+        "--concat",
+        action="store_true",
+        help="concat lora instead of merge (the dim (rank) of the output LoRA is the sum of the input dims) / "
+        + "マージの代わりに結合する(LoRAのdim(rank)は入力dimの合計になる)",
+    )
+    parser.add_argument(
+        "--shuffle",
+        action="store_true",
+        help="shuffle lora weight. / "
+        + "LoRAの重みをシャッフルする",
+    )
 
     return parser
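
Note on the two flags (illustrative, not part of the patch): --concat stacks each
LoRA's lora_down rows along dim 0 and its lora_up columns along dim 1, so the merged
LoRA reproduces the sum of the individual low-rank deltas exactly, at the cost of an
output rank equal to the sum of the input ranks; --shuffle then permutes the rank
axis with one permutation applied to both factors, which leaves their product
unchanged. A minimal sketch of why both operations are exact, by block matrix
multiplication (the dimensions and tensor names below are hypothetical):

    import torch

    # two LoRAs for the same module: delta_i = up_i @ down_i, ranks 4 and 6
    out_dim, in_dim = 8, 16
    up1, down1 = torch.randn(out_dim, 4), torch.randn(4, in_dim)
    up2, down2 = torch.randn(out_dim, 6), torch.randn(6, in_dim)

    # concat along the rank axis reproduces the summed delta exactly
    # (element-wise merging is impossible here anyway: the ranks differ)
    delta_sum = up1 @ down1 + up2 @ down2
    up_cat = torch.cat([up1, up2], dim=1)        # dim=1, as for lora_up in the patch
    down_cat = torch.cat([down1, down2], dim=0)  # dim=0, as for lora_down
    assert torch.allclose(delta_sum, up_cat @ down_cat, atol=1e-5)

    # shuffling with the same permutation on down's rows and up's columns
    # only reorders the rank-1 terms of the sum, so the product is unchanged
    perm = torch.randperm(up_cat.shape[1])
    assert torch.allclose(up_cat[:, perm] @ down_cat[perm], up_cat @ down_cat, atol=1e-5)

Example invocation through the script's existing --models/--ratios/--save_to options
(the .safetensors file names are placeholders):

    python networks/merge_lora.py --save_to merged_lora.safetensors \
        --models lora_A.safetensors lora_B.safetensors --ratios 1.0 1.0 \
        --concat --shuffle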