mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Merge branch 'dev' into sd3
This commit is contained in:
@@ -20,6 +20,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
LORA_DOWN_UP_FORMATS = [
|
||||
("lora_down", "lora_up"), # sd-scripts LoRA
|
||||
("lora_A", "lora_B"), # PEFT LoRA
|
||||
("down", "up"), # ControlLoRA
|
||||
]
|
||||
|
||||
|
||||
# Model save and load functions
|
||||
|
||||
|
||||
@@ -192,24 +199,11 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
max_old_rank = None
|
||||
new_alpha = None
|
||||
verbose_str = "\n"
|
||||
fro_list = []
|
||||
|
||||
# Extract loaded lora dim and alpha
|
||||
for key, value in lora_sd.items():
|
||||
if network_alpha is None and "alpha" in key:
|
||||
network_alpha = value
|
||||
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
|
||||
network_dim = value.size()[0]
|
||||
if network_alpha is not None and network_dim is not None:
|
||||
break
|
||||
if network_alpha is None:
|
||||
network_alpha = network_dim
|
||||
|
||||
scale = network_alpha / network_dim
|
||||
|
||||
if dynamic_method:
|
||||
logger.info(
|
||||
f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
|
||||
@@ -224,17 +218,33 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
|
||||
with torch.no_grad():
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
weight_name = None
|
||||
if "lora_down" in key:
|
||||
block_down_name = key.rsplit(".lora_down", 1)[0]
|
||||
weight_name = key.rsplit(".", 1)[-1]
|
||||
lora_down_weight = value
|
||||
else:
|
||||
key_parts = key.split(".")
|
||||
block_down_name = None
|
||||
for _format in LORA_DOWN_UP_FORMATS:
|
||||
# Currently we only match lora_down_name in the last two parts of key
|
||||
# because ("down", "up") are general words and may appear in block_down_name
|
||||
if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
|
||||
block_down_name = ".".join(key_parts[:-2])
|
||||
lora_down_name = "." + _format[0]
|
||||
lora_up_name = "." + _format[1]
|
||||
weight_name = "." + key_parts[-1]
|
||||
break
|
||||
if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
|
||||
block_down_name = ".".join(key_parts[:-1])
|
||||
lora_down_name = "." + _format[0]
|
||||
lora_up_name = "." + _format[1]
|
||||
weight_name = ""
|
||||
break
|
||||
|
||||
if block_down_name is None:
|
||||
# This parameter is not lora_down
|
||||
continue
|
||||
|
||||
# find corresponding lora_up and alpha
|
||||
# Now weight_name can be ".weight" or ""
|
||||
# Find corresponding lora_up and alpha
|
||||
block_up_name = block_down_name
|
||||
lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
|
||||
lora_down_weight = value
|
||||
lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
|
||||
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
|
||||
|
||||
weights_loaded = lora_down_weight is not None and lora_up_weight is not None
|
||||
@@ -242,10 +252,13 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
if weights_loaded:
|
||||
|
||||
conv2d = len(lora_down_weight.size()) == 4
|
||||
old_rank = lora_down_weight.size()[0]
|
||||
max_old_rank = max(max_old_rank or 0, old_rank)
|
||||
|
||||
if lora_alpha is None:
|
||||
scale = 1.0
|
||||
else:
|
||||
scale = lora_alpha / lora_down_weight.size()[0]
|
||||
scale = lora_alpha / old_rank
|
||||
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
@@ -272,9 +285,9 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
verbose_str += "\n"
|
||||
|
||||
new_alpha = param_dict["new_alpha"]
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
|
||||
o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
|
||||
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
@@ -287,7 +300,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
|
||||
print(verbose_str)
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
logger.info("resizing complete")
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
return o_lora_sd, max_old_rank, new_alpha
|
||||
|
||||
|
||||
def resize(args):
|
||||
|
||||
Reference in New Issue
Block a user