mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
* feat: Add LoHa/LoKr network support for SDXL and Anima - networks/network_base.py: shared AdditionalNetwork base class with architecture auto-detection (SDXL/Anima) and generic module injection - networks/loha.py: LoHa (Low-rank Hadamard Product) module with HadaWeight custom autograd, training/inference classes, and factory functions - networks/lokr.py: LoKr (Low-rank Kronecker Product) module with factorization, training/inference classes, and factory functions - library/lora_utils.py: extend weight merge hook to detect and merge LoHa/LoKr weights alongside standard LoRA Linear and Conv2d 1x1 layers only; Conv2d 3x3 (Tucker decomposition) support will be added separately. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * feat: Enhance LoHa and LoKr modules with Tucker decomposition support - Added Tucker decomposition functionality to LoHa and LoKr modules. - Implemented new methods for weight rebuilding using Tucker decomposition. - Updated initialization and weight handling for Conv2d 3x3+ layers. - Modified get_diff_weight methods to accommodate Tucker and non-Tucker modes. - Enhanced network base to include unet_conv_target_modules for architecture detection. 
* fix: rank dropout handling in LoRAModule for Conv2d and Linear layers, see #2272 for details * doc: add dtype comment for load_safetensors_with_lora_and_fp8 function * fix: enhance architecture detection to support InferSdxlUNet2DConditionModel for gen_img.py * doc: update model support structure to include Lumina Image 2.0, HunyuanImage-2.1, and Anima-Preview * doc: add documentation for LoHa and LoKr fine-tuning methods * Update networks/network_base.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update docs/loha_lokr.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: refactor LoHa and LoKr imports for weight merging in load_safetensors_with_lora_and_fp8 function --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
288 lines
14 KiB
Python
288 lines
14 KiB
Python
import os
|
|
import re
|
|
from typing import Dict, List, Optional, Union
|
|
import torch
|
|
from tqdm import tqdm
|
|
from library.device_utils import synchronize_device
|
|
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
|
|
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
|
|
from networks.loha import merge_weights_to_tensor as loha_merge
|
|
from networks.lokr import merge_weights_to_tensor as lokr_merge
|
|
|
|
from library.utils import setup_logging

# Configure the project-wide logging handlers/format before this module's logger is created.
setup_logging()
import logging

# Module-level logger; all merge/load progress messages in this file go through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def filter_lora_state_dict(
    weights_sd: Dict[str, torch.Tensor],
    include_pattern: Optional[str] = None,
    exclude_pattern: Optional[str] = None,
) -> Dict[str, torch.Tensor]:
    """
    Filter a LoRA state dict by regex include/exclude patterns.

    The include pattern is applied first (keeping only matching keys), then the
    exclude pattern (dropping matching keys). Patterns are matched with
    ``re.search``, i.e. anywhere inside the key, not anchored.

    Args:
        weights_sd: LoRA state dict to filter. Not mutated; a new dict is built
            whenever a pattern is applied.
        include_pattern: Regex; when given, only keys matching it are kept.
        exclude_pattern: Regex; when given, keys matching it are removed.

    Returns:
        The filtered state dict. May be empty when the patterns remove every
        key (a warning is logged in that case).
    """
    # apply include/exclude patterns
    original_key_count = len(weights_sd)
    if include_pattern is not None:
        regex_include = re.compile(include_pattern)
        weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
        logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd)}")

    if exclude_pattern is not None:
        original_key_count_ex = len(weights_sd)
        regex_exclude = re.compile(exclude_pattern)
        weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
        logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd)}")

    if len(weights_sd) != original_key_count:
        # report which top-level modules (key prefix before the first ".") survived filtering
        remaining_keys = sorted({k.split(".", 1)[0] for k in weights_sd})
        logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
        if len(weights_sd) == 0:
            logger.warning("No keys left after filtering.")

    return weights_sd
