Kohya-ss-sd-scripts/library/anima_train_utils.py
Kohya S. 34e7138b6a Add/modify some implementation for anima (#2261)
* fix: update extend-exclude list in _typos.toml to include configs

* fix: exclude anima tests from pytest

* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE

* fix: update default value for --discrete_flow_shift in anima training guide

* feat: add Qwen-Image VAE

* feat: simplify encode_tokens

* feat: use unified attention module, add wrapper for state dict compatibility

* feat: loading with dynamic fp8 optimization and LoRA support

* feat: add anima minimal inference script (WIP)

* format: format

* feat: simplify target module selection by regular expression patterns

* feat: keep caption dropout rate in cache and handle it in training script

* feat: update train_llm_adapter and verbose default values to string type

* fix: use strategy instead of using tokenizers directly

* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

* feat: support 5d tensor in get_noisy_model_input_and_timesteps

* feat: update loss calculation to support 5d tensor

* fix: update argument names in anima_train_utils to align with other architectures

* feat: simplify Anima training script and update empty caption handling

* feat: support LoRA format without `net.` prefix

* fix: make the fp8_scaled option work

* feat: add regex-based learning rates and dimensions handling in create_network

* fix: improve regex matching for module selection and learning rates in LoRANetwork

* fix: update logging message for regex match in LoRANetwork

* fix: keep latents 4D except DiT call

* feat: enhance block swap functionality for inference and training in Anima model

* feat: refactor Anima training script

* feat: optimize VAE processing by adjusting tensor dimensions and data types

* fix: wait for all block transfers before switching offloader mode

* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!

* feat: support LoRA for Qwen3

* feat: update Anima SAI model spec metadata handling

* fix: remove unused code

* feat: split CFG processing in do_sample function to reduce memory usage

* feat: add VAE chunking and caching options to reduce memory usage

* feat: optimize RMSNorm forward method and remove unused torch_attention_op

* Update library/strategy_anima.py

Use torch.all instead of all.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/safetensors_utils.py

Fix duplicated new_key for concat_hook.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_minimal_inference.py

Remove unused code.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_train.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/anima_train_utils.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: review with Copilot

* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)

* feat: add process_escape function to handle escape sequences in prompts

* feat: enhance LoRA weight handling in model loading and add text encoder loading function

* feat: improve ComfyUI conversion script with prefix constants and module name adjustments

* feat: update caption dropout documentation to clarify cache regeneration requirement

* feat: add clarification on learning rate adjustments

* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-13 08:15:06 +09:00

# Anima Training Utilities
import argparse
import gc
import math
import os
import time
from typing import Optional
import numpy as np
import torch
from accelerate import Accelerator
from tqdm import tqdm
from PIL import Image
from library.device_utils import init_ipex, clean_memory_on_device, synchronize_device
from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl
init_ipex()
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# Anima-specific training arguments
def add_anima_training_arguments(parser: argparse.ArgumentParser):
"""Add Anima-specific training arguments to the parser."""
parser.add_argument(
"--qwen3",
type=str,
default=None,
help="Path to Qwen3-0.6B model (safetensors file or directory)",
)
parser.add_argument(
"--llm_adapter_path",
type=str,
default=None,
help="Path to separate LLM adapter weights. If None, adapter is loaded from DiT file if present",
)
parser.add_argument(
"--llm_adapter_lr",
type=float,
default=None,
help="Learning rate for LLM adapter. None=same as base LR, 0=freeze adapter",
)
parser.add_argument(
"--self_attn_lr",
type=float,
default=None,
help="Learning rate for self-attention layers. None=same as base LR, 0=freeze",
)
parser.add_argument(
"--cross_attn_lr",
type=float,
default=None,
help="Learning rate for cross-attention layers. None=same as base LR, 0=freeze",
)
parser.add_argument(
"--mlp_lr",
type=float,
default=None,
help="Learning rate for MLP layers. None=same as base LR, 0=freeze",
)
parser.add_argument(
"--mod_lr",
type=float,
default=None,
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
)
parser.add_argument(
"--t5_tokenizer_path",
type=str,
default=None,
help="Path to T5 tokenizer directory. If None, uses default configs/t5_old/",
)
parser.add_argument(
"--qwen3_max_token_length",
type=int,
default=512,
help="Maximum token length for Qwen3 tokenizer (default: 512)",
)
parser.add_argument(
"--t5_max_token_length",
type=int,
default=512,
help="Maximum token length for T5 tokenizer (default: 512)",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=1.0,
help="Timestep distribution shift for rectified flow training (default: 1.0)",
)
parser.add_argument(
"--timestep_sampling",
type=str,
default="sigmoid",
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
help="Timestep sampling method (default: sigmoid (logit normal))",
)
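# Clarifying note (behavior assumed to mirror the flux implementation in this
# repo, not shown in this file): "sigmoid" draws t = sigmoid(sigmoid_scale * z)
# with z ~ N(0, 1), i.e. a logit-normal distribution concentrated mid-schedule;
# "shift" further warps t using --discrete_flow_shift.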
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
)
parser.add_argument(
"--attn_mode",
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
default=None,
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
" / 使用するAttentionの実装。デフォルトはNonetorchです。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません推論のみ。このオプションは--xformersまたは--sdpaを上書きします。",
)
parser.add_argument(
"--split_attn",
action="store_true",
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
)
parser.add_argument(
"--vae_chunk_size",
type=int,
default=None,
help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
+ " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
)
parser.add_argument(
"--vae_disable_cache",
action="store_true",
help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
+ " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
)
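# Illustrative usage sketch (not part of the original module; the helper name
# below is hypothetical): wire the Anima arguments into a fresh parser and
# read back the documented defaults.
def _example_anima_argparse() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    add_anima_training_arguments(parser)
    args = parser.parse_args([])  # no CLI input: defaults only
    assert args.timestep_sampling == "sigmoid"
    assert args.discrete_flow_shift == 1.0
    return args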
# Loss weighting
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Compute loss weighting for Anima training.
Uses the same weighting schemes as SD3; Anima-specific schemes can be added later if needed.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
elif weighting_scheme == "none" or weighting_scheme is None:
weighting = torch.ones_like(sigmas)
else:
weighting = torch.ones_like(sigmas)
return weighting
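# Numerical sanity check (illustrative only; helper name is hypothetical):
# "cosmap" peaks at sigma = 0.5, where 1 - 2*s + 2*s**2 hits its minimum of
# 0.5 and the weight is 2 / (pi * 0.5) = 4 / pi ~ 1.273.
def _example_cosmap_weighting() -> torch.Tensor:
    sigmas = torch.tensor([0.25, 0.5, 0.75])
    return compute_loss_weighting_for_anima("cosmap", sigmas)  # ~[1.019, 1.273, 1.019]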
# Parameter groups (6 groups with separate LRs)
def get_anima_param_groups(
dit,
base_lr: float,
self_attn_lr: Optional[float] = None,
cross_attn_lr: Optional[float] = None,
mlp_lr: Optional[float] = None,
mod_lr: Optional[float] = None,
llm_adapter_lr: Optional[float] = None,
):
"""Create parameter groups for Anima training with separate learning rates.
Args:
dit: Anima model
base_lr: Base learning rate
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
cross_attn_lr: LR for cross-attention layers
mlp_lr: LR for MLP layers
mod_lr: LR for AdaLN modulation layers
llm_adapter_lr: LR for LLM adapter
Returns:
List of parameter group dicts for optimizer
"""
if self_attn_lr is None:
self_attn_lr = base_lr
if cross_attn_lr is None:
cross_attn_lr = base_lr
if mlp_lr is None:
mlp_lr = base_lr
if mod_lr is None:
mod_lr = base_lr
if llm_adapter_lr is None:
llm_adapter_lr = base_lr
base_params = []
self_attn_params = []
cross_attn_params = []
mlp_params = []
mod_params = []
llm_adapter_params = []
for name, p in dit.named_parameters():
# Store original name for debugging
p.original_name = name
if "llm_adapter" in name:
llm_adapter_params.append(p)
elif ".self_attn" in name:
self_attn_params.append(p)
elif ".cross_attn" in name:
cross_attn_params.append(p)
elif ".mlp" in name:
mlp_params.append(p)
elif ".adaln_modulation" in name:
mod_params.append(p)
else:
base_params.append(p)
logger.info(f"Parameter groups:")
logger.info(f" base_params: {len(base_params)} (lr={base_lr})")
logger.info(f" self_attn_params: {len(self_attn_params)} (lr={self_attn_lr})")
logger.info(f" cross_attn_params: {len(cross_attn_params)} (lr={cross_attn_lr})")
logger.info(f" mlp_params: {len(mlp_params)} (lr={mlp_lr})")
logger.info(f" mod_params: {len(mod_params)} (lr={mod_lr})")
logger.info(f" llm_adapter_params: {len(llm_adapter_params)} (lr={llm_adapter_lr})")
param_groups = []
for lr, params, name in [
(base_lr, base_params, "base"),
(self_attn_lr, self_attn_params, "self_attn"),
(cross_attn_lr, cross_attn_params, "cross_attn"),
(mlp_lr, mlp_params, "mlp"),
(mod_lr, mod_params, "mod"),
(llm_adapter_lr, llm_adapter_params, "llm_adapter"),
]:
if lr == 0:
for p in params:
p.requires_grad_(False)
logger.info(f" Frozen {name} params ({len(params)} parameters)")
elif len(params) > 0:
param_groups.append({"params": params, "lr": lr})
total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
logger.info(f"Total trainable parameters: {total_trainable:,}")
return param_groups
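# Intended consumption (sketch; assumes a loaded Anima `dit` instance):
#   param_groups = get_anima_param_groups(dit, base_lr=1e-4, mod_lr=0.0)
#   optimizer = torch.optim.AdamW(param_groups)
# A group learning rate of 0 sets requires_grad_(False) on those parameters
# and omits the group entirely, so the optimizer holds no state for them.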
# Save functions
def save_anima_model_on_train_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
dit: anima_models.Anima,
):
"""Save Anima model at the end of training."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec_dataclass(
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
).to_metadata_dict()
dit_sd = dit.state_dict()
# Save with 'net.' prefix for ComfyUI compatibility
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
def save_anima_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator: Accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
dit: anima_models.Anima,
):
"""Save Anima model at epoch end or specific steps."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec_dataclass(
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
).to_metadata_dict()
dit_sd = dit.state_dict()
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
True,
True,
epoch,
num_train_epochs,
global_step,
sd_saver,
None,
)
# Sampling (Euler discrete for rectified flow)
def do_sample(
height: int,
width: int,
seed: Optional[int],
dit: anima_models.Anima,
crossattn_emb: torch.Tensor,
steps: int,
dtype: torch.dtype,
device: torch.device,
guidance_scale: float = 1.0,
flow_shift: float = 3.0,
neg_crossattn_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Generate a sample using Euler discrete sampling for rectified flow.
Args:
height, width: Output image dimensions
seed: Random seed (None for random)
dit: Anima model
crossattn_emb: Cross-attention embeddings (B, N, D)
steps: Number of sampling steps
dtype: Compute dtype
device: Compute device
guidance_scale: CFG scale (1.0 = no guidance)
flow_shift: Flow shift parameter for rectified flow
neg_crossattn_emb: Negative cross-attention embeddings for CFG
Returns:
Denoised latents
"""
# Latent shape: (1, 16, 1, H/8, W/8) for single image
latent_h = height // 8
latent_w = width // 8
latent = torch.zeros(1, 16, 1, latent_h, latent_w, device=device, dtype=dtype)
# Generate noise
if seed is not None:
generator = torch.manual_seed(seed)
else:
generator = None
noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
# Timestep schedule: linear from 1.0 to 0.0
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
flow_shift = float(flow_shift)
if flow_shift != 1.0:
sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
# Start from pure noise
x = noise.clone()
# Padding mask (zeros = no padding) — resized in prepare_embedded_sequence to match latent dims
padding_mask = torch.zeros(1, 1, latent_h, latent_w, dtype=dtype, device=device)
use_cfg = guidance_scale > 1.0 and neg_crossattn_emb is not None
for i in tqdm(range(steps), desc="Sampling"):
sigma = sigmas[i]
t = sigma.unsqueeze(0) # (1,)
if use_cfg:
# CFG: two separate passes to reduce memory usage
pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask)
pos_out = pos_out.float()
neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask)
neg_out = neg_out.float()
model_output = neg_out + guidance_scale * (pos_out - neg_out)
else:
model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
model_output = model_output.float()
# Euler step: x_{t-1} = x_t - (sigma_t - sigma_{t-1}) * model_output
dt = sigmas[i + 1] - sigma
x = x + model_output * dt
x = x.to(dtype)
return x
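# The schedule warp above is the standard rectified-flow time shift,
# sigma' = shift * sigma / (1 + (shift - 1) * sigma); for shift > 1 it spends
# more steps at high noise. Standalone check (illustrative; helper name is
# hypothetical - shift=3.0 maps sigma=0.5 to 0.75):
def _example_shifted_sigmas(steps: int = 4, shift: float = 3.0) -> torch.Tensor:
    sigmas = torch.linspace(1.0, 0.0, steps + 1)
    return (sigmas * shift) / (1 + (shift - 1) * sigmas)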
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
dit: anima_models.Anima,
vae,
text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs=None,
prompt_replacement=None,
):
"""Generate sample images during training.
This is a simplified sampler for Anima - it generates images using the current model state.
"""
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None:
return
logger.info(f"Generating sample images at step {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file: {args.sample_prompts}")
return
# Unwrap models
dit = accelerator.unwrap_model(dit)
if text_encoder is not None:
text_encoder = accelerator.unwrap_model(text_encoder)
dit.switch_block_swap_for_inference()
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = os.path.join(args.output_dir, "sample")
os.makedirs(save_dir, exist_ok=True)
# Save RNG state
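# Sampling consumes the global RNG; saving here and restoring below keeps training-time noise and shuffling reproducible.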
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
dit.prepare_block_swap_before_forward()
_sample_image_inference(
accelerator,
args,
dit,
text_encoder,
vae,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
# Restore RNG state
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
dit.switch_block_swap_for_training()
clean_memory_on_device(accelerator.device)
def _sample_image_inference(
accelerator,
args,
dit,
text_encoder,
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
"""Generate a single sample image."""
prompt = prompt_dict.get("prompt", "")
negative_prompt = prompt_dict.get("negative_prompt", "")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
flow_shift = prompt_dict.get("flow_shift", 3.0)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # seed all CUDA devices for multi-GPU
height = max(64, height - height % 16)
width = max(64, width - width % 16)
logger.info(
f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
)
# Encode prompt
def encode_prompt(prpt):
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
return sample_prompts_te_outputs[prpt]
if text_encoder is not None:
tokens = tokenize_strategy.tokenize(prpt)
encoded = text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
return encoded
return None
encoded = encode_prompt(prompt)
if encoded is None:
logger.warning("Cannot encode prompt, skipping sample")
return
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = encoded
# Convert to tensors if numpy
if isinstance(prompt_embeds, np.ndarray):
prompt_embeds = torch.from_numpy(prompt_embeds).unsqueeze(0)
attn_mask = torch.from_numpy(attn_mask).unsqueeze(0)
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Process through LLM adapter if available
if dit.use_llm_adapter:
crossattn_emb = dit.llm_adapter(
source_hidden_states=prompt_embeds,
target_input_ids=t5_input_ids,
target_attention_mask=t5_attn_mask,
source_attention_mask=attn_mask,
)
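# Zero adapter outputs at padded target positions so masked tokens cannot leak into cross-attention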
crossattn_emb[~t5_attn_mask.bool()] = 0
else:
crossattn_emb = prompt_embeds
# Encode negative prompt for CFG
neg_crossattn_emb = None
if scale > 1.0 and negative_prompt is not None:
neg_encoded = encode_prompt(negative_prompt)
if neg_encoded is not None:
neg_pe, neg_am, neg_t5_ids, neg_t5_am = neg_encoded
if isinstance(neg_pe, np.ndarray):
neg_pe = torch.from_numpy(neg_pe).unsqueeze(0)
neg_am = torch.from_numpy(neg_am).unsqueeze(0)
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
neg_am = neg_am.to(accelerator.device)
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
neg_t5_am = neg_t5_am.to(accelerator.device)
if dit.use_llm_adapter:
neg_crossattn_emb = dit.llm_adapter(
source_hidden_states=neg_pe,
target_input_ids=neg_t5_ids,
target_attention_mask=neg_t5_am,
source_attention_mask=neg_am,
)
neg_crossattn_emb[~neg_t5_am.bool()] = 0
else:
neg_crossattn_emb = neg_pe
# Generate sample
clean_memory_on_device(accelerator.device)
latents = do_sample(
height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
)
# Decode latents
gc.collect()
synchronize_device(accelerator.device)
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device
vae.to(accelerator.device)
decoded = vae.decode_to_pixels(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
# Convert to image
image = decoded.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
# Remove temporal dim if present
if image.ndim == 4:
image = image[:, 0, :, :]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
decoded_np = decoded_np.astype(np.uint8)
image = Image.fromarray(decoded_np)
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i = prompt_dict.get("enum", 0)
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# Log to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)