Files
Kohya-ss-sd-scripts/networks/lora_anima.py
Kohya S. 34e7138b6a Add/modify some implementation for anima (#2261)
* fix: update extend-exclude list in _typos.toml to include configs

* fix: exclude anima tests from pytest

* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE

* fix: update default value for --discrete_flow_shift in anima training guide

* feat: add Qwen-Image VAE

* feat: simplify encode_tokens

* feat: use unified attention module, add wrapper for state dict compatibility

* feat: loading with dynamic fp8 optimization and LoRA support

* feat: add anima minimal inference script (WIP)

* format: format

* feat: simplify target module selection by regular expression patterns

* feat: kept caption dropout rate in cache and handle in training script

* feat: update train_llm_adapter and verbose default values to string type

* fix: use strategy instead of using tokenizers directly

* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

* feat: support 5d tensor in get_noisy_model_input_and_timesteps

* feat: update loss calculation to support 5d tensor

* fix: update argument names in anima_train_utils to align with other archtectures

* feat: simplify Anima training script and update empty caption handling

* feat: support LoRA format without `net.` prefix

* fix: update to work fp8_scaled option

* feat: add regex-based learning rates and dimensions handling in create_network

* fix: improve regex matching for module selection and learning rates in LoRANetwork

* fix: update logging message for regex match in LoRANetwork

* fix: keep latents 4D except DiT call

* feat: enhance block swap functionality for inference and training in Anima model

* feat: refactor Anima training script

* feat: optimize VAE processing by adjusting tensor dimensions and data types

* fix: wait all block trasfer before siwtching offloader mode

* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!

* feat: support LORA for Qwen3

* feat: update Anima SAI model spec metadata handling

* fix: remove unused code

* feat: split CFG processing in do_sample function to reduce memory usage

* feat: add VAE chunking and caching options to reduce memory usage

* feat: optimize RMSNorm forward method and remove unused torch_attention_op

* Update library/strategy_anima.py

Use torch.all instead of all.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/safetensors_utils.py

Fix duplicated new_key for concat_hook.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_minimal_inference.py

Remove unused code.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_train.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/anima_train_utils.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: review with Copilot

* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)

* feat: add process_escape function to handle escape sequences in prompts

* feat: enhance LoRA weight handling in model loading and add text encoder loading function

* feat: improve ComfyUI conversion script with prefix constants and module name adjustments

* feat: update caption dropout documentation to clarify cache regeneration requirement

* feat: add clarification on learning rate adjustments

* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-13 08:15:06 +09:00

640 lines
26 KiB
Python

