Refactor to avoid mutable global variable

Author: woctordho
Date: 2025-08-15 11:14:43 +08:00
parent c6fab554f4
commit 3ad71e1acf


@@ -20,12 +20,12 @@ logger = logging.getLogger(__name__)
 MIN_SV = 1e-6
-# Tune layers to various trainer formats.
-LORAFMT1 = ["lora_down", "lora_up"]
-LORAFMT2 = ["lora.down", "lora.up"]
-LORAFMT3 = ["lora_A", "lora_B"]
-LORAFMT4 = ["down", "up"]
-LORAFMT = LORAFMT1
+LORA_DOWN_UP_FORMATS = [
+    ("lora_down", "lora_up"),  # sd-scripts LoRA
+    ("lora_A", "lora_B"),  # PEFT LoRA
+    ("down", "up"),  # ControlLoRA
+]
 # Model save and load functions
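For reference, a minimal sketch (not part of this diff; the example keys are made up) of how a state-dict key can be matched against these (down, up) name pairs:

    LORA_DOWN_UP_FORMATS = [
        ("lora_down", "lora_up"),  # sd-scripts LoRA
        ("lora_A", "lora_B"),      # PEFT LoRA
        ("down", "up"),            # ControlLoRA
    ]

    def match_down_key(key):
        # Return (block_name, down_name, up_name, weight_suffix) if key names a
        # lora_down parameter, else None. Only the last one or two dot-separated
        # parts are checked, because generic words like "down" can also appear
        # inside the block name itself.
        parts = key.split(".")
        for down_name, up_name in LORA_DOWN_UP_FORMATS:
            if len(parts) >= 2 and parts[-2] == down_name:  # e.g. "....lora_down.weight"
                return ".".join(parts[:-2]), down_name, up_name, "." + parts[-1]
            if parts and parts[-1] == down_name:  # e.g. "....lora_A" with no ".weight"
                return ".".join(parts[:-1]), down_name, up_name, ""
        return None

    print(match_down_key("lora_unet_foo.lora_down.weight"))  # ('lora_unet_foo', 'lora_down', 'lora_up', '.weight')
    print(match_down_key("lora_unet_foo.alpha"))  # None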
@@ -97,8 +97,8 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
     U = U @ torch.diag(S)
     Vh = Vh[:lora_rank, :]
-    param_dict[LORAFMT[0]] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
-    param_dict[LORAFMT[1]] = U.reshape(out_size, lora_rank, 1, 1).cpu()
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
     del U, S, Vh, weight
     return param_dict
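For context, a minimal sketch (not part of this diff, with plain truncation in place of the script's dynamic-rank logic) of the SVD factorization extract_conv performs on a conv weight:

    import torch

    def extract_conv_sketch(weight, lora_rank):
        # weight: (out_size, in_size, k, k) conv kernel
        out_size, in_size, kernel_size, _ = weight.size()
        mat = weight.reshape(out_size, -1)  # flatten so a plain matrix SVD applies
        U, S, Vh = torch.linalg.svd(mat, full_matrices=False)
        U = U[:, :lora_rank] @ torch.diag(S[:lora_rank])  # fold singular values into U
        Vh = Vh[:lora_rank, :]
        return {
            "lora_down": Vh.reshape(lora_rank, in_size, kernel_size, kernel_size),
            "lora_up": U.reshape(out_size, lora_rank, 1, 1),
        }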
@@ -116,8 +116,8 @@ def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, sca
     U = U @ torch.diag(S)
     Vh = Vh[:lora_rank, :]
-    param_dict[LORAFMT[0]] = Vh.reshape(lora_rank, in_size).cpu()
-    param_dict[LORAFMT[1]] = U.reshape(out_size, lora_rank).cpu()
+    param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
+    param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
     del U, S, Vh, weight
     return param_dict
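The linear path is the same factorization without the 4-D reshape. A quick self-contained check (not part of this diff) that the extracted pair reconstructs the weight when the rank is not truncated:

    import torch

    torch.manual_seed(0)
    weight = torch.randn(64, 128)  # (out_size, in_size) linear weight
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    rank = 64  # full rank, so the reconstruction is numerically exact
    lora_up = U[:, :rank] @ torch.diag(S[:rank])  # (out_size, rank)
    lora_down = Vh[:rank, :]  # (rank, in_size)
    print(torch.allclose(lora_up @ lora_down, weight, atol=1e-4))  # True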
@@ -199,34 +199,11 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
 def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
-    global LORAFMT
-    network_alpha = None
-    network_dim = None
+    max_old_rank = None
     new_alpha = None
     verbose_str = "\n"
     fro_list = []
-    # Extract loaded lora dim and alpha
-    for key, value in lora_sd.items():
-        if network_alpha is None and "alpha" in key:
-            network_alpha = value
-        if (network_dim is None and len(value.size()) == 2
-                and (LORAFMT1[0] in key or LORAFMT2[0] in key or LORAFMT3[0] in key or LORAFMT4[0] in key)):
-            if LORAFMT1[0] in key:
-                LORAFMT = LORAFMT1
-            elif LORAFMT2[0] in key:
-                LORAFMT = LORAFMT2
-            elif LORAFMT3[0] in key:
-                LORAFMT = LORAFMT3
-            elif LORAFMT4[0] in key:
-                LORAFMT = LORAFMT4
-            network_dim = value.size()[0]
-        if network_alpha is not None and network_dim is not None:
-            break
-    if network_alpha is None:
-        network_alpha = network_dim
-    scale = network_alpha / network_dim
     if dynamic_method:
         logger.info(
             f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
@@ -241,20 +218,33 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
     with torch.no_grad():
         for key, value in tqdm(lora_sd.items()):
             weight_name = None
-            if LORAFMT[0] in key:
-                block_down_name = key.rsplit(f".{LORAFMT[0]}", 1)[0]
-                if key.endswith(f".{LORAFMT[0]}"):
+            key_parts = key.split(".")
+            block_down_name = None
+            for _format in LORA_DOWN_UP_FORMATS:
+                # Currently we only match lora_down_name in the last two parts of key
+                # because ("down", "up") are general words and may appear in block_down_name
+                if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
+                    block_down_name = ".".join(key_parts[:-2])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
+                    weight_name = "." + key_parts[-1]
+                    break
+                if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
+                    block_down_name = ".".join(key_parts[:-1])
+                    lora_down_name = "." + _format[0]
+                    lora_up_name = "." + _format[1]
+                    weight_name = ""
-                else:
-                    weight_name = key.rsplit(f".{LORAFMT[0]}", 1)[-1]
-                lora_down_weight = value
-            else:
-                break
+            if block_down_name is None:
+                # This parameter is not lora_down
+                continue
-            # find corresponding lora_up and alpha
+            # Now weight_name can be ".weight" or ""
+            # Find corresponding lora_up and alpha
             block_up_name = block_down_name
-            lora_up_weight = lora_sd.get(block_up_name + f".{LORAFMT[1]}" + weight_name, None)
+            lora_down_weight = value
+            lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
             lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
             weights_loaded = lora_down_weight is not None and lora_up_weight is not None
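To make the lookup concrete, a small sketch (not part of this diff, using a made-up two-entry state dict) of how the companion lora_up and alpha keys are derived from a matched lora_down key:

    import torch

    # Hypothetical miniature LoRA state dict
    lora_sd = {
        "unet.foo.lora_down.weight": torch.randn(8, 128),
        "unet.foo.lora_up.weight": torch.randn(64, 8),
        "unet.foo.alpha": torch.tensor(4.0),
    }

    key = "unet.foo.lora_down.weight"
    key_parts = key.split(".")
    block_down_name = ".".join(key_parts[:-2])  # "unet.foo"
    lora_down_name, lora_up_name = ".lora_down", ".lora_up"
    weight_name = "." + key_parts[-1]           # ".weight"

    lora_down_weight = lora_sd[key]
    lora_up_weight = lora_sd.get(block_down_name + lora_up_name + weight_name, None)
    lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
    print(lora_up_weight.shape, float(lora_alpha))  # torch.Size([64, 8]) 4.0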
@@ -262,10 +252,13 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
             if weights_loaded:
                 conv2d = len(lora_down_weight.size()) == 4
+                old_rank = lora_down_weight.size()[0]
+                max_old_rank = max(max_old_rank or 0, old_rank)
                 if lora_alpha is None:
                     scale = 1.0
                 else:
-                    scale = lora_alpha / lora_down_weight.size()[0]
+                    scale = lora_alpha / old_rank
                 if conv2d:
                     full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
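merge_conv itself is not shown in this hunk; as a rough sketch of the usual approach (an assumption about this helper, not its actual body), merging a conv LoRA pair back into a full weight flattens both factors to 2-D, multiplies them, and reshapes to the conv kernel shape:

    import torch

    def merge_conv_sketch(lora_down, lora_up, device="cpu"):
        # lora_down: (rank, in_size, k, k), lora_up: (out_size, rank, 1, 1)
        rank, in_size, k, _ = lora_down.size()
        out_size = lora_up.size(0)
        down_2d = lora_down.reshape(rank, -1).to(device)    # (rank, in_size*k*k)
        up_2d = lora_up.reshape(out_size, rank).to(device)  # (out_size, rank)
        merged = up_2d @ down_2d                             # (out_size, in_size*k*k)
        return merged.reshape(out_size, in_size, k, k)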
@@ -292,9 +285,9 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
                     verbose_str += "\n"
                 new_alpha = param_dict["new_alpha"]
-                o_lora_sd[block_down_name + f".{LORAFMT[0]}" + weight_name] = param_dict[LORAFMT[0]].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + f".{LORAFMT[1]}" + weight_name] = param_dict[LORAFMT[1]].to(save_dtype).contiguous()
-                o_lora_sd[block_up_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
+                o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
+                o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
+                o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
                 block_down_name = None
                 block_up_name = None
@@ -307,7 +300,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
     print(verbose_str)
     print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
     logger.info("resizing complete")
-    return o_lora_sd, network_dim, new_alpha
+    return o_lora_sd, max_old_rank, new_alpha
 def resize(args):
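As background (not part of this diff, and an assumption about how rank_resize fills fro_list), the Frobenius norm retention printed above measures how much of the squared singular-value mass survives the rank truncation:

    import torch

    def frobenius_retention(S, new_rank):
        # S: singular values of the merged LoRA weight, in descending order
        # ||W_truncated||_F / ||W||_F = sqrt(kept sigma^2 sum / total sigma^2 sum)
        return float(torch.sqrt(S[:new_rank].pow(2).sum() / S.pow(2).sum()))

    S = torch.linalg.svdvals(torch.randn(64, 128))
    print(f"{frobenius_retention(S, 8):.2%}")  # roughly 50-60% for this random matrix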