mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
* fix: update extend-exclude list in _typos.toml to include configs * fix: exclude anima tests from pytest * feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE * fix: update default value for --discrete_flow_shift in anima training guide * feat: add Qwen-Image VAE * feat: simplify encode_tokens * feat: use unified attention module, add wrapper for state dict compatibility * feat: loading with dynamic fp8 optimization and LoRA support * feat: add anima minimal inference script (WIP) * format: format * feat: simplify target module selection by regular expression patterns * feat: kept caption dropout rate in cache and handle in training script * feat: update train_llm_adapter and verbose default values to string type * fix: use strategy instead of using tokenizers directly * feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock * feat: support 5d tensor in get_noisy_model_input_and_timesteps * feat: update loss calculation to support 5d tensor * fix: update argument names in anima_train_utils to align with other archtectures * feat: simplify Anima training script and update empty caption handling * feat: support LoRA format without `net.` prefix * fix: update to work fp8_scaled option * feat: add regex-based learning rates and dimensions handling in create_network * fix: improve regex matching for module selection and learning rates in LoRANetwork * fix: update logging message for regex match in LoRANetwork * fix: keep latents 4D except DiT call * feat: enhance block swap functionality for inference and training in Anima model * feat: refactor Anima training script * feat: optimize VAE processing by adjusting tensor dimensions and data types * fix: wait all block trasfer before siwtching offloader mode * feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude! 
* feat: support LORA for Qwen3 * feat: update Anima SAI model spec metadata handling * fix: remove unused code * feat: split CFG processing in do_sample function to reduce memory usage * feat: add VAE chunking and caching options to reduce memory usage * feat: optimize RMSNorm forward method and remove unused torch_attention_op * Update library/strategy_anima.py Use torch.all instead of all. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/safetensors_utils.py Fix duplicated new_key for concat_hook. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_minimal_inference.py Remove unused code. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_train.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/anima_train_utils.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: review with Copilot * feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet) * feat: add process_escape function to handle escape sequences in prompts * feat: enhance LoRA weight handling in model loading and add text encoder loading function * feat: improve ComfyUI conversion script with prefix constants and module name adjustments * feat: update caption dropout documentation to clarify cache regeneration requirement * feat: add clarification on learning rate adjustments * feat: add note on PyTorch version requirement to prevent NaN loss --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
310 lines
12 KiB
Python
310 lines
12 KiB
Python
# Anima model loading/saving utilities
|
|
|
|
import os
|
|
from typing import Dict, List, Optional, Union
|
|
import torch
|
|
from safetensors.torch import load_file, save_file
|
|
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
|
|
from accelerate import init_empty_weights
|
|
|
|
from library.fp8_optimization_utils import apply_fp8_monkey_patch
|
|
from library.lora_utils import load_safetensors_with_lora_and_fp8
|
|
from library import anima_models
|
|
from library.safetensors_utils import WeightTransformHooks
|
|
from .utils import setup_logging
|
|
|
|
setup_logging()
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Original Anima high-precision keys. Kept for reference, but not used currently.
# # Keys that should stay in high precision (float32/bfloat16, not quantized)
# KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]


# Substrings of state-dict keys that are candidates for fp8 quantization.
# NOTE(review): the empty string "" matches every key, so in practice all keys are
# candidates unless excluded below — presumably intentional; confirm with fp8_optimization_utils.
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]

