mirror of https://github.com/kohya-ss/sd-scripts.git
Merge aba56a2ca2 into 1dae34b0af
@@ -4,6 +4,7 @@ import argparse
 import os
 import time
 import concurrent.futures
+import re
 import torch
 from safetensors.torch import load_file, save_file
 from tqdm import tqdm
@@ -110,15 +111,56 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
            logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")

        if method == "LoRA":
+            def convert_diffusers_labels_to_unet(name: str) -> str:
+                if "_attentions_" not in name:
+                    return name  # only attention-schema names need remapping
+
+                # Normalize stage tokens everywhere, not only for attention names
+                s = (name.replace("unet_up", "unet_output")
+                     .replace("unet_down", "unet_input")
+                     .replace("unet_mid", "unet_middle"))
+
+                # Middle: ...middle_block_attentions_X_* -> ...middle_block_{X+1}_*
+                if "unet_middle" in s and "middle_block_attentions_" in s:
+                    return re.sub(
+                        r"middle_block_attentions_(\d+)_",
+                        lambda m: f"middle_block_{int(m.group(1)) + 1}_",
+                        s,
+                    )
+
+                left, right = s.split("_attentions_", 1)
+                L = left.split("_")  # lora_unet_[input|output]_blocks_{X}
+                stage = L[2]
+                X = int(L[-1])  # down/up block index in the "wrong" (diffusers) schema
+                Y_str, *rest = right.split("_")
+                Y = int(Y_str)  # attentions index in the "wrong" schema
+                # Remainder includes transformer/proj tail (kept as-is)
+                # Map to the "right" (LDM UNet) indices:
+                if stage == "input":  # from "down"
+                    if X == 1: i = 4 + Y  # → input_blocks_{4|5}_1
+                    elif X == 2: i = 7 + Y  # → input_blocks_{7|8}_1
+                    else: return s  # no attentions elsewhere
+                    j = 1
+                elif stage == "output":  # from "up"
+                    i = 3 * X + Y  # db=0→0..2, db=1→3..5, db=2→6..8
+                    j = 1
+                else:
+                    return s
+
+                L[-1] = str(i)
+                return "_".join(L + [str(j)] + rest)
+
            for key in tqdm(lora_sd.keys()):
                if "lora_down" in key:
-                    up_key = key.replace("lora_down", "lora_up")
-                    alpha_key = key[: key.index("lora_down")] + "alpha"
+                    key_base = key[: key.index("lora_down")]
+                    up_key = key_base + "lora_up.weight"
+                    dora_key = key_base + "dora_scale"
+                    alpha_key = key_base + "alpha"

                    # find original module for this lora
-                    module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
+                    module_name = ".".join(convert_diffusers_labels_to_unet(key).split(".")[:-2])  # remove trailing ".lora_down.weight"
                    if module_name not in name_to_module:
-                        logger.info(f"no module found for LoRA weight: {key}")
+                        logger.info(f"no module found for LoRA weight: {module_name}, from({key})")
                        continue
                    module = name_to_module[module_name]
                    # logger.info(f"apply {key} to {module}")
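
To see what the remapping above does in practice, here is a small sanity harness (illustrative only; the sample keys are hypothetical, and it assumes convert_diffusers_labels_to_unet from the hunk above is in scope). It exercises the three cases the function handles: diffusers down_blocks 1-2 map to input_blocks 4/5 and 7/8, up_blocks map to output_blocks 3*X+Y, and mid attentions shift by one:

# Sanity check for convert_diffusers_labels_to_unet (illustrative; sample keys
# are hypothetical and the function from the hunk above is assumed in scope).
examples = {
    # down_blocks_1 attentions_0 -> input_blocks_4_1  (X=1: i = 4 + Y)
    "lora_unet_down_blocks_1_attentions_0_proj_in": "lora_unet_input_blocks_4_1_proj_in",
    # down_blocks_2 attentions_1 -> input_blocks_8_1  (X=2: i = 7 + Y)
    "lora_unet_down_blocks_2_attentions_1_proj_out": "lora_unet_input_blocks_8_1_proj_out",
    # up_blocks_1 attentions_2 -> output_blocks_5_1   (i = 3*X + Y)
    "lora_unet_up_blocks_1_attentions_2_proj_in": "lora_unet_output_blocks_5_1_proj_in",
    # mid_block attentions_0 -> middle_block_1        (index shifted by one)
    "lora_unet_mid_block_attentions_0_proj_in": "lora_unet_middle_block_1_proj_in",
}
for src, expected in examples.items():
    assert convert_diffusers_labels_to_unet(src) == expected, (src, expected)
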
@@ -127,7 +169,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,
                    up_weight = lora_sd[up_key]

                    dim = down_weight.size()[0]
-                    alpha = lora_sd.get(alpha_key, dim)
+                    alpha = lora_sd.get(alpha_key, 1.0)
                    scale = alpha / dim

                    if lbw:
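
Note that this hunk changes the fallback used when a LoRA stores no alpha key: previously a missing alpha defaulted to the rank (so scale = 1.0), while the new default of 1.0 gives scale = 1/rank. A minimal illustration with made-up numbers:

# Effect of the default-alpha change (made-up numbers, for illustration only)
dim = 16                 # LoRA rank, taken from down_weight.size()[0]
old_scale = dim / dim    # old fallback: alpha = dim -> scale = 1.0
new_scale = 1.0 / dim    # new fallback: alpha = 1.0 -> scale = 0.0625
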
@@ -138,23 +180,62 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws,

                    # W <- W + U * D
                    weight = module.weight
+                    lora_diff = None
                    # logger.info(module_name, down_weight.size(), up_weight.size())
                    if len(weight.size()) == 2:
                        # linear
-                        weight = weight + ratio * (up_weight @ down_weight) * scale
+                        lora_diff = up_weight @ down_weight
                    elif down_weight.size()[2:4] == (1, 1):
                        # conv2d 1x1
-                        weight = (
-                            weight
-                            + ratio
-                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
-                            * scale
-                        )
+                        lora_diff = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                    else:
                        # conv2d 3x3
-                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
-                        # logger.info(conved.size(), weight.size(), module.stride, module.padding)
-                        weight = weight + ratio * conved * scale
+                        lora_diff = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+
+                    dora_scale = lora_sd.get(dora_key, None)
+
+                    # Algorithm/math taken from reForge
+                    if dora_scale is None:
+                        # -------- Plain LoRA (mirrors the original merge math) --------
+                        # W <- W + ratio * (lora_diff * scale)
+                        weight = weight + (ratio * (lora_diff * scale)).to(dtype=weight.dtype, device=weight.device)
+
+                    else:
+                        # -------- DoRA (literal reForge semantics) --------
+                        # cast dora_scale as reForge does (to the intermediate dtype; weight.dtype is used for the ops)
+                        ds = dora_scale.to(device=weight.device, dtype=merge_dtype)
+
+                        # lora_diff gets 'alpha' (scale == alpha/rank) BEFORE the magnitude step; strength is applied AFTER it
+                        lora_diff_scaled = (lora_diff * scale).to(dtype=weight.dtype, device=weight.device)
+
+                        # weight_calc = weight + function(lora_diff_scaled); function is identity here
+                        weight_calc = weight + lora_diff_scaled
+
+                        wd_on_output_axis = ds.shape[0] == weight_calc.shape[0]
+                        if wd_on_output_axis:
+                            # per-OUT norm taken from the ORIGINAL weight (matches reForge)
+                            weight_norm = (
+                                weight.reshape(weight.shape[0], -1)
+                                .norm(dim=1, keepdim=True)
+                                .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
+                            )
+                        else:
+                            # per-IN norm from weight_calc^T (matches reForge)
+                            wc = weight_calc.transpose(0, 1)
+                            weight_norm = (
+                                wc.reshape(wc.shape[0], -1)
+                                .norm(dim=1, keepdim=True)
+                                .reshape(wc.shape[0], *[1] * (wc.dim() - 1))
+                                .transpose(0, 1)
+                            )
+
+                        weight_norm = weight_norm + torch.finfo(weight.dtype).eps
+
+                        # Apply magnitude: weight_calc *= (dora_scale / weight_norm)
+                        # (Do NOT reshape ds; rely on its stored shape for broadcasting)
+                        weight_calc = weight_calc * (ds.to(dtype=weight.dtype) / weight_norm)
+                        weight = torch.lerp(weight, weight_calc, float(ratio))

                    module.weight = torch.nn.Parameter(weight)
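
As a standalone illustration of the merge math in this hunk, here is a minimal sketch on made-up tensors (not part of the commit; names mirror the diff, merge_dtype handling is omitted, and only the per-OUT DoRA branch is shown):

import torch

torch.manual_seed(0)
weight = torch.randn(4, 8)             # W: (out_features, in_features)
up_weight = torch.randn(4, 2)          # lora_up:   (out, rank)
down_weight = torch.randn(2, 8)        # lora_down: (rank, in)
dora_scale = torch.rand(4, 1) + 0.5    # learned magnitude, one value per output row
alpha, ratio = 2.0, 1.0
scale = alpha / down_weight.size()[0]  # alpha / rank, as in the hunk

# Linear branch: full-weight delta of the low-rank pair
lora_diff = up_weight @ down_weight

# DoRA: add the scaled delta, take the per-OUT norm of the ORIGINAL weight,
# re-apply the stored magnitude, then lerp toward the result by the merge ratio
weight_calc = weight + lora_diff * scale
weight_norm = weight.reshape(weight.shape[0], -1).norm(dim=1, keepdim=True)
weight_norm = weight_norm + torch.finfo(weight.dtype).eps
weight_calc = weight_calc * (dora_scale / weight_norm)
merged = torch.lerp(weight, weight_calc, float(ratio))

# Conv2d-3x3 branch: the conv trick composes down (rank, in, 3, 3) with
# up (out, rank, 1, 1) into a full (out, in, 3, 3) kernel delta
down3 = torch.randn(2, 8, 3, 3)
up3 = torch.randn(4, 2, 1, 1)
diff3 = torch.nn.functional.conv2d(down3.permute(1, 0, 2, 3), up3).permute(1, 0, 2, 3)
assert diff3.shape == (4, 8, 3, 3)
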