# LoRA network module for Anima
import ast
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from library.utils import setup_logging
from networks.lora_flux import LoRAModule, LoRAInfModule
import logging
setup_logging()
logger = logging.getLogger(__name__)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae,
text_encoders: list,
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4
if network_alpha is None:
network_alpha = 1.0
# train LLM adapter
train_llm_adapter = kwargs.get("train_llm_adapter", "false")
if train_llm_adapter is not None:
train_llm_adapter = True if train_llm_adapter.lower() == "true" else False
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns
exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")
# regular expression for module selection: exclude and include
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# verbose
verbose = kwargs.get("verbose", "false")
if verbose is not None:
verbose = True if verbose.lower() == "true" else False
# regex-specific learning rates / dimensions
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
"""
Parse a string of key-value pairs separated by commas.
"""
pairs = {}
for pair in kv_pair_str.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
continue
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip()
try:
pairs[key] = int(value) if is_int else float(value)
except ValueError:
logger.warning(f"Invalid value for {key}: {value}")
return pairs
network_reg_lrs = kwargs.get("network_reg_lrs", None)
if network_reg_lrs is not None:
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
else:
reg_lrs = None
network_reg_dims = kwargs.get("network_reg_dims", None)
if network_reg_dims is not None:
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
else:
reg_dims = None
network = LoRANetwork(
text_encoders,
unet,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
train_llm_adapter=train_llm_adapter,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
reg_dims=reg_dims,
reg_lrs=reg_lrs,
verbose=verbose,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
modules_dim = {}
modules_alpha = {}
train_llm_adapter = False
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
if "llm_adapter" in lora_name:
train_llm_adapter = True
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoders,
unet,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
train_llm_adapter=train_llm_adapter,
)
return network, weights_sd
class LoRANetwork(torch.nn.Module):
# Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
# Target modules: LLM Adapter blocks
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
# Target modules for text encoder (Qwen3)
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"]
LORA_PREFIX_ANIMA = "lora_unet" # ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER = "lora_te" # Qwen3
def __init__(
self,
text_encoders: list,
unet,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
module_class: Type[object] = LoRAModule,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
train_llm_adapter: bool = False,
exclude_patterns: Optional[List[str]] = None,
include_patterns: Optional[List[str]] = None,
reg_dims: Optional[Dict[str, int]] = None,
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.train_llm_adapter = train_llm_adapter
self.reg_dims = reg_dims
self.reg_lrs = reg_lrs
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info("create LoRA network from weights")
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# compile regular expression if specified
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
re_patterns = []
if patterns is not None:
for pattern in patterns:
try:
re_pattern = re.compile(pattern)
except re.error as e:
logger.error(f"Invalid pattern '{pattern}': {e}")
continue
re_patterns.append(re_pattern)
return re_patterns
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
include_re_patterns = str_to_re_patterns(include_patterns)
# create module instances
def create_modules(
is_unet: bool,
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
default_dim: Optional[int] = None,
) -> Tuple[List[LoRAModule], List[str]]:
prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER
loras = []
skipped = []
for name, module in root_module.named_modules():
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None:
module = root_module
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
original_name = (name + "." if name else "") + child_name
lora_name = f"{prefix}.{original_name}".replace(".", "_")
# exclude/include filter (fullmatch: pattern must match the entire original_name)
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
if excluded and not included:
if verbose:
logger.info(f"exclude: {original_name}")
continue
dim = None
alpha_val = None
if modules_dim is not None:
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha_val = modules_alpha[lora_name]
else:
if self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.fullmatch(reg, original_name):
dim = d
alpha_val = self.alpha
logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
break
# fallback to default dim if not matched by reg_dims or reg_dims is not specified
if dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1:
skipped.append(lora_name)
continue
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha_val,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
lora.original_name = original_name
loras.append(lora)
if target_replace_modules is None:
break
return loras, skipped
# Create LoRA for text encoders (Qwen3 - typically not trained for Anima)
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
skipped_te = []
if text_encoders is not None:
for i, text_encoder in enumerate(text_encoders):
if text_encoder is None:
continue
logger.info(f"create LoRA for Text Encoder {i+1}:")
te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
self.text_encoder_loras.extend(te_loras)
skipped_te += te_skipped
# Create LoRA for DiT blocks
target_modules = list(LoRANetwork.ANIMA_TARGET_REPLACE_MODULE)
if train_llm_adapter:
target_modules.extend(LoRANetwork.ANIMA_ADAPTER_TARGET_REPLACE_MODULE)
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")
skipped = skipped_te + skipped_un
if verbose and len(skipped) > 0:
logger.warning(f"dim (rank) is 0, {len(skipped)} LoRA modules are skipped:")
for name in skipped:
logger.info(f"\t{name}")
# assertion: no duplicate names
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info(f"enable LoRA for DiT: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
def is_mergeable(self):
return True
def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
apply_text_encoder = apply_unet = False
for key in weights_sd.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_ANIMA):
apply_unet = True
if apply_text_encoder:
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable LoRA for DiT")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info("weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
text_encoder_lr = [default_lr]
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
text_encoder_lr = [float(text_encoder_lr)]
elif len(text_encoder_lr) == 1:
pass # already a list with one element
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
reg_groups = {}
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
for lora in loras:
matched_reg_lr = None
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
if re.fullmatch(regex_str, lora.original_name):
matched_reg_lr = (i, reg_lr)
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
break
for name, param in lora.named_parameters():
if matched_reg_lr is not None:
reg_idx, reg_lr = matched_reg_lr
group_key = f"reg_lr_{reg_idx}"
if group_key not in reg_groups:
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
if loraplus_ratio is not None and "lora_up" in name:
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
else:
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
continue
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for group_key, group in reg_groups.items():
reg_lr = group["lr"]
for key in ("lora", "plus"):
param_data = {"params": group[key].values()}
if len(param_data["params"]) == 0:
continue
if key == "plus":
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
else:
param_data["lr"] = reg_lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
desc = f"reg_lr_{group_key.split('_')[-1]}"
descriptions.append(desc + (" plus" if key == "plus" else ""))
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * loraplus_ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
if self.text_encoder_loras:
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 1" + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params, lr_descriptions
def enable_gradient_checkpointing(self):
pass # not supported
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
downkeys = []
upkeys = []
alphakeys = []
norms = []
keys_scaled = 0
state_dict = self.state_dict()
for key in state_dict.keys():
if "lora_down" in key and "weight" in key:
downkeys.append(key)
upkeys.append(key.replace("lora_down", "lora_up"))
alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)):
down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
dim = down.shape[0]
scale = alpha / dim
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
updown *= scale
norm = updown.norm().clamp(min=max_norm_value / 2)
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired.cpu() / norm.cpu()
sqrt_ratio = ratio**0.5
if ratio != 1:
keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)