Add/modify some implementation for anima (#2261)

* fix: update extend-exclude list in _typos.toml to include configs

* fix: exclude anima tests from pytest

* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE

* fix: update default value for --discrete_flow_shift in anima training guide

* feat: add Qwen-Image VAE

* feat: simplify encode_tokens

* feat: use unified attention module, add wrapper for state dict compatibility

* feat: loading with dynamic fp8 optimization and LoRA support

* feat: add anima minimal inference script (WIP)

* format: format

* feat: simplify target module selection by regular expression patterns

* feat: kept caption dropout rate in cache and handle in training script

* feat: update train_llm_adapter and verbose default values to string type

* fix: use strategy instead of using tokenizers directly

* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

* feat: support 5d tensor in get_noisy_model_input_and_timesteps

* feat: update loss calculation to support 5d tensor

* fix: update argument names in anima_train_utils to align with other archtectures

* feat: simplify Anima training script and update empty caption handling

* feat: support LoRA format without `net.` prefix

* fix: update to work fp8_scaled option

* feat: add regex-based learning rates and dimensions handling in create_network

* fix: improve regex matching for module selection and learning rates in LoRANetwork

* fix: update logging message for regex match in LoRANetwork

* fix: keep latents 4D except DiT call

* feat: enhance block swap functionality for inference and training in Anima model

* feat: refactor Anima training script

* feat: optimize VAE processing by adjusting tensor dimensions and data types

* fix: wait all block trasfer before siwtching offloader mode

* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!

* feat: support LORA for Qwen3

* feat: update Anima SAI model spec metadata handling

* fix: remove unused code

* feat: split CFG processing in do_sample function to reduce memory usage

* feat: add VAE chunking and caching options to reduce memory usage

* feat: optimize RMSNorm forward method and remove unused torch_attention_op

* Update library/strategy_anima.py

Use torch.all instead of all.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/safetensors_utils.py

Fix duplicated new_key for concat_hook.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_minimal_inference.py

Remove unused code.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_train.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/anima_train_utils.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: review with Copilot

* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)

* feat: add process_escape function to handle escape sequences in prompts

* feat: enhance LoRA weight handling in model loading and add text encoder loading function

* feat: improve ComfyUI conversion script with prefix constants and module name adjustments

* feat: update caption dropout documentation to clarify cache regeneration requirement

* feat: add clarification on learning rate adjustments

* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Kohya S. 2026-02-13 08:15:06 +09:00 committed by GitHub
parent 9144463f7b
commit 34e7138b6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 4556 additions and 2259 deletions

View File

@ -32,6 +32,7 @@ hime="hime"
OT="OT"
byt="byt"
tak="tak"
temperal="temperal"
[files]
extend-exclude = ["_typos.toml", "venv"]
extend-exclude = ["_typos.toml", "venv", "configs"]

1082
anima_minimal_inference.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import gc
import math
import os
from multiprocessing import Value
@ -12,8 +13,9 @@ import toml
from tqdm import tqdm
import torch
from library import utils
from library import flux_train_utils, qwen_image_autoencoder_kl
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
init_ipex()
@ -49,21 +51,18 @@ def train(args):
args.skip_cache_check = args.skip_latents_validity_check
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
)
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
if getattr(args, 'unsloth_offload_checkpointing', False):
if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
assert not args.cpu_offload_checkpointing, \
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
@ -71,17 +70,7 @@ def train(args):
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not getattr(args, 'unsloth_offload_checkpointing', False), \
"blocks_to_swap is not supported with unsloth_offload_checkpointing"
# Flash attention: validate availability
if getattr(args, 'flash_attn', False):
try:
import flash_attn # noqa: F401
logger.info("Flash Attention enabled for DiT blocks")
except ImportError:
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
args.flash_attn = False
) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
@ -104,9 +93,7 @@ def train(args):
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0}".format(", ".join(ignored))
)
logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
@ -145,26 +132,13 @@ def train(args):
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
# dataset-level caption dropout, so we save the rate and zero out subset-level
# caption_dropout_rate to allow text encoder output caching.
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
if caption_dropout_rate > 0:
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
for dataset in train_dataset_group.datasets:
for subset in dataset.subsets:
subset.caption_dropout_rate = 0.0
train_dataset_group.verify_bucket_reso_steps(16) # Qwen-Image VAE spatial downscale = 8 * patch size = 2
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
False,
False,
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
)
)
train_dataset_group.set_current_strategies()
@ -175,13 +149,11 @@ def train(args):
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used"
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
assert train_dataset_group.is_text_encoder_output_cacheable(
cache_supports_dropout=True
), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
# prepare accelerator
@ -191,24 +163,10 @@ def train(args):
# mixed precision dtype
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# parse transformer_dtype
transformer_dtype = None
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
transformer_dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
# Load tokenizers and set strategies
logger.info("Loading tokenizers...")
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
args.qwen3_path, dtype=weight_dtype, device="cpu"
)
t5_tokenizer = anima_utils.load_t5_tokenizer(
getattr(args, 't5_tokenizer_path', None)
)
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)
# Set tokenize strategy
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
@ -219,11 +177,7 @@ def train(args):
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# Set text encoding strategy
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
dropout_rate=caption_dropout_rate,
)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# Prepare text encoder (always frozen for Anima)
@ -237,10 +191,7 @@ def train(args):
qwen3_text_encoder.eval()
text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
@ -259,27 +210,21 @@ def train(args):
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[qwen3_text_encoder],
tokens_and_masks,
enable_dropout=False,
tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
)
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
if caption_dropout_rate > 0.0:
with accelerator.autocast():
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
accelerator.wait_for_everyone()
# free text encoder memory
qwen3_text_encoder = None
gc.collect() # Force garbage collection to free memory
clean_memory_on_device(accelerator.device)
# Load VAE and cache latents
logger.info("Loading Anima VAE...")
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
vae = qwen_image_autoencoder_kl.load_vae(
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
)
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
@ -294,24 +239,16 @@ def train(args):
# Load DiT (MiniTrainDIT + optional LLM Adapter)
logger.info("Loading Anima DiT...")
dit = anima_utils.load_anima_dit(
args.dit_path,
dtype=weight_dtype,
device="cpu",
transformer_dtype=transformer_dtype,
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
dit = anima_utils.load_anima_model(
"cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
)
if args.gradient_checkpointing:
dit.enable_gradient_checkpointing(
cpu_offload=args.cpu_offload_checkpointing,
unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False),
unsloth_offload=args.unsloth_offload_checkpointing,
)
if getattr(args, 'flash_attn', False):
dit.set_flash_attn(True)
train_dit = args.learning_rate != 0
dit.requires_grad_(train_dit)
if not train_dit:
@ -327,19 +264,17 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# Move scale tensors to same device as VAE for on-the-fly encoding
vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
# Setup optimizer with parameter groups
if train_dit:
param_groups = anima_train_utils.get_anima_param_groups(
dit,
base_lr=args.learning_rate,
self_attn_lr=getattr(args, 'self_attn_lr', None),
cross_attn_lr=getattr(args, 'cross_attn_lr', None),
mlp_lr=getattr(args, 'mlp_lr', None),
mod_lr=getattr(args, 'mod_lr', None),
llm_adapter_lr=getattr(args, 'llm_adapter_lr', None),
self_attn_lr=args.self_attn_lr,
cross_attn_lr=args.cross_attn_lr,
mlp_lr=args.mlp_lr,
mod_lr=args.mod_lr,
llm_adapter_lr=args.llm_adapter_lr,
)
else:
param_groups = []
@ -361,57 +296,7 @@ def train(args):
# prepare optimizer
accelerator.print("prepare optimizer, data loader etc.")
if args.blockwise_fused_optimizers:
# Split params into per-block groups for blockwise fused optimizer
# Build param_id → lr mapping from param_groups to propagate per-component LRs
param_lr_map = {}
for group in param_groups:
for p in group['params']:
param_lr_map[id(p)] = group['lr']
grouped_params = []
param_group = {}
named_parameters = list(dit.named_parameters())
for name, p in named_parameters:
if not p.requires_grad:
continue
# Determine block type and index
if name.startswith("blocks."):
block_index = int(name.split(".")[1])
block_type = "blocks"
elif name.startswith("llm_adapter.blocks."):
block_index = int(name.split(".")[2])
block_type = "llm_adapter"
else:
block_index = -1
block_type = "other"
param_group_key = (block_type, block_index)
if param_group_key not in param_group:
param_group[param_group_key] = []
param_group[param_group_key].append(p)
for param_group_key, params in param_group.items():
# Use per-component LR from param_groups if available
lr = param_lr_map.get(id(params[0]), args.learning_rate)
grouped_params.append({"params": params, "lr": lr})
num_params = sum(p.numel() for p in params)
accelerator.print(f"block {param_group_key}: {num_params} parameters, lr={lr}")
# Create per-group optimizers
optimizers = []
for group in grouped_params:
_, _, opt = train_util.get_optimizer(args, trainable_params=[group])
optimizers.append(opt)
optimizer = optimizers[0] # avoid error in following code
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
if train_util.is_schedulefree_optimizer(optimizers[0], args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None
optimizer_eval_fn = lambda: None
elif args.fused_backward_pass:
if args.fused_backward_pass:
# Pass per-component param_groups directly to preserve per-component LRs
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
@ -442,21 +327,19 @@ def train(args):
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr scheduler
if args.blockwise_fused_optimizers:
lr_schedulers = [train_util.get_scheduler_fix(args, opt, accelerator.num_processes) for opt in optimizers]
lr_scheduler = lr_schedulers[0] # avoid error in following code
else:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# full fp16/bf16 training
dit_weight_dtype = weight_dtype
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
accelerator.print("enable full fp16 training.")
dit.to(weight_dtype)
elif args.full_bf16:
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
accelerator.print("enable full bf16 training.")
dit.to(weight_dtype)
else:
dit_weight_dtype = torch.float32 # If neither full_fp16 nor full_bf16, the model weights should be in float32
dit.to(dit_weight_dtype) # convert dit to target weight dtype
# move text encoder to GPU if not cached
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
@ -498,6 +381,7 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
@ -517,55 +401,29 @@ def train(args):
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
elif args.blockwise_fused_optimizers:
# Prepare additional optimizers and lr schedulers
for i in range(1, len(optimizers)):
optimizers[i] = accelerator.prepare(optimizers[i])
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
# Counters for blockwise gradient hook
optimizer_hooked_count = {}
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}
for opt_idx, opt in enumerate(optimizers):
for param_group in opt.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def grad_hook(parameter: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)
parameter.register_post_accumulate_grad_hook(grad_hook)
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1
# Training loop
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
accelerator.print("running training")
accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
accelerator.print(f" num epochs: {num_train_epochs}")
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps: {args.max_train_steps}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
# Copy for noise and timestep generation, because noise_scheduler may be changed during training in future
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
@ -580,6 +438,7 @@ def train(args):
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb
wandb.define_metric("epoch")
wandb.define_metric("loss/epoch", step_metric="epoch")
@ -589,8 +448,15 @@ def train(args):
# For --sample_at_first
optimizer_eval_fn()
anima_train_utils.sample_images(
accelerator, args, 0, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
0,
global_step,
dit,
vae,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
optimizer_train_fn()
@ -600,11 +466,11 @@ def train(args):
# Show model info
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
if unwrapped_dit is not None:
logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}")
logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
if qwen3_text_encoder is not None:
logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
if vae is not None:
logger.info(f"vae device: {next(vae.parameters()).device}")
logger.info(f"vae device: {vae.device}")
loss_recorder = train_util.LossRecorder()
epoch = 0
@ -618,19 +484,17 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
if args.blockwise_fused_optimizers:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
with accelerator.accumulate(*training_models):
# Get latents
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
if latents.ndim == 5: # Fallback for 5D latents (old cache)
latents = latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W)
else:
with torch.no_grad():
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
images = images.unsqueeze(2) # (B, C, 1, H, W)
latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype)
latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
@ -640,23 +504,24 @@ def train(args):
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Cached outputs
caption_dropout_rates = text_encoder_outputs_list[-1]
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
# Apply caption dropout to cached outputs
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
else:
# Encode on-the-fly
input_ids_list = batch["input_ids_list"]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = input_ids_list
with torch.no_grad():
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[qwen3_text_encoder],
[qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask],
tokenize_strategy, [qwen3_text_encoder], input_ids_list
)
# Move to device
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
@ -664,9 +529,11 @@ def train(args):
# Noise and timesteps
noise = torch.randn_like(latents)
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
args, latents, noise, accelerator.device, weight_dtype
# Get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
)
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
# NaN checks
if torch.any(torch.isnan(noisy_model_input)):
@ -678,15 +545,10 @@ def train(args):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
padding_mask = torch.zeros(
bs, 1, h_latent, w_latent,
dtype=weight_dtype, device=accelerator.device
)
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
if is_swapping_blocks:
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W)
with accelerator.autocast():
model_pred = dit(
noisy_model_input,
@ -697,6 +559,7 @@ def train(args):
t5_input_ids=t5_input_ids,
t5_attn_mask=t5_attn_mask,
)
model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W)
# Compute loss (rectified flow: target = noise - latents)
target = noise - latents
@ -708,12 +571,10 @@ def train(args):
# Loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
loss = train_util.conditional_loss(
model_pred.float(), target.float(), args.loss_type, "none", huber_c
)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,)
if weighting is not None:
loss = loss * weighting
@ -724,7 +585,7 @@ def train(args):
accelerator.backward(loss)
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
@ -737,9 +598,6 @@ def train(args):
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step
if accelerator.sync_gradients:
@ -748,8 +606,15 @@ def train(args):
optimizer_eval_fn()
anima_train_utils.sample_images(
accelerator, args, None, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
None,
global_step,
dit,
vae,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
@ -773,8 +638,10 @@ def train(args):
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs_with_names(
logs, lr_scheduler, args.optimizer_type,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else []
logs,
lr_scheduler,
args.optimizer_type,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
)
accelerator.log(logs, step=global_step)
@ -807,8 +674,15 @@ def train(args):
)
anima_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, dit, vae, vae_scale,
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
accelerator,
args,
epoch + 1,
global_step,
dit,
vae,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
@ -852,11 +726,6 @@ def setup_parser() -> argparse.ArgumentParser:
anima_train_utils.add_anima_training_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
parser.add_argument(
"--blockwise_fused_optimizers",
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step",
)
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
@ -884,4 +753,7 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
train(args)

View File