|
|
|
|
|
|
def load_safetensors_with_lora_and_fp8(
    model_files: Union[str, List[str]],
    lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
    lora_multipliers: Optional[List[float]],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
    """
    Merge LoRA weights into the state dict of a model with fp8 optimization if needed.

    Standard LoRA (``lora_down``/``lora_up``/``alpha``), LoHa (``hada_*``) and LoKr
    (``lokr_*``) adapters are supported; LoHa/LoKr merging is delegated to
    ``networks.loha`` / ``networks.lokr``. Merging happens on the fly through a
    per-weight hook invoked while the safetensors files are streamed, so the
    unmerged model is never held in memory a second time.

    Args:
        model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
        lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
        lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights. Padded with 1.0 / truncated to match the length of lora_weights_list.
        fp8_optimization (bool): Whether to apply FP8 optimization.
        calc_device (torch.device): Device to calculate on.
        move_to_device (bool): Whether to move tensors to the calculation device after loading.
        dit_weight_dtype (Optional[torch.dtype]): Dtype to load weights in when not using FP8 optimization.
        target_keys (Optional[List[str]]): Keys to target for optimization.
        exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
        disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
        weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.

    Returns:
        dict[str, torch.Tensor]: The fully loaded state dict with all adapters merged in.
    """

    # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
    if isinstance(model_files, str):
        model_files = [model_files]

    # expand each path into its sibling shards (or keep it as-is when not sharded)
    extended_model_files = []
    for model_file in model_files:
        split_filenames = get_split_weight_filenames(model_file)
        if split_filenames is not None:
            extended_model_files.extend(split_filenames)
        else:
            extended_model_files.append(model_file)
    model_files = extended_model_files
    logger.info(f"Loading model files: {model_files}")

    # load LoRA weights
    weight_hook = None
    if lora_weights_list is None or len(lora_weights_list) == 0:
        # nothing to merge: keep weight_hook as None so loading is a plain pass-through
        lora_weights_list = []
        lora_multipliers = []
        list_of_lora_weight_keys = []
    else:
        # one mutable key set per LoRA state dict; keys are removed as they are consumed,
        # so leftovers can be reported as "unused" after loading
        list_of_lora_weight_keys = []
        for lora_sd in lora_weights_list:
            lora_weight_keys = set(lora_sd.keys())
            list_of_lora_weight_keys.append(lora_weight_keys)

        # normalize multipliers: default to 1.0, pad/truncate to match lora_weights_list
        if lora_multipliers is None:
            lora_multipliers = [1.0] * len(lora_weights_list)
        while len(lora_multipliers) < len(lora_weights_list):
            lora_multipliers.append(1.0)
        if len(lora_multipliers) > len(lora_weights_list):
            lora_multipliers = lora_multipliers[: len(lora_weights_list)]

        # Merge LoRA weights into the state dict
        logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")

        # make hook for LoRA merging
        def weight_hook_func(model_weight_key: str, model_weight: torch.Tensor, keep_on_calc_device: bool = False) -> torch.Tensor:
            # Called once per tensor during loading; returns the tensor with every
            # matching adapter delta added. Non-".weight" tensors pass through untouched.
            nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device

            if not model_weight_key.endswith(".weight"):
                return model_weight

            original_device = model_weight.device
            if original_device != calc_device:
                model_weight = model_weight.to(calc_device)  # to make calculation faster

            for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
                # check if this weight has LoRA weights
                lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0]  # remove trailing ".weight"
                found = False
                # try the diffusers-style "lora_unet_" prefix first, then the bare name
                for prefix in ["lora_unet_", ""]:
                    lora_name = prefix + lora_name_without_prefix.replace(".", "_")
                    down_key = lora_name + ".lora_down.weight"
                    up_key = lora_name + ".lora_up.weight"
                    alpha_key = lora_name + ".alpha"
                    if down_key in lora_weight_keys and up_key in lora_weight_keys:
                        found = True
                        break

                if found:
                    # Standard LoRA merge
                    # get LoRA weights
                    down_weight = lora_sd[down_key]
                    up_weight = lora_sd[up_key]

                    # rank = first dim of down projection; alpha defaults to rank => scale 1.0
                    dim = down_weight.size()[0]
                    alpha = lora_sd.get(alpha_key, dim)
                    scale = alpha / dim

                    down_weight = down_weight.to(calc_device)
                    up_weight = up_weight.to(calc_device)

                    original_dtype = model_weight.dtype
                    if original_dtype.itemsize == 1:  # fp8 (any 1-byte float dtype)
                        # temporarily convert to float16 for calculation
                        model_weight = model_weight.to(torch.float16)
                        down_weight = down_weight.to(torch.float16)
                        up_weight = up_weight.to(torch.float16)

                    # W <- W + U * D
                    if len(model_weight.size()) == 2:
                        # linear
                        if len(up_weight.size()) == 4:  # use linear projection mismatch
                            # LoRA stored as 1x1 conv weights for a linear layer: drop the spatial dims
                            up_weight = up_weight.squeeze(3).squeeze(2)
                            down_weight = down_weight.squeeze(3).squeeze(2)
                        model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
                    elif down_weight.size()[2:4] == (1, 1):
                        # conv2d 1x1: equivalent to a matmul over the channel dims
                        model_weight = (
                            model_weight
                            + multiplier
                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                            * scale
                        )
                    else:
                        # conv2d 3x3: compose the two conv kernels via conv2d over the permuted down weight
                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                        # logger.info(conved.size(), weight.size(), module.stride, module.padding)
                        model_weight = model_weight + multiplier * conved * scale

                    if original_dtype.itemsize == 1:  # fp8
                        model_weight = model_weight.to(original_dtype)  # convert back to original dtype

                    # remove LoRA keys from set
                    lora_weight_keys.remove(down_key)
                    lora_weight_keys.remove(up_key)
                    if alpha_key in lora_weight_keys:
                        lora_weight_keys.remove(alpha_key)
                    # this LoRA set handled this weight; skip the LoHa/LoKr probe below
                    continue

                # Check for LoHa/LoKr weights with same prefix search
                for prefix in ["lora_unet_", ""]:
                    lora_name = prefix + lora_name_without_prefix.replace(".", "_")
                    hada_key = lora_name + ".hada_w1_a"
                    lokr_key = lora_name + ".lokr_w1"

                    if hada_key in lora_weight_keys:
                        # LoHa merge
                        # NOTE(review): loha_merge presumably also removes the consumed keys
                        # from lora_weight_keys (mirroring the standard-LoRA bookkeeping above)
                        # — confirm in networks/loha.py
                        model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
                        break
                    elif lokr_key in lora_weight_keys:
                        # LoKr merge
                        model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device)
                        break

            if not keep_on_calc_device and original_device != calc_device:
                model_weight = model_weight.to(original_device)  # move back to original device
            return model_weight

        weight_hook = weight_hook_func

    state_dict = load_safetensors_with_fp8_optimization_and_hook(
        model_files,
        fp8_optimization,
        calc_device,
        move_to_device,
        dit_weight_dtype,
        target_keys,
        exclude_keys,
        weight_hook=weight_hook,
        disable_numpy_memmap=disable_numpy_memmap,
        weight_transform_hooks=weight_transform_hooks,
    )

    for lora_weight_keys in list_of_lora_weight_keys:
        # check if all LoRA keys are used
        if len(lora_weight_keys) > 0:
            # if there are still LoRA keys left, it means they are not used in the model
            # this is a warning, not an error
            logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")

    return state_dict