# Substrings of keys that must NOT be fp8-quantized (norms, embedders, AdaLN, final layer).
# ".embed." excludes Embedding in LLMAdapter
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
|
|
|
|
|
|
def load_anima_model(
    device: Union[str, torch.device],
    dit_path: str,
    attn_mode: str,
    split_attn: bool,
    loading_device: Union[str, torch.device],
    dit_weight_dtype: Optional[torch.dtype],
    fp8_scaled: bool = False,
    lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
    lora_multipliers: Optional[list[float]] = None,
) -> anima_models.Anima:
    """
    Load Anima model from the specified checkpoint.

    Args:
        device (Union[str, torch.device]): Device for optimization or merging
        dit_path (str): Path to the DiT model checkpoint.
        attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
        split_attn (bool): Whether to use split attention.
        loading_device (Union[str, torch.device]): Device to load the model weights on.
        dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
            If None, it will be loaded as is (same as the state_dict) or scaled for fp8.
            If not None, model weights will be casted to this dtype.
        fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
        lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
        lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.

    Returns:
        anima_models.Anima: The loaded model with checkpoint weights assigned.

    Raises:
        RuntimeError: If the checkpoint contains unexpected keys, or is missing keys
            other than buffers known to be initialized in ``__init__``.
    """
    # fp8 optimization decides per-tensor dtypes itself, so an explicit weight dtype
    # must not be combined with fp8_scaled. (Simplified from the original
    # `(not fp8_scaled and dtype is not None) or dtype is None` expression, which
    # is logically equivalent but harder to read.)
    assert not (
        fp8_scaled and dit_weight_dtype is not None
    ), "dit_weight_dtype should be None when fp8_scaled is True"

    device = torch.device(device)
    loading_device = torch.device(loading_device)

    # We currently support fixed DiT config for Anima models
    dit_config = {
        "max_img_h": 512,
        "max_img_w": 512,
        "max_frames": 128,
        "in_channels": 16,
        "out_channels": 16,
        "patch_spatial": 2,
        "patch_temporal": 1,
        "model_channels": 2048,
        "concat_padding_mask": True,
        "crossattn_emb_channels": 1024,
        "pos_emb_cls": "rope3d",
        "pos_emb_learnable": True,
        "pos_emb_interpolation": "crop",
        "min_fps": 1,
        "max_fps": 30,
        "use_adaln_lora": True,
        "adaln_lora_dim": 256,
        "num_blocks": 28,
        "num_heads": 16,
        "extra_per_block_abs_pos_emb": False,
        "rope_h_extrapolation_ratio": 4.0,
        "rope_w_extrapolation_ratio": 4.0,
        "rope_t_extrapolation_ratio": 1.0,
        "extra_h_extrapolation_ratio": 1.0,
        "extra_w_extrapolation_ratio": 1.0,
        "extra_t_extrapolation_ratio": 1.0,
        "rope_enable_fps_modulation": False,
        "use_llm_adapter": True,
        "attn_mode": attn_mode,
        "split_attn": split_attn,
    }
    # Build the module structure on the meta device (no real allocations); the real
    # weights are assigned below via load_state_dict(..., assign=True).
    with init_empty_weights():
        model = anima_models.Anima(**dit_config)
    if dit_weight_dtype is not None:
        model.to(dit_weight_dtype)

    # load model weights with dynamic fp8 optimization and LoRA merging if needed
    logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
    # Checkpoints may store keys with a "net." prefix (ComfyUI-compatible format); strip it.
    rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
    sd = load_safetensors_with_lora_and_fp8(
        model_files=dit_path,
        lora_weights_list=lora_weights_list,
        lora_multipliers=lora_multipliers,
        fp8_optimization=fp8_scaled,
        calc_device=device,
        move_to_device=(loading_device == device),
        dit_weight_dtype=dit_weight_dtype,
        target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
        exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
        weight_transform_hooks=rename_hooks,
    )

    if fp8_scaled:
        # Patch Linear forwards to dequantize the fp8 weights in sd on the fly
        apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)

    if loading_device.type != "cpu":
        # make sure all the model weights are on the loading_device
        logger.info(f"Moving weights to {loading_device}")
        for key in sd.keys():
            sd[key] = sd[key].to(loading_device)

    missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
    if missing:
        # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
        unexpected_missing = [
            k
            for k in missing
            if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
        ]
        if unexpected_missing:
            # Raise error to avoid silent failures
            raise RuntimeError(
                f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
            )
        missing = []  # all missing keys were expected (keep a list, matching load_state_dict's return type)
    if unexpected:
        # Raise error to avoid silent failures
        raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
    logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")

    return model
|
|
|
|
|
|
def load_qwen3_tokenizer(qwen3_path: str):
    """Load only the Qwen3 tokenizer, without instantiating the text encoder model.

    Args:
        qwen3_path: Path to either a directory with model files or a safetensors file.
            If a directory, loads tokenizer from it directly.
            If a file, uses configs/qwen3_06b/ for tokenizer config.

    Returns:
        tokenizer
    """
    from transformers import AutoTokenizer

    if os.path.isdir(qwen3_path):
        source = qwen3_path
    else:
        # Single-file checkpoint: fall back to the bundled tokenizer config.
        source = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
        if not os.path.exists(source):
            raise FileNotFoundError(
                f"Qwen3 config directory not found at {source}. "
                "Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
                "You can download these from the Qwen3-0.6B HuggingFace repository."
            )

    tokenizer = AutoTokenizer.from_pretrained(source, local_files_only=True)

    # No pad token configured: reuse EOS so padded batches can be built.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer
|
|
|
|
|
|
def load_qwen3_text_encoder(
    qwen3_path: str,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu",
    lora_weights: Optional[List[Dict[str, torch.Tensor]]] = None,
    lora_multipliers: Optional[List[float]] = None,
):
    """Load Qwen3-0.6B text encoder.

    Args:
        qwen3_path: Path to either a directory with model files or a safetensors file
        dtype: Model dtype
        device: Device to load to
        lora_weights: Optional LoRA weight dicts merged into the base weights while
            loading (only supported when qwen3_path is a safetensors file)
        lora_multipliers: Multipliers matching lora_weights, if any

    Returns:
        (text_encoder_model, tokenizer)
    """
    import transformers
    from transformers import AutoTokenizer

    logger.info(f"Loading Qwen3 text encoder from {qwen3_path}")

    if os.path.isdir(qwen3_path):
        # Directory with full model
        tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
        # `.model` keeps only the transformer backbone (drops the LM head)
        model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
    else:
        # Single safetensors file - use configs/qwen3_06b/ for config
        config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
        if not os.path.exists(config_dir):
            raise FileNotFoundError(
                f"Qwen3 config directory not found at {config_dir}. "
                "Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
                "You can download these from the Qwen3-0.6B HuggingFace repository."
            )

        tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True)
        qwen3_config = transformers.Qwen3Config.from_pretrained(config_dir, local_files_only=True)
        # Instantiate from config (random weights), then overwrite with the checkpoint below
        model = transformers.Qwen3ForCausalLM(qwen3_config).model

        # Load weights
        if qwen3_path.endswith(".safetensors"):
            if lora_weights is None:
                state_dict = load_file(qwen3_path, device="cpu")
            else:
                # Merge LoRA into the base weights while loading; fp8 is not used here
                state_dict = load_safetensors_with_lora_and_fp8(
                    model_files=qwen3_path,
                    lora_weights_list=lora_weights,
                    lora_multipliers=lora_multipliers,
                    fp8_optimization=False,
                    calc_device=device,
                    move_to_device=True,
                    dit_weight_dtype=None,
                )
        else:
            assert lora_weights is None, "LoRA weights merging is only supported for safetensors checkpoints"
            state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)

        # Remove 'model.' prefix if present
        new_sd = {}
        for k, v in state_dict.items():
            if k.startswith("model."):
                new_sd[k[len("model.") :]] = v
            else:
                new_sd[k] = v

        # strict=False tolerates key mismatches (e.g. extra lm_head weights); result is logged
        info = model.load_state_dict(new_sd, strict=False)
        logger.info(f"Loaded Qwen3 state dict: {info}")

    # No pad token configured: reuse EOS so padded batches can be built
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Encoder-only usage: no KV cache, no gradients
    model.config.use_cache = False
    model = model.requires_grad_(False).to(device, dtype=dtype)

    logger.info(f"Loaded Qwen3 text encoder. Parameters: {sum(p.numel() for p in model.parameters()):,}")
    return model, tokenizer
|
|
|
|
|
|
def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
    """Load the T5 tokenizer used for LLM Adapter target tokens.

    Args:
        t5_tokenizer_path: Optional path to T5 tokenizer directory. If None, uses default configs.
    """
    from transformers import T5TokenizerFast

    # Explicit path wins: load directly from the given directory.
    if t5_tokenizer_path is not None:
        return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)

    # Otherwise construct the tokenizer from the bundled config files.
    config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
    if not os.path.exists(config_dir):
        raise FileNotFoundError(
            f"T5 tokenizer config directory not found at {config_dir}. "
            "Expected configs/t5_old/ with spiece.model and tokenizer.json. "
            "You can download these from the google/t5-v1_1-xxl HuggingFace repository."
        )

    return T5TokenizerFast(
        vocab_file=os.path.join(config_dir, "spiece.model"),
        tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
    )
|
|
|
|
|
|
def save_anima_model(
    save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]], dtype: Optional[torch.dtype] = None
):
    """Save Anima DiT model with 'net.' prefix for ComfyUI compatibility.

    Args:
        save_path: Output path (.safetensors)
        dit_state_dict: State dict from dit.state_dict()
        metadata: Metadata dict to include in the safetensors file, or None.
            (Was annotated ``Dict[str, any]`` — ``any`` is the builtin function, not a
            type; safetensors metadata is string-to-string.)
        dtype: Optional dtype to cast to before saving
    """
    prefixed_sd = {}
    for k, v in dit_state_dict.items():
        if dtype is not None:
            # Detach and move to CPU before casting to reduce GPU memory usage during save
            v = v.detach().clone().to("cpu").to(dtype)
        prefixed_sd["net." + k] = v.contiguous()

    if metadata is None:
        metadata = {}
    metadata["format"] = "pt"  # For compatibility with the official .safetensors file

    # safetensors.save_file consumes a lot of memory, but Anima is small enough
    save_file(prefixed_sd, save_path, metadata=metadata)
    logger.info(f"Saved Anima model to {save_path}")