This commit was authored by Reithan on 2026-03-30 13:08:26 +08:00 and committed via GitHub.

View File

@@ -4,6 +4,7 @@ import argparse
import os
import time
import concurrent.futures
import re
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
@@ -110,15 +111,56 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if method == "LoRA":
def convert_diffusers_labels_to_unet(name: str) -> str:
    """Map a diffusers-style LoRA module name to the native SD/LDM UNet schema.

    Diffusers names attention modules as
    ``..._{down|mid|up}_blocks_{X}_attentions_{Y}_...`` while the native UNet
    uses ``..._{input|middle|output}_blocks_{i}_{j}_...``.  Names without an
    ``_attentions_`` segment are returned unchanged.

    Returns the converted name, or the (possibly partially normalized) input
    when no mapping applies, so the caller can log a lookup miss for that key
    instead of the whole merge crashing.
    """
    if "_attentions_" not in name:
        return name  # attention-schema names only
    # Normalize stage tokens everywhere, not only for attention names
    s = (name.replace("unet_up", "unet_output")
         .replace("unet_down", "unet_input")
         .replace("unet_mid", "unet_middle"))
    # Middle: ...middle_block_attentions_X_* -> ...middle_block_{X+1}_*
    if "unet_middle" in s and "middle_block_attentions_" in s:
        return re.sub(
            r"middle_block_attentions_(\d+)_",
            lambda m: f"middle_block_{int(m.group(1)) + 1}_",
            s,
        )
    left, right = s.split("_attentions_", 1)
    parts = left.split("_")  # lora_unet_[input|output]_blocks_{X}
    try:
        stage = parts[2]
        block_idx = int(parts[-1])  # down/up block index in diffusers schema
        sub_str, *rest = right.split("_")
        attn_idx = int(sub_str)  # attentions index in diffusers schema
    except (IndexError, ValueError):
        # Unexpected key layout: fall through so the caller logs a lookup
        # miss for this key instead of the merge aborting on a crash.
        return s
    # Remainder includes transformer/proj tail (kept as-is)
    # Map to native UNet indices:
    if stage == "input":  # from "down"
        if block_idx == 1:
            i = 4 + attn_idx  # -> input_blocks_{4|5}_1
        elif block_idx == 2:
            i = 7 + attn_idx  # -> input_blocks_{7|8}_1
        else:
            return s  # no attentions elsewhere
        j = 1
    elif stage == "output":  # from "up"
        i = 3 * block_idx + attn_idx  # db=0 -> 0..2, db=1 -> 3..5, db=2 -> 6..8
        j = 1
    else:
        return s
    parts[-1] = str(i)
    return "_".join(parts + [str(j)] + rest)
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
key_base = key[: key.index("lora_down")]
up_key = key_base + "lora_up.weight"
dora_key = key_base + "dora_scale"
alpha_key = key_base + "alpha"
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
module_name = ".".join(convert_diffusers_labels_to_unet(key).split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
logger.info(f"no module found for LoRA weight: {key}")
logger.info(f"no module found for LoRA weight: {module_name}, from({key})")
continue
module = name_to_module[module_name]
# logger.info(f"apply {key} to {module}")
@@ -127,7 +169,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
alpha = lora_sd.get(alpha_key, 1.0)
scale = alpha / dim
if lbw:
@@ -138,23 +180,62 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
# W <- W + U * D
weight = module.weight
lora_diff = None
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
lora_diff = (up_weight @ down_weight)
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
lora_diff = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
lora_diff = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
dora_scale = lora_sd.get(dora_key, None)
# Algorithm/math taken from reForge
if dora_scale is None:
# -------- Plain LoRA (mirror your original math) --------
# W <- W + ratio * (lora_diff * scale)
weight = weight + (ratio * (lora_diff * scale)).to(dtype=weight.dtype, device=weight.device)
else:
# -------- DoRA (literal reForge semantics) --------
# cast dora_scale like reForge does (to intermediate, then we use weight.dtype for ops)
ds = dora_scale.to(device=weight.device, dtype=merge_dtype)
# lora_diff gets 'alpha' (scale == alpha/rank) BEFORE magnitude; strength applied AFTER magnitude
lora_diff_scaled = (lora_diff * scale).to(dtype=weight.dtype, device=weight.device)
# weight_calc = weight + function(lora_diff_scaled); function is identity here
weight_calc = weight + lora_diff_scaled
wd_on_output_axis = (ds.shape[0] == weight_calc.shape[0])
if wd_on_output_axis:
# per-OUT norm taken from ORIGINAL weight (matches reForge)
weight_norm = (
weight.reshape(weight.shape[0], -1)
.norm(dim=1, keepdim=True)
.reshape(weight.shape[0], *[1] * (weight.dim() - 1))
)
else:
# per-IN norm from weight_calc^T (matches reForge)
wc = weight_calc.transpose(0, 1)
weight_norm = (
wc.reshape(wc.shape[0], -1)
.norm(dim=1, keepdim=True)
.reshape(wc.shape[0], *[1] * (wc.dim() - 1))
.transpose(0, 1)
)
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
# Apply magnitude: weight_calc *= (dora_scale / weight_norm)
# (Do NOT reshape ds; rely on its stored shape for broadcasting)
weight_calc = weight_calc * (ds.to(dtype=weight.dtype) / weight_norm)
weight = torch.lerp(weight, weight_calc, float(ratio))
module.weight = torch.nn.Parameter(weight)