|
|
|
|
|
|
def load_safetensors_with_fp8_optimization_and_hook(
    model_files: list[str],
    fp8_optimization: bool,
    calc_device: torch.device,
    move_to_device: bool = False,
    dit_weight_dtype: Optional[torch.dtype] = None,
    target_keys: Optional[List[str]] = None,
    exclude_keys: Optional[List[str]] = None,
    weight_hook: callable = None,
    disable_numpy_memmap: bool = False,
    weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
    """
    Load a state dict from safetensors files, running the optional per-weight hook
    (e.g. for LoRA merging) and applying FP8 optimization when requested.
    """
    if fp8_optimization:
        logger.info(
            f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
        )
        # FP8 optimization chooses the storage dtype itself, so dit_weight_dtype is not forwarded
        return load_safetensors_with_fp8_optimization(
            model_files,
            calc_device,
            target_keys,
            exclude_keys,
            move_to_device=move_to_device,
            weight_hook=weight_hook,
            disable_numpy_memmap=disable_numpy_memmap,
            weight_transform_hooks=weight_transform_hooks,
        )

    logger.info(
        f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
    )
    state_dict: dict[str, torch.Tensor] = {}
    for model_file in model_files:
        with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as raw_reader:
            # wrap the file in the transform adapter only when transform hooks are supplied
            reader = raw_reader if weight_transform_hooks is None else TensorWeightAdapter(weight_transform_hooks, raw_reader)
            for key in tqdm(reader.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
                if weight_hook is None and move_to_device:
                    # fast path: load directly onto the target device / dtype
                    tensor = reader.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
                else:
                    # get_tensor performs a non-blocking transfer, so load on the source
                    # device first, run the hook, then move / convert explicitly
                    tensor = reader.get_tensor(key)
                    if weight_hook is not None:
                        tensor = weight_hook(key, tensor, keep_on_calc_device=move_to_device)
                    if move_to_device:
                        tensor = tensor.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
                    elif dit_weight_dtype is not None:
                        tensor = tensor.to(dit_weight_dtype)
                state_dict[key] = tensor
    if move_to_device:
        # flush the non-blocking transfers before handing the dict back
        synchronize_device(calc_device)

    return state_dict
|