Files
Kohya-ss-sd-scripts/library/anima_utils.py
Kohya S. 34e7138b6a Add/modify some implementation for anima (#2261)
* fix: update extend-exclude list in _typos.toml to include configs

* fix: exclude anima tests from pytest

* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE

* fix: update default value for --discrete_flow_shift in anima training guide

* feat: add Qwen-Image VAE

* feat: simplify encode_tokens

* feat: use unified attention module, add wrapper for state dict compatibility

* feat: loading with dynamic fp8 optimization and LoRA support

* feat: add anima minimal inference script (WIP)

* format: format

* feat: simplify target module selection by regular expression patterns

* feat: kept caption dropout rate in cache and handle in training script

* feat: update train_llm_adapter and verbose default values to string type

* fix: use strategy instead of using tokenizers directly

* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

* feat: support 5d tensor in get_noisy_model_input_and_timesteps

* feat: update loss calculation to support 5d tensor

* fix: update argument names in anima_train_utils to align with other archtectures

* feat: simplify Anima training script and update empty caption handling

* feat: support LoRA format without `net.` prefix

* fix: update to work fp8_scaled option

* feat: add regex-based learning rates and dimensions handling in create_network

* fix: improve regex matching for module selection and learning rates in LoRANetwork

* fix: update logging message for regex match in LoRANetwork

* fix: keep latents 4D except DiT call

* feat: enhance block swap functionality for inference and training in Anima model

* feat: refactor Anima training script

* feat: optimize VAE processing by adjusting tensor dimensions and data types

* fix: wait all block trasfer before siwtching offloader mode

* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!

* feat: support LORA for Qwen3

* feat: update Anima SAI model spec metadata handling

* fix: remove unused code

* feat: split CFG processing in do_sample function to reduce memory usage

* feat: add VAE chunking and caching options to reduce memory usage

* feat: optimize RMSNorm forward method and remove unused torch_attention_op

* Update library/strategy_anima.py

Use torch.all instead of all.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/safetensors_utils.py

Fix duplicated new_key for concat_hook.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_minimal_inference.py

Remove unused code.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_train.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/anima_train_utils.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: review with Copilot

* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)

* feat: add process_escape function to handle escape sequences in prompts

* feat: enhance LoRA weight handling in model loading and add text encoder loading function

* feat: improve ComfyUI conversion script with prefix constants and module name adjustments

* feat: update caption dropout documentation to clarify cache regeneration requirement

* feat: add clarification on learning rate adjustments

* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-13 08:15:06 +09:00

310 lines
12 KiB
Python

# Anima model loading/saving utilities
import os
from typing import Dict, List, Optional, Union
import torch
from safetensors.torch import load_file, save_file
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
from accelerate import init_empty_weights
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library import anima_models
from library.safetensors_utils import WeightTransformHooks
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# Original Anima high-precision keys. Kept for reference, but not used currently.
# # Keys that should stay in high precision (float32/bfloat16, not quantized)
# KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
# ".embed." excludes Embedding in LLMAdapter
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
def load_anima_model(
device: Union[str, torch.device],
dit_path: str,
attn_mode: str,
split_attn: bool,
loading_device: Union[str, torch.device],
dit_weight_dtype: Optional[torch.dtype],
fp8_scaled: bool = False,
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
lora_multipliers: Optional[list[float]] = None,
) -> anima_models.Anima:
"""
Load Anima model from the specified checkpoint.
Args:
device (Union[str, torch.device]): Device for optimization or merging
dit_path (str): Path to the DiT model checkpoint.
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
split_attn (bool): Whether to use split attention.
loading_device (Union[str, torch.device]): Device to load the model weights on.
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
"""
# dit_weight_dtype is None for fp8_scaled
assert (
not fp8_scaled and dit_weight_dtype is not None
) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
device = torch.device(device)
loading_device = torch.device(loading_device)
# We currently support fixed DiT config for Anima models
dit_config = {
"max_img_h": 512,
"max_img_w": 512,
"max_frames": 128,
"in_channels": 16,
"out_channels": 16,
"patch_spatial": 2,
"patch_temporal": 1,
"model_channels": 2048,
"concat_padding_mask": True,
"crossattn_emb_channels": 1024,
"pos_emb_cls": "rope3d",
"pos_emb_learnable": True,
"pos_emb_interpolation": "crop",
"min_fps": 1,
"max_fps": 30,
"use_adaln_lora": True,
"adaln_lora_dim": 256,
"num_blocks": 28,
"num_heads": 16,
"extra_per_block_abs_pos_emb": False,
"rope_h_extrapolation_ratio": 4.0,
"rope_w_extrapolation_ratio": 4.0,
"rope_t_extrapolation_ratio": 1.0,
"extra_h_extrapolation_ratio": 1.0,
"extra_w_extrapolation_ratio": 1.0,
"extra_t_extrapolation_ratio": 1.0,
"rope_enable_fps_modulation": False,
"use_llm_adapter": True,
"attn_mode": attn_mode,
"split_attn": split_attn,
}
with init_empty_weights():
model = anima_models.Anima(**dit_config)
if dit_weight_dtype is not None:
model.to(dit_weight_dtype)
# load model weights with dynamic fp8 optimization and LoRA merging if needed
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
sd = load_safetensors_with_lora_and_fp8(
model_files=dit_path,
lora_weights_list=lora_weights_list,
lora_multipliers=lora_multipliers,
fp8_optimization=fp8_scaled,
calc_device=device,
move_to_device=(loading_device == device),
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
weight_transform_hooks=rename_hooks,
)
if fp8_scaled:
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
if loading_device.type != "cpu":
# make sure all the model weights are on the loading_device
logger.info(f"Moving weights to {loading_device}")
for key in sd.keys():
sd[key] = sd[key].to(loading_device)
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
if missing:
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
unexpected_missing = [
k
for k in missing
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
]
if unexpected_missing:
# Raise error to avoid silent failures
raise RuntimeError(
f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
)
missing = {} # all missing keys were expected
if unexpected:
# Raise error to avoid silent failures
raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
return model
def load_qwen3_tokenizer(qwen3_path: str):
"""Load Qwen3 tokenizer only (without the text encoder model).
Args:
qwen3_path: Path to either a directory with model files or a safetensors file.
If a directory, loads tokenizer from it directly.
If a file, uses configs/qwen3_06b/ for tokenizer config.
Returns:
tokenizer
"""
from transformers import AutoTokenizer
if os.path.isdir(qwen3_path):
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
else:
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
"Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
"You can download these from the Qwen3-0.6B HuggingFace repository."
)
tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def load_qwen3_text_encoder(
qwen3_path: str,
dtype: torch.dtype = torch.bfloat16,
device: str = "cpu",
lora_weights: Optional[List[Dict[str, torch.Tensor]]] = None,
lora_multipliers: Optional[List[float]] = None,
):
"""Load Qwen3-0.6B text encoder.
Args:
qwen3_path: Path to either a directory with model files or a safetensors file
dtype: Model dtype
device: Device to load to
Returns:
(text_encoder_model, tokenizer)
"""
import transformers
from transformers import AutoTokenizer
logger.info(f"Loading Qwen3 text encoder from {qwen3_path}")
if os.path.isdir(qwen3_path):
# Directory with full model
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
else:
# Single safetensors file - use configs/qwen3_06b/ for config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
"Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
"You can download these from the Qwen3-0.6B HuggingFace repository."
)
tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True)
qwen3_config = transformers.Qwen3Config.from_pretrained(config_dir, local_files_only=True)
model = transformers.Qwen3ForCausalLM(qwen3_config).model
# Load weights
if qwen3_path.endswith(".safetensors"):
if lora_weights is None:
state_dict = load_file(qwen3_path, device="cpu")
else:
state_dict = load_safetensors_with_lora_and_fp8(
model_files=qwen3_path,
lora_weights_list=lora_weights,
lora_multipliers=lora_multipliers,
fp8_optimization=False,
calc_device=device,
move_to_device=True,
dit_weight_dtype=None,
)
else:
assert lora_weights is None, "LoRA weights merging is only supported for safetensors checkpoints"
state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
# Remove 'model.' prefix if present
new_sd = {}
for k, v in state_dict.items():
if k.startswith("model."):
new_sd[k[len("model.") :]] = v
else:
new_sd[k] = v
info = model.load_state_dict(new_sd, strict=False)
logger.info(f"Loaded Qwen3 state dict: {info}")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.use_cache = False
model = model.requires_grad_(False).to(device, dtype=dtype)
logger.info(f"Loaded Qwen3 text encoder. Parameters: {sum(p.numel() for p in model.parameters()):,}")
return model, tokenizer
def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
"""Load T5 tokenizer for LLM Adapter target tokens.
Args:
t5_tokenizer_path: Optional path to T5 tokenizer directory. If None, uses default configs.
"""
from transformers import T5TokenizerFast
if t5_tokenizer_path is not None:
return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
# Use bundled config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
if os.path.exists(config_dir):
return T5TokenizerFast(
vocab_file=os.path.join(config_dir, "spiece.model"),
tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
)
raise FileNotFoundError(
f"T5 tokenizer config directory not found at {config_dir}. "
"Expected configs/t5_old/ with spiece.model and tokenizer.json. "
"You can download these from the google/t5-v1_1-xxl HuggingFace repository."
)
def save_anima_model(
save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Dict[str, any], dtype: Optional[torch.dtype] = None
):
"""Save Anima DiT model with 'net.' prefix for ComfyUI compatibility.
Args:
save_path: Output path (.safetensors)
dit_state_dict: State dict from dit.state_dict()
metadata: Metadata dict to include in the safetensors file
dtype: Optional dtype to cast to before saving
"""
prefixed_sd = {}
for k, v in dit_state_dict.items():
if dtype is not None:
# v = v.to(dtype)
v = v.detach().clone().to("cpu").to(dtype) # Reduce GPU memory usage during save
prefixed_sd["net." + k] = v.contiguous()
if metadata is None:
metadata = {}
metadata["format"] = "pt" # For compatibility with the official .safetensors file
save_file(prefixed_sd, save_path, metadata=metadata) # safetensors.save_file cosumes a lot of memory, but Anima is small enough
logger.info(f"Saved Anima model to {save_path}")