From 3ad71e1acf0f38b31a623e5644ed0e375ee35951 Mon Sep 17 00:00:00 2001
From: woctordho
Date: Fri, 15 Aug 2025 11:14:43 +0800
Subject: [PATCH] Refactor to avoid mutable global variable

---
 networks/resize_lora.py | 93 +++++++++++++++++++----------------------
 1 file changed, 43 insertions(+), 50 deletions(-)

diff --git a/networks/resize_lora.py b/networks/resize_lora.py
index d8a37da2..18326437 100644
--- a/networks/resize_lora.py
+++ b/networks/resize_lora.py
@@ -20,12 +20,12 @@ logger = logging.getLogger(__name__)
 
 MIN_SV = 1e-6
 
-# Tune layers to various trainer formats.
-LORAFMT1 = ["lora_down", "lora_up"]
-LORAFMT2 = ["lora.down", "lora.up"]
-LORAFMT3 = ["lora_A", "lora_B"]
-LORAFMT4 = ["down", "up"]
-LORAFMT = LORAFMT1
+LORA_DOWN_UP_FORMATS = [
+    ("lora_down", "lora_up"),  # sd-scripts LoRA
+    ("lora_A", "lora_B"),  # PEFT LoRA
+    ("down", "up"),  # ControlLoRA
+]
+
 
 
 # Model save and load functions
@@ -97,8 +97,8 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
         U = U @ torch.diag(S)
         Vh = Vh[:lora_rank, :]
 
-    param_dict[LORAFMT[0]] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
-    param_dict[LORAFMT[1]] = U.reshape(out_size, lora_rank, 1, 1).cpu()
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
     del U, S, Vh, weight
     return param_dict
 
@@ -116,8 +116,8 @@ def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, sca
         U = U @ torch.diag(S)
         Vh = Vh[:lora_rank, :]
 
-    param_dict[LORAFMT[0]] = Vh.reshape(lora_rank, in_size).cpu()
-    param_dict[LORAFMT[1]] = U.reshape(out_size, lora_rank).cpu()
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
     del U, S, Vh, weight
     return param_dict
 
@@ -199,34 +199,11 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
 
 
 def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
-    global LORAFMT
-    network_alpha = None
-    network_dim = None
+    max_old_rank = None
+    new_alpha = None
     verbose_str = "\n"
     fro_list = []
 
-    # Extract loaded lora dim and alpha
-    for key, value in lora_sd.items():
-        if network_alpha is None and "alpha" in key:
-            network_alpha = value
-        if (network_dim is None and len(value.size()) == 2
-                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key or LORAFMT4[0] in key)):
-            if LORAFMT1[0] in key:
-                LORAFMT = LORAFMT1
-            elif LORAFMT2[0] in key:
-                LORAFMT = LORAFMT2
-            elif LORAFMT3[0] in key:
-                LORAFMT = LORAFMT3
-            elif LORAFMT4[0] in key:
-                LORAFMT = LORAFMT4
-            network_dim = value.size()[0]
-        if network_alpha is not None and network_dim is not None:
-            break
-    if network_alpha is None:
-        network_alpha = network_dim
-
-    scale = network_alpha / network_dim
-
     if dynamic_method:
         logger.info(
             f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
@@ -241,20 +218,33 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
 
     with torch.no_grad():
        for key, value in tqdm(lora_sd.items()):
-            weight_name = None
-            if LORAFMT[0] in key:
-                block_down_name = key.rsplit(f".{LORAFMT[0]}", 1)[0]
-                if key.endswith(f".{LORAFMT[0]}"):
+            key_parts = key.split(".")
+            block_down_name = None
+            for _format in LORA_DOWN_UP_FORMATS:
+                # Currently we only match lora_down_name in the last two parts of key
+                # because ("down", "up") are general words and may appear in block_down_name
+                if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
+                    block_down_name = ".".join(key_parts[:-2])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
+                    weight_name = "." + key_parts[-1]
+                    break
+                if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
+                    block_down_name = ".".join(key_parts[:-1])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
                     weight_name = ""
-                else:
-                    weight_name = key.rsplit(f".{LORAFMT[0]}", 1)[-1]
-                lora_down_weight = value
-            else:
+                    break
+
+            if block_down_name is None:
+                # This parameter is not lora_down
                 continue
 
-            # find corresponding lora_up and alpha
+            # Now weight_name can be ".weight" or ""
+            # Find corresponding lora_up and alpha
             block_up_name = block_down_name
-            lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}" + weight_name, None)
+            lora_down_weight = value
+            lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
             lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
 
             weights_loaded = lora_down_weight is not None and lora_up_weight is not None
@@ -262,10 +252,13 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
 
             if weights_loaded:
                 conv2d = len(lora_down_weight.size()) == 4
+                old_rank = lora_down_weight.size()[0]
+                max_old_rank = max(max_old_rank or 0, old_rank)
+
                 if lora_alpha is None:
                     scale = 1.0
                 else:
-                    scale = lora_alpha / lora_down_weight.size()[0]
+                    scale = lora_alpha / old_rank
 
                 if conv2d:
                     full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
@@ -292,9 +285,9 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
                     verbose_str += "\n"
 
                 new_alpha = param_dict["new_alpha"]
-                o_lora_sd[block_down_name + f".{LORAFMT[0]}" + weight_name] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + f".{LORAFMT[1]}" + weight_name] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
+                o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
+                o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
+                o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
 
                 block_down_name = None
                 block_up_name = None
@@ -307,7 +300,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
         print(verbose_str)
         print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
     logger.info("resizing complete")
-    return o_lora_sd, network_dim, new_alpha
+    return o_lora_sd, max_old_rank, new_alpha
 
 
 def resize(args):
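
A minimal, self-contained sketch of the key matching that the new loop in resize_lora_model performs, shown for illustration and not applied by the patch itself: the state-dict key is split on "." and only its last one or two parts are compared against LORA_DOWN_UP_FORMATS, so generic words such as "down" appearing inside a block name cannot trigger a false match. The helper name match_down_key and the example keys below are hypothetical, not names from the repository.

# Illustrative sketch of the per-key matching (assumptions noted above).
LORA_DOWN_UP_FORMATS = [
    ("lora_down", "lora_up"),  # sd-scripts LoRA
    ("lora_A", "lora_B"),  # PEFT LoRA
    ("down", "up"),  # ControlLoRA
]


def match_down_key(key):
    """Return (block_down_name, lora_up_name, weight_name) for a lora_down key, else None."""
    key_parts = key.split(".")
    for down_name, up_name in LORA_DOWN_UP_FORMATS:
        # Only the last two parts are checked, because "down"/"up" are common
        # words that may also appear inside the block name itself.
        if len(key_parts) >= 2 and key_parts[-2] == down_name:
            return ".".join(key_parts[:-2]), "." + up_name, "." + key_parts[-1]
        if key_parts[-1] == down_name:
            return ".".join(key_parts[:-1]), "." + up_name, ""
    return None  # not a lora_down parameter


# Hypothetical keys for illustration:
print(match_down_key("lora_unet_mid_block_attn.lora_down.weight"))
# ('lora_unet_mid_block_attn', '.lora_up', '.weight')
print(match_down_key("base_model.layers.0.lora_A"))
# ('base_model.layers.0', '.lora_B', '')
print(match_down_key("lora_unet_mid_block_attn.alpha"))
# None

The same matched tuple also yields lora_down_name and lora_up_name, which the patched loop uses to look up the corresponding lora_up weight and alpha in the state dict and to write the resized tensors back under the original naming scheme.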