@ -1,16 +1,26 @@
# Anima LoRA training script
import argparse
import math
from typing import Any, Optional, Union
import torch
import torch.nn as nn
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import anima_models, anima_train_utils, anima_utils, strategy_anima, strategy_base, train_util
from library import (
anima_models,
anima_train_utils,
anima_utils,
flux_train_utils,
qwen_image_autoencoder_kl,
sd3_train_utils,
strategy_anima,
strategy_base,
train_util,
)
import train_network
from library.utils import setup_logging
@ -24,13 +34,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.vae = None
self.vae_scale = None
self.qwen3_text_encoder = None
self.qwen3_tokenizer = None
self.t5_tokenizer = None
self.tokenize_strategy = None
self.text_encoding_strategy = None
def assert_extra_args(
self,
@ -38,140 +41,118 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
if args.fp8_base or args.fp8_base_unet:
logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。")
args.fp8_base = False
args.fp8_base_unet = False
args.fp8_scaled = False # Anima DiT does not support fp8_scaled
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
)
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
# dataset-level caption dropout, so zero out subset-level rates to allow caching.
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
if caption_dropout_rate > 0:
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
if hasattr(train_dataset_group, 'datasets'):
for dataset in train_dataset_group.datasets:
for subset in dataset.subsets:
subset.caption_dropout_rate = 0.0
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
assert train_dataset_group.is_text_encoder_output_cacheable(
cache_supports_dropout=True
), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
assert (
args.network_train_unet_only or not args.cache_text_encoder_outputs
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
if getattr(args, 'unsloth_offload_checkpointing', False):
if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
assert not args.cpu_offload_checkpointing, \
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
not args.cpu_offload_checkpointing
), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
# Flash attention: validate availability
if getattr(args, 'flash_attn', False):
try:
import flash_attn # noqa: F401
logger.info("Flash Attention enabled for DiT blocks")
except ImportError:
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
args.flash_attn = False
if getattr(args, 'blockwise_fused_optimizers', False):
raise ValueError("blockwise_fused_optimizers is not supported with LoRA/NetworkTrainer")
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
train_dataset_group.verify_bucket_reso_steps(16) # WanVAE spatial downscale = 8 and patch size = 2
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(8)
val_dataset_group.verify_bucket_reso_steps(16)
def load_target_model(self, args, weight_dtype, accelerator):
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
# Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
logger.info("Loading Qwen3 text encoder...")
self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(
args.qwen3_path, dtype=weight_dtype, device="cpu"
)
self.qwen3_text_encoder.eval()
qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
qwen3_text_encoder.eval()
# Parse transformer_dtype
transformer_dtype = None
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
transformer_dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
# Load VAE
logger.info("Loading Anima VAE...")
vae = qwen_image_autoencoder_kl.load_vae(
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
)
vae.to(weight_dtype)
vae.eval()
# Return format: (model_type, text_encoders, vae, unet)
return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
loading_dtype = None if args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
attn_mode = "torch"
if args.xformers:
attn_mode = "xformers"
if args.attn_mode is not None:
attn_mode = args.attn_mode
# Load DiT
logger.info("Loading Anima DiT...")
dit = anima_utils.load_anima_dit(
args.dit_path,
dtype=weight_dtype,
device="cpu",
transformer_dtype=transformer_dtype,
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
logger.info(f"Loading Anima DiT model with attn_mode={attn_mode}, split_attn: {args.split_attn}...")
model = anima_utils.load_anima_model(
accelerator.device,
args.pretrained_model_name_or_path,
attn_mode,
args.split_attn,
loading_device,
loading_dtype,
args.fp8_scaled,
)
# Flash attention
if getattr(args, 'flash_attn', False):
dit.set_flash_attn(True)
# Store unsloth preference so that when the base NetworkTrainer calls
# dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
# The base trainer only passes cpu_offload, so we store the flag on the model.
self._use_unsloth_offload_checkpointing = getattr(args, 'unsloth_offload_checkpointing', False)
self._use_unsloth_offload_checkpointing = args.unsloth_offload_checkpointing
# Block swap
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
dit.enable_block_swap(args.blocks_to_swap, accelerator.device)
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
# Load VAE
logger.info("Loading Anima VAE...")
self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(
args.vae_path, dtype=weight_dtype, device="cpu"
)
# Return format: (model_type, text_encoders, vae, unet)
return "anima", [self.qwen3_text_encoder], self.vae, dit
return model, text_encoders
def get_tokenize_strategy(self, args):
# Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
self.tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.qwen3_path,
t5_tokenizer_path=getattr(args, 't5_tokenizer_path', None),
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.qwen3,
t5_tokenizer_path=args.t5_tokenizer_path,
qwen3_max_length=args.qwen3_max_token_length,
t5_max_length=args.t5_max_token_length,
)
# Store references so load_target_model can reuse them
self.qwen3_tokenizer = self.tokenize_strategy.qwen3_tokenizer
self.t5_tokenizer = self.tokenize_strategy.t5_tokenizer
return self.tokenize_strategy
return tokenize_strategy
def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
return [tokenize_strategy.qwen3_tokenizer]
def get_latents_caching_strategy(self, args):
return strategy_anima.AnimaLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)
def get_text_encoding_strategy(self, args):
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
dropout_rate=caption_dropout_rate,
)
return self.text_encoding_strategy
return strategy_anima.AnimaTextEncodingStrategy()
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# Qwen3 text encoder is always frozen for Anima
pass
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
@ -179,19 +160,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
return None # no text encoders needed for encoding
return text_encoders
def get_text_encoders_train_flags(self, args, text_encoders):
return [False] # Qwen3 always frozen
def is_train_text_encoder(self, args):
return False # Qwen3 text encoder is always frozen for Anima
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
return None
@ -200,15 +172,14 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
):
if args.cache_text_encoder_outputs:
if not args.lowram:
logger.info("move vae and unet to cpu to save memory")
org_vae_device = next(vae.parameters()).device
org_unet_device = unet.device
# We cannot move DiT to CPU because of block swap, so only move VAE
logger.info("move vae to cpu to save memory")
org_vae_device = vae.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
logger.info("move text encoder to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[0].to(accelerator.device)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
@ -229,59 +200,52 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
enable_dropout=False,
tokenize_strategy, text_encoders, tokens_and_masks
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
if caption_dropout_rate > 0.0:
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
with accelerator.autocast():
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
accelerator.wait_for_everyone()
# move text encoder back to cpu
logger.info("move text encoder back to cpu")
text_encoders[0].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
logger.info("move vae back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
clean_memory_on_device(accelerator.device)
else:
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
# move text encoder to device for encoding during training/validation
text_encoders[0].to(accelerator.device)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
qwen3_te = te[0] if te is not None else None
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
anima_train_utils.sample_images(
accelerator, args, epoch, global_step, unet, vae, self.vae_scale,
qwen3_te, self.tokenize_strategy, self.text_encoding_strategy,
accelerator,
args,
epoch,
global_step,
unet,
vae,
qwen3_te,
tokenize_strategy,
text_encoding_strategy,
self.sample_prompts_te_outputs,
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = anima_train_utils.FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000, shift=args.discrete_flow_shift
)
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
return noise_scheduler
def encode_images_to_latents(self, args, vae, images):
# images are already [-1,1] from IMAGE_TRANSFORMS, add temporal dim
images = images.unsqueeze(2) # (B, C, 1, H, W)
# Ensure scale tensors are on the same device as images
vae_device = images.device
scale = [s.to(vae_device) if isinstance(s, torch.Tensor) else s for s in self.vae_scale]
return vae.encode(images, scale)
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
return vae.encode_pixels_to_latents(images) # Keep 4D for input/output
def shift_scale_latents(self, args, latents):
# Latents already normalized by vae.encode with scale
@ -301,13 +265,18 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_unet,
is_train=True,
):
anima: anima_models.Anima = unet
# Sample noise
if latents.ndim == 5: # Fallback for 5D latents (old cache)
latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
noise = torch.randn_like(latents)
# Get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
args, latents, noise, accelerator.device, weight_dtype
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
# Gradient checkpointing support
if args.gradient_checkpointing:
@ -329,161 +298,103 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
padding_mask = torch.zeros(
bs, 1, h_latent, w_latent,
dtype=weight_dtype, device=accelerator.device
)
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
# Prepare block swap
if self.is_swapping_blocks:
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
# Call model (LLM adapter runs inside forward for DDP gradient sync)
# Call model
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W]
with torch.set_grad_enabled(is_train), accelerator.autocast():
model_pred = unet(
model_pred = anima(
noisy_model_input,
timesteps,
prompt_embeds,
padding_mask=padding_mask,
target_input_ids=t5_input_ids,
target_attention_mask=t5_attn_mask,
source_attention_mask=attn_mask,
t5_input_ids=t5_input_ids,
t5_attn_mask=t5_attn_mask,
)
model_pred = model_pred.squeeze(2) # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W]
# Rectified flow target: noise - latents
target = noise - latents
# Loss weighting
weighting = anima_train_utils.compute_loss_weighting_for_anima(
weighting_scheme=args.weighting_scheme, sigmas=sigmas
)
# Differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
if self.is_swapping_blocks:
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
model_pred_prior = unet(
noisy_model_input[diff_output_pr_indices],
timesteps[diff_output_pr_indices],
prompt_embeds[diff_output_pr_indices],
padding_mask=padding_mask[diff_output_pr_indices],
source_attention_mask=attn_mask[diff_output_pr_indices],
t5_input_ids=t5_input_ids[diff_output_pr_indices],
t5_attn_mask=t5_attn_mask[diff_output_pr_indices],
)
network.set_multiplier(1.0)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
return model_pred, target, timesteps, weighting
def process_batch(
self, batch, text_encoders, unet, network, vae, noise_scheduler,
vae_dtype, weight_dtype, accelerator, args,
text_encoding_strategy, tokenize_strategy,
is_train=True, train_text_encoder=True, train_unet=True,
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
"""Override base process_batch for 5D video latents (B, C, T, H, W).
Base class assumes 4D (B, C, H, W) for loss.mean([1,2,3]) and weighting broadcast.
"""
import typing
from library.custom_train_functions import apply_masked_loss
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
else:
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
else:
chunks = [
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
]
list_latents = []
for chunk in chunks:
with torch.no_grad():
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
list_latents.append(chunk)
latents = torch.cat(list_latents, dim=0)
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
latents = self.shift_scale_latents(args, latents)
"""Override base process_batch for caption dropout with cached text encoder outputs."""
# Text encoder conditions
text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
caption_dropout_rates = text_encoder_outputs_list[-1]
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
if args.full_fp16:
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
# Apply caption dropout to cached outputs
text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args, accelerator, noise_scheduler, latents, batch,
text_encoder_conds, unet, network, weight_dtype, train_unet, is_train=is_train,
return super().process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train,
train_text_encoder,
train_unet,
)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
# Reduce all non-batch dims: (B, C, T, H, W) -> (B,) for 5D, (B, C, H, W) -> (B,) for 4D
reduce_dims = list(range(1, loss.ndim))
loss = loss.mean(reduce_dims)
# Apply weighting after reducing to (B,)
if weighting is not None:
loss = loss * weighting
loss_weights = batch["loss_weights"]
loss = loss * loss_weights
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
return loss.mean()
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True)
return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, anima="preview").to_metadata_dict()
def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')
metadata["ss_sigmoid_scale"] = getattr(args, 'sigmoid_scale', 1.0)
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
# Set first parameter's requires_grad to True to workaround Accelerate gradient checkpointing bug
first_param = next(text_encoder.parameters())
first_param.requires_grad_(True)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
@ -496,23 +407,16 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
dit = unet
dit = accelerator.prepare(dit, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
model = unet
model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
accelerator.unwrap_model(model).prepare_block_swap_before_forward()
return dit
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
# Drop cached text encoder outputs for caption dropout
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
return model
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
@ -520,6 +424,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
anima_train_utils.add_anima_training_arguments(parser)
# parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument(
"--unsloth_offload_checkpointing",
action="store_true",
@ -536,5 +441,8 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
trainer = AnimaNetworkTrainer()
trainer.train(args)

View File

@ -11,7 +11,9 @@ This document explains how to train LoRA (Low-Rank Adaptation) models for Anima
## 1. Introduction / はじめに
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a WanVAE (16-channel, 8x spatial downscale).
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a Qwen-Image VAE (16-channel, 8x spatial downscale).
Qwen-Image VAE and Qwen-Image VAE have same architecture, but [official Anima weight is named for Qwen-Image VAE](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae).
This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
@ -24,7 +26,9 @@ This guide assumes you already understand the basics of LoRA training. For commo
<details>
<summary>日本語</summary>
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびWanVAE (16チャンネル、8倍空間ダウンスケール) を使用します。
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびQwen-Image VAE (16チャンネル、8倍空間ダウンスケール) を使用します。
Qwen-Image VAEとQwen-Image VAEは同じアーキテクチャですが、[Anima公式の重みはQwen-Image VAE用](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae)のようです。
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
@ -37,14 +41,14 @@ This guide assumes you already understand the basics of LoRA training. For commo
## 2. Differences from `train_network.py` / `train_network.py` との違い
`anima_train_network.py` is based on `train_network.py` but modified for Anima . Main differences are:
`anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are:
* **Target models:** Anima DiT models.
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a WanVAE (16-channel latent space with 8x spatial downscale).
* **Arguments:** Options exist to specify the Anima DiT model, Qwen3 text encoder, WanVAE, LLM adapter, and T5 tokenizer separately.
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used.
* **Anima specific options:** Additional parameters for component-wise learning rates (self_attn, cross_attn, mlp, mod, llm_adapter), timestep sampling, discrete flow shift, and flash attention.
* **6 Parameter Groups:** Independent learning rates for `base`, `self_attn`, `cross_attn`, `mlp`, `adaln_modulation`, and `llm_adapter` components.
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a Qwen-Image VAE (16-channel latent space with 8x spatial downscale).
* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the Qwen-Image VAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported.
* **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`).
* **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`.
<details>
<summary>日本語</summary>
@ -52,11 +56,11 @@ This guide assumes you already understand the basics of LoRA training. For commo
`anima_train_network.py``train_network.py`をベースに、Anima モデルに対応するための変更が加えられています。主な違いは以下の通りです。
* **対象モデル:** Anima DiTモデルを対象とします。
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびWanVAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
* **引数:** Anima DiTモデル、Qwen3テキストエンコーダー、WanVAE、LLM Adapter、T5トークナイザーを個別に指定する引数があります。
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数例: `--v2`, `--v_parameterization`, `--clip_skip`)はAnimaの学習では使用されません。
* **Anima特有の引数:** コンポーネント別学習率self_attn, cross_attn, mlp, mod, llm_adapter、タイムステップサンプリング、離散フローシフト、Flash Attentionに関する引数が追加されています。
* **6パラメータグループ:** `base``self_attn``cross_attn``mlp``adaln_modulation``llm_adapter`の各コンポーネントに対して独立した学習率を設定できます。
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびQwen-Image VAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、Qwen-Image VAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path``--t5_tokenizer_path`で個別に指定できます。
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。
* **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma``uniform``sigmoid``shift``flux_shift`)を使用します。
* **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims``network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns``include_patterns`で制御します。
</details>
## 3. Preparation / 準備
@ -65,16 +69,16 @@ The following files are required before starting training:
1. **Training script:** `anima_train_network.py`
2. **Anima DiT model file:** `.safetensors` file for the base DiT model.
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory or a single `.safetensors` file (requires `configs/qwen3_06b/` config files).
4. **WanVAE model file:** `.safetensors` or `.pth` file for the VAE.
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory, or a single `.safetensors` file (uses the bundled config files in `configs/qwen3_06b/`).
4. **Qwen-Image VAE model file:** `.safetensors` or `.pth` file for the VAE.
5. **LLM Adapter model file (optional):** `.safetensors` file. If not provided separately, the adapter is loaded from the DiT file if the key `llm_adapter.out_proj.weight` exists.
6. **T5 Tokenizer (optional):** If not specified, uses the bundled tokenizer at `configs/t5_old/`.
7. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_anima_dataset_config.toml` as an example.
Model files can be obtained from the [Anima HuggingFace repository](https://huggingface.co/circlestone-labs/Anima).
**Notes:**
* When using a single `.safetensors` file for Qwen3, download the `config.json`, `tokenizer.json`, `tokenizer_config.json`, and `vocab.json` from the [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFace repository into the `configs/qwen3_06b/` directory.
* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
* Models are saved with a `net.` prefix on all keys for ComfyUI compatibility.
<details>
<summary>日本語</summary>
@ -83,16 +87,16 @@ The following files are required before starting training:
1. **学習スクリプト:** `anima_train_network.py`
2. **Anima DiTモデルファイル:** ベースとなるDiTモデルの`.safetensors`ファイル。
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(`configs/qwen3_06b/`の設定ファイルが必要)。
4. **WanVAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(バンドル版の`configs/qwen3_06b/`の設定ファイルが使用されます)。
4. **Qwen-Image VAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
5. **LLM Adapterモデルファイルオプション:** `.safetensors`ファイル。個別に指定しない場合、DiTファイル内に`llm_adapter.out_proj.weight`キーが存在すればそこから読み込まれます。
6. **T5トークナイザーオプション:** 指定しない場合、`configs/t5_old/`のバンドル版トークナイザーを使用します。
7. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。例として`my_anima_dataset_config.toml`を使用します。
モデルファイルは[HuggingFaceのAnimaリポジトリ](https://huggingface.co/circlestone-labs/Anima)から入手できます。
**注意:**
* Qwen3の単体`.safetensors`ファイルを使用する場合、[Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFaceリポジトリから`config.json``tokenizer.json``tokenizer_config.json``vocab.json`をダウンロードし、`configs/qwen3_06b/`ディレクトリに配置してください。
* T5トークナイザーはトークナイザーファイルのみ必要ですT5モデルの重みは不要`google/t5-v1_1-xxl`の語彙を使用します。
* モデルはComfyUI互換のため、すべてのキーに`net.`プレフィックスを付けて保存されます。
* T5トークナイザーを別途指定する場合、トークナイザーファイルのみ必要ですT5モデルの重みは不要`google/t5-v1_1-xxl`の語彙を使用します。
</details>
## 4. Running the Training / 学習の実行
@ -103,33 +107,38 @@ Example command:
```bash
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
--dit_path="<path to Anima DiT model>" \
--qwen3_path="<path to Qwen3-0.6B model or directory>" \
--vae_path="<path to WanVAE model>" \
--llm_adapter_path="<path to LLM adapter model>" \
--pretrained_model_name_or_path="<path to Anima DiT model>" \
--qwen3="<path to Qwen3-0.6B model or directory>" \
--vae="<path to Qwen-Image VAE model>" \
--dataset_config="my_anima_dataset_config.toml" \
--output_dir="<output directory>" \
--output_name="my_anima_lora" \
--save_model_as=safetensors \
--network_module=networks.lora_anima \
--network_dim=8 \
--network_alpha=8 \
--learning_rate=1e-4 \
--optimizer_type="AdamW8bit" \
--lr_scheduler="constant" \
--timestep_sample_method="logit_normal" \
--discrete_flow_shift=3.0 \
--timestep_sampling="sigmoid" \
--discrete_flow_shift=1.0 \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--cache_text_encoder_outputs \
--blocks_to_swap=18
--vae_chunk_size=64 \
--vae_disable_cache
```
*(Write the command on one line or use `\` or `^` for line breaks.)*
The learning rate of `1e-4` is just an example. Adjust it according to your dataset and objectives. This value is for `alpha=1.0` (default). If increasing `--network_alpha`, consider lowering the learning rate.
If loss becomes NaN, ensure you are using PyTorch version 2.5 or higher.
**Note:** `--vae_chunk_size` and `--vae_disable_cache` are custom options in this repository to reduce memory usage of the Qwen-Image VAE.
<details>
<summary>日本語</summary>
@ -138,6 +147,13 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
コマンドラインの例は英語のドキュメントを参照してください。
※実際には1行で書くか、適切な改行文字`\` または `^`)を使用してください。
学習率1e-4はあくまで一例です。データセットや目的に応じて適切に調整してください。またこの値はalpha=1.0(デフォルト)での値です。`--network_alpha`を増やす場合は学習率を下げることを検討してください。
lossがNaNになる場合は、PyTorchのバージョンが2.5以上であることを確認してください。
注意: `--vae_chunk_size`および`--vae_disable_cache`は当リポジトリ独自のオプションで、Qwen-Image VAEのメモリ使用量を削減するために使用します。
</details>
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
@ -146,12 +162,15 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
#### Model Options [Required] / モデル関連 [必須]
* `--dit_path="<path to Anima DiT model>"` **[Required]**
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[Required]**
- Path to the Anima DiT model `.safetensors` file. The model config (channels, blocks, heads) is auto-detected from the state dict. ComfyUI format with `net.` prefix is supported.
* `--qwen3_path="<path to Qwen3-0.6B model>"` **[Required]**
* `--qwen3="<path to Qwen3-0.6B model>"` **[Required]**
- Path to the Qwen3-0.6B text encoder. Can be a HuggingFace model directory or a single `.safetensors` file. The text encoder is always frozen during training.
* `--vae_path="<path to WanVAE model>"` **[Required]**
- Path to the WanVAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
* `--vae="<path to Qwen-Image VAE model>"` **[Required]**
- Path to the Qwen-Image VAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
#### Model Options [Optional] / モデル関連 [オプション]
* `--llm_adapter_path="<path to LLM adapter>"` *[Optional]*
- Path to a separate LLM adapter weights file. If omitted, the adapter is loaded from the DiT file when the key `llm_adapter.out_proj.weight` exists.
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[Optional]*
@ -159,53 +178,58 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
#### Anima Training Parameters / Anima 学習パラメータ
* `--timestep_sample_method=<choice>`
- Timestep sampling method. Choose from `logit_normal` (default) or `uniform`.
* `--timestep_sampling=<choice>`
- Timestep sampling method. Choose from `sigma`, `uniform`, `sigmoid` (default), `shift`, `flux_shift`. Same options as FLUX training. See the [flux_train_network.py guide](flux_train_network.md) for details on each method.
* `--discrete_flow_shift=<float>`
- Shift for the timestep distribution in Rectified Flow training. Default `3.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
- Shift for the timestep distribution in Rectified Flow training. Default `1.0`. This value is used when `--timestep_sampling` is set to **`shift`**. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
* `--sigmoid_scale=<float>`
- Scale factor for `logit_normal` timestep sampling. Default `1.0`.
- Scale factor when `--timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default `1.0`.
* `--qwen3_max_token_length=<integer>`
- Maximum token length for the Qwen3 tokenizer. Default `512`.
* `--t5_max_token_length=<integer>`
- Maximum token length for the T5 tokenizer. Default `512`.
* `--flash_attn`
- Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`. Falls back to PyTorch SDPA if the package is not installed. Note: Flash Attention is only applied to DiT blocks; the LLM Adapter uses standard attention because it requires attention masks.
* `--transformer_dtype=<choice>`
- Separate dtype for transformer blocks. Choose from `float16`, `bfloat16`, `float32`. If not specified, uses the same dtype as `--mixed_precision`.
* `--attn_mode=<choice>`
- Attention implementation to use. Choose from `torch` (default), `xformers`, `flash`, `sageattn`. `xformers` requires `--split_attn`. `sageattn` does not support training (inference only). This option overrides `--xformers`.
* `--split_attn`
- Split attention computation to reduce memory usage. Required when using `--attn_mode xformers`.
#### Component-wise Learning Rates / コンポーネント別学習率
Anima supports 6 independent learning rate groups. Set to `0` to freeze a component:
These options set separate learning rates for each component of the Anima model. They are primarily used for full fine-tuning. Set to `0` to freeze a component:
* `--self_attn_lr=<float>` - Learning rate for self-attention layers. Default: same as `--learning_rate`.
* `--cross_attn_lr=<float>` - Learning rate for cross-attention layers. Default: same as `--learning_rate`.
* `--mlp_lr=<float>` - Learning rate for MLP layers. Default: same as `--learning_rate`.
* `--mod_lr=<float>` - Learning rate for AdaLN modulation layers. Default: same as `--learning_rate`.
* `--mod_lr=<float>` - Learning rate for AdaLN modulation layers. Default: same as `--learning_rate`. Note: modulation layers are not included in LoRA by default.
* `--llm_adapter_lr=<float>` - Learning rate for LLM adapter layers. Default: same as `--learning_rate`.
For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Section 5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御).
#### Memory and Speed / メモリ・速度関連
* `--blocks_to_swap=<integer>` **[Experimental]**
* `--blocks_to_swap=<integer>`
- Number of Transformer blocks to swap between CPU and GPU. More blocks reduce VRAM but slow training. Maximum values depend on model size:
- 28-block model: max **26**
- 28-block model: max **26** (Anima-Preview)
- 36-block model: max **34**
- 20-block model: max **18**
- Cannot be used with `--cpu_offload_checkpointing` or `--unsloth_offload_checkpointing`.
* `--unsloth_offload_checkpointing`
- Offload activations to CPU RAM using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
- Offload activations to CPU RAM using async non-blocking transfers (faster than `--cpu_offload_checkpointing`). Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
* `--cache_text_encoder_outputs`
- Cache Qwen3 text encoder outputs to reduce VRAM usage. Recommended when not training text encoder LoRA.
* `--cache_text_encoder_outputs_to_disk`
- Cache text encoder outputs to disk. Auto-enables `--cache_text_encoder_outputs`.
* `--cache_latents`, `--cache_latents_to_disk`
- Cache WanVAE latent outputs.
* `--fp8_base`
- Use FP8 precision for the base model to reduce VRAM usage.
#### Incompatible or Deprecated Options / 非互換・非推奨の引数
- Cache Qwen-Image VAE latent outputs.
* `--vae_chunk_size=<integer>`
- Chunk size for Qwen-Image VAE processing. Reduces VRAM usage at the cost of speed. Default is no chunking.
* `--vae_disable_cache`
- Disable internal caching in Qwen-Image VAE to reduce VRAM usage.
#### Incompatible or Unsupported Options / 非互換・非サポートの引数
* `--v2`, `--v_parameterization`, `--clip_skip` - Options for Stable Diffusion v1/v2 that are not used for Anima training.
* `--fp8_base` - Not supported for Anima. If specified, it will be disabled with a warning.
<details>
<summary>日本語</summary>
@ -214,39 +238,50 @@ Anima supports 6 independent learning rate groups. Set to `0` to freeze a compon
#### モデル関連 [必須]
* `--dit_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。
* `--qwen3_path="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。
* `--vae_path="<path to WanVAE model>"` **[必須]** - WanVAEモデルのパスを指定します。
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。モデルの設定はstate dictから自動検出されます。`net.`プレフィックス付きのComfyUIフォーマットもサポートしています。
* `--qwen3="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイルが使用できます。
* `--vae="<path to Qwen-Image VAE model>"` **[必須]** - Qwen-Image VAEモデルのパスを指定します。
#### モデル関連 [オプション]
* `--llm_adapter_path="<path to LLM adapter>"` *[オプション]* - 個別のLLM Adapterの重みファイルのパス。
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[オプション]* - T5トークナイザーディレクトリのパス。
#### Anima 学習パラメータ
* `--timestep_sample_method` - タイムステップのサンプリング方法。`logit_normal`(デフォルト)または`uniform`
* `--discrete_flow_shift` - Rectified Flow学習のタイムステップ分布シフト。デフォルト`3.0`
* `--sigmoid_scale` - logit_normalタイムステップサンプリングのスケール係数。デフォルト`1.0`
* `--timestep_sampling` - タイムステップのサンプリング方法。`sigma``uniform``sigmoid`(デフォルト)、`shift``flux_shift`から選択。FLUX学習と同じオプションです。各方法の詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください
* `--discrete_flow_shift` - Rectified Flow学習のタイムステップ分布シフト。デフォルト`1.0`。`--timestep_sampling``shift`の場合に使用されます
* `--sigmoid_scale` - `sigmoid``shift``flux_shift`タイムステップサンプリングのスケール係数。デフォルト`1.0`
* `--qwen3_max_token_length` - Qwen3トークナイザーの最大トークン長。デフォルト`512`
* `--t5_max_token_length` - T5トークナイザーの最大トークン長。デフォルト`512`
* `--flash_attn` - DiTのself/cross-attentionにFlash Attentionを使用。`pip install flash-attn`が必要
* `--transformer_dtype` - Transformerブロック用の個別dtype
* `--attn_mode` - 使用するAttentionの実装。`torch`(デフォルト)、`xformers``flash``sageattn`から選択。`xformers``--split_attn`の指定が必要です。`sageattn`はトレーニングをサポートしていません(推論のみ)
* `--split_attn` - メモリ使用量を減らすためにattention時にバッチを分割します。`--attn_mode xformers`使用時に必要です
#### コンポーネント別学習率
Animaは6つの独立した学習率グループをサポートします。`0`に設定するとそのコンポーネントをフリーズします:
これらのオプションは、Animaモデルの各コンポーネントに個別の学習率を設定します。主にフルファインチューニング用です。`0`に設定するとそのコンポーネントをフリーズします:
* `--self_attn_lr` - Self-attention層の学習率。
* `--cross_attn_lr` - Cross-attention層の学習率。
* `--mlp_lr` - MLP層の学習率。
* `--mod_lr` - AdaLNモジュレーション層の学習率。
* `--mod_lr` - AdaLNモジュレーション層の学習率。モジュレーション層はデフォルトではLoRAに含まれません。
* `--llm_adapter_lr` - LLM Adapter層の学習率。
LoRA学習の場合は、`--network_args``network_reg_lrs`を使用してください。[セクション5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御)を参照。
#### メモリ・速度関連
* `--blocks_to_swap` **[実験的機能]** - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。
* `--blocks_to_swap` - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。`--cpu_offload_checkpointing`および`--unsloth_offload_checkpointing`とは併用できません。
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。`--cpu_offload_checkpointing`および`--blocks_to_swap`とは併用できません。
* `--cache_text_encoder_outputs` - Qwen3の出力をキャッシュしてメモリ使用量を削減。
* `--cache_latents`, `--cache_latents_to_disk` - WanVAEの出力をキャッシュ。
* `--fp8_base` - ベースモデルにFP8精度を使用。
* `--cache_latents`, `--cache_latents_to_disk` - Qwen-Image VAEの出力をキャッシュ。
* `--vae_chunk_size` - Qwen-Image VAEのチャンク処理サイズ。メモリ使用量を削減しますが速度が低下します。デフォルトはチャンク処理なし。
* `--vae_disable_cache` - Qwen-Image VAEの内部キャッシュを無効化してメモリ使用量を削減します。
#### 非互換・非サポートの引数
* `--v2`, `--v_parameterization`, `--clip_skip` - Stable Diffusion v1/v2向けの引数。Animaの学習では使用されません。
* `--fp8_base` - Animaではサポートされていません。指定した場合、警告とともに無効化されます。
</details>
### 4.2. Starting Training / 学習の開始
@ -262,67 +297,66 @@ After setting the required arguments, run the command to begin training. The ove
## 5. LoRA Target Modules / LoRAの学習対象モジュール
When training LoRA with `anima_train_network.py`, the following modules are targeted:
When training LoRA with `anima_train_network.py`, the following modules are targeted by default:
* **DiT Blocks (`Block`)**: Self-attention, cross-attention, MLP, and AdaLN modulation layers within each transformer block.
* **DiT Blocks (`Block`)**: Self-attention (`self_attn`), cross-attention (`cross_attn`), and MLP (`mlp`) layers within each transformer block. Modulation (`adaln_modulation`), norm, embedder, and final layers are excluded by default.
* **Embedding layers (`PatchEmbed`, `TimestepEmbedding`) and Final layer (`FinalLayer`)**: Excluded by default but can be included using `include_patterns`.
* **LLM Adapter Blocks (`LLMAdapterTransformerBlock`)**: Only when `--network_args "train_llm_adapter=True"` is specified.
* **Text Encoder (Qwen3)**: Only when `--network_train_unet_only` is NOT specified.
* **Text Encoder (Qwen3)**: Only when `--network_train_unet_only` is NOT specified and `--cache_text_encoder_outputs` is NOT used.
The LoRA network module is `networks.lora_anima`.
### 5.1. Layer-specific Rank Configuration / 各層に対するランク指定
### 5.1. Module Selection with Patterns / パターンによるモジュール選択
You can specify different ranks (network_dim) for each component of the Anima model. Setting `0` disables LoRA for that component.
| network_args | Target Component |
|---|---|
| `self_attn_dim` | Self-attention layers in DiT blocks |
| `cross_attn_dim` | Cross-attention layers in DiT blocks |
| `mlp_dim` | MLP layers in DiT blocks |
| `mod_dim` | AdaLN modulation layers in DiT blocks |
| `llm_adapter_dim` | LLM adapter layers (requires `train_llm_adapter=True`) |
Example usage:
By default, the following modules are excluded from LoRA via the built-in exclude pattern:
```
--network_args "self_attn_dim=8" "cross_attn_dim=4" "mlp_dim=8" "mod_dim=4"
.*(_modulation|_norm|_embedder|final_layer).*
```
### 5.2. Embedding Layer LoRA / 埋め込み層LoRA
You can customize which modules are included or excluded using regex patterns in `--network_args`:
You can apply LoRA to embedding/output layers by specifying `emb_dims` in network_args as a comma-separated list of 3 numbers:
* `exclude_patterns` - Exclude modules matching these patterns (in addition to the default exclusion).
* `include_patterns` - Force-include modules matching these patterns, overriding exclusion.
Patterns are matched against the full module name using `re.fullmatch()`.
Example to include the final layer:
```
--network_args "emb_dims=[8,4,8]"
--network_args "include_patterns=['.*final_layer.*']"
```
Each number corresponds to:
1. `x_embedder` (patch embedding)
2. `t_embedder` (timestep embedding)
3. `final_layer` (output layer)
Setting `0` disables LoRA for that layer.
### 5.3. Block Selection for Training / 学習するブロックの指定
You can specify which DiT blocks to train using `train_block_indices` in network_args. The indices are 0-based. Default is to train all blocks.
Specify indices as comma-separated integers or ranges:
Example to additionally exclude MLP layers:
```
--network_args "train_block_indices=0-5,10,15-27"
--network_args "exclude_patterns=['.*mlp.*']"
```
Special values: `all` (train all blocks), `none` (skip all blocks).
### 5.2. Regex-based Rank and Learning Rate Control / 正規表現によるランク・学習率の制御
### 5.4. LLM Adapter LoRA / LLM Adapter LoRA
You can specify different ranks (network_dim) and learning rates for modules matching specific regex patterns:
* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`.
* Example: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
* This sets the rank to 8 for self-attention modules, 4 for cross-attention modules, and 8 for MLP modules.
* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`.
* Example: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
* This sets the learning rate to `1e-4` for self-attention modules and `5e-5` for cross-attention modules.
**Notes:**
* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
* Patterns are matched using `re.fullmatch()` against the module's original name (e.g., `blocks.0.self_attn.q_proj`).
### 5.3. LLM Adapter LoRA / LLM Adapter LoRA
To apply LoRA to the LLM Adapter blocks:
```
--network_args "train_llm_adapter=True" "llm_adapter_dim=4"
--network_args "train_llm_adapter=True"
```
### 5.5. Other Network Args / その他のネットワーク引数
In preliminary tests, lowering the learning rate for the LLM Adapter seems to improve stability. Adjust it using something like: `"network_reg_lrs=.*llm_adapter.*=5e-5"`.
### 5.4. Other Network Args / その他のネットワーク引数
* `--network_args "verbose=True"` - Print all LoRA module names and their dimensions.
* `--network_args "rank_dropout=0.1"` - Rank dropout rate.
@ -336,48 +370,58 @@ To apply LoRA to the LLM Adapter blocks:
`anima_train_network.py`でLoRAを学習させる場合、デフォルトでは以下のモジュールが対象となります。
* **DiTブロック (`Block`)**: 各Transformerブロック内のSelf-attention、Cross-attention、MLP、AdaLNモジュレーション層。
* **DiTブロック (`Block`)**: 各Transformerブロック内のSelf-attention`self_attn`、Cross-attention`cross_attn`、MLP`mlp`)層。モジュレーション(`adaln_modulation`、norm、embedder、final layerはデフォルトで除外されます。
* **埋め込み層 (`PatchEmbed`, `TimestepEmbedding`) と最終層 (`FinalLayer`)**: デフォルトで除外されますが、`include_patterns`で含めることができます。
* **LLM Adapterブロック (`LLMAdapterTransformerBlock`)**: `--network_args "train_llm_adapter=True"`を指定した場合のみ。
* **テキストエンコーダー (Qwen3)**: `--network_train_unet_only`を指定しない場合のみ。
* **テキストエンコーダー (Qwen3)**: `--network_train_unet_only`を指定せず、かつ`--cache_text_encoder_outputs`を使用しない場合のみ。
### 5.1. 各層のランクを指定する
### 5.1. パターンによるモジュール選択
`--network_args`で各コンポーネントに異なるランクを指定できます。`0`を指定するとその層にはLoRAが適用されません。
デフォルトでは以下のモジュールが組み込みの除外パターンによりLoRAから除外されます
```
.*(_modulation|_norm|_embedder|final_layer).*
```
|network_args|対象コンポーネント|
|---|---|
|`self_attn_dim`|DiTブロック内のSelf-attention層|
|`cross_attn_dim`|DiTブロック内のCross-attention層|
|`mlp_dim`|DiTブロック内のMLP層|
|`mod_dim`|DiTブロック内のAdaLNモジュレーション層|
|`llm_adapter_dim`|LLM Adapter層`train_llm_adapter=True`が必要)|
`--network_args`で正規表現パターンを使用して、含めるモジュールと除外するモジュールをカスタマイズできます:
### 5.2. 埋め込み層LoRA
* `exclude_patterns` - これらのパターンにマッチするモジュールを除外(デフォルトの除外に追加)。
* `include_patterns` - これらのパターンにマッチするモジュールを強制的に含める(除外を上書き)。
`emb_dims`で埋め込み/出力層にLoRAを適用できます。3つの数値をカンマ区切りで指定します。
パターンは`re.fullmatch()`を使用して完全なモジュール名に対してマッチングされます。
各数値は `x_embedder`(パッチ埋め込み)、`t_embedder`(タイムステップ埋め込み)、`final_layer`(出力層)に対応します。
### 5.2. 正規表現によるランク・学習率の制御
### 5.3. 学習するブロックの指定
正規表現にマッチするモジュールに対して、異なるランクや学習率を指定できます:
`train_block_indices`でLoRAを適用するDiTブロックを指定できます。
* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。`pattern=rank`形式の文字列をカンマで区切って指定します。
* 例: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。`pattern=lr`形式の文字列をカンマで区切って指定します。
* 例: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
### 5.4. LLM Adapter LoRA
**注意点:**
* `network_reg_dims`および`network_reg_lrs`での設定は、全体設定である`--network_dim``--learning_rate`よりも優先されます。
* パターンはモジュールのオリジナル名(例: `blocks.0.self_attn.q_proj`)に対して`re.fullmatch()`でマッチングされます。
LLM AdapterブロックにLoRAを適用するには`--network_args "train_llm_adapter=True" "llm_adapter_dim=4"`
### 5.3. LLM Adapter LoRA
### 5.5. その他のネットワーク引数
LLM AdapterブロックにLoRAを適用するには`--network_args "train_llm_adapter=True"`
簡易な検証ではLLM Adapterの学習率はある程度下げた方が安定するようです。`"network_reg_lrs=.*llm_adapter.*=5e-5"`などで調整してください。
### 5.4. その他のネットワーク引数
* `verbose=True` - 全LoRAモジュール名とdimを表示
* `rank_dropout` - ランクドロップアウト率
* `module_dropout` - モジュールドロップアウト率
* `loraplus_lr_ratio` - LoRA+学習率比率
* `loraplus_unet_lr_ratio` - DiT専用のLoRA+学習率比率
* `loraplus_text_encoder_lr_ratio` - テキストエンコーダー専用のLoRA+学習率比率
</details>
## 6. Using the Trained Model / 学習済みモデルの利用
When training finishes, a LoRA model file (e.g. `my_anima_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Anima , such as ComfyUI with appropriate nodes.
When training finishes, a LoRA model file (e.g. `my_anima_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Anima, such as ComfyUI with appropriate nodes.
<details>
<summary>日本語</summary>
@ -394,8 +438,6 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
#### Key VRAM Reduction Options
- **`--fp8_base`**: Enables training in FP8 format for the DiT model.
- **`--blocks_to_swap <number>`**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. See model-specific max values in section 4.1.
- **`--unsloth_offload_checkpointing`**: Offloads gradient checkpoints to CPU using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--blocks_to_swap`.
@ -404,7 +446,7 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
- **`--cache_text_encoder_outputs`**: Caches Qwen3 outputs so the text encoder can be freed from VRAM during training.
- **`--cache_latents`**: Caches WanVAE outputs so the VAE can be freed from VRAM during training.
- **`--cache_latents`**: Caches Qwen-Image VAE outputs so the VAE can be freed from VRAM during training.
- **Using Adafactor optimizer**: Can reduce VRAM usage:
```
@ -417,12 +459,11 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
Animaモデルは大きい場合があるため、VRAMが限られたGPUでは最適化が必要です。
主要なVRAM削減オプション
- `--fp8_base`: FP8形式での学習を有効化
- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ
- `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード
- `--gradient_checkpointing`: 標準的な勾配チェックポイント
- `--cache_text_encoder_outputs`: Qwen3の出力をキャッシュ
- `--cache_latents`: WanVAEの出力をキャッシュ
- `--cache_latents`: Qwen-Image VAEの出力をキャッシュ
- Adafactorオプティマイザの使用
</details>
@ -431,21 +472,24 @@ Animaモデルは大きい場合があるため、VRAMが限られたGPUでは
#### Timestep Sampling
The `--timestep_sample_method` option specifies how timesteps (0-1) are sampled:
The `--timestep_sampling` option specifies how timesteps are sampled. The available methods are the same as FLUX training:
- `logit_normal` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
- `sigma`: Sigma-based sampling like SD3.
- `uniform`: Uniform random sampling from [0, 1].
- `sigmoid` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
- `shift`: Like `sigmoid`, but applies the discrete flow shift formula: `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
- `flux_shift`: Resolution-dependent shift used in FLUX training.
See the [flux_train_network.py guide](flux_train_network.md) for detailed descriptions.
#### Discrete Flow Shift
The `--discrete_flow_shift` option (default `3.0`) shifts the timestep distribution toward higher noise levels. The formula is:
The `--discrete_flow_shift` option (default `1.0`) only applies when `--timestep_sampling` is set to `shift`. The formula is:
```
t_shifted = (t * shift) / (1 + (shift - 1) * t)
```
Timesteps are clamped to `[1e-5, 1-1e-5]` after shifting.
#### Loss Weighting
The `--weighting_scheme` option specifies loss weighting by timestep:
@ -454,23 +498,34 @@ The `--weighting_scheme` option specifies loss weighting by timestep:
- `sigma_sqrt`: Weight by `sigma^(-2)`.
- `cosmap`: Weight by `2 / (pi * (1 - 2*sigma + 2*sigma^2))`.
- `none`: Same as uniform.
- `logit_normal`, `mode`: Additional schemes from SD3 training. See the [`sd3_train_network.md` guide](sd3_train_network.md) for details.
#### Caption Dropout
Use `--caption_dropout_rate` for embedding-level caption dropout. This is handled by `AnimaTextEncodingStrategy` and is compatible with text encoder output caching. The subset-level `caption_dropout_rate` is automatically zeroed when this is set.
Caption dropout uses the `caption_dropout_rate` setting from the dataset configuration (per-subset in TOML). When using `--cache_text_encoder_outputs`, the dropout rate is stored with each cached entry and applied during training, so caption dropout is compatible with text encoder output caching.
**If you change the `caption_dropout_rate` setting, you must delete and regenerate the cache.**
Note: Currently, only Anima supports combining `caption_dropout_rate` with text encoder output caching.
<details>
<summary>日本語</summary>
#### タイムステップサンプリング
`--timestep_sample_method`でタイムステップのサンプリング方法を指定します:
- `logit_normal`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。
`--timestep_sampling`でタイムステップのサンプリング方法を指定します。FLUX学習と同じ方法が利用できます
- `sigma`: SD3と同様のシグマベースサンプリング。
- `uniform`: [0, 1]の一様分布からサンプリング。
- `sigmoid`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。汎用的なオプション。
- `shift`: `sigmoid`と同様だが、離散フローシフトの式を適用。
- `flux_shift`: FLUX学習で使用される解像度依存のシフト。
詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
#### 離散フローシフト
`--discrete_flow_shift`(デフォルト`3.0`)はタイムステップ分布を高ノイズ側にシフトします。
`--discrete_flow_shift`(デフォルト`1.0`)は`--timestep_sampling``shift`の場合のみ適用されます。
#### 損失の重み付け
@ -478,7 +533,11 @@ Use `--caption_dropout_rate` for embedding-level caption dropout. This is handle
#### キャプションドロップアウト
`--caption_dropout_rate`で埋め込みレベルのキャプションドロップアウトを使用します。テキストエンコーダー出力のキャッシュと互換性があります。
キャプションドロップアウトにはデータセット設定TOMLでのサブセット単位`caption_dropout_rate`を使用します。`--cache_text_encoder_outputs`使用時は、ドロップアウト率が各キャッシュエントリとともに保存され、学習中に適用されるため、テキストエンコーダー出力キャッシュと同時に使用できます。
**`caption_dropout_rate`の設定を変えた場合、キャッシュを削除し、再生成する必要があります。**
`caption_dropout_rate`をテキストエンコーダー出力キャッシュと組み合わせられるのは、今のところAnimaのみです。
</details>
@ -487,17 +546,23 @@ Use `--caption_dropout_rate` for embedding-level caption dropout. This is handle
Anima LoRA training supports training Qwen3 text encoder LoRA:
- To train only DiT: specify `--network_train_unet_only`
- To train DiT and Qwen3: omit `--network_train_unet_only`
- To train DiT and Qwen3: omit `--network_train_unet_only` and do NOT use `--cache_text_encoder_outputs`
You can specify a separate learning rate for Qwen3 with `--text_encoder_lr`. If not specified, the default `--learning_rate` is used.
Note: When `--cache_text_encoder_outputs` is used, text encoder outputs are pre-computed and the text encoder is removed from GPU, so text encoder LoRA cannot be trained.
<details>
<summary>日本語</summary>
Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレーニングできます。
- DiTのみ学習: `--network_train_unet_only`を指定
- DiTとQwen3を学習: `--network_train_unet_only`を省略
- DiTとQwen3を学習: `--network_train_unet_only`を省略し、`--cache_text_encoder_outputs`を使用しない
Qwen3に個別の学習率を指定するには`--text_encoder_lr`を使用します。未指定の場合は`--learning_rate`が使われます。
注意: `--cache_text_encoder_outputs`を使用する場合、テキストエンコーダーの出力が事前に計算されGPUから解放されるため、テキストエンコーダーLoRAは学習できません。
</details>
@ -528,16 +593,47 @@ Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレー
</details>
## 9. Others / その他
## 9. Related Tools / 関連ツール
### `networks/anima_convert_lora_to_comfy.py`
A script to convert LoRA models to ComfyUI-compatible format. ComfyUI does not directly support sd-scripts format Qwen3 LoRA, so conversion is necessary (conversion may not be needed for DiT-only LoRA). You can convert from the sd-scripts format to ComfyUI format with:
```bash
python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
```
Using the `--reverse` option allows conversion in the opposite direction (ComfyUI format to sd-scripts format). However, reverse conversion is only possible for LoRAs converted by this script. LoRAs created with other training tools cannot be converted.
<details>
<summary>日本語</summary>
**`networks/convert_anima_lora_to_comfy.py`**
LoRAモデルをComfyUI互換形式に変換するスクリプト。ComfyUIがsd-scripts形式のQwen3 LoRAを直接サポートしていないため、変換が必要ですDiTのみのLoRAの場合は変換不要のようです。sd-scripts形式からComfyUI形式への変換は以下のコマンドで行います
```bash
python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
```
`--reverse`オプションを付けると、逆変換ComfyUI形式からsd-scripts形式も可能です。ただし、逆変換ができるのはこのスクリプトで変換したLoRAに限ります。他の学習ツールで作成したLoRAは変換できません。
</details>
## 10. Others / その他
### Metadata Saved in LoRA Models
The following Anima-specific metadata is saved in the LoRA model file:
The following metadata is saved in the LoRA model file:
* `ss_weighting_scheme`
* `ss_discrete_flow_shift`
* `ss_timestep_sample_method`
* `ss_logit_mean`
* `ss_logit_std`
* `ss_mode_scale`
* `ss_timestep_sampling`
* `ss_sigmoid_scale`
* `ss_discrete_flow_shift`
<details>
<summary>日本語</summary>
@ -546,11 +642,14 @@ The following Anima-specific metadata is saved in the LoRA model file:
### LoRAモデルに保存されるメタデータ
以下のAnima固有のメタデータがLoRAモデルファイルに保存されます
以下のメタデータがLoRAモデルファイルに保存されます
* `ss_weighting_scheme`
* `ss_discrete_flow_shift`
* `ss_timestep_sample_method`
* `ss_logit_mean`
* `ss_logit_std`
* `ss_mode_scale`
* `ss_timestep_sampling`
* `ss_sigmoid_scale`
* `ss_discrete_flow_shift`
</details>

View File

@ -2,7 +2,7 @@
# Original code: NVIDIA CORPORATION & AFFILIATES, licensed under Apache-2.0
import math
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
@ -13,9 +13,7 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from library import custom_offloading_utils
from library.device_utils import clean_memory_on_device
from library import custom_offloading_utils, attention
def to_device(x, device):
@ -39,11 +37,13 @@ def to_cpu(x):
else:
return x
# Unsloth Offloaded Gradient Checkpointing
# Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team
try:
from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable
except ImportError:
def detach_variable(inputs, device=None):
"""Detach tensors from computation graph, optionally moving to a device.
@ -80,11 +80,11 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
"""
@staticmethod
@torch.amp.custom_fwd(device_type='cuda')
@torch.amp.custom_fwd(device_type="cuda")
def forward(ctx, forward_function, hidden_states, *args):
# Remember the original device for backward pass (multi-GPU support)
ctx.input_device = hidden_states.device
saved_hidden_states = hidden_states.to('cpu', non_blocking=True)
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
@ -96,7 +96,7 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
return output
@staticmethod
@torch.amp.custom_bwd(device_type='cuda')
@torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, *grads):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach()
@ -108,8 +108,9 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
output_tensors = []
grad_tensors = []
for out, grad in zip(outputs if isinstance(outputs, tuple) else (outputs,),
grads if isinstance(grads, tuple) else (grads,)):
for out, grad in zip(
outputs if isinstance(outputs, tuple) else (outputs,), grads if isinstance(grads, tuple) else (grads,)
):
if isinstance(out, torch.Tensor) and out.requires_grad:
output_tensors.append(out)
grad_tensors.append(grad)
@ -123,26 +124,6 @@ def unsloth_checkpoint(function, *args):
return UnslothOffloadedGradientCheckpointer.apply(function, *args)
# Flash Attention support
try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
FLASH_ATTN_AVAILABLE = True
except ImportError:
_flash_attn_func = None
FLASH_ATTN_AVAILABLE = False
def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using Flash Attention.
Input format: (batch, seq_len, n_heads, head_dim)
Output format: (batch, seq_len, n_heads * head_dim) matches torch_attention_op output.
"""
# flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
return rearrange(out, "b s h d -> b s (h d)")
from .utils import setup_logging
setup_logging()
@ -174,14 +155,10 @@ def _apply_rotary_pos_emb_base(
if start_positions is not None:
max_offset = torch.max(start_positions)
assert (
max_offset + cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
assert max_offset + cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
assert cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
@ -205,13 +182,9 @@ def apply_rotary_pos_emb(
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
) -> torch.Tensor:
assert not (
cp_size > 1 and start_positions is not None
), "start_positions != None with CP SIZE > 1 is not supported!"
assert not (cp_size > 1 and start_positions is not None), "start_positions != None with CP SIZE > 1 is not supported!"
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
assert tensor_format != "thd" or cu_seqlens is not None, "cu_seqlens must not be None when tensor_format is 'thd'."
assert fused == False
@ -223,9 +196,7 @@ def apply_rotary_pos_emb(
_apply_rotary_pos_emb_base(
x.unsqueeze(1),
freqs,
start_positions=(
start_positions[idx : idx + 1] if start_positions is not None else None
),
start_positions=(start_positions[idx : idx + 1] if start_positions is not None else None),
interleaved=interleaved,
)
for idx, x in enumerate(torch.split(t, seqlens))
@ -262,10 +233,10 @@ class RMSNorm(torch.nn.Module):
def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
@torch.amp.autocast(device_type='cuda', dtype=torch.float32)
def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight
with torch.autocast(device_type=x.device.type, dtype=torch.float32):
output = self._norm(x.float()).type_as(x)
return output * self.weight
class GPT2FeedForward(nn.Module):
@ -298,22 +269,6 @@ class GPT2FeedForward(nn.Module):
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native scaled_dot_product_attention.
Input/output format: (batch, seq_len, n_heads, head_dim)
"""
in_q_shape = q_B_S_H_D.shape
in_k_shape = k_B_S_H_D.shape
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
result_B_S_HD = rearrange(
F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)"
)
return result_B_S_HD
# Attention module for DiT
class Attention(nn.Module):
"""Multi-head attention supporting both self-attention and cross-attention.
@ -354,8 +309,6 @@ class Attention(nn.Module):
self.output_proj = nn.Linear(inner_dim, query_dim, bias=False)
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
self.attn_op = torch_attention_op
self._query_dim = query_dim
self._context_dim = context_dim
self._inner_dim = inner_dim
@ -399,18 +352,25 @@ class Attention(nn.Module):
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
self,
x: torch.Tensor,
attn_params: attention.AttentionParams,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
if q.dtype != v.dtype:
if (not attn_params.supports_fp32 or attn_params.requires_same_dtype) and torch.is_autocast_enabled():
# FlashAttention requires fp16/bf16, xformers require same dtype; only cast when autocast is active.
target_dtype = v.dtype # v has fp16/bf16 dtype
q = q.to(target_dtype)
k = k.to(target_dtype)
# return self.compute_attention(q, k, v)
qkv = [q, k, v]
del q, k, v
result = attention.attention(qkv, attn_params=attn_params)
return self.output_dropout(self.output_proj(result))
# Positional Embeddings
@ -484,12 +444,8 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
dim_t = self._dim_t
self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device)
self.dim_spatial_range = (
torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
)
self.dim_temporal_range = (
torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
)
self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
def generate_embeddings(
self,
@ -664,31 +620,30 @@ class TimestepEmbedding(nn.Module):
return emb_B_T_D, adaln_lora_B_T_3D
class FourierFeatures(nn.Module):
"""Fourier feature transform: [B] -> [B, D]."""
# Commented out Fourier Features (not used in Anima). Kept for reference.
# class FourierFeatures(nn.Module):
# """Fourier feature transform: [B] -> [B, D]."""
def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
super().__init__()
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
self.gain = np.sqrt(2) if normalize else 1
self.bandwidth = bandwidth
self.num_channels = num_channels
self.reset_parameters()
# def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
# super().__init__()
# self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
# self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
# self.gain = np.sqrt(2) if normalize else 1
# self.bandwidth = bandwidth
# self.num_channels = num_channels
# self.reset_parameters()
def reset_parameters(self) -> None:
generator = torch.Generator()
generator.manual_seed(0)
self.freqs = (
2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
)
self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
# def reset_parameters(self) -> None:
# generator = torch.Generator()
# generator.manual_seed(0)
# self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
# self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
in_dtype = x.dtype
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
x = x.cos().mul(self.gain * gain).to(in_dtype)
return x
# def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
# in_dtype = x.dtype
# x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
# x = x.cos().mul(self.gain * gain).to(in_dtype)
# return x
# Patch Embedding
@ -713,9 +668,7 @@ class PatchEmbed(nn.Module):
m=spatial_patch_size,
n=spatial_patch_size,
),
nn.Linear(
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False
),
nn.Linear(in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False),
)
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
@ -765,9 +718,7 @@ class FinalLayer(nn.Module):
nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False),
)
else:
self.adaln_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)
)
self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False))
self.init_weights()
@ -790,9 +741,9 @@ class FinalLayer(nn.Module):
):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
shift_B_T_D, scale_B_T_D = (
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
).chunk(2, dim=-1)
shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk(
2, dim=-1
)
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
@ -833,7 +784,11 @@ class Block(nn.Module):
self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.cross_attn = Attention(
x_dim, context_dim, num_heads, x_dim // num_heads, qkv_format="bshd",
x_dim,
context_dim,
num_heads,
x_dim // num_heads,
qkv_format="bshd",
)
self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
@ -904,6 +859,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@ -919,13 +875,13 @@ class Block(nn.Module):
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
3, dim=-1
)
else:
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
emb_B_T_D
).chunk(3, dim=-1)
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk(
3, dim=-1
)
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
@ -954,11 +910,14 @@ class Block(nn.Module):
result = rearrange(
self.self_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
attn_params,
None,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T, h=H, w=W,
t=T,
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result
@ -967,11 +926,14 @@ class Block(nn.Module):
result = rearrange(
self.cross_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
attn_params,
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
t=T, h=H, w=W,
t=T,
h=H,
w=W,
)
x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@ -987,6 +949,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@ -996,8 +959,13 @@ class Block(nn.Module):
# Unsloth: async non-blocking CPU RAM offload (fastest offload method)
return unsloth_checkpoint(
self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
)
elif self.cpu_offload_checkpointing:
# Standard cpu offload: blocking transfers
@ -1008,36 +976,54 @@ class Block(nn.Module):
device_inputs = to_device(inputs, device)
outputs = func(*device_inputs)
return to_cpu(outputs)
return custom_forward
return torch_checkpoint(
create_custom_forward(self._forward),
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
use_reentrant=False,
)
else:
# Standard gradient checkpointing (no offload)
return torch_checkpoint(
self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
use_reentrant=False,
)
else:
return self._forward(
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
x_B_T_H_W_D,
emb_B_T_D,
crossattn_emb,
attn_params,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
)
# Main DiT Model: MiniTrainDIT
class MiniTrainDIT(nn.Module):
# Main DiT Model: MiniTrainDIT (renamed to Anima)
class Anima(nn.Module):
"""Cosmos-Predict2 DiT model for image/video generation.
28 transformer blocks with AdaLN-LoRA modulation, 3D RoPE, and optional LLM Adapter.
"""
LATENT_CHANNELS = 16
def __init__(
self,
max_img_h: int,
@ -1069,6 +1055,8 @@ class MiniTrainDIT(nn.Module):
extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True,
use_llm_adapter: bool = False,
attn_mode: str = "torch",
split_attn: bool = False,
) -> None:
super().__init__()
self.max_img_h = max_img_h
@ -1097,6 +1085,9 @@ class MiniTrainDIT(nn.Module):
self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.use_llm_adapter = use_llm_adapter
self.attn_mode = attn_mode
self.split_attn = split_attn
# Block swap support
self.blocks_to_swap = None
self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None
@ -1156,7 +1147,6 @@ class MiniTrainDIT(nn.Module):
self.final_layer.init_weights()
self.t_embedding_norm.reset_parameters()
def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False):
for block in self.blocks:
block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload)
@ -1169,18 +1159,9 @@ class MiniTrainDIT(nn.Module):
def device(self):
return next(self.parameters()).device
def set_flash_attn(self, use_flash_attn: bool):
"""Toggle flash attention for all DiT blocks (self-attn + cross-attn).
LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
"""
if use_flash_attn and not FLASH_ATTN_AVAILABLE:
raise ImportError("flash_attn package is required for --flash_attn but is not installed")
attn_op = flash_attention_op if use_flash_attn else torch_attention_op
for block in self.blocks:
block.self_attn.attn_op = attn_op
block.cross_attn.attn_op = attn_op
@property
def dtype(self):
return next(self.parameters()).dtype
def build_patch_embed(self) -> None:
in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels
@ -1232,9 +1213,7 @@ class MiniTrainDIT(nn.Module):
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
x_B_C_T_H_W = torch.cat(
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
)
x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
@ -1258,7 +1237,6 @@ class MiniTrainDIT(nn.Module):
)
return x_B_C_Tt_Hp_Wp
def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks
@ -1266,9 +1244,7 @@ class MiniTrainDIT(nn.Module):
self.blocks_to_swap <= self.num_blocks - 2
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
self.offloader = custom_offloading_utils.ModelOffloader(
self.blocks, self.blocks_to_swap, device
)
self.offloader = custom_offloading_utils.ModelOffloader(self.blocks, self.blocks_to_swap, device)
logger.info(f"Anima: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
def move_to_device_except_swap_blocks(self, device: torch.device):
@ -1282,12 +1258,26 @@ class MiniTrainDIT(nn.Module):
if self.blocks_to_swap:
self.blocks = save_blocks
def switch_block_swap_for_inference(self):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader.set_forward_only(True)
self.prepare_block_swap_before_forward()
print(f"Anima: Block swap set to forward only.")
def switch_block_swap_for_training(self):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader.set_forward_only(False)
self.prepare_block_swap_before_forward()
print(f"Anima: Block swap set to forward and backward.")
def prepare_block_swap_before_forward(self):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader.prepare_block_devices_before_forward(self.blocks)
def forward(
def forward_mini_train_dit(
self,
x_B_C_T_H_W: torch.Tensor,
timesteps_B_T: torch.Tensor,
@ -1310,7 +1300,7 @@ class MiniTrainDIT(nn.Module):
t5_attn_mask: Optional T5 attention mask
"""
# Run LLM adapter inside forward for correct DDP gradient synchronization
if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, 'llm_adapter'):
if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, "llm_adapter"):
crossattn_emb = self.llm_adapter(
source_hidden_states=crossattn_emb,
target_input_ids=t5_input_ids,
@ -1337,16 +1327,13 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb,
}
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(
x_B_T_H_W_D,
t_embedding_B_T_D,
crossattn_emb,
**block_kwargs,
)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)
@ -1355,6 +1342,36 @@ class MiniTrainDIT(nn.Module):
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
context: Optional[torch.Tensor] = None,
fps: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
target_input_ids: Optional[torch.Tensor] = None,
target_attention_mask: Optional[torch.Tensor] = None,
source_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask)
return self.forward_mini_train_dit(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
def _preprocess_text_embeds(
self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
):
if target_input_ids is not None:
context = self.llm_adapter(
source_hidden_states,
target_input_ids,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
)
context[~target_attention_mask.bool()] = 0 # zero out padding tokens
return context
else:
return source_hidden_states
# LLM Adapter: Bridges Qwen3 embeddings to T5-compatible space
class LLMAdapterRMSNorm(nn.Module):
@ -1485,24 +1502,37 @@ class LLMAdapterTransformerBlock(nn.Module):
self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim)
self.mlp = nn.Sequential(
nn.Linear(model_dim, int(model_dim * mlp_ratio)),
nn.GELU(),
nn.Linear(int(model_dim * mlp_ratio), model_dim)
nn.Linear(model_dim, int(model_dim * mlp_ratio)), nn.GELU(), nn.Linear(int(model_dim * mlp_ratio), model_dim)
)
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None,
position_embeddings=None, position_embeddings_context=None):
def forward(
self,
x,
context,
target_attention_mask=None,
source_attention_mask=None,
position_embeddings=None,
position_embeddings_context=None,
):
if self.has_self_attn:
# Self-attention: target_attention_mask is not expected to be all zeros
normed = self.norm_self_attn(x)
attn_out = self.self_attn(normed, mask=target_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings)
attn_out = self.self_attn(
normed,
mask=target_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings,
)
x = x + attn_out
normed = self.norm_cross_attn(x)
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context)
attn_out = self.cross_attn(
normed,
mask=source_attention_mask,
context=context,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
@ -1518,8 +1548,9 @@ class LLMAdapter(nn.Module):
Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states.
"""
def __init__(self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16,
embed=None, self_attn=False, layer_norm=False):
def __init__(
self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, embed=None, self_attn=False, layer_norm=False
):
super().__init__()
if embed is not None:
self.embed = nn.Embedding.from_pretrained(embed.weight)
@ -1530,11 +1561,12 @@ class LLMAdapter(nn.Module):
else:
self.in_proj = nn.Identity()
self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads)
self.blocks = nn.ModuleList([
LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads,
self_attn=self_attn, layer_norm=layer_norm)
for _ in range(num_layers)
])
self.blocks = nn.ModuleList(
[
LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, self_attn=self_attn, layer_norm=layer_norm)
for _ in range(num_layers)
]
)
self.out_proj = nn.Linear(model_dim, target_dim)
self.norm = LLMAdapterRMSNorm(target_dim)
@ -1556,75 +1588,67 @@ class LLMAdapter(nn.Module):
position_embeddings = self.rotary_emb(x, position_ids)
position_embeddings_context = self.rotary_emb(x, position_ids_context)
for block in self.blocks:
x = block(x, context, target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context)
x = block(
x,
context,
target_attention_mask=target_attention_mask,
source_attention_mask=source_attention_mask,
position_embeddings=position_embeddings,
position_embeddings_context=position_embeddings_context,
)
return self.norm(self.out_proj(x))
# VAE Wrapper
# Not used currently, but kept for reference
# VAE normalization constants
ANIMA_VAE_MEAN = [
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
]
ANIMA_VAE_STD = [
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]
# def get_dit_config(state_dict, key_prefix=""):
# """Derive DiT configuration from state_dict weight shapes."""
# dit_config = {}
# dit_config["max_img_h"] = 512
# dit_config["max_img_w"] = 512
# dit_config["max_frames"] = 128
# concat_padding_mask = True
# dit_config["in_channels"] = (state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[1] // 4) - int(
# concat_padding_mask
# )
# dit_config["out_channels"] = 16
# dit_config["patch_spatial"] = 2
# dit_config["patch_temporal"] = 1
# dit_config["model_channels"] = state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[0]
# dit_config["concat_padding_mask"] = concat_padding_mask
# dit_config["crossattn_emb_channels"] = 1024
# dit_config["pos_emb_cls"] = "rope3d"
# dit_config["pos_emb_learnable"] = True
# dit_config["pos_emb_interpolation"] = "crop"
# dit_config["min_fps"] = 1
# dit_config["max_fps"] = 30
# DiT config detection from state_dict
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
# dit_config["use_adaln_lora"] = True
# dit_config["adaln_lora_dim"] = 256
# if dit_config["model_channels"] == 2048:
# dit_config["num_blocks"] = 28
# dit_config["num_heads"] = 16
# elif dit_config["model_channels"] == 5120:
# dit_config["num_blocks"] = 36
# dit_config["num_heads"] = 40
# elif dit_config["model_channels"] == 1280:
# dit_config["num_blocks"] = 20
# dit_config["num_heads"] = 20
# if dit_config["in_channels"] == 16:
# dit_config["extra_per_block_abs_pos_emb"] = False
# dit_config["rope_h_extrapolation_ratio"] = 4.0
# dit_config["rope_w_extrapolation_ratio"] = 4.0
# dit_config["rope_t_extrapolation_ratio"] = 1.0
# elif dit_config["in_channels"] == 17:
# dit_config["extra_per_block_abs_pos_emb"] = False
# dit_config["rope_h_extrapolation_ratio"] = 3.0
# dit_config["rope_w_extrapolation_ratio"] = 3.0
# dit_config["rope_t_extrapolation_ratio"] = 1.0
def get_dit_config(state_dict, key_prefix=''):
"""Derive DiT configuration from state_dict weight shapes."""
dit_config = {}
dit_config["max_img_h"] = 512
dit_config["max_img_w"] = 512
dit_config["max_frames"] = 128
concat_padding_mask = True
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["crossattn_emb_channels"] = 1024
dit_config["pos_emb_cls"] = "rope3d"
dit_config["pos_emb_learnable"] = True
dit_config["pos_emb_interpolation"] = "crop"
dit_config["min_fps"] = 1
dit_config["max_fps"] = 30
# dit_config["extra_h_extrapolation_ratio"] = 1.0
# dit_config["extra_w_extrapolation_ratio"] = 1.0
# dit_config["extra_t_extrapolation_ratio"] = 1.0
# dit_config["rope_enable_fps_modulation"] = False
dit_config["use_adaln_lora"] = True
dit_config["adaln_lora_dim"] = 256
if dit_config["model_channels"] == 2048:
dit_config["num_blocks"] = 28
dit_config["num_heads"] = 16
elif dit_config["model_channels"] == 5120:
dit_config["num_blocks"] = 36
dit_config["num_heads"] = 40
elif dit_config["model_channels"] == 1280:
dit_config["num_blocks"] = 20
dit_config["num_heads"] = 20
if dit_config["in_channels"] == 16:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 4.0
dit_config["rope_w_extrapolation_ratio"] = 4.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
elif dit_config["in_channels"] == 17:
dit_config["extra_per_block_abs_pos_emb"] = False
dit_config["rope_h_extrapolation_ratio"] = 3.0
dit_config["rope_w_extrapolation_ratio"] = 3.0
dit_config["rope_t_extrapolation_ratio"] = 1.0
dit_config["extra_h_extrapolation_ratio"] = 1.0
dit_config["extra_w_extrapolation_ratio"] = 1.0
dit_config["extra_t_extrapolation_ratio"] = 1.0
dit_config["rope_enable_fps_modulation"] = False
return dit_config
# return dit_config

View File

@ -1,20 +1,20 @@
# Anima Training Utilities
import argparse
import gc
import math
import os
import time
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from safetensors.torch import save_file
from accelerate import Accelerator, PartialState
from accelerate import Accelerator
from tqdm import tqdm
from PIL import Image
from library.device_utils import init_ipex, clean_memory_on_device
from library.device_utils import init_ipex, clean_memory_on_device, synchronize_device
from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl
init_ipex()
@ -25,29 +25,14 @@ import logging
logger = logging.getLogger(__name__)
from library import anima_models, anima_utils, strategy_base, train_util
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
# Anima-specific training arguments
def add_anima_training_arguments(parser: argparse.ArgumentParser):
"""Add Anima-specific training arguments to the parser."""
parser.add_argument(
"--dit_path",
type=str,
default=None,
help="Path to Anima DiT model safetensors file",
)
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path to WanVAE safetensors/pth file",
)
parser.add_argument(
"--qwen3_path",
"--qwen3",
type=str,
default=None,
help="Path to Qwen3-0.6B model (safetensors file or directory)",
@ -86,7 +71,7 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
"--mod_lr",
type=float,
default=None,
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze",
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
)
parser.add_argument(
"--t5_tokenizer_path",
@ -113,110 +98,52 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
help="Timestep distribution shift for rectified flow training (default: 1.0)",
)
parser.add_argument(
"--timestep_sample_method",
"--timestep_sampling",
type=str,
default="logit_normal",
choices=["logit_normal", "uniform"],
help="Timestep sampling method (default: logit_normal)",
default="sigmoid",
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
help="Timestep sampling method (default: sigmoid (logit normal))",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help="Scale factor for logit_normal timestep sampling (default: 1.0)",
help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
)
# Note: --caption_dropout_rate is defined by base add_dataset_arguments().
# Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
# instead of dataset-level caption dropout, so the subset caption_dropout_rate
# is zeroed out in the training scripts to allow caching.
parser.add_argument(
"--transformer_dtype",
type=str,
"--attn_mode",
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
default=None,
choices=["float16", "bfloat16", "float32", None],
help="Separate dtype for transformer blocks. If None, uses same as mixed_precision",
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
" / 使用するAttentionの実装。デフォルトはNonetorchです。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません推論のみ。このオプションは--xformersまたは--sdpaを上書きします。",
)
parser.add_argument(
"--flash_attn",
"--split_attn",
action="store_true",
help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). "
"Falls back to PyTorch SDPA if flash-attn is not installed.",
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
)
parser.add_argument(
"--vae_chunk_size",
type=int,
default=None,
help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
+ " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
)
parser.add_argument(
"--vae_disable_cache",
action="store_true",
help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
+ " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
)
# Noise & Timestep sampling (Rectified Flow)
def get_noisy_model_input_and_timesteps(
args,
latents: torch.Tensor,
noise: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate noisy model input and timesteps for rectified flow training.
Rectified flow: noisy_input = (1 - t) * latents + t * noise
Target: noise - latents
Args:
args: Training arguments with timestep_sample_method, sigmoid_scale, discrete_flow_shift
latents: Clean latent tensors
noise: Random noise tensors
device: Target device
dtype: Target dtype
Returns:
(noisy_model_input, timesteps, sigmas)
"""
bs = latents.shape[0]
timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal')
sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0)
shift = getattr(args, 'discrete_flow_shift', 1.0)
if timestep_sample_method == 'logit_normal':
dist = torch.distributions.normal.Normal(0, 1)
elif timestep_sample_method == 'uniform':
dist = torch.distributions.uniform.Uniform(0, 1)
else:
raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
t = dist.sample((bs,)).to(device)
if timestep_sample_method == 'logit_normal':
t = t * sigmoid_scale
t = torch.sigmoid(t)
# Apply shift
if shift is not None and shift != 1.0:
t = (t * shift) / (1 + (shift - 1) * t)
# Clamp to avoid exact 0 or 1
t = t.clamp(1e-5, 1.0 - 1e-5)
# Create noisy input: (1 - t) * latents + t * noise
t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
ip_noise_gamma = getattr(args, 'ip_noise_gamma', None)
if ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if getattr(args, 'ip_noise_gamma_random_strength', False):
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1 - t_expanded) * latents + t_expanded * noise
# Sigmas for potential loss weighting
sigmas = t.view(-1, 1)
return noisy_model_input.to(dtype), t.to(dtype), sigmas.to(dtype)
# Loss weighting
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Compute loss weighting for Anima training.
Same schemes as SD3 but can add Anima-specific ones.
Same schemes as SD3 but can add Anima-specific ones if needed in future.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
@ -243,7 +170,7 @@ def get_anima_param_groups(
"""Create parameter groups for Anima training with separate learning rates.
Args:
dit: MiniTrainDIT model
dit: Anima model
base_lr: Base learning rate
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
cross_attn_lr: LR for cross-attention layers
@ -276,15 +203,15 @@ def get_anima_param_groups(
# Store original name for debugging
p.original_name = name
if 'llm_adapter' in name:
if "llm_adapter" in name:
llm_adapter_params.append(p)
elif '.self_attn' in name:
elif ".self_attn" in name:
self_attn_params.append(p)
elif '.cross_attn' in name:
elif ".cross_attn" in name:
cross_attn_params.append(p)
elif '.mlp' in name:
elif ".mlp" in name:
mlp_params.append(p)
elif '.adaln_modulation' in name:
elif ".adaln_modulation" in name:
mod_params.append(p)
else:
base_params.append(p)
@ -311,9 +238,9 @@ def get_anima_param_groups(
p.requires_grad_(False)
logger.info(f" Frozen {name} params ({len(params)} parameters)")
elif len(params) > 0:
param_groups.append({'params': params, 'lr': lr})
param_groups.append({"params": params, "lr": lr})
total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad)
total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
logger.info(f"Total trainable parameters: {total_trainable:,}")
return param_groups
@ -325,16 +252,17 @@ def save_anima_model_on_train_end(
save_dtype: torch.dtype,
epoch: int,
global_step: int,
dit: anima_models.MiniTrainDIT,
dit: anima_models.Anima,
):
"""Save Anima model at the end of training."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
sai_metadata = train_util.get_sai_model_spec_dataclass(
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
).to_metadata_dict()
dit_sd = dit.state_dict()
# Save with 'net.' prefix for ComfyUI compatibility
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
@ -347,15 +275,16 @@ def save_anima_model_on_epoch_end_or_stepwise(
epoch: int,
num_train_epochs: int,
global_step: int,
dit: anima_models.MiniTrainDIT,
dit: anima_models.Anima,
):
"""Save Anima model at epoch end or specific steps."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True
)
sai_metadata = train_util.get_sai_model_spec_dataclass(
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
).to_metadata_dict()
dit_sd = dit.state_dict()
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
@ -376,12 +305,13 @@ def do_sample(
height: int,
width: int,
seed: Optional[int],
dit: anima_models.MiniTrainDIT,
dit: anima_models.Anima,
crossattn_emb: torch.Tensor,
steps: int,
dtype: torch.dtype,
device: torch.device,
guidance_scale: float = 1.0,
flow_shift: float = 3.0,
neg_crossattn_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Generate a sample using Euler discrete sampling for rectified flow.
@ -389,12 +319,13 @@ def do_sample(
Args:
height, width: Output image dimensions
seed: Random seed (None for random)
dit: MiniTrainDIT model
dit: Anima model
crossattn_emb: Cross-attention embeddings (B, N, D)
steps: Number of sampling steps
dtype: Compute dtype
device: Compute device
guidance_scale: CFG scale (1.0 = no guidance)
flow_shift: Flow shift parameter for rectified flow
neg_crossattn_emb: Negative cross-attention embeddings for CFG
Returns:
@ -410,12 +341,13 @@ def do_sample(
generator = torch.manual_seed(seed)
else:
generator = None
noise = torch.randn(
latent.size(), dtype=torch.float32, generator=generator, device="cpu"
).to(dtype).to(device)
noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
# Timestep schedule: linear from 1.0 to 0.0
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
flow_shift = float(flow_shift)
if flow_shift != 1.0:
sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
# Start from pure noise
x = noise.clone()
@ -429,19 +361,13 @@ def do_sample(
sigma = sigmas[i]
t = sigma.unsqueeze(0) # (1,)
dit.prepare_block_swap_before_forward()
if use_cfg:
# CFG: concat positive and negative
x_input = torch.cat([x, x], dim=0)
t_input = torch.cat([t, t], dim=0)
crossattn_input = torch.cat([crossattn_emb, neg_crossattn_emb], dim=0)
padding_input = torch.cat([padding_mask, padding_mask], dim=0)
# CFG: two separate passes to reduce memory usage
pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask)
pos_out = pos_out.float()
neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask)
neg_out = neg_out.float()
model_output = dit(x_input, t_input, crossattn_input, padding_mask=padding_input)
model_output = model_output.float()
pos_out, neg_out = model_output.chunk(2)
model_output = neg_out + guidance_scale * (pos_out - neg_out)
else:
model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
@ -452,7 +378,6 @@ def do_sample(
x = x + model_output * dt
x = x.to(dtype)
dit.prepare_block_swap_before_forward()
return x
@ -461,9 +386,8 @@ def sample_images(
args: argparse.Namespace,
epoch,
steps,
dit,
dit: anima_models.Anima,
vae,
vae_scale,
text_encoder,
tokenize_strategy,
text_encoding_strategy,
@ -497,6 +421,8 @@ def sample_images(
if text_encoder is not None:
text_encoder = accelerator.unwrap_model(text_encoder)
dit.switch_block_swap_for_inference()
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = os.path.join(args.output_dir, "sample")
os.makedirs(save_dir, exist_ok=True)
@ -511,11 +437,21 @@ def sample_images(
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
dit.prepare_block_swap_before_forward()
_sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale,
tokenize_strategy, text_encoding_strategy,
save_dir, prompt_dict, epoch, steps,
sample_prompts_te_outputs, prompt_replacement,
accelerator,
args,
dit,
text_encoder,
vae,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
# Restore RNG state
@ -523,14 +459,24 @@ def sample_images(
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
dit.switch_block_swap_for_training()
clean_memory_on_device(accelerator.device)
def _sample_image_inference(
accelerator, args, dit, text_encoder, vae, vae_scale,
tokenize_strategy, text_encoding_strategy,
save_dir, prompt_dict, epoch, steps,
sample_prompts_te_outputs, prompt_replacement,
accelerator,
args,
dit,
text_encoder,
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
"""Generate a single sample image."""
prompt = prompt_dict.get("prompt", "")
@ -540,6 +486,7 @@ def _sample_image_inference(
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
flow_shift = prompt_dict.get("flow_shift", 3.0)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
@ -553,7 +500,9 @@ def _sample_image_inference(
height = max(64, height - height % 16)
width = max(64, width - width % 16)
logger.info(f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}")
logger.info(
f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
)
# Encode prompt
def encode_prompt(prpt):
@ -579,13 +528,13 @@ def _sample_image_inference(
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Process through LLM adapter if available
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
if dit.use_llm_adapter:
crossattn_emb = dit.llm_adapter(
source_hidden_states=prompt_embeds,
target_input_ids=t5_input_ids,
@ -608,12 +557,12 @@ def _sample_image_inference(
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
neg_am = neg_am.to(accelerator.device)
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
neg_t5_am = neg_t5_am.to(accelerator.device)
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
if dit.use_llm_adapter:
neg_crossattn_emb = dit.llm_adapter(
source_hidden_states=neg_pe,
target_input_ids=neg_t5_ids,
@ -627,16 +576,16 @@ def _sample_image_inference(
# Generate sample
clean_memory_on_device(accelerator.device)
latents = do_sample(
height, width, seed, dit, crossattn_emb,
sample_steps, dit.t_embedding_norm.weight.dtype,
accelerator.device, scale, neg_crossattn_emb,
height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
)
# Decode latents
gc.collect()
synchronize_device(accelerator.device)
clean_memory_on_device(accelerator.device)
org_vae_device = next(vae.parameters()).device
org_vae_device = vae.device
vae.to(accelerator.device)
decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale)
decoded = vae.decode_to_pixels(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
@ -662,4 +611,5 @@ def _sample_image_inference(
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)

View File

@ -3,10 +3,14 @@
import os
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
from accelerate import init_empty_weights
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library import anima_models
from library.safetensors_utils import WeightTransformHooks
from .utils import setup_logging
setup_logging()
@ -14,150 +18,134 @@ import logging
logger = logging.getLogger(__name__)
from library import anima_models
# Original Anima high-precision keys. Kept for reference, but not used currently.
# # Keys that should stay in high precision (float32/bfloat16, not quantized)
# KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
# Keys that should stay in high precision (float32/bfloat16, not quantized)
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
# ".embed." excludes Embedding in LLMAdapter
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]:
"""Load a safetensors file and optionally cast to dtype."""
sd = load_file(path, device=device)
if dtype is not None:
sd = {k: v.to(dtype) for k, v in sd.items()}
return sd
def load_anima_dit(
def load_anima_model(
device: Union[str, torch.device],
dit_path: str,
dtype: torch.dtype,
device: Union[str, torch.device] = "cpu",
transformer_dtype: Optional[torch.dtype] = None,
llm_adapter_path: Optional[str] = None,
disable_mmap: bool = False,
) -> anima_models.MiniTrainDIT:
"""Load the MiniTrainDIT model from safetensors.
attn_mode: str,
split_attn: bool,
loading_device: Union[str, torch.device],
dit_weight_dtype: Optional[torch.dtype],
fp8_scaled: bool = False,
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
lora_multipliers: Optional[list[float]] = None,
) -> anima_models.Anima:
"""
Load Anima model from the specified checkpoint.
Args:
dit_path: Path to DiT safetensors file
dtype: Base dtype for model parameters
device: Device to load to
transformer_dtype: Optional separate dtype for transformer blocks (lower precision)
llm_adapter_path: Optional separate path for LLM adapter weights
disable_mmap: If True, disable memory-mapped loading (reduces peak memory)
device (Union[str, torch.device]): Device for optimization or merging
dit_path (str): Path to the DiT model checkpoint.
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
split_attn (bool): Whether to use split attention.
loading_device (Union[str, torch.device]): Device to load the model weights on.
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
"""
if transformer_dtype is None:
transformer_dtype = dtype
# dit_weight_dtype is None for fp8_scaled
assert (
not fp8_scaled and dit_weight_dtype is not None
) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
logger.info(f"Loading Anima DiT from {dit_path}")
if disable_mmap:
from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap
state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True)
else:
state_dict = load_file(dit_path, device="cpu")
device = torch.device(device)
loading_device = torch.device(loading_device)
# Remove 'net.' prefix if present
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('net.'):
k = k[len('net.'):]
new_state_dict[k] = v
state_dict = new_state_dict
# We currently support fixed DiT config for Anima models
dit_config = {
"max_img_h": 512,
"max_img_w": 512,
"max_frames": 128,
"in_channels": 16,
"out_channels": 16,
"patch_spatial": 2,
"patch_temporal": 1,
"model_channels": 2048,
"concat_padding_mask": True,
"crossattn_emb_channels": 1024,
"pos_emb_cls": "rope3d",
"pos_emb_learnable": True,
"pos_emb_interpolation": "crop",
"min_fps": 1,
"max_fps": 30,
"use_adaln_lora": True,
"adaln_lora_dim": 256,
"num_blocks": 28,
"num_heads": 16,
"extra_per_block_abs_pos_emb": False,
"rope_h_extrapolation_ratio": 4.0,
"rope_w_extrapolation_ratio": 4.0,
"rope_t_extrapolation_ratio": 1.0,
"extra_h_extrapolation_ratio": 1.0,
"extra_w_extrapolation_ratio": 1.0,
"extra_t_extrapolation_ratio": 1.0,
"rope_enable_fps_modulation": False,
"use_llm_adapter": True,
"attn_mode": attn_mode,
"split_attn": split_attn,
}
with init_empty_weights():
model = anima_models.Anima(**dit_config)
if dit_weight_dtype is not None:
model.to(dit_weight_dtype)
# Derive config from state_dict
dit_config = anima_models.get_dit_config(state_dict)
# Detect LLM adapter
if llm_adapter_path is not None:
use_llm_adapter = True
dit_config['use_llm_adapter'] = True
llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu")
elif 'llm_adapter.out_proj.weight' in state_dict:
use_llm_adapter = True
dit_config['use_llm_adapter'] = True
llm_adapter_state_dict = None # Loaded as part of DiT
else:
use_llm_adapter = False
llm_adapter_state_dict = None
logger.info(f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, "
f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}")
# Build model normally on CPU — buffers get proper values from __init__
dit = anima_models.MiniTrainDIT(**dit_config)
# Merge LLM adapter weights into state_dict if loaded separately
if use_llm_adapter and llm_adapter_state_dict is not None:
for k, v in llm_adapter_state_dict.items():
state_dict[f"llm_adapter.{k}"] = v
# Load checkpoint: strict=False keeps buffers not in checkpoint (e.g. pos_embedder.seq)
missing, unexpected = dit.load_state_dict(state_dict, strict=False)
if missing:
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
unexpected_missing = [k for k in missing if not any(
buf_name in k for buf_name in ('seq', 'dim_spatial_range', 'dim_temporal_range', 'inv_freq')
)]
if unexpected_missing:
logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}")
if unexpected:
logger.info(f"Unexpected keys in checkpoint (ignored): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
# Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest)
for name, p in dit.named_parameters():
dtype_to_use = dtype if (
any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1
) else transformer_dtype
p.data = p.data.to(dtype=dtype_to_use)
dit.to(device)
logger.info(f"Loaded Anima DiT successfully. Parameters: {sum(p.numel() for p in dit.parameters()):,}")
return dit
def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"):
"""Load WanVAE from a safetensors/pth file.
Returns (vae_model, mean_tensor, std_tensor, scale).
"""
from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
logger.info(f"Loading Anima VAE from {vae_path}")
# VAE config (fixed for WanVAE)
vae_config = dict(
dim=96,
z_dim=16,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[False, True, True],
dropout=0.0,
# load model weights with dynamic fp8 optimization and LoRA merging if needed
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
sd = load_safetensors_with_lora_and_fp8(
model_files=dit_path,
lora_weights_list=lora_weights_list,
lora_multipliers=lora_multipliers,
fp8_optimization=fp8_scaled,
calc_device=device,
move_to_device=(loading_device == device),
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
weight_transform_hooks=rename_hooks,
)
from library.anima_vae import WanVAE_
if fp8_scaled:
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
# Build model
with torch.device('meta'):
vae = WanVAE_(**vae_config)
if loading_device.type != "cpu":
# make sure all the model weights are on the loading_device
logger.info(f"Moving weights to {loading_device}")
for key in sd.keys():
sd[key] = sd[key].to(loading_device)
# Load state dict
if vae_path.endswith('.safetensors'):
vae_sd = load_file(vae_path, device='cpu')
else:
vae_sd = torch.load(vae_path, map_location='cpu', weights_only=True)
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
if missing:
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
unexpected_missing = [
k
for k in missing
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
]
if unexpected_missing:
# Raise error to avoid silent failures
raise RuntimeError(
f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
)
missing = {} # all missing keys were expected
if unexpected:
# Raise error to avoid silent failures
raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
vae.load_state_dict(vae_sd, assign=True)
vae = vae.eval().requires_grad_(False).to(device, dtype=dtype)
# Create normalization tensors
mean = torch.tensor(ANIMA_VAE_MEAN, dtype=dtype, device=device)
std = torch.tensor(ANIMA_VAE_STD, dtype=dtype, device=device)
scale = [mean, 1.0 / std]
logger.info(f"Loaded Anima VAE successfully.")
return vae, mean, std, scale
return model
def load_qwen3_tokenizer(qwen3_path: str):
@ -175,7 +163,7 @@ def load_qwen3_tokenizer(qwen3_path: str):
if os.path.isdir(qwen3_path):
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
else:
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
@ -190,7 +178,13 @@ def load_qwen3_tokenizer(qwen3_path: str):
return tokenizer
def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16, device: str = "cpu"):
def load_qwen3_text_encoder(
qwen3_path: str,
dtype: torch.dtype = torch.bfloat16,
device: str = "cpu",
lora_weights: Optional[List[Dict[str, torch.Tensor]]] = None,
lora_multipliers: Optional[List[float]] = None,
):
"""Load Qwen3-0.6B text encoder.
Args:
@ -209,12 +203,10 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
if os.path.isdir(qwen3_path):
# Directory with full model
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
model = transformers.AutoModelForCausalLM.from_pretrained(
qwen3_path, torch_dtype=dtype, local_files_only=True
).model
model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
else:
# Single safetensors file - use configs/qwen3_06b/ for config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
@ -227,16 +219,28 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
model = transformers.Qwen3ForCausalLM(qwen3_config).model
# Load weights
if qwen3_path.endswith('.safetensors'):
state_dict = load_file(qwen3_path, device='cpu')
if qwen3_path.endswith(".safetensors"):
if lora_weights is None:
state_dict = load_file(qwen3_path, device="cpu")
else:
state_dict = load_safetensors_with_lora_and_fp8(
model_files=qwen3_path,
lora_weights_list=lora_weights,
lora_multipliers=lora_multipliers,
fp8_optimization=False,
calc_device=device,
move_to_device=True,
dit_weight_dtype=None,
)
else:
state_dict = torch.load(qwen3_path, map_location='cpu', weights_only=True)
assert lora_weights is None, "LoRA weights merging is only supported for safetensors checkpoints"
state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
# Remove 'model.' prefix if present
new_sd = {}
for k, v in state_dict.items():
if k.startswith('model.'):
new_sd[k[len('model.'):]] = v
if k.startswith("model."):
new_sd[k[len("model.") :]] = v
else:
new_sd[k] = v
@ -265,11 +269,11 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
# Use bundled config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 't5_old')
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
if os.path.exists(config_dir):
return T5TokenizerFast(
vocab_file=os.path.join(config_dir, 'spiece.model'),
tokenizer_file=os.path.join(config_dir, 'tokenizer.json'),
vocab_file=os.path.join(config_dir, "spiece.model"),
tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
)
raise FileNotFoundError(
@ -279,47 +283,27 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
)
def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dtype: Optional[torch.dtype] = None):
def save_anima_model(
save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Dict[str, any], dtype: Optional[torch.dtype] = None
):
"""Save Anima DiT model with 'net.' prefix for ComfyUI compatibility.
Args:
save_path: Output path (.safetensors)
dit_state_dict: State dict from dit.state_dict()
metadata: Metadata dict to include in the safetensors file
dtype: Optional dtype to cast to before saving
"""
prefixed_sd = {}
for k, v in dit_state_dict.items():
if dtype is not None:
v = v.to(dtype)
prefixed_sd['net.' + k] = v.contiguous()
# v = v.to(dtype)
v = v.detach().clone().to("cpu").to(dtype) # Reduce GPU memory usage during save
prefixed_sd["net." + k] = v.contiguous()
save_file(prefixed_sd, save_path, metadata={'format': 'pt'})
if metadata is None:
metadata = {}
metadata["format"] = "pt" # For compatibility with the official .safetensors file
save_file(prefixed_sd, save_path, metadata=metadata) # safetensors.save_file cosumes a lot of memory, but Anima is small enough
logger.info(f"Saved Anima model to {save_path}")
def vae_encode(tensor: torch.Tensor, vae, scale):
"""Encode tensor through WanVAE with normalization.
Args:
tensor: Input tensor (B, C, T, H, W) in [-1, 1] range
vae: WanVAE_ model
scale: [mean, 1/std] list
Returns:
Normalized latents
"""
return vae.encode(tensor, scale)
def vae_decode(latents: torch.Tensor, vae, scale):
"""Decode latents through WanVAE with denormalization.
Args:
latents: Normalized latents
vae: WanVAE_ model
scale: [mean, 1/std] list
Returns:
Decoded tensor in [-1, 1] range
"""
return vae.decode(latents, scale)

View File

@ -1,577 +0,0 @@
import logging
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
CACHE_T = 2
class CausalConv3d(nn.Conv3d):
"""
Causal 3d convolusion.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._padding = (self.padding[2], self.padding[2], self.padding[1],
self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class RMS_norm(nn.Module):
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
def forward(self, x):
return F.normalize(
x, dim=(1 if self.channel_first else
-1)) * self.scale * self.gamma + self.bias
class Upsample(nn.Upsample):
def forward(self, x):
"""
Fix bfloat16 support for nearest neighbor interpolation.
"""
return super().forward(x.float()).type_as(x)
class Resample(nn.Module):
def __init__(self, dim, mode):
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
'downsample3d')
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == 'upsample2d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
elif mode == 'upsample3d':
self.resample = nn.Sequential(
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
nn.Conv2d(dim, dim // 2, 3, padding=1))
self.time_conv = CausalConv3d(
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == 'downsample2d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == 'downsample3d':
self.resample = nn.Sequential(
nn.ZeroPad2d((0, 1, 0, 1)),
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = CausalConv3d(
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == 'upsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = 'Rep'
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep':
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.resample(x)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
if self.mode == 'downsample3d':
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
def init_weight(self, conv):
conv_weight = conv.weight
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
one_matrix = torch.eye(c1, c2)
init_matrix = one_matrix
nn.init.zeros_(conv_weight)
conv_weight.data[:, :, 1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
def init_weight2(self, conv):
conv_weight = conv.weight.data
nn.init.zeros_(conv_weight)
c1, c2, t, h, w = conv_weight.size()
init_matrix = torch.eye(c1 // 2, c2)
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
conv.weight.data.copy_(conv_weight)
nn.init.zeros_(conv.bias.data)
class ResidualBlock(nn.Module):
def __init__(self, in_dim, out_dim, dropout=0.0):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# layers
self.residual = nn.Sequential(
RMS_norm(in_dim, images=False), nn.SiLU(),
CausalConv3d(in_dim, out_dim, 3, padding=1),
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
CausalConv3d(out_dim, out_dim, 3, padding=1))
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
h = self.shortcut(x)
for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x + h
class AttentionBlock(nn.Module):
"""
Causal self-attention with a single head.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = RMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
# zero out the last layer params
nn.init.zeros_(self.proj.weight)
def forward(self, x):
identity = x
b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w')
x = self.norm(x)
# compute query, key, value
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
-1).permute(0, 1, 3,
2).contiguous().chunk(
3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(
q,
k,
v,
)
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
# output
x = self.proj(x)
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
return x + identity
class Encoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
downsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
downsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = 'downsample3d' if temperal_downsample[
i] else 'downsample2d'
downsamples.append(Resample(out_dim, mode=mode))
scale /= 2.0
self.downsamples = nn.Sequential(*downsamples)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
ResidualBlock(out_dim, out_dim, dropout))
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## downsamples
for layer in self.downsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
class Decoder3d(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2**(len(dim_mult) - 2)
# init block
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.middle = nn.Sequential(
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
ResidualBlock(dims[0], dims[0], dropout))
# upsample blocks
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i == 1 or i == 2 or i == 3:
in_dim = in_dim // 2
for _ in range(num_res_blocks + 1):
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
upsamples.append(AttentionBlock(out_dim))
in_dim = out_dim
# upsample block
if i != len(dim_mult) - 1:
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
upsamples.append(Resample(out_dim, mode=mode))
scale *= 2.0
self.upsamples = nn.Sequential(*upsamples)
# output blocks
self.head = nn.Sequential(
RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, 3, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
## middle
for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## head
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, CausalConv3d):
count += 1
return count
class WanVAE_(nn.Module):
def __init__(self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
# modules
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def encode(self, x, scale):
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(
x[:, :, :1, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx)
out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1)
if isinstance(scale[0], torch.Tensor):
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
1, self.z_dim, 1, 1, 1)
else:
mu = (mu - scale[0]) * scale[1]
self.clear_cache()
return mu
def decode(self, z, scale):
self.clear_cache()
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
else:
out_ = self.decoder(
x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,
feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
self.clear_cache()
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return eps * std + mu
def sample(self, imgs, deterministic=False):
mu, log_var = self.encode(imgs)
if deterministic:
return mu
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
return mu + std * torch.randn_like(std)
def clear_cache(self):
self._conv_num = count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
#cache encode
self._enc_conv_num = count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num

View File

@ -37,6 +37,14 @@ class AttentionParams:
cu_seqlens: Optional[torch.Tensor] = None
max_seqlen: Optional[int] = None
@property
def supports_fp32(self) -> bool:
return self.attn_mode not in ["flash"]
@property
def requires_same_dtype(self) -> bool:
return self.attn_mode in ["xformers"]
@staticmethod
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
return AttentionParams(attn_mode, split_attn)
@ -95,7 +103,7 @@ def attention(
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
k: Key tensor [B, L, H, D].
v: Value tensor [B, L, H, D].
attn_param: Attention parameters including mask and sequence lengths.
attn_params: Attention parameters including mask and sequence lengths.
drop_rate: Attention dropout rate.
Returns:

View File

@ -195,6 +195,9 @@ class ModelOffloader(Offloader):
self.remove_handles.append(handle)
def set_forward_only(self, forward_only: bool):
# switching must wait for all pending transfers
for block_idx in list(self.futures.keys()):
self._wait_blocks_move(block_idx)
self.forward_only = forward_only
def __del__(self):
@ -237,6 +240,10 @@ class ModelOffloader(Offloader):
if self.debug:
print(f"Prepare block devices before forward")
# wait for all pending transfers
for block_idx in list(self.futures.keys()):
self._wait_blocks_move(block_idx)
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device

View File

@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
@ -512,7 +512,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
sigmas = sigmas.view(-1, 1, 1, 1) if latents.ndim == 4 else sigmas.view(-1, 1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)

View File

@ -9,7 +9,7 @@ import logging
from tqdm import tqdm
from library.device_utils import clean_memory_on_device
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
from library.utils import setup_logging
setup_logging()
@ -220,6 +220,8 @@ def quantize_weight(
tensor_max = torch.max(torch.abs(tensor).view(-1))
scale = tensor_max / max_value
# print(f"Optimizing {key} with scale: {scale}")
# numerical safety
scale = torch.clamp(scale, min=1e-8)
scale = scale.to(torch.float32) # ensure scale is in float32 for division
@ -245,6 +247,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook=None,
quantization_mode: str = "block",
block_size: Optional[int] = 64,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict:
"""
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
@ -260,6 +264,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading
Returns:
dict: FP8 optimized state dict
@ -288,7 +294,9 @@ def load_safetensors_with_fp8_optimization(
# Process each file
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
keys = f.keys()
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
value = f.get_tensor(key)
@ -311,6 +319,11 @@ def load_safetensors_with_fp8_optimization(
value = value.to(calc_device)
original_dtype = value.dtype
if original_dtype.itemsize == 1:
raise ValueError(
f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
+ f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
)
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)
@ -387,7 +400,7 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
else:
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1)
o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
return o.to(input_dtype)
else:

View File

@ -5,7 +5,7 @@ import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
from library.utils import setup_logging
setup_logging()
@ -44,7 +44,7 @@ def filter_lora_state_dict(
def load_safetensors_with_lora_and_fp8(
model_files: Union[str, List[str]],
lora_weights_list: Optional[Dict[str, torch.Tensor]],
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
lora_multipliers: Optional[List[float]],
fp8_optimization: bool,
calc_device: torch.device,
@ -52,19 +52,23 @@ def load_safetensors_with_lora_and_fp8(
dit_weight_dtype: Optional[torch.dtype] = None,
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
Args:
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
fp8_optimization (bool): Whether to apply FP8 optimization.
calc_device (torch.device): Device to calculate on.
move_to_device (bool): Whether to move tensors to the calculation device after loading.
target_keys (Optional[List[str]]): Keys to target for optimization.
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
"""
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
@ -73,19 +77,9 @@ def load_safetensors_with_lora_and_fp8(
extended_model_files = []
for model_file in model_files:
basename = os.path.basename(model_file)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(model_file), filename)
if os.path.exists(filepath):
extended_model_files.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
split_filenames = get_split_weight_filenames(model_file)
if split_filenames is not None:
extended_model_files.extend(split_filenames)
else:
extended_model_files.append(model_file)
model_files = extended_model_files
@ -114,7 +108,7 @@ def load_safetensors_with_lora_and_fp8(
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
# make hook for LoRA merging
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
if not model_weight_key.endswith(".weight"):
@ -126,13 +120,18 @@ def load_safetensors_with_lora_and_fp8(
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
# check if this weight has LoRA weights
lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
lora_name = "lora_unet_" + lora_name.replace(".", "_")
down_key = lora_name + ".lora_down.weight"
up_key = lora_name + ".lora_up.weight"
alpha_key = lora_name + ".alpha"
if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
continue
lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
found = False
for prefix in ["lora_unet_", ""]:
lora_name = prefix + lora_name_without_prefix.replace(".", "_")
down_key = lora_name + ".lora_down.weight"
up_key = lora_name + ".lora_up.weight"
alpha_key = lora_name + ".alpha"
if down_key in lora_weight_keys and up_key in lora_weight_keys:
found = True
break
if not found:
continue # no LoRA weights for this model weight
# get LoRA weights
down_weight = lora_sd[down_key]
@ -145,6 +144,13 @@ def load_safetensors_with_lora_and_fp8(
down_weight = down_weight.to(calc_device)
up_weight = up_weight.to(calc_device)
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
# temporarily convert to float16 for calculation
model_weight = model_weight.to(torch.float16)
down_weight = down_weight.to(torch.float16)
up_weight = up_weight.to(torch.float16)
# W <- W + U * D
if len(model_weight.size()) == 2:
# linear
@ -166,6 +172,9 @@ def load_safetensors_with_lora_and_fp8(
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
model_weight = model_weight + multiplier * conved * scale
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(original_dtype) # convert back to original dtype
# remove LoRA keys from set
lora_weight_keys.remove(down_key)
lora_weight_keys.remove(up_key)
@ -187,6 +196,8 @@ def load_safetensors_with_lora_and_fp8(
target_keys,
exclude_keys,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
)
for lora_weight_keys in list_of_lora_weight_keys:
@ -208,6 +219,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
weight_hook: callable = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
@ -218,7 +231,14 @@ def load_safetensors_with_fp8_optimization_and_hook(
)
# dit_weight_dtype is not used because we use fp8 optimization
state_dict = load_safetensors_with_fp8_optimization(
model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
model_files,
calc_device,
target_keys,
exclude_keys,
move_to_device=move_to_device,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
)
else:
logger.info(
@ -226,7 +246,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
)
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
if weight_hook is None and move_to_device:
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,4 @@
from dataclasses import dataclass
import os
import re
import numpy as np
@ -44,6 +45,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
validated[key] = value
return validated
# print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
@ -88,15 +90,17 @@ class MemoryEfficientSafeOpen:
by using memory mapping for large tensors and avoiding unnecessary copies.
"""
def __init__(self, filename):
def __init__(self, filename, disable_numpy_memmap=False):
"""Initialize the SafeTensor reader.
Args:
filename (str): Path to the safetensors file to read.
disable_numpy_memmap (bool): If True, disable numpy memory mapping for large tensors, using standard file read instead.
"""
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
self.disable_numpy_memmap = disable_numpy_memmap
def __enter__(self):
"""Enter context manager."""
@ -178,7 +182,8 @@ class MemoryEfficientSafeOpen:
# Use memmap for large tensors to avoid intermediate copies.
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
# So we only use memmap if device is not cpu.
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# If disable_numpy_memmap is True, skip numpy memory mapping to load with standard file read.
if not self.disable_numpy_memmap and num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# Create memory map for zero-copy reading
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
byte_tensor = torch.from_numpy(mm) # zero copy
@ -285,7 +290,11 @@ class MemoryEfficientSafeOpen:
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
path: str,
device: Union[str, torch.device],
disable_mmap: bool = False,
dtype: Optional[torch.dtype] = None,
disable_numpy_memmap: bool = False,
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
@ -293,7 +302,7 @@ def load_safetensors(
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
device = torch.device(device) if device is not None else None
with MemoryEfficientSafeOpen(path) as f:
with MemoryEfficientSafeOpen(path, disable_numpy_memmap=disable_numpy_memmap) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
synchronize_device(device)
@ -309,6 +318,29 @@ def load_safetensors(
return state_dict
def get_split_weight_filenames(file_path: str) -> Optional[list[str]]:
"""
Get the list of split weight filenames (full paths) if the file name ends with 00001-of-00004 etc.
Returns None if the file is not split.
"""
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
filenames = []
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
filenames.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
return filenames
else:
return None
def load_split_weights(
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]:
@ -319,19 +351,11 @@ def load_split_weights(
device = torch.device(device)
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
split_filenames = get_split_weight_filenames(file_path)
if split_filenames is not None:
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
raise FileNotFoundError(f"File {filepath} not found")
for filename in split_filenames:
state_dict.update(load_safetensors(filename, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
return state_dict
@ -349,3 +373,106 @@ def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
return key
return None
@dataclass
class WeightTransformHooks:
split_hook: Optional[callable] = None
concat_hook: Optional[callable] = None
rename_hook: Optional[callable] = None
class TensorWeightAdapter:
"""
A wrapper for weight conversion hooks (split and concat) to be used with MemoryEfficientSafeOpen.
This wrapper adapts the original MemoryEfficientSafeOpen to apply the provided split and concat hooks
when loading tensors.
split_hook: A callable that takes (original_key: str, original_tensor: torch.Tensor) and returns (new_keys: list[str], new_tensors: list[torch.Tensor]).
concat_hook: A callable that takes (original_key: str, tensors: dict[str, torch.Tensor]) and returns (new_key: str, concatenated_tensor: torch.Tensor).
rename_hook: A callable that takes (original_key: str) and returns (new_key: str).
If tensors is None, the hook should return only the new keys (for split) or new key (for concat), without tensors.
No need to implement __enter__ and __exit__ methods, as they are handled by the original MemoryEfficientSafeOpen.
Do not use this wrapper as a context manager directly, like `with WeightConvertHookWrapper(...) as f:`.
**concat_hook is not tested yet.**
"""
def __init__(self, weight_convert_hook: WeightTransformHooks, original_f: MemoryEfficientSafeOpen):
self.original_f = original_f
self.new_key_to_original_key_map: dict[str, Union[str, list[str]]] = (
{}
) # for split: new_key -> original_key; for concat: new_key -> list of original_keys; for direct mapping: new_key -> original_key
self.concat_key_set = set() # set of concatenated keys
self.split_key_set = set() # set of split keys
self.new_keys = []
self.tensor_cache = {} # cache for split tensors
self.split_hook = weight_convert_hook.split_hook
self.concat_hook = weight_convert_hook.concat_hook
self.rename_hook = weight_convert_hook.rename_hook
for key in self.original_f.keys():
if self.split_hook is not None:
converted_keys, _ = self.split_hook(key, None) # get new keys only
if converted_keys is not None:
for converted_key in converted_keys:
self.new_key_to_original_key_map[converted_key] = key
self.split_key_set.add(converted_key)
self.new_keys.extend(converted_keys)
continue # skip concat_hook if split_hook is applied
if self.concat_hook is not None:
converted_key, _ = self.concat_hook(key, None) # get new key only
if converted_key is not None:
if converted_key not in self.concat_key_set: # first time seeing this concatenated key
self.concat_key_set.add(converted_key)
self.new_key_to_original_key_map[converted_key] = []
self.new_keys.append(converted_key)
# multiple original keys map to the same concatenated key
self.new_key_to_original_key_map[converted_key].append(key)
continue # skip to next key
# direct mapping
if self.rename_hook is not None:
new_key = self.rename_hook(key)
self.new_key_to_original_key_map[new_key] = key
else:
new_key = key
self.new_keys.append(new_key)
def keys(self) -> list[str]:
return self.new_keys
def get_tensor(self, new_key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# load tensor by new_key, applying split or concat hooks as needed
if new_key not in self.new_key_to_original_key_map:
# direct mapping
return self.original_f.get_tensor(new_key, device=device, dtype=dtype)
elif new_key in self.split_key_set:
# split hook: split key is requested multiple times, so we cache the result
original_key = self.new_key_to_original_key_map[new_key]
if original_key not in self.tensor_cache: # not yet split
original_tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
new_keys, new_tensors = self.split_hook(original_key, original_tensor) # apply split hook
for k, t in zip(new_keys, new_tensors):
self.tensor_cache[k] = t
return self.tensor_cache.pop(new_key) # return and remove from cache
elif new_key in self.concat_key_set:
# concat hook: concatenated key is requested only once, so we do not cache the result
tensors = {}
for original_key in self.new_key_to_original_key_map[new_key]:
tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
tensors[original_key] = tensor
_, concatenated_tensors = self.concat_hook(self.new_key_to_original_key_map[new_key][0], tensors) # apply concat hook
return concatenated_tensors
else:
# direct mapping
original_key = self.new_key_to_original_key_map[new_key]
return self.original_f.get_tensor(original_key, device=device, dtype=dtype)

View File

@ -81,6 +81,8 @@ ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"
ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1"
ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image"
ARCH_ANIMA_PREVIEW = "anima-preview"
ARCH_ANIMA_UNKNOWN = "anima-unknown"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@ -92,6 +94,7 @@ IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1"
IMPL_ANIMA = "https://huggingface.co/circlestone-labs/Anima"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@ -220,6 +223,12 @@ def determine_architecture(
arch = ARCH_HUNYUAN_IMAGE_2_1
else:
arch = ARCH_HUNYUAN_IMAGE_UNKNOWN
elif "anima" in model_config:
anima_type = model_config["anima"]
if anima_type == "preview":
arch = ARCH_ANIMA_PREVIEW
else:
arch = ARCH_ANIMA_UNKNOWN
elif v2:
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
else:
@ -252,6 +261,8 @@ def determine_implementation(
return IMPL_FLUX
elif "lumina" in model_config:
return IMPL_LUMINA
elif "anima" in model_config:
return IMPL_ANIMA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
return IMPL_STABILITY_AI
else:
@ -325,7 +336,7 @@ def determine_resolution(
reso = (reso[0], reso[0])
else:
# Determine default resolution based on model type
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config:
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config or "anima" in model_config:
reso = (1024, 1024)
elif v2 and v_parameterization:
reso = (768, 768)

View File

@ -9,6 +9,7 @@ import torch
from library import anima_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library import qwen_image_autoencoder_kl
from library.utils import setup_logging
@ -45,8 +46,8 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
self.qwen3_tokenizer = qwen3_tokenizer
self.t5_tokenizer = t5_tokenizer
self.qwen3_max_length = qwen3_max_length
self.t5_tokenizer = t5_tokenizer
self.t5_max_length = t5_max_length
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
@ -54,26 +55,17 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
# Tokenize with Qwen3
qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=self.qwen3_max_length,
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
)
qwen3_input_ids = qwen3_encoding["input_ids"]
qwen3_attn_mask = qwen3_encoding["attention_mask"]
# Tokenize with T5 (for LLM Adapter target tokens)
t5_encoding = self.t5_tokenizer.batch_encode_plus(
text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=self.t5_max_length,
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
)
t5_input_ids = t5_encoding["input_ids"]
t5_attn_mask = t5_encoding["attention_mask"]
return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
@ -84,46 +76,11 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
T5 tokens are passed through unchanged (only used by LLM Adapter).
"""
def __init__(
self,
dropout_rate: float = 0.0,
) -> None:
self.dropout_rate = dropout_rate
# Cached unconditional embeddings (from encoding empty caption "")
# Must be initialized via cache_uncond_embeddings() before text encoder is deleted
self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden)
self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len)
self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len)
self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len)
def cache_uncond_embeddings(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
) -> None:
"""Pre-encode empty caption "" and cache the unconditional embeddings.
Must be called before the text encoder is deleted from GPU.
This matches diffusion-pipe-main behavior where empty caption embeddings
are pre-cached and swapped in during caption dropout.
"""
logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
tokens = tokenize_strategy.tokenize("")
with torch.no_grad():
uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
# Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
self._uncond_prompt_embeds = uncond_outputs[0].cpu()
self._uncond_attn_mask = uncond_outputs[1].cpu()
self._uncond_t5_input_ids = uncond_outputs[2].cpu()
self._uncond_t5_attn_mask = uncond_outputs[3].cpu()
logger.info(" Unconditional embeddings cached successfully")
def __init__(self) -> None:
super().__init__()
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
enable_dropout: bool = True,
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
@ -134,82 +91,20 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Returns:
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
"""
# Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()
qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
# Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
batch_size = qwen3_input_ids.shape[0]
non_drop_indices = []
for i in range(batch_size):
drop = enable_dropout and (self.dropout_rate > 0.0 and random.random() < self.dropout_rate)
if not drop:
non_drop_indices.append(i)
encoder_device = qwen3_text_encoder.device
encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device
qwen3_input_ids = qwen3_input_ids.to(encoder_device)
qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
prompt_embeds = outputs.last_hidden_state
prompt_embeds[~qwen3_attn_mask.bool()] = 0
if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
# Only encode non-dropped items to save compute
nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
elif len(non_drop_indices) == batch_size:
nd_input_ids = qwen3_input_ids.to(encoder_device)
nd_attn_mask = qwen3_attn_mask.to(encoder_device)
else:
nd_input_ids = None
nd_attn_mask = None
if nd_input_ids is not None:
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
nd_encoded_text = outputs.last_hidden_state
# Zero out padding positions
nd_encoded_text[~nd_attn_mask.bool()] = 0
# Build full batch: fill non-dropped with encoded, dropped with unconditional
if len(non_drop_indices) == batch_size:
prompt_embeds = nd_encoded_text
attn_mask = qwen3_attn_mask.to(encoder_device)
else:
# Get unconditional embeddings
if self._uncond_prompt_embeds is not None:
uncond_pe = self._uncond_prompt_embeds[0]
uncond_am = self._uncond_attn_mask[0]
uncond_t5_ids = self._uncond_t5_input_ids[0]
uncond_t5_am = self._uncond_t5_attn_mask[0]
else:
# Encode empty caption on-the-fly (text encoder still available)
uncond_tokens = tokenize_strategy.tokenize("")
uncond_ids = uncond_tokens[0].to(encoder_device)
uncond_mask = uncond_tokens[1].to(encoder_device)
uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
uncond_pe = uncond_out.last_hidden_state[0]
uncond_pe[~uncond_mask[0].bool()] = 0
uncond_am = uncond_mask[0]
uncond_t5_ids = uncond_tokens[2][0]
uncond_t5_am = uncond_tokens[3][0]
seq_len = qwen3_input_ids.shape[1]
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
if len(non_drop_indices) > 0:
prompt_embeds[non_drop_indices] = nd_encoded_text
attn_mask[non_drop_indices] = nd_attn_mask
# Fill dropped items with unconditional embeddings
t5_input_ids = t5_input_ids.clone()
t5_attn_mask = t5_attn_mask.clone()
drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
for i in drop_indices:
prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
@ -217,6 +112,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
attn_mask: torch.Tensor,
t5_input_ids: torch.Tensor,
t5_attn_mask: torch.Tensor,
caption_dropout_rates: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""Apply dropout to cached text encoder outputs.
@ -224,37 +120,30 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
to match diffusion-pipe-main behavior.
"""
if prompt_embeds is not None and self.dropout_rate > 0.0:
# Clone to avoid in-place modification of cached tensors
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
if caption_dropout_rates is None or torch.all(caption_dropout_rates == 0.0).item():
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
for i in range(prompt_embeds.shape[0]):
if random.random() < self.dropout_rate:
if self._uncond_prompt_embeds is not None:
# Use pre-cached unconditional embeddings
prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if attn_mask is not None:
attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
if t5_input_ids is not None:
t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
if t5_attn_mask is not None:
t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
else:
# Fallback: zero out (should not happen if cache_uncond_embeddings was called)
logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout")
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
if attn_mask is not None:
attn_mask[i] = torch.zeros_like(attn_mask[i])
if t5_input_ids is not None:
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
# Clone to avoid in-place modification of cached tensors
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
for i in range(prompt_embeds.shape[0]):
if random.random() < caption_dropout_rates[i].item():
# Use pre-cached unconditional embeddings
prompt_embeds[i] = 0
if attn_mask is not None:
attn_mask[i] = 0
if t5_input_ids is not None:
t5_input_ids[i, 0] = 1 # Set to </s> token ID
t5_input_ids[i, 1:] = 0
if t5_attn_mask is not None:
t5_attn_mask[i, 0] = 1
t5_attn_mask[i, 1:] = 0
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
@ -297,6 +186,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@ -309,7 +200,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask = data["attn_mask"]
t5_input_ids = data["t5_input_ids"]
t5_attn_mask = data["t5_attn_mask"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
caption_dropout_rate = data["caption_dropout_rate"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
def cache_batch_outputs(
self,
@ -323,12 +215,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# Always disable dropout during caching
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = anima_text_encoding_strategy.encode_tokens(
tokenize_strategy,
models,
tokens_and_masks,
enable_dropout=False,
tokenize_strategy, models, tokens_and_masks
)
# Convert to numpy for caching
@ -344,6 +232,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
@ -352,9 +241,10 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i)
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
@ -374,18 +264,10 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
return self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ self.ANIMA_LATENTS_NPZ_SUFFIX
)
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
):
return self._default_is_disk_cached_latents_expected(
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
@ -393,32 +275,23 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
"""Cache batch of latents using WanVAE.
"""Cache batch of latents using Qwen Image VAE.
vae is expected to be the WanVAE_ model (not the wrapper).
vae is expected to be the Qwen Image VAE (AutoencoderKLQwenImage).
The encoding function handles the mean/std normalization.
"""
from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
vae_device = next(vae.parameters()).device
vae_dtype = next(vae.parameters()).dtype
# Create scale tensors on VAE device
mean = torch.tensor(ANIMA_VAE_MEAN, dtype=vae_dtype, device=vae_device)
std = torch.tensor(ANIMA_VAE_STD, dtype=vae_dtype, device=vae_device)
scale = [mean, 1.0 / std]
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage = vae
vae_device = vae.device
vae_dtype = vae.dtype
def encode_by_vae(img_tensor):
"""Encode image tensor to latents.
img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS)
Need to add temporal dim to get (B, C, T=1, H, W) for WanVAE
Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
Returns latents in (B, 16, 1, H/8, W/8) shape on CPU.
"""
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
img_tensor = img_tensor.unsqueeze(2)
img_tensor = img_tensor.to(vae_device, dtype=vae_dtype)
latents = vae.encode(img_tensor, scale)
latents = vae.encode_pixels_to_latents(img_tensor) # Keep 4D for input/output
return latents.to("cpu")
self._default_cache_batch_latents(

View File

@ -179,12 +179,15 @@ def split_train_val(
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
def __init__(
self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.caption: str = caption
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.caption_dropout_rate: float = caption_dropout_rate
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
@ -197,7 +200,7 @@ class ImageInfo:
)
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
self.text_encoder_outputs_npz: Optional[str] = None # filename. set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
@ -1096,11 +1099,11 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def is_text_encoder_output_cacheable(self):
def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False):
return all(
[
not (
subset.caption_dropout_rate > 0
subset.caption_dropout_rate > 0 and not cache_supports_dropout
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
@ -2137,7 +2140,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images += num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
@ -2338,7 +2341,7 @@ class FineTuningDataset(BaseDataset):
if caption is None:
caption = ""
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
image_info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
@ -2661,8 +2664,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
def is_text_encoder_output_cacheable(self) -> bool:
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False) -> bool:
return all([dataset.is_text_encoder_output_cacheable(cache_supports_dropout) for dataset in self.datasets])
def set_current_strategies(self):
for dataset in self.datasets:
@ -3578,6 +3581,7 @@ def get_sai_model_spec_dataclass(
flux: str = None,
lumina: str = None,
hunyuan_image: str = None,
anima: str = None,
optional_metadata: dict[str, str] | None = None,
) -> sai_model_spec.ModelSpecMetadata:
"""
@ -3609,7 +3613,8 @@ def get_sai_model_spec_dataclass(
model_config["lumina"] = lumina
if hunyuan_image is not None:
model_config["hunyuan_image"] = hunyuan_image
if anima is not None:
model_config["anima"] = anima
# Use the dataclass function directly
return sai_model_spec.build_metadata_dataclass(
state_dict,

View File

@ -0,0 +1,160 @@
import argparse
from safetensors.torch import save_file
from safetensors import safe_open
from library import train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
COMFYUI_DIT_PREFIX = "diffusion_model."
COMFYUI_QWEN3_PREFIX = "text_encoders.qwen3_06b.transformer.model."
def main(args):
# load source safetensors
logger.info(f"Loading source file {args.src_path}")
state_dict = {}
with safe_open(args.src_path, framework="pt") as f:
metadata = f.metadata()
for k in f.keys():
state_dict[k] = f.get_tensor(k)
logger.info(f"Converting...")
keys = list(state_dict.keys())
count = 0
for k in keys:
if not args.reverse:
is_dit_lora = k.startswith("lora_unet_")
module_and_weight_name = "_".join(k.split("_")[2:]) # Remove `lora_unet_`or `lora_te_` prefix
# Split at the first dot, e.g., "block1_linear.weight" -> "block1_linear", "weight"
module_name, weight_name = module_and_weight_name.split(".", 1)
# Weight name conversion: lora_up/lora_down to lora_A/lora_B
if weight_name.startswith("lora_up"):
weight_name = weight_name.replace("lora_up", "lora_B")
elif weight_name.startswith("lora_down"):
weight_name = weight_name.replace("lora_down", "lora_A")
else:
# Keep other weight names as-is: e.g. alpha
pass
# Module name conversion: convert dots to underscores
original_module_name = module_name.replace("_", ".") # Convert to dot notation
# Convert back illegal dots in module names
# DiT
original_module_name = original_module_name.replace("llm.adapter", "llm_adapter")
original_module_name = original_module_name.replace(".linear.", ".linear_")
original_module_name = original_module_name.replace("t.embedding.norm", "t_embedding_norm")
original_module_name = original_module_name.replace("x.embedder", "x_embedder")
original_module_name = original_module_name.replace("adaln.modulation.cross_attn", "adaln_modulation_cross_attn")
original_module_name = original_module_name.replace("adaln.modulation.mlp", "adaln_modulation_mlp")
original_module_name = original_module_name.replace("cross.attn", "cross_attn")
original_module_name = original_module_name.replace("k.proj", "k_proj")
original_module_name = original_module_name.replace("k.norm", "k_norm")
original_module_name = original_module_name.replace("q.proj", "q_proj")
original_module_name = original_module_name.replace("q.norm", "q_norm")
original_module_name = original_module_name.replace("v.proj", "v_proj")
original_module_name = original_module_name.replace("o.proj", "o_proj")
original_module_name = original_module_name.replace("output.proj", "output_proj")
original_module_name = original_module_name.replace("self.attn", "self_attn")
original_module_name = original_module_name.replace("final.layer", "final_layer")
original_module_name = original_module_name.replace("adaln.modulation", "adaln_modulation")
original_module_name = original_module_name.replace("norm.cross.attn", "norm_cross_attn")
original_module_name = original_module_name.replace("norm.mlp", "norm_mlp")
original_module_name = original_module_name.replace("norm.self.attn", "norm_self_attn")
original_module_name = original_module_name.replace("out.proj", "out_proj")
# Qwen3
original_module_name = original_module_name.replace("embed.tokens", "embed_tokens")
original_module_name = original_module_name.replace("input.layernorm", "input_layernorm")
original_module_name = original_module_name.replace("down.proj", "down_proj")
original_module_name = original_module_name.replace("gate.proj", "gate_proj")
original_module_name = original_module_name.replace("up.proj", "up_proj")
original_module_name = original_module_name.replace("post.attention.layernorm", "post_attention_layernorm")
# Prefix conversion
new_prefix = COMFYUI_DIT_PREFIX if is_dit_lora else COMFYUI_QWEN3_PREFIX
new_k = f"{new_prefix}{original_module_name}.{weight_name}"
else:
if k.startswith(COMFYUI_DIT_PREFIX):
is_dit_lora = True
module_and_weight_name = k[len(COMFYUI_DIT_PREFIX) :]
elif k.startswith(COMFYUI_QWEN3_PREFIX):
is_dit_lora = False
module_and_weight_name = k[len(COMFYUI_QWEN3_PREFIX) :]
else:
logger.warning(f"Skipping unrecognized key {k}")
continue
# Get weight name
if ".lora_" in module_and_weight_name:
module_name, weight_name = module_and_weight_name.rsplit(".lora_", 1)
weight_name = "lora_" + weight_name
else:
module_name, weight_name = module_and_weight_name.rsplit(".", 1) # Keep other weight names as-is: e.g. alpha
# Weight name conversion: lora_A/lora_B to lora_up/lora_down
# Note: we only convert lora_A and lora_B weights, other weights are kept as-is
if weight_name.startswith("lora_B"):
weight_name = weight_name.replace("lora_B", "lora_up")
elif weight_name.startswith("lora_A"):
weight_name = weight_name.replace("lora_A", "lora_down")
# Module name conversion: convert dots to underscores
module_name = module_name.replace(".", "_") # Convert to underscore notation
# Prefix conversion
prefix = "lora_unet_" if is_dit_lora else "lora_te_"
new_k = f"{prefix}{module_name}.{weight_name}"
state_dict[new_k] = state_dict.pop(k)
count += 1
logger.info(f"Converted {count} keys")
if count == 0:
logger.warning("No keys were converted. Please check if the source file is in the expected format.")
elif count > 0 and count < len(keys):
logger.warning(
f"Only {count} out of {len(keys)} keys were converted. Please check if there are unexpected keys in the source file."
)
# Calculate hash
if metadata is not None:
logger.info(f"Calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
# save destination safetensors
logger.info(f"Saving destination file {args.dst_path}")
save_file(state_dict, args.dst_path, metadata=metadata)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert LoRA format")
parser.add_argument(
"src_path",
type=str,
default=None,
help="source path, sd-scripts format (or ComfyUI compatible format if --reverse is set, only supported for LoRAs converted by this script)",
)
parser.add_argument(
"dst_path",
type=str,
default=None,
help="destination path, ComfyUI compatible format (or sd-scripts format if --reverse is set)",
)
parser.add_argument("--reverse", action="store_true", help="reverse conversion direction")
args = parser.parse_args()
main(args)

View File

@ -1,18 +1,17 @@
# LoRA network module for Anima
import math
# LoRA network module for Anima
import ast
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import numpy as np
import torch
from library.utils import setup_logging
from networks.lora_flux import LoRAModule, LoRAInfModule
setup_logging()
import logging
setup_logging()
logger = logging.getLogger(__name__)
from networks.lora_flux import LoRAModule, LoRAInfModule
def create_network(
multiplier: float,
@ -29,68 +28,28 @@ def create_network(
if network_alpha is None:
network_alpha = 1.0
# type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
self_attn_dim = kwargs.get("self_attn_dim", None)
cross_attn_dim = kwargs.get("cross_attn_dim", None)
mlp_dim = kwargs.get("mlp_dim", None)
mod_dim = kwargs.get("mod_dim", None)
llm_adapter_dim = kwargs.get("llm_adapter_dim", None)
if self_attn_dim is not None:
self_attn_dim = int(self_attn_dim)
if cross_attn_dim is not None:
cross_attn_dim = int(cross_attn_dim)
if mlp_dim is not None:
mlp_dim = int(mlp_dim)
if mod_dim is not None:
mod_dim = int(mod_dim)
if llm_adapter_dim is not None:
llm_adapter_dim = int(llm_adapter_dim)
type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
if all([d is None for d in type_dims]):
type_dims = None
# emb_dims: [x_embedder, t_embedder, final_layer]
emb_dims = kwargs.get("emb_dims", None)
if emb_dims is not None:
emb_dims = emb_dims.strip()
if emb_dims.startswith("[") and emb_dims.endswith("]"):
emb_dims = emb_dims[1:-1]
emb_dims = [int(d) for d in emb_dims.split(",")]
assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)"
# block selection
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
if selection == "all":
return [True] * total_blocks
if selection == "none" or selection == "":
return [False] * total_blocks
selected = [False] * total_blocks
ranges = selection.split(",")
for r in ranges:
if "-" in r:
start, end = map(str.strip, r.split("-"))
start, end = int(start), int(end)
assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end
for i in range(start, end + 1):
selected[i] = True
else:
index = int(r)
assert 0 <= index < total_blocks
selected[index] = True
return selected
train_block_indices = kwargs.get("train_block_indices", None)
if train_block_indices is not None:
num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999
train_block_indices = parse_block_selection(train_block_indices, num_blocks)
# train LLM adapter
train_llm_adapter = kwargs.get("train_llm_adapter", False)
train_llm_adapter = kwargs.get("train_llm_adapter", "false")
if train_llm_adapter is not None:
train_llm_adapter = True if train_llm_adapter == "True" else False
train_llm_adapter = True if train_llm_adapter.lower() == "true" else False
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns
exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")
# regular expression for module selection: exclude and include
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
@ -101,9 +60,43 @@ def create_network(
module_dropout = float(module_dropout)
# verbose
verbose = kwargs.get("verbose", False)
verbose = kwargs.get("verbose", "false")
if verbose is not None:
verbose = True if verbose == "True" else False
verbose = True if verbose.lower() == "true" else False
# regex-specific learning rates / dimensions
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
"""
Parse a string of key-value pairs separated by commas.
"""
pairs = {}
for pair in kv_pair_str.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
continue
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip()
try:
pairs[key] = int(value) if is_int else float(value)
except ValueError:
logger.warning(f"Invalid value for {key}: {value}")
return pairs
network_reg_lrs = kwargs.get("network_reg_lrs", None)
if network_reg_lrs is not None:
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
else:
reg_lrs = None
network_reg_dims = kwargs.get("network_reg_dims", None)
if network_reg_dims is not None:
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
else:
reg_dims = None
network = LoRANetwork(
text_encoders,
@ -115,9 +108,10 @@ def create_network(
rank_dropout=rank_dropout,
module_dropout=module_dropout,
train_llm_adapter=train_llm_adapter,
type_dims=type_dims,
emb_dims=emb_dims,
train_block_indices=train_block_indices,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
reg_dims=reg_dims,
reg_lrs=reg_lrs,
verbose=verbose,
)
@ -137,6 +131,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
@ -173,15 +168,15 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
class LoRANetwork(torch.nn.Module):
# Target modules: DiT blocks
ANIMA_TARGET_REPLACE_MODULE = ["Block"]
# Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
# Target modules: LLM Adapter blocks
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
# Target modules for text encoder (Qwen3)
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"]
LORA_PREFIX_ANIMA = "lora_unet" # ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER = "lora_te1" # Qwen3
LORA_PREFIX_TEXT_ENCODER = "lora_te" # Qwen3
def __init__(
self,
@ -197,9 +192,10 @@ class LoRANetwork(torch.nn.Module):
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
train_llm_adapter: bool = False,
type_dims: Optional[List[int]] = None,
emb_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
exclude_patterns: Optional[List[str]] = None,
include_patterns: Optional[List[str]] = None,
reg_dims: Optional[Dict[str, int]] = None,
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@ -210,21 +206,36 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.train_llm_adapter = train_llm_adapter
self.type_dims = type_dims
self.emb_dims = emb_dims
self.train_block_indices = train_block_indices
self.reg_dims = reg_dims
self.reg_lrs = reg_lrs
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
if self.emb_dims is None:
self.emb_dims = [0] * 3
logger.info("create LoRA network from weights")
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# compile regular expression if specified
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
re_patterns = []
if patterns is not None:
for pattern in patterns:
try:
re_pattern = re.compile(pattern)
except re.error as e:
logger.error(f"Invalid pattern '{pattern}': {e}")
continue
re_patterns.append(re_pattern)
return re_patterns
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
include_re_patterns = str_to_re_patterns(include_patterns)
# create module instances
def create_modules(
@ -232,15 +243,9 @@ class LoRANetwork(torch.nn.Module):
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> Tuple[List[LoRAModule], List[str]]:
prefix = (
self.LORA_PREFIX_ANIMA
if is_unet
else self.LORA_PREFIX_TEXT_ENCODER
)
prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER
loras = []
skipped = []
@ -255,14 +260,16 @@ class LoRANetwork(torch.nn.Module):
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
original_name = (name + "." if name else "") + child_name
lora_name = f"{prefix}.{original_name}".replace(".", "_")
force_incl_conv2d = False
if filter is not None:
if filter not in lora_name:
continue
force_incl_conv2d = include_conv2d_if_filter
# exclude/include filter (fullmatch: pattern must match the entire original_name)
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
if excluded and not included:
if verbose:
logger.info(f"exclude: {original_name}")
continue
dim = None
alpha_val = None
@ -272,43 +279,18 @@ class LoRANetwork(torch.nn.Module):
dim = modules_dim[lora_name]
alpha_val = modules_alpha[lora_name]
else:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if is_unet and type_dims is not None:
# type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
# Order matters: check most specific identifiers first to avoid mismatches.
identifier_order = [
(4, ("llm_adapter",)),
(3, ("adaln_modulation",)),
(0, ("self_attn",)),
(1, ("cross_attn",)),
(2, ("mlp",)),
]
for idx, ids in identifier_order:
d = type_dims[idx]
if d is not None and all(id_str in lora_name for id_str in ids):
dim = d # 0 means skip
break
# block index filtering
if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name:
# Extract block index from lora_name: "lora_unet_blocks_0_self_attn..."
parts = lora_name.split("_")
for pi, part in enumerate(parts):
if part == "blocks" and pi + 1 < len(parts):
try:
block_index = int(parts[pi + 1])
if not self.train_block_indices[block_index]:
dim = 0
except (ValueError, IndexError):
pass
break
elif force_incl_conv2d:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.fullmatch(reg, original_name):
dim = d
alpha_val = self.alpha
logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
break
# fallback to default dim if not matched by reg_dims or reg_dims is not specified
if dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1:
@ -325,6 +307,7 @@ class LoRANetwork(torch.nn.Module):
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
lora.original_name = original_name
loras.append(lora)
if target_replace_modules is None:
@ -339,9 +322,7 @@ class LoRANetwork(torch.nn.Module):
if text_encoder is None:
continue
logger.info(f"create LoRA for Text Encoder {i+1}:")
te_loras, te_skipped = create_modules(
False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
)
te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
self.text_encoder_loras.extend(te_loras)
skipped_te += te_skipped
@ -354,19 +335,6 @@ class LoRANetwork(torch.nn.Module):
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
# emb_dims: [x_embedder, t_embedder, final_layer]
if self.emb_dims:
for filter_name, in_dim in zip(
["x_embedder", "t_embedder", "final_layer"],
self.emb_dims,
):
loras, _ = create_modules(
True, None, unet, None,
filter=filter_name, default_dim=in_dim,
include_conv2d_if_filter=(filter_name == "x_embedder"),
)
self.unet_loras.extend(loras)
logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
@ -396,6 +364,7 @@ class LoRANetwork(torch.nn.Module):
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
@ -443,10 +412,10 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1:]] = weights_sd[key]
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info(f"weights are merged")
logger.info("weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
@ -471,8 +440,29 @@ class LoRANetwork(torch.nn.Module):
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
reg_groups = {}
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
for lora in loras:
matched_reg_lr = None
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
if re.fullmatch(regex_str, lora.original_name):
matched_reg_lr = (i, reg_lr)
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
break
for name, param in lora.named_parameters():
if matched_reg_lr is not None:
reg_idx, reg_lr = matched_reg_lr
group_key = f"reg_lr_{reg_idx}"
if group_key not in reg_groups:
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
if loraplus_ratio is not None and "lora_up" in name:
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
else:
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
continue
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
@ -480,6 +470,23 @@ class LoRANetwork(torch.nn.Module):
params = []
descriptions = []
for group_key, group in reg_groups.items():
reg_lr = group["lr"]
for key in ("lora", "plus"):
param_data = {"params": group[key].values()}
if len(param_data["params"]) == 0:
continue
if key == "plus":
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
else:
param_data["lr"] = reg_lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
desc = f"reg_lr_{group_key.split('_')[-1]}"
descriptions.append(desc + (" plus" if key == "plus" else ""))
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
@ -498,10 +505,7 @@ class LoRANetwork(torch.nn.Module):
if self.text_encoder_loras:
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
te1_loras = [
lora for lora in self.text_encoder_loras
if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)
]
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)

View File

@ -2,7 +2,7 @@
Diagnostic script to test Anima latent & text encoder caching independently.
Usage:
python test_anima_cache.py \
python manual_test_anima_cache.py \
--image_dir /path/to/images \
--qwen3_path /path/to/qwen3 \
--vae_path /path/to/vae.safetensors \
@ -30,10 +30,12 @@ from torchvision import transforms
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"}
IMAGE_TRANSFORMS = transforms.Compose([
transforms.ToTensor(), # [0,1]
transforms.Normalize([0.5], [0.5]), # [-1,1]
])
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(), # [0,1]
transforms.Normalize([0.5], [0.5]), # [-1,1]
]
)
def find_image_caption_pairs(image_dir: str):
@ -60,35 +62,32 @@ def print_tensor_info(name: str, t, indent=2):
print(f"{prefix}{name}: None")
return
if isinstance(t, np.ndarray):
print(f"{prefix}{name}: numpy {t.dtype} shape={t.shape} "
f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
print(f"{prefix}{name}: numpy {t.dtype} shape={t.shape} " f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
elif isinstance(t, torch.Tensor):
print(f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}")
print(
f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}"
)
else:
print(f"{prefix}{name}: type={type(t)} value={t}")
# Test 1: Latent Cache
def test_latent_cache(args, pairs):
print("\n" + "=" * 70)
print("TEST 1: LATENT CACHING (VAE encode -> cache -> reload)")
print("=" * 70)
from library import anima_utils
from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
from library import qwen_image_autoencoder_kl
# Load VAE
print("\n[1.1] Loading VAE...")
device = "cuda" if torch.cuda.is_available() else "cpu"
vae_dtype = torch.float32
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(
args.vae_path, dtype=vae_dtype, device=device
)
vae = qwen_image_autoencoder_kl.load_vae(args.vae_path, dtype=vae_dtype, device=device)
print(f" VAE loaded on {device}, dtype={vae_dtype}")
print(f" VAE mean (first 4): {ANIMA_VAE_MEAN[:4]}")
print(f" VAE std (first 4): {ANIMA_VAE_STD[:4]}")
for img_path, caption in pairs:
print(f"\n[1.2] Processing: {os.path.basename(img_path)}")
@ -96,13 +95,13 @@ def test_latent_cache(args, pairs):
# Load image
img = Image.open(img_path).convert("RGB")
img_np = np.array(img)
print(f" Raw image: {img_np.shape} dtype={img_np.dtype} "
f"min={img_np.min()} max={img_np.max()}")
print(f" Raw image: {img_np.shape} dtype={img_np.dtype} " f"min={img_np.min()} max={img_np.max()}")
# Apply IMAGE_TRANSFORMS (same as sd-scripts training)
img_tensor = IMAGE_TRANSFORMS(img_np)
print(f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} "
f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}")
print(
f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} " f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}"
)
# Check range is [-1, 1]
if img_tensor.min() < -1.01 or img_tensor.max() > 1.01:
@ -116,7 +115,7 @@ def test_latent_cache(args, pairs):
print(f" VAE input: shape={tuple(img_5d.shape)} dtype={img_5d.dtype}")
with torch.no_grad():
latents = vae.encode(img_5d, vae_scale)
latents = vae.encode_pixels_to_latents(img_5d)
latents_cpu = latents.cpu()
print_tensor_info("Encoded latents", latents_cpu)
@ -165,7 +164,9 @@ def test_latent_cache(args, pairs):
# Test 2: Text Encoder Output Cache
def test_text_encoder_cache(args, pairs):
# TODO Rewrite this
print("\n" + "=" * 70)
print("TEST 2: TEXT ENCODER OUTPUT CACHING")
print("=" * 70)
@ -175,9 +176,7 @@ def test_text_encoder_cache(args, pairs):
# Load tokenizers
print("\n[2.1] Loading tokenizers...")
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
t5_tokenizer = anima_utils.load_t5_tokenizer(
getattr(args, 't5_tokenizer_path', None)
)
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
print(f" Qwen3 tokenizer vocab: {qwen3_tokenizer.vocab_size}")
print(f" T5 tokenizer vocab: {t5_tokenizer.vocab_size}")
@ -185,9 +184,7 @@ def test_text_encoder_cache(args, pairs):
print("\n[2.2] Loading Qwen3 text encoder...")
device = "cuda" if torch.cuda.is_available() else "cpu"
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(
args.qwen3_path, dtype=te_dtype, device=device
)
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
qwen3_model.eval()
# Create strategy objects
@ -199,9 +196,7 @@ def test_text_encoder_cache(args, pairs):
qwen3_max_length=args.qwen3_max_length,
t5_max_length=args.t5_max_length,
)
text_encoding_strategy = AnimaTextEncodingStrategy(
dropout_rate=0.0,
)
text_encoding_strategy = AnimaTextEncodingStrategy()
captions = [cap for _, cap in pairs]
print(f"\n[2.3] Tokenizing {len(captions)} captions...")
@ -221,10 +216,7 @@ def test_text_encoder_cache(args, pairs):
print(f"\n[2.4] Encoding with Qwen3 text encoder...")
with torch.no_grad():
prompt_embeds, attn_mask, t5_ids_out, t5_mask_out = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[qwen3_model],
tokens_and_masks,
enable_dropout=False,
tokenize_strategy, [qwen3_model], tokens_and_masks
)
print(f" Encoding results:")
@ -374,13 +366,13 @@ def test_text_encoder_cache(args, pairs):
# Test 3: Full batch simulation
def test_full_batch_simulation(args, pairs):
print("\n" + "=" * 70)
print("TEST 3: FULL BATCH SIMULATION (mimics process_batch flow)")
print("=" * 70)
from library import anima_utils
from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
device = "cuda" if torch.cuda.is_available() else "cpu"
@ -390,14 +382,16 @@ def test_full_batch_simulation(args, pairs):
# Load all models
print("\n[3.1] Loading models...")
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, 't5_tokenizer_path', None))
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
qwen3_model.eval()
vae, _, _, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=vae_dtype, device=device)
tokenize_strategy = AnimaTokenizeStrategy(
qwen3_tokenizer=qwen3_tokenizer, t5_tokenizer=t5_tokenizer,
qwen3_max_length=args.qwen3_max_length, t5_max_length=args.t5_max_length,
qwen3_tokenizer=qwen3_tokenizer,
t5_tokenizer=t5_tokenizer,
qwen3_max_length=args.qwen3_max_length,
t5_max_length=args.t5_max_length,
)
text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
@ -408,7 +402,10 @@ def test_full_batch_simulation(args, pairs):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
te_outputs = text_encoding_strategy.encode_tokens(
tokenize_strategy, [qwen3_model], tokens_and_masks, enable_dropout=False,
tokenize_strategy,
[qwen3_model],
tokens_and_masks,
enable_dropout=False,
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = te_outputs
@ -473,7 +470,7 @@ def test_full_batch_simulation(args, pairs):
print(f" text_encoder_conds: empty (no cache)")
# The critical condition
train_text_encoder_TRUE = True # OLD behavior (base class default, no override)
train_text_encoder_TRUE = True # OLD behavior (base class default, no override)
train_text_encoder_FALSE = False # NEW behavior (with is_train_text_encoder override)
cond_old = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_TRUE
@ -541,26 +538,19 @@ def test_full_batch_simulation(args, pairs):
# Main
def main():
parser = argparse.ArgumentParser(description="Test Anima caching mechanisms")
parser.add_argument("--image_dir", type=str, required=True,
help="Directory with image+txt pairs")
parser.add_argument("--qwen3_path", type=str, required=True,
help="Path to Qwen3 model (directory or safetensors)")
parser.add_argument("--vae_path", type=str, required=True,
help="Path to WanVAE safetensors")
parser.add_argument("--t5_tokenizer_path", type=str, default=None,
help="Path to T5 tokenizer (optional, uses bundled config)")
parser.add_argument("--image_dir", type=str, required=True, help="Directory with image+txt pairs")
parser.add_argument("--qwen3_path", type=str, required=True, help="Path to Qwen3 model (directory or safetensors)")
parser.add_argument("--vae_path", type=str, required=True, help="Path to WanVAE safetensors")
parser.add_argument("--t5_tokenizer_path", type=str, default=None, help="Path to T5 tokenizer (optional, uses bundled config)")
parser.add_argument("--qwen3_max_length", type=int, default=512)
parser.add_argument("--t5_max_length", type=int, default=512)
parser.add_argument("--cache_to_disk", action="store_true",
help="Also test disk cache round-trip")
parser.add_argument("--skip_latent", action="store_true",
help="Skip latent cache test")
parser.add_argument("--skip_text", action="store_true",
help="Skip text encoder cache test")
parser.add_argument("--skip_full", action="store_true",
help="Skip full batch simulation")
parser.add_argument("--cache_to_disk", action="store_true", help="Also test disk cache round-trip")
parser.add_argument("--skip_latent", action="store_true", help="Skip latent cache test")
parser.add_argument("--skip_text", action="store_true", help="Skip text encoder cache test")
parser.add_argument("--skip_full", action="store_true", help="Skip full batch simulation")
args = parser.parse_args()
# Find pairs

View File

@ -470,7 +470,7 @@ class NetworkTrainer:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss = loss.mean(dim=list(range(1, loss.ndim))) # mean over all dims except batch
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights