mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Support Anima model (#2260)
* Support Anima model * Update document and fix bug * Fix latent normalization * Fix typo * Fix cache embedding * fix typo in tests/test_anima_cache.py * Remove redundant argument apply_t5_attn_mask * Improve caching with argument caption_dropout_rate * Fix W&B logging bugs * Fix discrete_flow_shift default value
This commit is contained in:
887
anima_train.py
Normal file
887
anima_train.py
Normal file
@@ -0,0 +1,887 @@
|
|||||||
|
# Anima full finetune training script
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import copy
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from multiprocessing import Value
|
||||||
|
from typing import List
|
||||||
|
import toml
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from library import utils
|
||||||
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
|
init_ipex()
|
||||||
|
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
from library import deepspeed_utils, anima_models, anima_train_utils, anima_utils, strategy_base, strategy_anima, sai_model_spec
|
||||||
|
|
||||||
|
import library.train_util as train_util
|
||||||
|
|
||||||
|
from library.utils import setup_logging, add_logging_arguments
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
import library.config_util as config_util
|
||||||
|
|
||||||
|
from library.config_util import (
|
||||||
|
ConfigSanitizer,
|
||||||
|
BlueprintGenerator,
|
||||||
|
)
|
||||||
|
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
|
||||||
|
|
||||||
|
|
||||||
|
def train(args):
|
||||||
|
train_util.verify_training_args(args)
|
||||||
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
deepspeed_utils.prepare_deepspeed_args(args)
|
||||||
|
setup_logging(args, reset=True)
|
||||||
|
|
||||||
|
# backward compatibility
|
||||||
|
if not args.skip_cache_check:
|
||||||
|
args.skip_cache_check = args.skip_latents_validity_check
|
||||||
|
|
||||||
|
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||||
|
logger.warning(
|
||||||
|
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
|
||||||
|
)
|
||||||
|
args.cache_text_encoder_outputs = True
|
||||||
|
|
||||||
|
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
|
||||||
|
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||||
|
args.gradient_checkpointing = True
|
||||||
|
|
||||||
|
if getattr(args, 'unsloth_offload_checkpointing', False):
|
||||||
|
if not args.gradient_checkpointing:
|
||||||
|
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||||
|
args.gradient_checkpointing = True
|
||||||
|
assert not args.cpu_offload_checkpointing, \
|
||||||
|
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||||
|
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||||
|
) or not getattr(args, 'unsloth_offload_checkpointing', False), \
|
||||||
|
"blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||||
|
|
||||||
|
# Flash attention: validate availability
|
||||||
|
if getattr(args, 'flash_attn', False):
|
||||||
|
try:
|
||||||
|
import flash_attn # noqa: F401
|
||||||
|
logger.info("Flash Attention enabled for DiT blocks")
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
|
||||||
|
args.flash_attn = False
|
||||||
|
|
||||||
|
cache_latents = args.cache_latents
|
||||||
|
use_dreambooth_method = args.in_json is None
|
||||||
|
|
||||||
|
if args.seed is not None:
|
||||||
|
set_seed(args.seed)
|
||||||
|
|
||||||
|
# prepare caching strategy: must be set before preparing dataset
|
||||||
|
if args.cache_latents:
|
||||||
|
latents_caching_strategy = strategy_anima.AnimaLatentsCachingStrategy(
|
||||||
|
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
|
||||||
|
)
|
||||||
|
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||||
|
|
||||||
|
# prepare dataset
|
||||||
|
if args.dataset_class is None:
|
||||||
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
||||||
|
if args.dataset_config is not None:
|
||||||
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
|
ignored = ["train_data_dir", "in_json"]
|
||||||
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
|
logger.warning(
|
||||||
|
"ignore following options because config file is found: {0}".format(", ".join(ignored))
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if use_dreambooth_method:
|
||||||
|
logger.info("Using DreamBooth method.")
|
||||||
|
user_config = {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||||
|
args.train_data_dir, args.reg_data_dir
|
||||||
|
)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
logger.info("Training with captions.")
|
||||||
|
user_config = {
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"subsets": [
|
||||||
|
{
|
||||||
|
"image_dir": args.train_data_dir,
|
||||||
|
"metadata_file": args.in_json,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
else:
|
||||||
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
val_dataset_group = None
|
||||||
|
|
||||||
|
current_epoch = Value("i", 0)
|
||||||
|
current_step = Value("i", 0)
|
||||||
|
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||||
|
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||||
|
|
||||||
|
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
|
||||||
|
|
||||||
|
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
|
||||||
|
# dataset-level caption dropout, so we save the rate and zero out subset-level
|
||||||
|
# caption_dropout_rate to allow text encoder output caching.
|
||||||
|
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||||
|
if caption_dropout_rate > 0:
|
||||||
|
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
|
||||||
|
for dataset in train_dataset_group.datasets:
|
||||||
|
for subset in dataset.subsets:
|
||||||
|
subset.caption_dropout_rate = 0.0
|
||||||
|
|
||||||
|
if args.debug_dataset:
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||||
|
strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||||
|
args.cache_text_encoder_outputs_to_disk,
|
||||||
|
args.text_encoder_batch_size,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
|
train_util.debug_dataset(train_dataset_group, True)
|
||||||
|
return
|
||||||
|
if len(train_dataset_group) == 0:
|
||||||
|
logger.error("No data found. Please verify the metadata file and train_data_dir option.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if cache_latents:
|
||||||
|
assert (
|
||||||
|
train_dataset_group.is_latent_cacheable()
|
||||||
|
), "when caching latents, either color_aug or random_crop cannot be used"
|
||||||
|
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
assert (
|
||||||
|
train_dataset_group.is_text_encoder_output_cacheable()
|
||||||
|
), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
|
||||||
|
|
||||||
|
# prepare accelerator
|
||||||
|
logger.info("prepare accelerator")
|
||||||
|
accelerator = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
|
# mixed precision dtype
|
||||||
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
|
# parse transformer_dtype
|
||||||
|
transformer_dtype = None
|
||||||
|
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
|
||||||
|
transformer_dtype_map = {
|
||||||
|
"float16": torch.float16,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
"float32": torch.float32,
|
||||||
|
}
|
||||||
|
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
|
||||||
|
|
||||||
|
# Load tokenizers and set strategies
|
||||||
|
logger.info("Loading tokenizers...")
|
||||||
|
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
|
||||||
|
args.qwen3_path, dtype=weight_dtype, device="cpu"
|
||||||
|
)
|
||||||
|
t5_tokenizer = anima_utils.load_t5_tokenizer(
|
||||||
|
getattr(args, 't5_tokenizer_path', None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set tokenize strategy
|
||||||
|
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
|
||||||
|
qwen3_tokenizer=qwen3_tokenizer,
|
||||||
|
t5_tokenizer=t5_tokenizer,
|
||||||
|
qwen3_max_length=args.qwen3_max_token_length,
|
||||||
|
t5_max_length=args.t5_max_token_length,
|
||||||
|
)
|
||||||
|
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||||
|
|
||||||
|
# Set text encoding strategy
|
||||||
|
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||||
|
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
|
||||||
|
dropout_rate=caption_dropout_rate,
|
||||||
|
)
|
||||||
|
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||||
|
|
||||||
|
# Prepare text encoder (always frozen for Anima)
|
||||||
|
qwen3_text_encoder.to(weight_dtype)
|
||||||
|
qwen3_text_encoder.requires_grad_(False)
|
||||||
|
|
||||||
|
# Cache text encoder outputs
|
||||||
|
sample_prompts_te_outputs = None
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
qwen3_text_encoder.to(accelerator.device)
|
||||||
|
qwen3_text_encoder.eval()
|
||||||
|
|
||||||
|
text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||||
|
args.cache_text_encoder_outputs_to_disk,
|
||||||
|
args.text_encoder_batch_size,
|
||||||
|
args.skip_cache_check,
|
||||||
|
is_partial=False,
|
||||||
|
)
|
||||||
|
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||||
|
|
||||||
|
with accelerator.autocast():
|
||||||
|
train_dataset_group.new_cache_text_encoder_outputs([qwen3_text_encoder], accelerator)
|
||||||
|
|
||||||
|
# cache sample prompt embeddings
|
||||||
|
if args.sample_prompts is not None:
|
||||||
|
logger.info(f"Cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
|
||||||
|
prompts = train_util.load_prompts(args.sample_prompts)
|
||||||
|
sample_prompts_te_outputs = {}
|
||||||
|
with accelerator.autocast(), torch.no_grad():
|
||||||
|
for prompt_dict in prompts:
|
||||||
|
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
||||||
|
if p not in sample_prompts_te_outputs:
|
||||||
|
logger.info(f" cache TE outputs for: {p}")
|
||||||
|
tokens_and_masks = tokenize_strategy.tokenize(p)
|
||||||
|
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy,
|
||||||
|
[qwen3_text_encoder],
|
||||||
|
tokens_and_masks,
|
||||||
|
enable_dropout=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
|
||||||
|
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||||
|
if caption_dropout_rate > 0.0:
|
||||||
|
with accelerator.autocast():
|
||||||
|
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
# free text encoder memory
|
||||||
|
qwen3_text_encoder = None
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
|
# Load VAE and cache latents
|
||||||
|
logger.info("Loading Anima VAE...")
|
||||||
|
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
|
||||||
|
|
||||||
|
if cache_latents:
|
||||||
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
|
|
||||||
|
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||||
|
|
||||||
|
vae.to("cpu")
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
# Load DiT (MiniTrainDIT + optional LLM Adapter)
|
||||||
|
logger.info("Loading Anima DiT...")
|
||||||
|
dit = anima_utils.load_anima_dit(
|
||||||
|
args.dit_path,
|
||||||
|
dtype=weight_dtype,
|
||||||
|
device="cpu",
|
||||||
|
transformer_dtype=transformer_dtype,
|
||||||
|
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
|
||||||
|
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
dit.enable_gradient_checkpointing(
|
||||||
|
cpu_offload=args.cpu_offload_checkpointing,
|
||||||
|
unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False),
|
||||||
|
)
|
||||||
|
|
||||||
|
if getattr(args, 'flash_attn', False):
|
||||||
|
dit.set_flash_attn(True)
|
||||||
|
|
||||||
|
train_dit = args.learning_rate != 0
|
||||||
|
dit.requires_grad_(train_dit)
|
||||||
|
if not train_dit:
|
||||||
|
dit.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
# Block swap
|
||||||
|
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||||
|
if is_swapping_blocks:
|
||||||
|
logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||||
|
dit.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||||
|
|
||||||
|
if not cache_latents:
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
# Move scale tensors to same device as VAE for on-the-fly encoding
|
||||||
|
vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
|
||||||
|
|
||||||
|
# Setup optimizer with parameter groups
|
||||||
|
if train_dit:
|
||||||
|
param_groups = anima_train_utils.get_anima_param_groups(
|
||||||
|
dit,
|
||||||
|
base_lr=args.learning_rate,
|
||||||
|
self_attn_lr=getattr(args, 'self_attn_lr', None),
|
||||||
|
cross_attn_lr=getattr(args, 'cross_attn_lr', None),
|
||||||
|
mlp_lr=getattr(args, 'mlp_lr', None),
|
||||||
|
mod_lr=getattr(args, 'mod_lr', None),
|
||||||
|
llm_adapter_lr=getattr(args, 'llm_adapter_lr', None),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
param_groups = []
|
||||||
|
|
||||||
|
training_models = []
|
||||||
|
if train_dit:
|
||||||
|
training_models.append(dit)
|
||||||
|
|
||||||
|
# calculate trainable parameters
|
||||||
|
n_params = 0
|
||||||
|
for group in param_groups:
|
||||||
|
for p in group["params"]:
|
||||||
|
n_params += p.numel()
|
||||||
|
|
||||||
|
accelerator.print(f"train dit: {train_dit}")
|
||||||
|
accelerator.print(f"number of training models: {len(training_models)}")
|
||||||
|
accelerator.print(f"number of trainable parameters: {n_params:,}")
|
||||||
|
|
||||||
|
# prepare optimizer
|
||||||
|
accelerator.print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
|
if args.blockwise_fused_optimizers:
|
||||||
|
# Split params into per-block groups for blockwise fused optimizer
|
||||||
|
# Build param_id → lr mapping from param_groups to propagate per-component LRs
|
||||||
|
param_lr_map = {}
|
||||||
|
for group in param_groups:
|
||||||
|
for p in group['params']:
|
||||||
|
param_lr_map[id(p)] = group['lr']
|
||||||
|
|
||||||
|
grouped_params = []
|
||||||
|
param_group = {}
|
||||||
|
named_parameters = list(dit.named_parameters())
|
||||||
|
for name, p in named_parameters:
|
||||||
|
if not p.requires_grad:
|
||||||
|
continue
|
||||||
|
# Determine block type and index
|
||||||
|
if name.startswith("blocks."):
|
||||||
|
block_index = int(name.split(".")[1])
|
||||||
|
block_type = "blocks"
|
||||||
|
elif name.startswith("llm_adapter.blocks."):
|
||||||
|
block_index = int(name.split(".")[2])
|
||||||
|
block_type = "llm_adapter"
|
||||||
|
else:
|
||||||
|
block_index = -1
|
||||||
|
block_type = "other"
|
||||||
|
|
||||||
|
param_group_key = (block_type, block_index)
|
||||||
|
if param_group_key not in param_group:
|
||||||
|
param_group[param_group_key] = []
|
||||||
|
param_group[param_group_key].append(p)
|
||||||
|
|
||||||
|
for param_group_key, params in param_group.items():
|
||||||
|
# Use per-component LR from param_groups if available
|
||||||
|
lr = param_lr_map.get(id(params[0]), args.learning_rate)
|
||||||
|
grouped_params.append({"params": params, "lr": lr})
|
||||||
|
num_params = sum(p.numel() for p in params)
|
||||||
|
accelerator.print(f"block {param_group_key}: {num_params} parameters, lr={lr}")
|
||||||
|
|
||||||
|
# Create per-group optimizers
|
||||||
|
optimizers = []
|
||||||
|
for group in grouped_params:
|
||||||
|
_, _, opt = train_util.get_optimizer(args, trainable_params=[group])
|
||||||
|
optimizers.append(opt)
|
||||||
|
optimizer = optimizers[0] # avoid error in following code
|
||||||
|
|
||||||
|
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
|
||||||
|
|
||||||
|
if train_util.is_schedulefree_optimizer(optimizers[0], args):
|
||||||
|
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
|
||||||
|
optimizer_train_fn = lambda: None
|
||||||
|
optimizer_eval_fn = lambda: None
|
||||||
|
elif args.fused_backward_pass:
|
||||||
|
# Pass per-component param_groups directly to preserve per-component LRs
|
||||||
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
|
||||||
|
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||||
|
else:
|
||||||
|
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
|
||||||
|
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||||
|
|
||||||
|
# prepare dataloader
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
|
|
||||||
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count())
|
||||||
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
|
train_dataset_group,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=collator,
|
||||||
|
num_workers=n_workers,
|
||||||
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# calculate training steps
|
||||||
|
if args.max_train_epochs is not None:
|
||||||
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
|
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs: {args.max_train_steps}")
|
||||||
|
|
||||||
|
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||||
|
|
||||||
|
# lr scheduler
|
||||||
|
if args.blockwise_fused_optimizers:
|
||||||
|
lr_schedulers = [train_util.get_scheduler_fix(args, opt, accelerator.num_processes) for opt in optimizers]
|
||||||
|
lr_scheduler = lr_schedulers[0] # avoid error in following code
|
||||||
|
else:
|
||||||
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
|
# full fp16/bf16 training
|
||||||
|
if args.full_fp16:
|
||||||
|
assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
|
||||||
|
accelerator.print("enable full fp16 training.")
|
||||||
|
dit.to(weight_dtype)
|
||||||
|
elif args.full_bf16:
|
||||||
|
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
|
||||||
|
accelerator.print("enable full bf16 training.")
|
||||||
|
dit.to(weight_dtype)
|
||||||
|
|
||||||
|
# move text encoder to GPU if not cached
|
||||||
|
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||||
|
qwen3_text_encoder.to(accelerator.device)
|
||||||
|
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
|
# Prepare with accelerator
|
||||||
|
# Temporarily move non-training models off GPU to reduce memory during DDP init
|
||||||
|
# if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||||
|
# qwen3_text_encoder.to("cpu")
|
||||||
|
# if not cache_latents and vae is not None:
|
||||||
|
# vae.to("cpu")
|
||||||
|
# clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
|
if args.deepspeed:
|
||||||
|
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
|
||||||
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
ds_model, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
training_models = [ds_model]
|
||||||
|
else:
|
||||||
|
if train_dit:
|
||||||
|
dit = accelerator.prepare(dit, device_placement=[not is_swapping_blocks])
|
||||||
|
if is_swapping_blocks:
|
||||||
|
accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
|
||||||
|
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
|
# Move non-training models back to GPU
|
||||||
|
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||||
|
qwen3_text_encoder.to(accelerator.device)
|
||||||
|
if not cache_latents and vae is not None:
|
||||||
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
if args.full_fp16:
|
||||||
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
|
||||||
|
# resume
|
||||||
|
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||||
|
|
||||||
|
if args.fused_backward_pass:
|
||||||
|
import library.adafactor_fused
|
||||||
|
|
||||||
|
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||||
|
|
||||||
|
for param_group in optimizer.param_groups:
|
||||||
|
for parameter in param_group["params"]:
|
||||||
|
if parameter.requires_grad:
|
||||||
|
|
||||||
|
def create_grad_hook(p_group):
|
||||||
|
def grad_hook(tensor: torch.Tensor):
|
||||||
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
|
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
|
||||||
|
optimizer.step_param(tensor, p_group)
|
||||||
|
tensor.grad = None
|
||||||
|
|
||||||
|
return grad_hook
|
||||||
|
|
||||||
|
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
|
||||||
|
|
||||||
|
elif args.blockwise_fused_optimizers:
|
||||||
|
# Prepare additional optimizers and lr schedulers
|
||||||
|
for i in range(1, len(optimizers)):
|
||||||
|
optimizers[i] = accelerator.prepare(optimizers[i])
|
||||||
|
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
|
||||||
|
|
||||||
|
# Counters for blockwise gradient hook
|
||||||
|
optimizer_hooked_count = {}
|
||||||
|
num_parameters_per_group = [0] * len(optimizers)
|
||||||
|
parameter_optimizer_map = {}
|
||||||
|
|
||||||
|
for opt_idx, opt in enumerate(optimizers):
|
||||||
|
for param_group in opt.param_groups:
|
||||||
|
for parameter in param_group["params"]:
|
||||||
|
if parameter.requires_grad:
|
||||||
|
|
||||||
|
def grad_hook(parameter: torch.Tensor):
|
||||||
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
|
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
||||||
|
|
||||||
|
i = parameter_optimizer_map[parameter]
|
||||||
|
optimizer_hooked_count[i] += 1
|
||||||
|
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||||
|
optimizers[i].step()
|
||||||
|
optimizers[i].zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||||
|
parameter_optimizer_map[parameter] = opt_idx
|
||||||
|
num_parameters_per_group[opt_idx] += 1
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||||
|
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||||
|
|
||||||
|
accelerator.print("running training")
|
||||||
|
accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
|
||||||
|
accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
|
||||||
|
accelerator.print(f" num epochs: {num_train_epochs}")
|
||||||
|
accelerator.print(
|
||||||
|
f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||||
|
)
|
||||||
|
accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
|
||||||
|
accelerator.print(f" total optimization steps: {args.max_train_steps}")
|
||||||
|
|
||||||
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
|
global_step = 0
|
||||||
|
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
init_kwargs = {}
|
||||||
|
if args.wandb_run_name:
|
||||||
|
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||||
|
if args.log_tracker_config is not None:
|
||||||
|
init_kwargs = toml.load(args.log_tracker_config)
|
||||||
|
accelerator.init_trackers(
|
||||||
|
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
|
||||||
|
config=train_util.get_sanitized_config_or_none(args),
|
||||||
|
init_kwargs=init_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||||
|
import wandb
|
||||||
|
wandb.define_metric("epoch")
|
||||||
|
wandb.define_metric("loss/epoch", step_metric="epoch")
|
||||||
|
|
||||||
|
if is_swapping_blocks:
|
||||||
|
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
|
# For --sample_at_first
|
||||||
|
optimizer_eval_fn()
|
||||||
|
anima_train_utils.sample_images(
|
||||||
|
accelerator, args, 0, global_step, dit, vae, vae_scale,
|
||||||
|
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
|
||||||
|
sample_prompts_te_outputs,
|
||||||
|
)
|
||||||
|
optimizer_train_fn()
|
||||||
|
if len(accelerator.trackers) > 0:
|
||||||
|
accelerator.log({}, step=0)
|
||||||
|
|
||||||
|
# Show model info
|
||||||
|
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
|
||||||
|
if unwrapped_dit is not None:
|
||||||
|
logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}")
|
||||||
|
if qwen3_text_encoder is not None:
|
||||||
|
logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
|
||||||
|
if vae is not None:
|
||||||
|
logger.info(f"vae device: {next(vae.parameters()).device}")
|
||||||
|
|
||||||
|
loss_recorder = train_util.LossRecorder()
|
||||||
|
epoch = 0
|
||||||
|
for epoch in range(num_train_epochs):
|
||||||
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||||
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
|
for m in training_models:
|
||||||
|
m.train()
|
||||||
|
|
||||||
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
current_step.value = global_step
|
||||||
|
|
||||||
|
if args.blockwise_fused_optimizers:
|
||||||
|
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
||||||
|
|
||||||
|
with accelerator.accumulate(*training_models):
|
||||||
|
# Get latents
|
||||||
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
|
||||||
|
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
|
||||||
|
images = images.unsqueeze(2) # (B, C, 1, H, W)
|
||||||
|
latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
if torch.any(torch.isnan(latents)):
|
||||||
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
|
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||||
|
|
||||||
|
# Get text encoder outputs
|
||||||
|
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||||
|
if text_encoder_outputs_list is not None:
|
||||||
|
# Cached outputs
|
||||||
|
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
|
||||||
|
*text_encoder_outputs_list
|
||||||
|
)
|
||||||
|
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
|
||||||
|
else:
|
||||||
|
# Encode on-the-fly
|
||||||
|
input_ids_list = batch["input_ids_list"]
|
||||||
|
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = input_ids_list
|
||||||
|
with torch.no_grad():
|
||||||
|
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy,
|
||||||
|
[qwen3_text_encoder],
|
||||||
|
[qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Move to device
|
||||||
|
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
attn_mask = attn_mask.to(accelerator.device)
|
||||||
|
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||||
|
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||||
|
|
||||||
|
# Noise and timesteps
|
||||||
|
noise = torch.randn_like(latents)
|
||||||
|
|
||||||
|
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
|
||||||
|
args, latents, noise, accelerator.device, weight_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# NaN checks
|
||||||
|
if torch.any(torch.isnan(noisy_model_input)):
|
||||||
|
accelerator.print("NaN found in noisy_model_input, replacing with zeros")
|
||||||
|
noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
|
||||||
|
|
||||||
|
# Create padding mask
|
||||||
|
# padding_mask: (B, 1, H_latent, W_latent)
|
||||||
|
bs = latents.shape[0]
|
||||||
|
h_latent = latents.shape[-2]
|
||||||
|
w_latent = latents.shape[-1]
|
||||||
|
padding_mask = torch.zeros(
|
||||||
|
bs, 1, h_latent, w_latent,
|
||||||
|
dtype=weight_dtype, device=accelerator.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
|
||||||
|
if is_swapping_blocks:
|
||||||
|
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
|
with accelerator.autocast():
|
||||||
|
model_pred = dit(
|
||||||
|
noisy_model_input,
|
||||||
|
timesteps,
|
||||||
|
prompt_embeds,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
source_attention_mask=attn_mask,
|
||||||
|
t5_input_ids=t5_input_ids,
|
||||||
|
t5_attn_mask=t5_attn_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute loss (rectified flow: target = noise - latents)
|
||||||
|
target = noise - latents
|
||||||
|
|
||||||
|
# Weighting
|
||||||
|
weighting = anima_train_utils.compute_loss_weighting_for_anima(
|
||||||
|
weighting_scheme=args.weighting_scheme, sigmas=sigmas
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loss
|
||||||
|
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
|
||||||
|
loss = train_util.conditional_loss(
|
||||||
|
model_pred.float(), target.float(), args.loss_type, "none", huber_c
|
||||||
|
)
|
||||||
|
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||||
|
loss = apply_masked_loss(loss, batch)
|
||||||
|
loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
|
||||||
|
|
||||||
|
if weighting is not None:
|
||||||
|
loss = loss * weighting
|
||||||
|
|
||||||
|
loss_weights = batch["loss_weights"]
|
||||||
|
loss = loss * loss_weights
|
||||||
|
loss = loss.mean()
|
||||||
|
|
||||||
|
accelerator.backward(loss)
|
||||||
|
|
||||||
|
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
|
||||||
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
|
params_to_clip = []
|
||||||
|
for m in training_models:
|
||||||
|
params_to_clip.extend(m.parameters())
|
||||||
|
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
else:
|
||||||
|
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
||||||
|
lr_scheduler.step()
|
||||||
|
if args.blockwise_fused_optimizers:
|
||||||
|
for i in range(1, len(optimizers)):
|
||||||
|
lr_schedulers[i].step()
|
||||||
|
|
||||||
|
# Checks if the accelerator has performed an optimization step
|
||||||
|
if accelerator.sync_gradients:
|
||||||
|
progress_bar.update(1)
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
|
optimizer_eval_fn()
|
||||||
|
anima_train_utils.sample_images(
|
||||||
|
accelerator, args, None, global_step, dit, vae, vae_scale,
|
||||||
|
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
|
||||||
|
sample_prompts_te_outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save at specific steps
|
||||||
|
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
|
||||||
|
args,
|
||||||
|
False,
|
||||||
|
accelerator,
|
||||||
|
save_dtype,
|
||||||
|
epoch,
|
||||||
|
num_train_epochs,
|
||||||
|
global_step,
|
||||||
|
accelerator.unwrap_model(dit) if train_dit else None,
|
||||||
|
)
|
||||||
|
optimizer_train_fn()
|
||||||
|
|
||||||
|
current_loss = loss.detach().item()
|
||||||
|
if len(accelerator.trackers) > 0:
|
||||||
|
logs = {"loss": current_loss}
|
||||||
|
train_util.append_lr_to_logs_with_names(
|
||||||
|
logs, lr_scheduler, args.optimizer_type,
|
||||||
|
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else []
|
||||||
|
)
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
|
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||||
|
avr_loss: float = loss_recorder.moving_average
|
||||||
|
logs = {"avr_loss": avr_loss}
|
||||||
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
|
if global_step >= args.max_train_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(accelerator.trackers) > 0:
|
||||||
|
logs = {"loss/epoch": loss_recorder.moving_average, "epoch": epoch + 1}
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
optimizer_eval_fn()
|
||||||
|
if args.save_every_n_epochs is not None:
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
|
||||||
|
args,
|
||||||
|
True,
|
||||||
|
accelerator,
|
||||||
|
save_dtype,
|
||||||
|
epoch,
|
||||||
|
num_train_epochs,
|
||||||
|
global_step,
|
||||||
|
accelerator.unwrap_model(dit) if train_dit else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
anima_train_utils.sample_images(
|
||||||
|
accelerator, args, epoch + 1, global_step, dit, vae, vae_scale,
|
||||||
|
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
|
||||||
|
sample_prompts_te_outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# End training
|
||||||
|
is_main_process = accelerator.is_main_process
|
||||||
|
dit = accelerator.unwrap_model(dit)
|
||||||
|
|
||||||
|
accelerator.end_training()
|
||||||
|
optimizer_eval_fn()
|
||||||
|
|
||||||
|
if args.save_state or args.save_state_on_train_end:
|
||||||
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
|
del accelerator
|
||||||
|
|
||||||
|
if is_main_process and train_dit:
|
||||||
|
anima_train_utils.save_anima_model_on_train_end(
|
||||||
|
args,
|
||||||
|
save_dtype,
|
||||||
|
epoch,
|
||||||
|
global_step,
|
||||||
|
dit,
|
||||||
|
)
|
||||||
|
logger.info("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for Anima full finetuning.

    Registers the shared sd-scripts argument groups first, then the
    Anima-script-specific boolean flags.
    """
    parser = argparse.ArgumentParser()

    # Shared argument groups provided by the library modules (order preserved
    # so --help output stays stable).
    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_masked_loss_arguments(parser)
    deepspeed_utils.add_deepspeed_arguments(parser)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    add_custom_train_arguments(parser)
    train_util.add_dit_training_arguments(parser)
    anima_train_utils.add_anima_training_arguments(parser)
    sai_model_spec.add_model_spec_arguments(parser)

    # Script-specific store_true flags, registered table-driven.
    script_flags = [
        (
            "--blockwise_fused_optimizers",
            "enable blockwise optimizers for fused backward pass and optimizer step",
        ),
        (
            "--cpu_offload_checkpointing",
            "offload gradient checkpointing to CPU (reduces VRAM at cost of speed)",
        ),
        (
            "--unsloth_offload_checkpointing",
            "offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
            "Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
        ),
        (
            "--skip_latents_validity_check",
            "[Deprecated] use 'skip_cache_check' instead",
        ),
    ]
    for flag, help_text in script_flags:
        parser.add_argument(flag, action="store_true", help=help_text)

    return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: parse args, validate, merge optional config, train.
    parser = setup_parser()
    args = parser.parse_args()

    # Validate CLI-only constraints before merging values from a config file.
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    train(args)
|
||||||
540
anima_train_network.py
Normal file
540
anima_train_network.py
Normal file
@@ -0,0 +1,540 @@
|
|||||||
|
# Anima LoRA training script
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
|
init_ipex()
|
||||||
|
|
||||||
|
from library import anima_models, anima_train_utils, anima_utils, strategy_anima, strategy_base, train_util
|
||||||
|
import train_network
|
||||||
|
from library.utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.sample_prompts_te_outputs = None
|
||||||
|
self.vae = None
|
||||||
|
self.vae_scale = None
|
||||||
|
self.qwen3_text_encoder = None
|
||||||
|
self.qwen3_tokenizer = None
|
||||||
|
self.t5_tokenizer = None
|
||||||
|
self.tokenize_strategy = None
|
||||||
|
self.text_encoding_strategy = None
|
||||||
|
|
||||||
|
def assert_extra_args(
|
||||||
|
self,
|
||||||
|
args,
|
||||||
|
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||||
|
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||||
|
):
|
||||||
|
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||||
|
logger.warning(
|
||||||
|
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
|
||||||
|
)
|
||||||
|
args.cache_text_encoder_outputs = True
|
||||||
|
|
||||||
|
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
|
||||||
|
# dataset-level caption dropout, so zero out subset-level rates to allow caching.
|
||||||
|
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||||
|
if caption_dropout_rate > 0:
|
||||||
|
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
|
||||||
|
if hasattr(train_dataset_group, 'datasets'):
|
||||||
|
for dataset in train_dataset_group.datasets:
|
||||||
|
for subset in dataset.subsets:
|
||||||
|
subset.caption_dropout_rate = 0.0
|
||||||
|
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
assert (
|
||||||
|
train_dataset_group.is_text_encoder_output_cacheable()
|
||||||
|
), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
|
||||||
|
|
||||||
|
assert (
|
||||||
|
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||||
|
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
|
||||||
|
|
||||||
|
if getattr(args, 'unsloth_offload_checkpointing', False):
|
||||||
|
if not args.gradient_checkpointing:
|
||||||
|
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||||
|
args.gradient_checkpointing = True
|
||||||
|
assert not args.cpu_offload_checkpointing, \
|
||||||
|
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
|
||||||
|
assert (
|
||||||
|
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||||
|
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||||
|
|
||||||
|
# Flash attention: validate availability
|
||||||
|
if getattr(args, 'flash_attn', False):
|
||||||
|
try:
|
||||||
|
import flash_attn # noqa: F401
|
||||||
|
logger.info("Flash Attention enabled for DiT blocks")
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
|
||||||
|
args.flash_attn = False
|
||||||
|
|
||||||
|
if getattr(args, 'blockwise_fused_optimizers', False):
|
||||||
|
raise ValueError("blockwise_fused_optimizers is not supported with LoRA/NetworkTrainer")
|
||||||
|
|
||||||
|
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
|
||||||
|
if val_dataset_group is not None:
|
||||||
|
val_dataset_group.verify_bucket_reso_steps(8)
|
||||||
|
|
||||||
|
def load_target_model(self, args, weight_dtype, accelerator):
|
||||||
|
# Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
|
||||||
|
logger.info("Loading Qwen3 text encoder...")
|
||||||
|
self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(
|
||||||
|
args.qwen3_path, dtype=weight_dtype, device="cpu"
|
||||||
|
)
|
||||||
|
self.qwen3_text_encoder.eval()
|
||||||
|
|
||||||
|
# Parse transformer_dtype
|
||||||
|
transformer_dtype = None
|
||||||
|
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
|
||||||
|
transformer_dtype_map = {
|
||||||
|
"float16": torch.float16,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
"float32": torch.float32,
|
||||||
|
}
|
||||||
|
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
|
||||||
|
|
||||||
|
# Load DiT
|
||||||
|
logger.info("Loading Anima DiT...")
|
||||||
|
dit = anima_utils.load_anima_dit(
|
||||||
|
args.dit_path,
|
||||||
|
dtype=weight_dtype,
|
||||||
|
device="cpu",
|
||||||
|
transformer_dtype=transformer_dtype,
|
||||||
|
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
|
||||||
|
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flash attention
|
||||||
|
if getattr(args, 'flash_attn', False):
|
||||||
|
dit.set_flash_attn(True)
|
||||||
|
|
||||||
|
# Store unsloth preference so that when the base NetworkTrainer calls
|
||||||
|
# dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
|
||||||
|
# The base trainer only passes cpu_offload, so we store the flag on the model.
|
||||||
|
self._use_unsloth_offload_checkpointing = getattr(args, 'unsloth_offload_checkpointing', False)
|
||||||
|
|
||||||
|
# Block swap
|
||||||
|
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||||
|
if self.is_swapping_blocks:
|
||||||
|
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||||
|
dit.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||||
|
|
||||||
|
# Load VAE
|
||||||
|
logger.info("Loading Anima VAE...")
|
||||||
|
self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(
|
||||||
|
args.vae_path, dtype=weight_dtype, device="cpu"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return format: (model_type, text_encoders, vae, unet)
|
||||||
|
return "anima", [self.qwen3_text_encoder], self.vae, dit
|
||||||
|
|
||||||
|
def get_tokenize_strategy(self, args):
|
||||||
|
# Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
|
||||||
|
self.tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
|
||||||
|
qwen3_path=args.qwen3_path,
|
||||||
|
t5_tokenizer_path=getattr(args, 't5_tokenizer_path', None),
|
||||||
|
qwen3_max_length=args.qwen3_max_token_length,
|
||||||
|
t5_max_length=args.t5_max_token_length,
|
||||||
|
)
|
||||||
|
# Store references so load_target_model can reuse them
|
||||||
|
self.qwen3_tokenizer = self.tokenize_strategy.qwen3_tokenizer
|
||||||
|
self.t5_tokenizer = self.tokenize_strategy.t5_tokenizer
|
||||||
|
return self.tokenize_strategy
|
||||||
|
|
||||||
|
def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
|
||||||
|
return [tokenize_strategy.qwen3_tokenizer]
|
||||||
|
|
||||||
|
def get_latents_caching_strategy(self, args):
|
||||||
|
return strategy_anima.AnimaLatentsCachingStrategy(
|
||||||
|
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_text_encoding_strategy(self, args):
|
||||||
|
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||||
|
self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
|
||||||
|
dropout_rate=caption_dropout_rate,
|
||||||
|
)
|
||||||
|
return self.text_encoding_strategy
|
||||||
|
|
||||||
|
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
||||||
|
# Qwen3 text encoder is always frozen for Anima
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
return None # no text encoders needed for encoding
|
||||||
|
return text_encoders
|
||||||
|
|
||||||
|
def get_text_encoders_train_flags(self, args, text_encoders):
|
||||||
|
return [False] # Qwen3 always frozen
|
||||||
|
|
||||||
|
def is_train_text_encoder(self, args):
|
||||||
|
return False # Qwen3 text encoder is always frozen for Anima
|
||||||
|
|
||||||
|
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||||
|
args.cache_text_encoder_outputs_to_disk,
|
||||||
|
args.text_encoder_batch_size,
|
||||||
|
args.skip_cache_check,
|
||||||
|
is_partial=False,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def cache_text_encoder_outputs_if_needed(
|
||||||
|
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
|
||||||
|
):
|
||||||
|
if args.cache_text_encoder_outputs:
|
||||||
|
if not args.lowram:
|
||||||
|
logger.info("move vae and unet to cpu to save memory")
|
||||||
|
org_vae_device = next(vae.parameters()).device
|
||||||
|
org_unet_device = unet.device
|
||||||
|
vae.to("cpu")
|
||||||
|
unet.to("cpu")
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
|
logger.info("move text encoder to gpu")
|
||||||
|
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
with accelerator.autocast():
|
||||||
|
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
||||||
|
|
||||||
|
# cache sample prompts
|
||||||
|
if args.sample_prompts is not None:
|
||||||
|
logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
|
||||||
|
|
||||||
|
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||||
|
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||||
|
|
||||||
|
prompts = train_util.load_prompts(args.sample_prompts)
|
||||||
|
sample_prompts_te_outputs = {}
|
||||||
|
with accelerator.autocast(), torch.no_grad():
|
||||||
|
for prompt_dict in prompts:
|
||||||
|
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
|
||||||
|
if p not in sample_prompts_te_outputs:
|
||||||
|
logger.info(f" cache TE outputs for: {p}")
|
||||||
|
tokens_and_masks = tokenize_strategy.tokenize(p)
|
||||||
|
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||||
|
tokenize_strategy,
|
||||||
|
text_encoders,
|
||||||
|
tokens_and_masks,
|
||||||
|
enable_dropout=False,
|
||||||
|
)
|
||||||
|
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
||||||
|
|
||||||
|
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
|
||||||
|
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||||
|
text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
|
||||||
|
if caption_dropout_rate > 0.0:
|
||||||
|
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
|
||||||
|
with accelerator.autocast():
|
||||||
|
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
# move text encoder back to cpu
|
||||||
|
logger.info("move text encoder back to cpu")
|
||||||
|
text_encoders[0].to("cpu")
|
||||||
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
|
if not args.lowram:
|
||||||
|
logger.info("move vae and unet back to original device")
|
||||||
|
vae.to(org_vae_device)
|
||||||
|
unet.to(org_unet_device)
|
||||||
|
else:
|
||||||
|
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||||
|
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
|
||||||
|
te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
||||||
|
qwen3_te = te[0] if te is not None else None
|
||||||
|
|
||||||
|
anima_train_utils.sample_images(
|
||||||
|
accelerator, args, epoch, global_step, unet, vae, self.vae_scale,
|
||||||
|
qwen3_te, self.tokenize_strategy, self.text_encoding_strategy,
|
||||||
|
self.sample_prompts_te_outputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||||
|
noise_scheduler = anima_train_utils.FlowMatchEulerDiscreteScheduler(
|
||||||
|
num_train_timesteps=1000, shift=args.discrete_flow_shift
|
||||||
|
)
|
||||||
|
return noise_scheduler
|
||||||
|
|
||||||
|
def encode_images_to_latents(self, args, vae, images):
|
||||||
|
# images are already [-1,1] from IMAGE_TRANSFORMS, add temporal dim
|
||||||
|
images = images.unsqueeze(2) # (B, C, 1, H, W)
|
||||||
|
# Ensure scale tensors are on the same device as images
|
||||||
|
vae_device = images.device
|
||||||
|
scale = [s.to(vae_device) if isinstance(s, torch.Tensor) else s for s in self.vae_scale]
|
||||||
|
return vae.encode(images, scale)
|
||||||
|
|
||||||
|
def shift_scale_latents(self, args, latents):
|
||||||
|
# Latents already normalized by vae.encode with scale
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def get_noise_pred_and_target(
|
||||||
|
self,
|
||||||
|
args,
|
||||||
|
accelerator,
|
||||||
|
noise_scheduler,
|
||||||
|
latents,
|
||||||
|
batch,
|
||||||
|
text_encoder_conds,
|
||||||
|
unet,
|
||||||
|
network,
|
||||||
|
weight_dtype,
|
||||||
|
train_unet,
|
||||||
|
is_train=True,
|
||||||
|
):
|
||||||
|
# Sample noise
|
||||||
|
noise = torch.randn_like(latents)
|
||||||
|
|
||||||
|
# Get noisy model input and timesteps
|
||||||
|
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
|
||||||
|
args, latents, noise, accelerator.device, weight_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gradient checkpointing support
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
noisy_model_input.requires_grad_(True)
|
||||||
|
for t in text_encoder_conds:
|
||||||
|
if t is not None and t.dtype.is_floating_point:
|
||||||
|
t.requires_grad_(True)
|
||||||
|
|
||||||
|
# Unpack text encoder conditions
|
||||||
|
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds
|
||||||
|
|
||||||
|
# Move to device
|
||||||
|
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
attn_mask = attn_mask.to(accelerator.device)
|
||||||
|
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||||
|
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||||
|
|
||||||
|
# Create padding mask
|
||||||
|
bs = latents.shape[0]
|
||||||
|
h_latent = latents.shape[-2]
|
||||||
|
w_latent = latents.shape[-1]
|
||||||
|
padding_mask = torch.zeros(
|
||||||
|
bs, 1, h_latent, w_latent,
|
||||||
|
dtype=weight_dtype, device=accelerator.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare block swap
|
||||||
|
if self.is_swapping_blocks:
|
||||||
|
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
|
# Call model (LLM adapter runs inside forward for DDP gradient sync)
|
||||||
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
|
model_pred = unet(
|
||||||
|
noisy_model_input,
|
||||||
|
timesteps,
|
||||||
|
prompt_embeds,
|
||||||
|
padding_mask=padding_mask,
|
||||||
|
source_attention_mask=attn_mask,
|
||||||
|
t5_input_ids=t5_input_ids,
|
||||||
|
t5_attn_mask=t5_attn_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rectified flow target: noise - latents
|
||||||
|
target = noise - latents
|
||||||
|
|
||||||
|
# Loss weighting
|
||||||
|
weighting = anima_train_utils.compute_loss_weighting_for_anima(
|
||||||
|
weighting_scheme=args.weighting_scheme, sigmas=sigmas
|
||||||
|
)
|
||||||
|
|
||||||
|
# Differential output preservation
|
||||||
|
if "custom_attributes" in batch:
|
||||||
|
diff_output_pr_indices = []
|
||||||
|
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
||||||
|
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
||||||
|
diff_output_pr_indices.append(i)
|
||||||
|
|
||||||
|
if len(diff_output_pr_indices) > 0:
|
||||||
|
network.set_multiplier(0.0)
|
||||||
|
with torch.no_grad(), accelerator.autocast():
|
||||||
|
if self.is_swapping_blocks:
|
||||||
|
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||||
|
model_pred_prior = unet(
|
||||||
|
noisy_model_input[diff_output_pr_indices],
|
||||||
|
timesteps[diff_output_pr_indices],
|
||||||
|
prompt_embeds[diff_output_pr_indices],
|
||||||
|
padding_mask=padding_mask[diff_output_pr_indices],
|
||||||
|
source_attention_mask=attn_mask[diff_output_pr_indices],
|
||||||
|
t5_input_ids=t5_input_ids[diff_output_pr_indices],
|
||||||
|
t5_attn_mask=t5_attn_mask[diff_output_pr_indices],
|
||||||
|
)
|
||||||
|
network.set_multiplier(1.0)
|
||||||
|
|
||||||
|
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||||
|
|
||||||
|
return model_pred, target, timesteps, weighting
|
||||||
|
|
||||||
|
def process_batch(
    self, batch, text_encoders, unet, network, vae, noise_scheduler,
    vae_dtype, weight_dtype, accelerator, args,
    text_encoding_strategy, tokenize_strategy,
    is_train=True, train_text_encoder=True, train_unet=True,
) -> torch.Tensor:
    """Override base process_batch for 5D video latents (B, C, T, H, W).

    Base class assumes 4D (B, C, H, W) for loss.mean([1,2,3]) and weighting broadcast.
    This override reduces over all non-batch dims (works for both 4D and 5D) and
    applies the flow-matching weighting after the per-sample reduction.

    Returns:
        Scalar training loss (mean over the batch).
    """
    # Local imports mirror the base-class implementation; kept local to avoid
    # import-time cost when this module is loaded for other entry points.
    import typing
    from library.custom_train_functions import apply_masked_loss

    with torch.no_grad():
        # Prefer cached latents from the dataset; otherwise encode images with the VAE.
        if "latents" in batch and batch["latents"] is not None:
            latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
        else:
            if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
                latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
            else:
                # Encode in chunks of vae_batch_size to bound VAE peak memory.
                chunks = [
                    batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
                ]
                list_latents = []
                for chunk in chunks:
                    with torch.no_grad():
                        chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
                    list_latents.append(chunk)
                latents = torch.cat(list_latents, dim=0)

        # Defensive: NaNs in latents (e.g. from fp16 VAE overflow) would poison the loss.
        if torch.any(torch.isnan(latents)):
            accelerator.print("NaN found in latents, replacing with zeros")
            latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))

        latents = self.shift_scale_latents(args, latents)

    # Text encoder conditions: use cached outputs when available; re-encode when
    # training the text encoder or when no cache exists for this batch.
    text_encoder_conds = []
    text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
    if text_encoder_outputs_list is not None:
        text_encoder_conds = text_encoder_outputs_list

    if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
        with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
            input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
            encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
                tokenize_strategy,
                self.get_models_for_text_encoding(args, accelerator, text_encoders),
                input_ids,
            )
            if args.full_fp16:
                encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]

        if len(text_encoder_conds) == 0:
            text_encoder_conds = encoded_text_encoder_conds
        else:
            # Merge: freshly encoded entries overwrite cached ones; None entries
            # (caption dropout / partially cached) keep the cached value.
            for i in range(len(encoded_text_encoder_conds)):
                if encoded_text_encoder_conds[i] is not None:
                    text_encoder_conds[i] = encoded_text_encoder_conds[i]

    noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
        args, accelerator, noise_scheduler, latents, batch,
        text_encoder_conds, unet, network, weight_dtype, train_unet, is_train=is_train,
    )

    huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
    loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)

    if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
        loss = apply_masked_loss(loss, batch)

    # Reduce all non-batch dims: (B, C, T, H, W) -> (B,) for 5D, (B, C, H, W) -> (B,) for 4D
    reduce_dims = list(range(1, loss.ndim))
    loss = loss.mean(reduce_dims)

    # Apply weighting after reducing to (B,)
    # NOTE(review): this assumes `weighting` is shaped (B,) (or broadcasts to it);
    # confirm compute_loss_weighting_for_anima does not return (B, 1, 1, 1).
    if weighting is not None:
        loss = loss * weighting

    # Per-sample dataset weights (e.g. repeats/balancing) from the collator.
    loss_weights = batch["loss_weights"]
    loss = loss * loss_weights

    loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
    return loss.mean()
|
||||||
|
|
||||||
|
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
    """No-op hook: Anima applies no extra loss post-processing.

    Overrides the base trainer hook so that schemes like min-SNR weighting
    (which assume a noise-prediction objective) are not applied to the
    rectified-flow loss. Returns ``loss`` unchanged.
    """
    return loss
|
||||||
|
|
||||||
|
def get_sai_model_spec(self, args):
    # Build SAI-style metadata for saved checkpoints by delegating to the
    # shared helper with fixed architecture flags (no state_dict is passed).
    # NOTE(review): the positional booleans select the architecture variant in
    # train_util.get_sai_model_spec — confirm they map to the Anima entry there.
    return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True)
|
||||||
|
|
||||||
|
def update_metadata(self, metadata, args):
    """Record Anima flow-matching hyperparameters in the saved training metadata."""
    anima_entries = {
        "ss_weighting_scheme": args.weighting_scheme,
        "ss_discrete_flow_shift": args.discrete_flow_shift,
        # getattr with defaults: these options may be absent on older arg sets.
        "ss_timestep_sample_method": getattr(args, "timestep_sample_method", "logit_normal"),
        "ss_sigmoid_scale": getattr(args, "sigmoid_scale", 1.0),
    }
    metadata.update(anima_entries)
|
||||||
|
|
||||||
|
def is_text_encoder_not_needed_for_training(self, args):
    """Text encoders can be dropped from memory once their outputs are cached."""
    return args.cache_text_encoder_outputs
|
||||||
|
|
||||||
|
def prepare_unet_with_accelerator(
    self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
    """Wrap the DiT with Accelerate, handling unsloth checkpointing and block swap.

    Returns the accelerator-prepared model; when block swapping is enabled the
    model is deliberately NOT moved wholesale to the device.
    """
    # The base NetworkTrainer only calls enable_gradient_checkpointing(cpu_offload=True/False),
    # so we re-apply with unsloth_offload if needed (after base has already enabled it).
    if self._use_unsloth_offload_checkpointing and args.gradient_checkpointing:
        unet.enable_gradient_checkpointing(unsloth_offload=True)

    # Without block swapping the base class handles preparation entirely.
    if not self.is_swapping_blocks:
        return super().prepare_unet_with_accelerator(args, accelerator, unet)

    dit = unet
    # device_placement=[False]: keep swap-managed blocks off-device; we place
    # everything else manually below.
    dit = accelerator.prepare(dit, device_placement=[not self.is_swapping_blocks])
    accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
    accelerator.unwrap_model(dit).prepare_block_swap_before_forward()

    return dit
|
||||||
|
|
||||||
|
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
    """Apply caption dropout to cached text-encoder outputs before each step.

    When the batch carries precomputed text-encoder outputs, the encoding
    strategy decides (per sample) whether to drop them for caption dropout,
    and the batch entry is replaced with the result.
    """
    cached_outputs = batch.get("text_encoder_outputs_list", None)
    if cached_outputs is None:
        return
    encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
    batch["text_encoder_outputs_list"] = encoding_strategy.drop_cached_text_encoder_outputs(*cached_outputs)
|
||||||
|
|
||||||
|
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
    """Re-arm block swapping after a validation step so training can resume."""
    if not self.is_swapping_blocks:
        return
    accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
|
|
||||||
|
def setup_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: base network-training args plus DiT and Anima options."""
    parser = train_network.setup_parser()
    train_util.add_dit_training_arguments(parser)
    anima_train_utils.add_anima_training_arguments(parser)
    parser.add_argument(
        "--unsloth_offload_checkpointing",
        action="store_true",
        help="offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
        "Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
    )
    return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    # Validate CLI-only constraints before merging values from a TOML config file.
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = AnimaNetworkTrainer()
    trainer.train(args)
|
||||||
30
configs/qwen3_06b/config.json
Normal file
30
configs/qwen3_06b/config.json
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
{
|
||||||
|
"architectures": [
|
||||||
|
"Qwen3ForCausalLM"
|
||||||
|
],
|
||||||
|
"attention_bias": false,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"bos_token_id": 151643,
|
||||||
|
"eos_token_id": 151643,
|
||||||
|
"head_dim": 128,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 1024,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 3072,
|
||||||
|
"max_position_embeddings": 32768,
|
||||||
|
"max_window_layers": 28,
|
||||||
|
"model_type": "qwen3",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_hidden_layers": 28,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": null,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"sliding_window": null,
|
||||||
|
"tie_word_embeddings": true,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.51.0",
|
||||||
|
"use_cache": true,
|
||||||
|
"use_sliding_window": false,
|
||||||
|
"vocab_size": 151936
|
||||||
|
}
|
||||||
151388
configs/qwen3_06b/merges.txt
Normal file
151388
configs/qwen3_06b/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
303282
configs/qwen3_06b/tokenizer.json
Normal file
303282
configs/qwen3_06b/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
239
configs/qwen3_06b/tokenizer_config.json
Normal file
239
configs/qwen3_06b/tokenizer_config.json
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
{
|
||||||
|
"add_bos_token": false,
|
||||||
|
"add_prefix_space": false,
|
||||||
|
"added_tokens_decoder": {
|
||||||
|
"151643": {
|
||||||
|
"content": "<|endoftext|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151644": {
|
||||||
|
"content": "<|im_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151645": {
|
||||||
|
"content": "<|im_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151646": {
|
||||||
|
"content": "<|object_ref_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151647": {
|
||||||
|
"content": "<|object_ref_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151648": {
|
||||||
|
"content": "<|box_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151649": {
|
||||||
|
"content": "<|box_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151650": {
|
||||||
|
"content": "<|quad_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151651": {
|
||||||
|
"content": "<|quad_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151652": {
|
||||||
|
"content": "<|vision_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151653": {
|
||||||
|
"content": "<|vision_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151654": {
|
||||||
|
"content": "<|vision_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151655": {
|
||||||
|
"content": "<|image_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151656": {
|
||||||
|
"content": "<|video_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"151657": {
|
||||||
|
"content": "<tool_call>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151658": {
|
||||||
|
"content": "</tool_call>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151659": {
|
||||||
|
"content": "<|fim_prefix|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151660": {
|
||||||
|
"content": "<|fim_middle|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151661": {
|
||||||
|
"content": "<|fim_suffix|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151662": {
|
||||||
|
"content": "<|fim_pad|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151663": {
|
||||||
|
"content": "<|repo_name|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151664": {
|
||||||
|
"content": "<|file_sep|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151665": {
|
||||||
|
"content": "<tool_response>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151666": {
|
||||||
|
"content": "</tool_response>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151667": {
|
||||||
|
"content": "<think>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
},
|
||||||
|
"151668": {
|
||||||
|
"content": "</think>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|object_ref_start|>",
|
||||||
|
"<|object_ref_end|>",
|
||||||
|
"<|box_start|>",
|
||||||
|
"<|box_end|>",
|
||||||
|
"<|quad_start|>",
|
||||||
|
"<|quad_end|>",
|
||||||
|
"<|vision_start|>",
|
||||||
|
"<|vision_end|>",
|
||||||
|
"<|vision_pad|>",
|
||||||
|
"<|image_pad|>",
|
||||||
|
"<|video_pad|>"
|
||||||
|
],
|
||||||
|
"bos_token": null,
|
||||||
|
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = 
message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
||||||
|
"clean_up_tokenization_spaces": false,
|
||||||
|
"eos_token": "<|endoftext|>",
|
||||||
|
"errors": "replace",
|
||||||
|
"model_max_length": 131072,
|
||||||
|
"pad_token": "<|endoftext|>",
|
||||||
|
"split_special_tokens": false,
|
||||||
|
"tokenizer_class": "Qwen2Tokenizer",
|
||||||
|
"unk_token": null
|
||||||
|
}
|
||||||
1
configs/qwen3_06b/vocab.json
Normal file
1
configs/qwen3_06b/vocab.json
Normal file
File diff suppressed because one or more lines are too long
51
configs/t5_old/config.json
Normal file
51
configs/t5_old/config.json
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
{
|
||||||
|
"architectures": [
|
||||||
|
"T5WithLMHeadModel"
|
||||||
|
],
|
||||||
|
"d_ff": 65536,
|
||||||
|
"d_kv": 128,
|
||||||
|
"d_model": 1024,
|
||||||
|
"decoder_start_token_id": 0,
|
||||||
|
"dropout_rate": 0.1,
|
||||||
|
"eos_token_id": 1,
|
||||||
|
"initializer_factor": 1.0,
|
||||||
|
"is_encoder_decoder": true,
|
||||||
|
"layer_norm_epsilon": 1e-06,
|
||||||
|
"model_type": "t5",
|
||||||
|
"n_positions": 512,
|
||||||
|
"num_heads": 128,
|
||||||
|
"num_layers": 24,
|
||||||
|
"output_past": true,
|
||||||
|
"pad_token_id": 0,
|
||||||
|
"relative_attention_num_buckets": 32,
|
||||||
|
"task_specific_params": {
|
||||||
|
"summarization": {
|
||||||
|
"early_stopping": true,
|
||||||
|
"length_penalty": 2.0,
|
||||||
|
"max_length": 200,
|
||||||
|
"min_length": 30,
|
||||||
|
"no_repeat_ngram_size": 3,
|
||||||
|
"num_beams": 4,
|
||||||
|
"prefix": "summarize: "
|
||||||
|
},
|
||||||
|
"translation_en_to_de": {
|
||||||
|
"early_stopping": true,
|
||||||
|
"max_length": 300,
|
||||||
|
"num_beams": 4,
|
||||||
|
"prefix": "translate English to German: "
|
||||||
|
},
|
||||||
|
"translation_en_to_fr": {
|
||||||
|
"early_stopping": true,
|
||||||
|
"max_length": 300,
|
||||||
|
"num_beams": 4,
|
||||||
|
"prefix": "translate English to French: "
|
||||||
|
},
|
||||||
|
"translation_en_to_ro": {
|
||||||
|
"early_stopping": true,
|
||||||
|
"max_length": 300,
|
||||||
|
"num_beams": 4,
|
||||||
|
"prefix": "translate English to Romanian: "
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"vocab_size": 32128
|
||||||
|
}
|
||||||
BIN
configs/t5_old/spiece.model
Normal file
BIN
configs/t5_old/spiece.model
Normal file
Binary file not shown.
1
configs/t5_old/tokenizer.json
Normal file
1
configs/t5_old/tokenizer.json
Normal file
File diff suppressed because one or more lines are too long
556
docs/anima_train_network.md
Normal file
556
docs/anima_train_network.md
Normal file
@@ -0,0 +1,556 @@
|
|||||||
|
# LoRA Training Guide for Anima using `anima_train_network.py` / `anima_train_network.py` を用いたAnima モデルのLoRA学習ガイド
|
||||||
|
|
||||||
|
This document explains how to train LoRA (Low-Rank Adaptation) models for Anima using `anima_train_network.py` in the `sd-scripts` repository.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
このドキュメントでは、`sd-scripts`リポジトリに含まれる`anima_train_network.py`を使用して、Anima モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 1. Introduction / はじめに
|
||||||
|
|
||||||
|
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a WanVAE (16-channel, 8x spatial downscale).
|
||||||
|
|
||||||
|
This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
|
||||||
|
|
||||||
|
**Prerequisites:**
|
||||||
|
|
||||||
|
* The `sd-scripts` repository has been cloned and the Python environment is ready.
|
||||||
|
* A training dataset has been prepared. See the [Dataset Configuration Guide](./config_README-en.md).
|
||||||
|
* Anima model files for training are available.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびWanVAE (16チャンネル、8倍空間ダウンスケール) を使用します。
|
||||||
|
|
||||||
|
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
|
||||||
|
|
||||||
|
**前提条件:**
|
||||||
|
|
||||||
|
* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。
|
||||||
|
* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](./config_README-en.md)を参照してください)
|
||||||
|
* 学習対象のAnimaモデルファイルが準備できていること。
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 2. Differences from `train_network.py` / `train_network.py` との違い
|
||||||
|
|
||||||
|
`anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are:
|
||||||
|
|
||||||
|
* **Target models:** Anima DiT models.
|
||||||
|
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a WanVAE (16-channel latent space with 8x spatial downscale).
|
||||||
|
* **Arguments:** Options exist to specify the Anima DiT model, Qwen3 text encoder, WanVAE, LLM adapter, and T5 tokenizer separately.
|
||||||
|
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used.
|
||||||
|
* **Anima-specific options:** Additional parameters for component-wise learning rates (self_attn, cross_attn, mlp, mod, llm_adapter), timestep sampling, discrete flow shift, and flash attention.
|
||||||
|
* **6 Parameter Groups:** Independent learning rates for `base`, `self_attn`, `cross_attn`, `mlp`, `adaln_modulation`, and `llm_adapter` components.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
`anima_train_network.py`は`train_network.py`をベースに、Anima モデルに対応するための変更が加えられています。主な違いは以下の通りです。
|
||||||
|
|
||||||
|
* **対象モデル:** Anima DiTモデルを対象とします。
|
||||||
|
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびWanVAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
|
||||||
|
* **引数:** Anima DiTモデル、Qwen3テキストエンコーダー、WanVAE、LLM Adapter、T5トークナイザーを個別に指定する引数があります。
|
||||||
|
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はAnimaの学習では使用されません。
|
||||||
|
* **Anima特有の引数:** コンポーネント別学習率(self_attn, cross_attn, mlp, mod, llm_adapter)、タイムステップサンプリング、離散フローシフト、Flash Attentionに関する引数が追加されています。
|
||||||
|
* **6パラメータグループ:** `base`、`self_attn`、`cross_attn`、`mlp`、`adaln_modulation`、`llm_adapter`の各コンポーネントに対して独立した学習率を設定できます。
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 3. Preparation / 準備
|
||||||
|
|
||||||
|
The following files are required before starting training:
|
||||||
|
|
||||||
|
1. **Training script:** `anima_train_network.py`
|
||||||
|
2. **Anima DiT model file:** `.safetensors` file for the base DiT model.
|
||||||
|
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory or a single `.safetensors` file (requires `configs/qwen3_06b/` config files).
|
||||||
|
4. **WanVAE model file:** `.safetensors` or `.pth` file for the VAE.
|
||||||
|
5. **LLM Adapter model file (optional):** `.safetensors` file. If not provided separately, the adapter is loaded from the DiT file if the key `llm_adapter.out_proj.weight` exists.
|
||||||
|
6. **T5 Tokenizer (optional):** If not specified, uses the bundled tokenizer at `configs/t5_old/`.
|
||||||
|
7. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_anima_dataset_config.toml` as an example.
|
||||||
|
|
||||||
|
**Notes:**
|
||||||
|
* When using a single `.safetensors` file for Qwen3, download the `config.json`, `tokenizer.json`, `tokenizer_config.json`, and `vocab.json` from the [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFace repository into the `configs/qwen3_06b/` directory.
|
||||||
|
* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
|
||||||
|
* Models are saved with a `net.` prefix on all keys for ComfyUI compatibility.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
学習を開始する前に、以下のファイルが必要です。
|
||||||
|
|
||||||
|
1. **学習スクリプト:** `anima_train_network.py`
|
||||||
|
2. **Anima DiTモデルファイル:** ベースとなるDiTモデルの`.safetensors`ファイル。
|
||||||
|
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(`configs/qwen3_06b/`の設定ファイルが必要)。
|
||||||
|
4. **WanVAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
|
||||||
|
5. **LLM Adapterモデルファイル(オプション):** `.safetensors`ファイル。個別に指定しない場合、DiTファイル内に`llm_adapter.out_proj.weight`キーが存在すればそこから読み込まれます。
|
||||||
|
6. **T5トークナイザー(オプション):** 指定しない場合、`configs/t5_old/`のバンドル版トークナイザーを使用します。
|
||||||
|
7. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。例として`my_anima_dataset_config.toml`を使用します。
|
||||||
|
|
||||||
|
**注意:**
|
||||||
|
* Qwen3の単体`.safetensors`ファイルを使用する場合、[Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFaceリポジトリから`config.json`、`tokenizer.json`、`tokenizer_config.json`、`vocab.json`をダウンロードし、`configs/qwen3_06b/`ディレクトリに配置してください。
|
||||||
|
* T5トークナイザーはトークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
|
||||||
|
* モデルはComfyUI互換のため、すべてのキーに`net.`プレフィックスを付けて保存されます。
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 4. Running the Training / 学習の実行
|
||||||
|
|
||||||
|
Execute `anima_train_network.py` from the terminal to start training. The overall command-line format is the same as `train_network.py`, but Anima-specific options must be supplied.
|
||||||
|
|
||||||
|
Example command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
|
||||||
|
--dit_path="<path to Anima DiT model>" \
|
||||||
|
--qwen3_path="<path to Qwen3-0.6B model or directory>" \
|
||||||
|
--vae_path="<path to WanVAE model>" \
|
||||||
|
--llm_adapter_path="<path to LLM adapter model>" \
|
||||||
|
--dataset_config="my_anima_dataset_config.toml" \
|
||||||
|
--output_dir="<output directory>" \
|
||||||
|
--output_name="my_anima_lora" \
|
||||||
|
--save_model_as=safetensors \
|
||||||
|
--network_module=networks.lora_anima \
|
||||||
|
--network_dim=8 \
|
||||||
|
--network_alpha=8 \
|
||||||
|
--learning_rate=1e-4 \
|
||||||
|
--optimizer_type="AdamW8bit" \
|
||||||
|
--lr_scheduler="constant" \
|
||||||
|
--timestep_sample_method="logit_normal" \
|
||||||
|
--discrete_flow_shift=3.0 \
|
||||||
|
--max_train_epochs=10 \
|
||||||
|
--save_every_n_epochs=1 \
|
||||||
|
--mixed_precision="bf16" \
|
||||||
|
--gradient_checkpointing \
|
||||||
|
--cache_latents \
|
||||||
|
--cache_text_encoder_outputs \
|
||||||
|
--blocks_to_swap=18
|
||||||
|
```
|
||||||
|
|
||||||
|
*(Write the command on one line or use `\` or `^` for line breaks.)*
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
学習は、ターミナルから`anima_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、Anima特有の引数を指定する必要があります。
|
||||||
|
|
||||||
|
コマンドラインの例は英語のドキュメントを参照してください。
|
||||||
|
|
||||||
|
※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
|
||||||
|
|
||||||
|
Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following Anima specific options. For shared options (`--output_dir`, `--output_name`, `--network_module`, etc.), see that guide.
|
||||||
|
|
||||||
|
#### Model Options [Required] / モデル関連 [必須]
|
||||||
|
|
||||||
|
* `--dit_path="<path to Anima DiT model>"` **[Required]**
|
||||||
|
- Path to the Anima DiT model `.safetensors` file. The model config (channels, blocks, heads) is auto-detected from the state dict. ComfyUI format with `net.` prefix is supported.
|
||||||
|
* `--qwen3_path="<path to Qwen3-0.6B model>"` **[Required]**
|
||||||
|
- Path to the Qwen3-0.6B text encoder. Can be a HuggingFace model directory or a single `.safetensors` file. The text encoder is always frozen during training.
|
||||||
|
* `--vae_path="<path to WanVAE model>"` **[Required]**
|
||||||
|
- Path to the WanVAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
|
||||||
|
* `--llm_adapter_path="<path to LLM adapter>"` *[Optional]*
|
||||||
|
- Path to a separate LLM adapter weights file. If omitted, the adapter is loaded from the DiT file when the key `llm_adapter.out_proj.weight` exists.
|
||||||
|
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[Optional]*
|
||||||
|
- Path to the T5 tokenizer directory. If omitted, uses the bundled config at `configs/t5_old/`.
|
||||||
|
|
||||||
|
#### Anima Training Parameters / Anima 学習パラメータ
|
||||||
|
|
||||||
|
* `--timestep_sample_method=<choice>`
|
||||||
|
- Timestep sampling method. Choose from `logit_normal` (default) or `uniform`.
|
||||||
|
* `--discrete_flow_shift=<float>`
|
||||||
|
- Shift for the timestep distribution in Rectified Flow training. Default `3.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
|
||||||
|
* `--sigmoid_scale=<float>`
|
||||||
|
- Scale factor for `logit_normal` timestep sampling. Default `1.0`.
|
||||||
|
* `--qwen3_max_token_length=<integer>`
|
||||||
|
- Maximum token length for the Qwen3 tokenizer. Default `512`.
|
||||||
|
* `--t5_max_token_length=<integer>`
|
||||||
|
- Maximum token length for the T5 tokenizer. Default `512`.
|
||||||
|
* `--flash_attn`
|
||||||
|
- Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`. Falls back to PyTorch SDPA if the package is not installed. Note: Flash Attention is only applied to DiT blocks; the LLM Adapter uses standard attention because it requires attention masks.
|
||||||
|
* `--transformer_dtype=<choice>`
|
||||||
|
- Separate dtype for transformer blocks. Choose from `float16`, `bfloat16`, `float32`. If not specified, uses the same dtype as `--mixed_precision`.
|
||||||
|
|
||||||
|
#### Component-wise Learning Rates / コンポーネント別学習率
|
||||||
|
|
||||||
|
Anima supports 6 independent learning rate groups. Set to `0` to freeze a component:
|
||||||
|
|
||||||
|
* `--self_attn_lr=<float>` - Learning rate for self-attention layers. Default: same as `--learning_rate`.
|
||||||
|
* `--cross_attn_lr=<float>` - Learning rate for cross-attention layers. Default: same as `--learning_rate`.
|
||||||
|
* `--mlp_lr=<float>` - Learning rate for MLP layers. Default: same as `--learning_rate`.
|
||||||
|
* `--mod_lr=<float>` - Learning rate for AdaLN modulation layers. Default: same as `--learning_rate`.
|
||||||
|
* `--llm_adapter_lr=<float>` - Learning rate for LLM adapter layers. Default: same as `--learning_rate`.
|
||||||
|
|
||||||
|
#### Memory and Speed / メモリ・速度関連
|
||||||
|
|
||||||
|
* `--blocks_to_swap=<integer>` **[Experimental]**
|
||||||
|
- Number of Transformer blocks to swap between CPU and GPU. More blocks reduce VRAM but slow training. Maximum values depend on model size:
|
||||||
|
- 28-block model: max **26**
|
||||||
|
- 36-block model: max **34**
|
||||||
|
- 20-block model: max **18**
|
||||||
|
- Cannot be used with `--cpu_offload_checkpointing` or `--unsloth_offload_checkpointing`.
|
||||||
|
* `--unsloth_offload_checkpointing`
|
||||||
|
- Offload activations to CPU RAM using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
|
||||||
|
* `--cache_text_encoder_outputs`
|
||||||
|
- Cache Qwen3 text encoder outputs to reduce VRAM usage. Recommended when not training text encoder LoRA.
|
||||||
|
* `--cache_text_encoder_outputs_to_disk`
|
||||||
|
- Cache text encoder outputs to disk. Auto-enables `--cache_text_encoder_outputs`.
|
||||||
|
* `--cache_latents`, `--cache_latents_to_disk`
|
||||||
|
- Cache WanVAE latent outputs.
|
||||||
|
* `--fp8_base`
|
||||||
|
- Use FP8 precision for the base model to reduce VRAM usage.
|
||||||
|
|
||||||
|
#### Incompatible or Deprecated Options / 非互換・非推奨の引数
|
||||||
|
|
||||||
|
* `--v2`, `--v_parameterization`, `--clip_skip` - Options for Stable Diffusion v1/v2 that are not used for Anima training.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のAnima特有の引数を指定します。共通の引数については、上記ガイドを参照してください。
|
||||||
|
|
||||||
|
#### モデル関連 [必須]
|
||||||
|
|
||||||
|
* `--dit_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。
|
||||||
|
* `--qwen3_path="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。
|
||||||
|
* `--vae_path="<path to WanVAE model>"` **[必須]** - WanVAEモデルのパスを指定します。
|
||||||
|
* `--llm_adapter_path="<path to LLM adapter>"` *[オプション]* - 個別のLLM Adapterの重みファイルのパス。
|
||||||
|
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[オプション]* - T5トークナイザーディレクトリのパス。
|
||||||
|
|
||||||
|
#### Anima 学習パラメータ
|
||||||
|
|
||||||
|
* `--timestep_sample_method` - タイムステップのサンプリング方法。`logit_normal`(デフォルト)または`uniform`。
|
||||||
|
* `--discrete_flow_shift` - Rectified Flow学習のタイムステップ分布シフト。デフォルト`3.0`。
|
||||||
|
* `--sigmoid_scale` - logit_normalタイムステップサンプリングのスケール係数。デフォルト`1.0`。
|
||||||
|
* `--qwen3_max_token_length` - Qwen3トークナイザーの最大トークン長。デフォルト`512`。
|
||||||
|
* `--t5_max_token_length` - T5トークナイザーの最大トークン長。デフォルト`512`。
|
||||||
|
* `--flash_attn` - DiTのself/cross-attentionにFlash Attentionを使用。`pip install flash-attn`が必要。
|
||||||
|
* `--transformer_dtype` - Transformerブロック用の個別dtype。
|
||||||
|
|
||||||
|
#### コンポーネント別学習率
|
||||||
|
|
||||||
|
Animaは6つの独立した学習率グループをサポートします。`0`に設定するとそのコンポーネントをフリーズします:
|
||||||
|
|
||||||
|
* `--self_attn_lr` - Self-attention層の学習率。
|
||||||
|
* `--cross_attn_lr` - Cross-attention層の学習率。
|
||||||
|
* `--mlp_lr` - MLP層の学習率。
|
||||||
|
* `--mod_lr` - AdaLNモジュレーション層の学習率。
|
||||||
|
* `--llm_adapter_lr` - LLM Adapter層の学習率。
|
||||||
|
|
||||||
|
#### メモリ・速度関連
|
||||||
|
|
||||||
|
* `--blocks_to_swap` **[実験的機能]** - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。
|
||||||
|
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。
|
||||||
|
* `--cache_text_encoder_outputs` - Qwen3の出力をキャッシュしてメモリ使用量を削減。
|
||||||
|
* `--cache_latents`, `--cache_latents_to_disk` - WanVAEの出力をキャッシュ。
|
||||||
|
* `--fp8_base` - ベースモデルにFP8精度を使用。
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 4.2. Starting Training / 学習の開始
|
||||||
|
|
||||||
|
After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始).
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
必要な引数を設定したら、コマンドを実行して学習を開始します。全体の流れやログの確認方法は、[train_network.pyのガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 5. LoRA Target Modules / LoRAの学習対象モジュール
|
||||||
|
|
||||||
|
When training LoRA with `anima_train_network.py`, the following modules are targeted:
|
||||||
|
|
||||||
|
* **DiT Blocks (`Block`)**: Self-attention, cross-attention, MLP, and AdaLN modulation layers within each transformer block.
|
||||||
|
* **LLM Adapter Blocks (`LLMAdapterTransformerBlock`)**: Only when `--network_args "train_llm_adapter=True"` is specified.
|
||||||
|
* **Text Encoder (Qwen3)**: Only when `--network_train_unet_only` is NOT specified.
|
||||||
|
|
||||||
|
The LoRA network module is `networks.lora_anima`.
|
||||||
|
|
||||||
|
### 5.1. Layer-specific Rank Configuration / 各層に対するランク指定
|
||||||
|
|
||||||
|
You can specify different ranks (network_dim) for each component of the Anima model. Setting `0` disables LoRA for that component.
|
||||||
|
|
||||||
|
| network_args | Target Component |
|
||||||
|
|---|---|
|
||||||
|
| `self_attn_dim` | Self-attention layers in DiT blocks |
|
||||||
|
| `cross_attn_dim` | Cross-attention layers in DiT blocks |
|
||||||
|
| `mlp_dim` | MLP layers in DiT blocks |
|
||||||
|
| `mod_dim` | AdaLN modulation layers in DiT blocks |
|
||||||
|
| `llm_adapter_dim` | LLM adapter layers (requires `train_llm_adapter=True`) |
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
```
|
||||||
|
--network_args "self_attn_dim=8" "cross_attn_dim=4" "mlp_dim=8" "mod_dim=4"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2. Embedding Layer LoRA / 埋め込み層LoRA
|
||||||
|
|
||||||
|
You can apply LoRA to embedding/output layers by specifying `emb_dims` in network_args as a comma-separated list of 3 numbers:
|
||||||
|
|
||||||
|
```
|
||||||
|
--network_args "emb_dims=[8,4,8]"
|
||||||
|
```
|
||||||
|
|
||||||
|
Each number corresponds to:
|
||||||
|
1. `x_embedder` (patch embedding)
|
||||||
|
2. `t_embedder` (timestep embedding)
|
||||||
|
3. `final_layer` (output layer)
|
||||||
|
|
||||||
|
Setting `0` disables LoRA for that layer.
|
||||||
|
|
||||||
|
### 5.3. Block Selection for Training / 学習するブロックの指定
|
||||||
|
|
||||||
|
You can specify which DiT blocks to train using `train_block_indices` in network_args. The indices are 0-based. Default is to train all blocks.
|
||||||
|
|
||||||
|
Specify indices as comma-separated integers or ranges:
|
||||||
|
|
||||||
|
```
|
||||||
|
--network_args "train_block_indices=0-5,10,15-27"
|
||||||
|
```
|
||||||
|
|
||||||
|
Special values: `all` (train all blocks), `none` (skip all blocks).
|
||||||
|
|
||||||
|
### 5.4. LLM Adapter LoRA / LLM Adapter LoRA
|
||||||
|
|
||||||
|
To apply LoRA to the LLM Adapter blocks:
|
||||||
|
|
||||||
|
```
|
||||||
|
--network_args "train_llm_adapter=True" "llm_adapter_dim=4"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.5. Other Network Args / その他のネットワーク引数
|
||||||
|
|
||||||
|
* `--network_args "verbose=True"` - Print all LoRA module names and their dimensions.
|
||||||
|
* `--network_args "rank_dropout=0.1"` - Rank dropout rate.
|
||||||
|
* `--network_args "module_dropout=0.1"` - Module dropout rate.
|
||||||
|
* `--network_args "loraplus_lr_ratio=2.0"` - LoRA+ learning rate ratio.
|
||||||
|
* `--network_args "loraplus_unet_lr_ratio=2.0"` - LoRA+ learning rate ratio for DiT only.
|
||||||
|
* `--network_args "loraplus_text_encoder_lr_ratio=2.0"` - LoRA+ learning rate ratio for text encoder only.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
`anima_train_network.py`でLoRAを学習させる場合、デフォルトでは以下のモジュールが対象となります。
|
||||||
|
|
||||||
|
* **DiTブロック (`Block`)**: 各Transformerブロック内のSelf-attention、Cross-attention、MLP、AdaLNモジュレーション層。
|
||||||
|
* **LLM Adapterブロック (`LLMAdapterTransformerBlock`)**: `--network_args "train_llm_adapter=True"`を指定した場合のみ。
|
||||||
|
* **テキストエンコーダー (Qwen3)**: `--network_train_unet_only`を指定しない場合のみ。
|
||||||
|
|
||||||
|
### 5.1. 各層のランクを指定する
|
||||||
|
|
||||||
|
`--network_args`で各コンポーネントに異なるランクを指定できます。`0`を指定するとその層にはLoRAが適用されません。
|
||||||
|
|
||||||
|
|network_args|対象コンポーネント|
|
||||||
|
|---|---|
|
||||||
|
|`self_attn_dim`|DiTブロック内のSelf-attention層|
|
||||||
|
|`cross_attn_dim`|DiTブロック内のCross-attention層|
|
||||||
|
|`mlp_dim`|DiTブロック内のMLP層|
|
||||||
|
|`mod_dim`|DiTブロック内のAdaLNモジュレーション層|
|
||||||
|
|`llm_adapter_dim`|LLM Adapter層(`train_llm_adapter=True`が必要)|
|
||||||
|
|
||||||
|
### 5.2. 埋め込み層LoRA
|
||||||
|
|
||||||
|
`emb_dims`で埋め込み/出力層にLoRAを適用できます。3つの数値をカンマ区切りで指定します。
|
||||||
|
|
||||||
|
各数値は `x_embedder`(パッチ埋め込み)、`t_embedder`(タイムステップ埋め込み)、`final_layer`(出力層)に対応します。
|
||||||
|
|
||||||
|
### 5.3. 学習するブロックの指定
|
||||||
|
|
||||||
|
`train_block_indices`でLoRAを適用するDiTブロックを指定できます。
|
||||||
|
|
||||||
|
### 5.4. LLM Adapter LoRA
|
||||||
|
|
||||||
|
LLM AdapterブロックにLoRAを適用するには:`--network_args "train_llm_adapter=True" "llm_adapter_dim=4"`
|
||||||
|
|
||||||
|
### 5.5. その他のネットワーク引数
|
||||||
|
|
||||||
|
* `verbose=True` - 全LoRAモジュール名とdimを表示
|
||||||
|
* `rank_dropout` - ランクドロップアウト率
|
||||||
|
* `module_dropout` - モジュールドロップアウト率
|
||||||
|
* `loraplus_lr_ratio` - LoRA+学習率比率
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 6. Using the Trained Model / 学習済みモデルの利用
|
||||||
|
|
||||||
|
When training finishes, a LoRA model file (e.g. `my_anima_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Anima, such as ComfyUI with appropriate nodes.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
学習が完了すると、指定した`output_dir`にLoRAモデルファイル(例: `my_anima_lora.safetensors`)が保存されます。このファイルは、Anima モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 7. Advanced Settings / 高度な設定
|
||||||
|
|
||||||
|
### 7.1. VRAM Usage Optimization / VRAM使用量の最適化
|
||||||
|
|
||||||
|
Anima models can be large, so GPUs with limited VRAM may require optimization:
|
||||||
|
|
||||||
|
#### Key VRAM Reduction Options
|
||||||
|
|
||||||
|
- **`--fp8_base`**: Enables training in FP8 format for the DiT model.
|
||||||
|
|
||||||
|
- **`--blocks_to_swap <number>`**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. See model-specific max values in section 4.1.
|
||||||
|
|
||||||
|
- **`--unsloth_offload_checkpointing`**: Offloads gradient checkpoints to CPU using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--blocks_to_swap`.
|
||||||
|
|
||||||
|
- **`--gradient_checkpointing`**: Standard gradient checkpointing to reduce VRAM at the cost of compute.
|
||||||
|
|
||||||
|
- **`--cache_text_encoder_outputs`**: Caches Qwen3 outputs so the text encoder can be freed from VRAM during training.
|
||||||
|
|
||||||
|
- **`--cache_latents`**: Caches WanVAE outputs so the VAE can be freed from VRAM during training.
|
||||||
|
|
||||||
|
- **Using Adafactor optimizer**: Can reduce VRAM usage:
|
||||||
|
```
|
||||||
|
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
Animaモデルは大きい場合があるため、VRAMが限られたGPUでは最適化が必要です。
|
||||||
|
|
||||||
|
主要なVRAM削減オプション:
|
||||||
|
- `--fp8_base`: FP8形式での学習を有効化
|
||||||
|
- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ
|
||||||
|
- `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード
|
||||||
|
- `--gradient_checkpointing`: 標準的な勾配チェックポイント
|
||||||
|
- `--cache_text_encoder_outputs`: Qwen3の出力をキャッシュ
|
||||||
|
- `--cache_latents`: WanVAEの出力をキャッシュ
|
||||||
|
- Adafactorオプティマイザの使用
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 7.2. Training Settings / 学習設定
|
||||||
|
|
||||||
|
#### Timestep Sampling
|
||||||
|
|
||||||
|
The `--timestep_sample_method` option specifies how timesteps (0-1) are sampled:
|
||||||
|
|
||||||
|
- `logit_normal` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
|
||||||
|
- `uniform`: Uniform random sampling from [0, 1].
|
||||||
|
|
||||||
|
#### Discrete Flow Shift
|
||||||
|
|
||||||
|
The `--discrete_flow_shift` option (default `3.0`) shifts the timestep distribution toward higher noise levels. The formula is:
|
||||||
|
|
||||||
|
```
|
||||||
|
t_shifted = (t * shift) / (1 + (shift - 1) * t)
|
||||||
|
```
|
||||||
|
|
||||||
|
Timesteps are clamped to `[1e-5, 1-1e-5]` after shifting.
|
||||||
|
|
||||||
|
#### Loss Weighting
|
||||||
|
|
||||||
|
The `--weighting_scheme` option specifies loss weighting by timestep:
|
||||||
|
|
||||||
|
- `uniform` (default): Equal weight for all timesteps.
|
||||||
|
- `sigma_sqrt`: Weight by `sigma^(-2)`.
|
||||||
|
- `cosmap`: Weight by `2 / (pi * (1 - 2*sigma + 2*sigma^2))`.
|
||||||
|
- `none`: Same as uniform.
|
||||||
|
|
||||||
|
#### Caption Dropout
|
||||||
|
|
||||||
|
Use `--caption_dropout_rate` for embedding-level caption dropout. This is handled by `AnimaTextEncodingStrategy` and is compatible with text encoder output caching. The subset-level `caption_dropout_rate` is automatically zeroed when this is set.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
#### タイムステップサンプリング
|
||||||
|
|
||||||
|
`--timestep_sample_method`でタイムステップのサンプリング方法を指定します:
|
||||||
|
- `logit_normal`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。
|
||||||
|
- `uniform`: [0, 1]の一様分布からサンプリング。
|
||||||
|
|
||||||
|
#### 離散フローシフト
|
||||||
|
|
||||||
|
`--discrete_flow_shift`(デフォルト`3.0`)はタイムステップ分布を高ノイズ側にシフトします。
|
||||||
|
|
||||||
|
#### 損失の重み付け
|
||||||
|
|
||||||
|
`--weighting_scheme`でタイムステップごとの損失の重み付けを指定します。
|
||||||
|
|
||||||
|
#### キャプションドロップアウト
|
||||||
|
|
||||||
|
`--caption_dropout_rate`で埋め込みレベルのキャプションドロップアウトを使用します。テキストエンコーダー出力のキャッシュと互換性があります。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 7.3. Text Encoder LoRA Support / Text Encoder LoRAのサポート
|
||||||
|
|
||||||
|
Anima LoRA training supports training Qwen3 text encoder LoRA:
|
||||||
|
|
||||||
|
- To train only DiT: specify `--network_train_unet_only`
|
||||||
|
- To train DiT and Qwen3: omit `--network_train_unet_only`
|
||||||
|
|
||||||
|
You can specify a separate learning rate for Qwen3 with `--text_encoder_lr`. If not specified, the default `--learning_rate` is used.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレーニングできます。
|
||||||
|
|
||||||
|
- DiTのみ学習: `--network_train_unet_only`を指定
|
||||||
|
- DiTとQwen3を学習: `--network_train_unet_only`を省略
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 8. Other Training Options / その他の学習オプション
|
||||||
|
|
||||||
|
- **`--loss_type`**: Loss function for training. Default `l2`.
|
||||||
|
- `l1`: L1 loss.
|
||||||
|
- `l2`: L2 loss (mean squared error).
|
||||||
|
- `huber`: Huber loss.
|
||||||
|
- `smooth_l1`: Smooth L1 loss.
|
||||||
|
|
||||||
|
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: Parameters for Huber loss when `--loss_type` is `huber` or `smooth_l1`.
|
||||||
|
|
||||||
|
- **`--ip_noise_gamma`**, **`--ip_noise_gamma_random_strength`**: Input Perturbation noise gamma values.
|
||||||
|
|
||||||
|
- **`--fused_backward_pass`**: Fuses the backward pass and optimizer step to reduce VRAM usage. Only works with Adafactor. For details, see the [`sdxl_train_network.py` guide](sdxl_train_network.md).
|
||||||
|
|
||||||
|
- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: Timestep loss weighting options. For details, refer to the [`sd3_train_network.md` guide](sd3_train_network.md).
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
- **`--loss_type`**: 学習に用いる損失関数。デフォルト`l2`。`l1`, `l2`, `huber`, `smooth_l1`から選択。
|
||||||
|
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: Huber損失のパラメータ。
|
||||||
|
- **`--ip_noise_gamma`**: Input Perturbationノイズガンマ値。
|
||||||
|
- **`--fused_backward_pass`**: バックワードパスとオプティマイザステップの融合。
|
||||||
|
- **`--weighting_scheme`** 等: タイムステップ損失の重み付け。詳細は[`sd3_train_network.md`](sd3_train_network.md)を参照。
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 9. Others / その他
|
||||||
|
|
||||||
|
### Metadata Saved in LoRA Models
|
||||||
|
|
||||||
|
The following Anima-specific metadata is saved in the LoRA model file:
|
||||||
|
|
||||||
|
* `ss_weighting_scheme`
|
||||||
|
* `ss_discrete_flow_shift`
|
||||||
|
* `ss_timestep_sample_method`
|
||||||
|
* `ss_sigmoid_scale`
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>日本語</summary>
|
||||||
|
|
||||||
|
`anima_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python anima_train_network.py --help`) を参照してください。
|
||||||
|
|
||||||
|
### LoRAモデルに保存されるメタデータ
|
||||||
|
|
||||||
|
以下のAnima固有のメタデータがLoRAモデルファイルに保存されます:
|
||||||
|
|
||||||
|
* `ss_weighting_scheme`
|
||||||
|
* `ss_discrete_flow_shift`
|
||||||
|
* `ss_timestep_sample_method`
|
||||||
|
* `ss_sigmoid_scale`
|
||||||
|
|
||||||
|
</details>
|
||||||
1630
library/anima_models.py
Normal file
1630
library/anima_models.py
Normal file
File diff suppressed because it is too large
Load Diff
665
library/anima_train_utils.py
Normal file
665
library/anima_train_utils.py
Normal file
@@ -0,0 +1,665 @@
|
|||||||
|
# Anima Training Utilities
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from accelerate import Accelerator, PartialState
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
|
init_ipex()
|
||||||
|
|
||||||
|
from .utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from library import anima_models, anima_utils, strategy_base, train_util
|
||||||
|
|
||||||
|
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
|
||||||
|
|
||||||
|
|
||||||
|
# Anima-specific training arguments
|
||||||
|
|
||||||
|
def add_anima_training_arguments(parser: argparse.ArgumentParser):
    """Register every Anima-specific command-line option on *parser*."""
    add = parser.add_argument

    # --- model component paths ---
    add("--dit_path", type=str, default=None,
        help="Path to Anima DiT model safetensors file")
    add("--vae_path", type=str, default=None,
        help="Path to WanVAE safetensors/pth file")
    add("--qwen3_path", type=str, default=None,
        help="Path to Qwen3-0.6B model (safetensors file or directory)")
    add("--llm_adapter_path", type=str, default=None,
        help="Path to separate LLM adapter weights. If None, adapter is loaded from DiT file if present")

    # --- per-group learning rates (None = inherit base LR, 0 = freeze the group) ---
    add("--llm_adapter_lr", type=float, default=None,
        help="Learning rate for LLM adapter. None=same as base LR, 0=freeze adapter")
    add("--self_attn_lr", type=float, default=None,
        help="Learning rate for self-attention layers. None=same as base LR, 0=freeze")
    add("--cross_attn_lr", type=float, default=None,
        help="Learning rate for cross-attention layers. None=same as base LR, 0=freeze")
    add("--mlp_lr", type=float, default=None,
        help="Learning rate for MLP layers. None=same as base LR, 0=freeze")
    add("--mod_lr", type=float, default=None,
        help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze")

    # --- tokenizers ---
    add("--t5_tokenizer_path", type=str, default=None,
        help="Path to T5 tokenizer directory. If None, uses default configs/t5_old/")
    add("--qwen3_max_token_length", type=int, default=512,
        help="Maximum token length for Qwen3 tokenizer (default: 512)")
    add("--t5_max_token_length", type=int, default=512,
        help="Maximum token length for T5 tokenizer (default: 512)")

    # --- rectified-flow timestep sampling ---
    add("--discrete_flow_shift", type=float, default=1.0,
        help="Timestep distribution shift for rectified flow training (default: 1.0)")
    add("--timestep_sample_method", type=str, default="logit_normal",
        choices=["logit_normal", "uniform"],
        help="Timestep sampling method (default: logit_normal)")
    add("--sigmoid_scale", type=float, default=1.0,
        help="Scale factor for logit_normal timestep sampling (default: 1.0)")

    # Note: --caption_dropout_rate is defined by base add_dataset_arguments().
    # Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
    # instead of dataset-level caption dropout, so the subset caption_dropout_rate
    # is zeroed out in the training scripts to allow caching.

    # --- precision / attention backends ---
    add("--transformer_dtype", type=str, default=None,
        choices=["float16", "bfloat16", "float32", None],
        help="Separate dtype for transformer blocks. If None, uses same as mixed_precision")
    add("--flash_attn", action="store_true",
        help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). "
             "Falls back to PyTorch SDPA if flash-attn is not installed.")
|
||||||
|
|
||||||
|
|
||||||
|
# Noise & Timestep sampling (Rectified Flow)
|
||||||
|
def get_noisy_model_input_and_timesteps(
    args,
    latents: torch.Tensor,
    noise: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the rectified-flow training input and its timesteps.

    The forward process is ``noisy = (1 - t) * latents + t * noise`` and the
    training target elsewhere is ``noise - latents``.

    Args:
        args: Namespace providing timestep_sample_method, sigmoid_scale,
            discrete_flow_shift and (optionally) ip_noise_gamma settings.
        latents: Clean latent tensors.
        noise: Random noise tensors, same shape as latents.
        device: Target device.
        dtype: Target dtype.

    Returns:
        Tuple of (noisy_model_input, timesteps, sigmas).
    """
    batch_size = latents.shape[0]

    sample_method = getattr(args, 'timestep_sample_method', 'logit_normal')
    sig_scale = getattr(args, 'sigmoid_scale', 1.0)
    flow_shift = getattr(args, 'discrete_flow_shift', 1.0)

    if sample_method == 'logit_normal':
        sampler = torch.distributions.normal.Normal(0, 1)
    elif sample_method == 'uniform':
        sampler = torch.distributions.uniform.Uniform(0, 1)
    else:
        raise NotImplementedError(f"Unknown timestep_sample_method: {sample_method}")

    t = sampler.sample((batch_size,)).to(device)

    # logit-normal: scale the gaussian draw, then squash into (0, 1)
    if sample_method == 'logit_normal':
        t = torch.sigmoid(t * sig_scale)

    # Optional distribution shift toward higher-noise timesteps
    if flow_shift is not None and flow_shift != 1.0:
        t = (t * flow_shift) / (1 + (flow_shift - 1) * t)

    # Keep t strictly inside (0, 1)
    t = t.clamp(1e-5, 1.0 - 1e-5)

    # Broadcast t over all non-batch dimensions of the latents
    t_b = t.view(-1, *([1] * (latents.ndim - 1)))

    gamma = getattr(args, 'ip_noise_gamma', None)
    if gamma:
        # Input-perturbation noise: perturb the noise term before mixing
        perturbation = torch.randn_like(latents, device=latents.device, dtype=dtype)
        if getattr(args, 'ip_noise_gamma_random_strength', False):
            gamma = torch.rand(1, device=latents.device, dtype=dtype) * gamma
        noisy = (1 - t_b) * latents + t_b * (noise + gamma * perturbation)
    else:
        noisy = (1 - t_b) * latents + t_b * noise

    # Sigmas for potential loss weighting
    sigmas = t.view(-1, 1)

    return noisy.to(dtype), t.to(dtype), sigmas.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# Loss weighting
|
||||||
|
|
||||||
|
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
    """Return per-sample loss weights for the given weighting scheme.

    Mirrors the SD3 weighting schemes; "none", None, and any unrecognized
    scheme all fall back to uniform weights.
    """
    if weighting_scheme == "sigma_sqrt":
        return (sigmas**-2.0).float()
    if weighting_scheme == "cosmap":
        denominator = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * denominator)
    # uniform weighting for "none"/None and unknown schemes alike
    return torch.ones_like(sigmas)
|
||||||
|
|
||||||
|
|
||||||
|
# Parameter groups (6 groups with separate LRs)
|
||||||
|
def get_anima_param_groups(
    dit,
    base_lr: float,
    self_attn_lr: Optional[float] = None,
    cross_attn_lr: Optional[float] = None,
    mlp_lr: Optional[float] = None,
    mod_lr: Optional[float] = None,
    llm_adapter_lr: Optional[float] = None,
):
    """Bucket the DiT parameters into six optimizer groups with separate LRs.

    A group LR of None inherits ``base_lr``; an LR of exactly 0 freezes that
    group (requires_grad is cleared and the group is omitted from the result).

    Args:
        dit: MiniTrainDIT model whose named_parameters are routed by substring.
        base_lr: Base learning rate.
        self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze).
        cross_attn_lr: LR for cross-attention layers.
        mlp_lr: LR for MLP layers.
        mod_lr: LR for AdaLN modulation layers.
        llm_adapter_lr: LR for the LLM adapter.

    Returns:
        List of ``{'params': ..., 'lr': ...}`` dicts for the optimizer.
    """
    log = logging.getLogger(__name__)

    # None means "inherit the base learning rate" for every specialized group.
    self_attn_lr = base_lr if self_attn_lr is None else self_attn_lr
    cross_attn_lr = base_lr if cross_attn_lr is None else cross_attn_lr
    mlp_lr = base_lr if mlp_lr is None else mlp_lr
    mod_lr = base_lr if mod_lr is None else mod_lr
    llm_adapter_lr = base_lr if llm_adapter_lr is None else llm_adapter_lr

    buckets = {key: [] for key in ("base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter")}

    for pname, param in dit.named_parameters():
        # Keep the full dotted name on the tensor for debugging.
        param.original_name = pname
        if 'llm_adapter' in pname:
            buckets["llm_adapter"].append(param)
        elif '.self_attn' in pname:
            buckets["self_attn"].append(param)
        elif '.cross_attn' in pname:
            buckets["cross_attn"].append(param)
        elif '.mlp' in pname:
            buckets["mlp"].append(param)
        elif '.adaln_modulation' in pname:
            buckets["mod"].append(param)
        else:
            buckets["base"].append(param)

    log.info(f"Parameter groups:")
    log.info(f"  base_params: {len(buckets['base'])} (lr={base_lr})")
    log.info(f"  self_attn_params: {len(buckets['self_attn'])} (lr={self_attn_lr})")
    log.info(f"  cross_attn_params: {len(buckets['cross_attn'])} (lr={cross_attn_lr})")
    log.info(f"  mlp_params: {len(buckets['mlp'])} (lr={mlp_lr})")
    log.info(f"  mod_params: {len(buckets['mod'])} (lr={mod_lr})")
    log.info(f"  llm_adapter_params: {len(buckets['llm_adapter'])} (lr={llm_adapter_lr})")

    schedule = [
        (base_lr, buckets["base"], "base"),
        (self_attn_lr, buckets["self_attn"], "self_attn"),
        (cross_attn_lr, buckets["cross_attn"], "cross_attn"),
        (mlp_lr, buckets["mlp"], "mlp"),
        (mod_lr, buckets["mod"], "mod"),
        (llm_adapter_lr, buckets["llm_adapter"], "llm_adapter"),
    ]

    param_groups = []
    for lr, params, name in schedule:
        if lr == 0:
            # Zero LR: freeze the whole group rather than wasting optimizer state.
            for param in params:
                param.requires_grad_(False)
            log.info(f"  Frozen {name} params ({len(params)} parameters)")
        elif len(params) > 0:
            param_groups.append({'params': params, 'lr': lr})

    total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad)
    log.info(f"Total trainable parameters: {total_trainable:,}")

    return param_groups
|
||||||
|
|
||||||
|
|
||||||
|
# Save functions
|
||||||
|
def save_anima_model_on_train_end(
    args: argparse.Namespace,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    dit: anima_models.MiniTrainDIT,
):
    """Save the Anima DiT checkpoint at the end of training.

    File naming/rotation is delegated to
    ``train_util.save_sd_model_on_train_end_common``; the actual state-dict
    write goes through ``anima_utils.save_anima_model``.
    """
    def sd_saver(ckpt_file, epoch_no, global_step):
        # NOTE(review): sai_metadata is computed but never passed to the
        # saver below — confirm whether save_anima_model should embed it.
        sai_metadata = train_util.get_sai_model_spec(
            None, args, False, False, False, is_stable_diffusion_ckpt=True
        )
        dit_sd = dit.state_dict()
        # Save with 'net.' prefix for ComfyUI compatibility
        anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)

    train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
||||||
|
|
||||||
|
|
||||||
|
def save_anima_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator: Accelerator,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    dit: anima_models.MiniTrainDIT,
):
    """Save the Anima DiT checkpoint at an epoch boundary or save-step.

    Mirrors ``save_anima_model_on_train_end`` but delegates the decision of
    whether/when to save (and old-checkpoint removal) to
    ``train_util.save_sd_model_on_epoch_end_or_stepwise_common``.
    """
    def sd_saver(ckpt_file, epoch_no, global_step):
        # NOTE(review): sai_metadata is computed but never passed to the
        # saver below — confirm whether save_anima_model should embed it.
        sai_metadata = train_util.get_sai_model_spec(
            None, args, False, False, False, is_stable_diffusion_ckpt=True
        )
        dit_sd = dit.state_dict()
        anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)

    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
        args,
        on_epoch_end,
        accelerator,
        True,
        True,
        epoch,
        num_train_epochs,
        global_step,
        sd_saver,
        None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Sampling (Euler discrete for rectified flow)
|
||||||
|
def do_sample(
    height: int,
    width: int,
    seed: Optional[int],
    dit: "anima_models.MiniTrainDIT",
    crossattn_emb: torch.Tensor,
    steps: int,
    dtype: torch.dtype,
    device: torch.device,
    guidance_scale: float = 1.0,
    neg_crossattn_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Generate a latent sample via Euler discrete integration of the rectified flow.

    Args:
        height, width: Output image dimensions in pixels (latents are 8x smaller).
        seed: Random seed for the initial noise (None for nondeterministic).
        dit: MiniTrainDIT model; invoked as ``dit(x, t, emb, padding_mask=...)``.
        crossattn_emb: Cross-attention embeddings (B, N, D).
        steps: Number of Euler integration steps.
        dtype: Compute dtype.
        device: Compute device.
        guidance_scale: CFG scale; values > 1.0 enable classifier-free guidance.
        neg_crossattn_emb: Negative cross-attention embeddings, required for CFG.

    Returns:
        Denoised latents of shape (1, 16, 1, height // 8, width // 8).
    """
    # Optional progress bar: fall back to a plain iterator if tqdm is unavailable.
    try:
        from tqdm import tqdm as _progress
    except ImportError:
        def _progress(iterable, **_kwargs):
            return iterable

    # Latent shape: (1, 16, 1, H/8, W/8) for a single image (one temporal frame).
    # Build the shape directly instead of allocating a throwaway zeros tensor.
    latent_h = height // 8
    latent_w = width // 8
    latent_shape = (1, 16, 1, latent_h, latent_w)

    # Use a LOCAL generator so sampling does not clobber the global RNG state.
    # A fresh CPU Generator seeded with the same value produces the same
    # sequence as the previous torch.manual_seed(seed) + default generator, so
    # generated samples are unchanged for a given seed.
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)
    else:
        generator = None
    noise = torch.randn(
        latent_shape, dtype=torch.float32, generator=generator, device="cpu"
    ).to(dtype).to(device)

    # Sigma schedule: linear from 1.0 (pure noise) down to 0.0 (clean sample).
    sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)

    # Start from pure noise
    x = noise.clone()

    # Padding mask (zeros = no padding) — resized in prepare_embedded_sequence to match latent dims
    padding_mask = torch.zeros(1, 1, latent_h, latent_w, dtype=dtype, device=device)

    use_cfg = guidance_scale > 1.0 and neg_crossattn_emb is not None

    for i in _progress(range(steps), desc="Sampling"):
        sigma = sigmas[i]
        t = sigma.unsqueeze(0)  # (1,)

        dit.prepare_block_swap_before_forward()

        if use_cfg:
            # CFG: run positive and negative conditioning in one batched forward.
            x_input = torch.cat([x, x], dim=0)
            t_input = torch.cat([t, t], dim=0)
            crossattn_input = torch.cat([crossattn_emb, neg_crossattn_emb], dim=0)
            padding_input = torch.cat([padding_mask, padding_mask], dim=0)

            model_output = dit(x_input, t_input, crossattn_input, padding_mask=padding_input)
            model_output = model_output.float()

            pos_out, neg_out = model_output.chunk(2)
            model_output = neg_out + guidance_scale * (pos_out - neg_out)
        else:
            model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
            model_output = model_output.float()

        # Euler step: dt is negative, moving sigma toward 0.
        dt = sigmas[i + 1] - sigma
        x = x + model_output * dt
        x = x.to(dtype)

    # Release swapped blocks after the final forward pass.
    dit.prepare_block_swap_before_forward()
    return x
|
||||||
|
|
||||||
|
|
||||||
|
def sample_images(
    accelerator: Accelerator,
    args: argparse.Namespace,
    epoch,
    steps,
    dit,
    vae,
    vae_scale,
    text_encoder,
    tokenize_strategy,
    text_encoding_strategy,
    sample_prompts_te_outputs=None,
    prompt_replacement=None,
):
    """Generate sample images during training.

    This is a simplified sampler for Anima - it generates images using the current model state.

    Gating: at steps == 0 only --sample_at_first triggers sampling. Afterwards,
    epoch-based sampling (sample_every_n_epochs) takes priority over step-based
    sampling; step-based sampling is skipped on epoch-end calls (epoch is not None).
    """
    if steps == 0:
        if not args.sample_at_first:
            return
    else:
        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
            return
        if args.sample_every_n_epochs is not None:
            if epoch is None or epoch % args.sample_every_n_epochs != 0:
                return
        else:
            # epoch is not None means this is an epoch-end call; step-based mode skips it
            if steps % args.sample_every_n_steps != 0 or epoch is not None:
                return

    logger.info(f"Generating sample images at step {steps}")
    # Cached text-encoder outputs can substitute for the prompt file
    if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
        logger.error(f"No prompt file: {args.sample_prompts}")
        return

    # Unwrap models from the accelerator wrappers before inference
    dit = accelerator.unwrap_model(dit)
    if text_encoder is not None:
        text_encoder = accelerator.unwrap_model(text_encoder)

    prompts = train_util.load_prompts(args.sample_prompts)
    save_dir = os.path.join(args.output_dir, "sample")
    os.makedirs(save_dir, exist_ok=True)

    # Save RNG state so per-prompt seeding does not perturb training randomness
    rng_state = torch.get_rng_state()
    cuda_rng_state = None
    try:
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    except Exception:
        pass

    with torch.no_grad(), accelerator.autocast():
        for prompt_dict in prompts:
            _sample_image_inference(
                accelerator, args, dit, text_encoder, vae, vae_scale,
                tokenize_strategy, text_encoding_strategy,
                save_dir, prompt_dict, epoch, steps,
                sample_prompts_te_outputs, prompt_replacement,
            )

    # Restore RNG state
    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)

    clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_image_inference(
    accelerator, args, dit, text_encoder, vae, vae_scale,
    tokenize_strategy, text_encoding_strategy,
    save_dir, prompt_dict, epoch, steps,
    sample_prompts_te_outputs, prompt_replacement,
):
    """Generate and save a single sample image for one prompt dict.

    prompt_dict keys used: prompt, negative_prompt, sample_steps, width,
    height, scale, seed, enum. Cached text-encoder outputs
    (sample_prompts_te_outputs) are preferred over on-the-fly encoding.
    """
    prompt = prompt_dict.get("prompt", "")
    negative_prompt = prompt_dict.get("negative_prompt", "")
    sample_steps = prompt_dict.get("sample_steps", 30)
    width = prompt_dict.get("width", 512)
    height = prompt_dict.get("height", 512)
    scale = prompt_dict.get("scale", 7.5)
    seed = prompt_dict.get("seed")

    if prompt_replacement is not None:
        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
        if negative_prompt:
            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # seed all CUDA devices for multi-GPU

    # Snap dimensions down to multiples of 16 with a minimum of 64 px
    height = max(64, height - height % 16)
    width = max(64, width - width % 16)

    logger.info(f"  prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}")

    # Encode prompt: prefer cached TE outputs, fall back to live encoding
    def encode_prompt(prpt):
        if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
            return sample_prompts_te_outputs[prpt]
        if text_encoder is not None:
            tokens = tokenize_strategy.tokenize(prpt)
            encoded = text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
            return encoded
        return None

    encoded = encode_prompt(prompt)
    if encoded is None:
        logger.warning("Cannot encode prompt, skipping sample")
        return

    # Expected 4-tuple: Qwen3 embeddings/mask plus T5 ids/mask for the adapter
    prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = encoded

    # Convert to tensors if numpy (cached outputs are stored as numpy arrays)
    if isinstance(prompt_embeds, np.ndarray):
        prompt_embeds = torch.from_numpy(prompt_embeds).unsqueeze(0)
        attn_mask = torch.from_numpy(attn_mask).unsqueeze(0)
        t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
        t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)

    # dit.t_embedding_norm.weight.dtype serves as the model's compute dtype
    prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
    attn_mask = attn_mask.to(accelerator.device)
    t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
    t5_attn_mask = t5_attn_mask.to(accelerator.device)

    # Process through LLM adapter if available
    if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
        crossattn_emb = dit.llm_adapter(
            source_hidden_states=prompt_embeds,
            target_input_ids=t5_input_ids,
            target_attention_mask=t5_attn_mask,
            source_attention_mask=attn_mask,
        )
        # Zero out embeddings at padded T5 positions
        crossattn_emb[~t5_attn_mask.bool()] = 0
    else:
        crossattn_emb = prompt_embeds

    # Encode negative prompt for CFG (only needed when scale > 1.0)
    neg_crossattn_emb = None
    if scale > 1.0 and negative_prompt is not None:
        neg_encoded = encode_prompt(negative_prompt)
        if neg_encoded is not None:
            neg_pe, neg_am, neg_t5_ids, neg_t5_am = neg_encoded
            if isinstance(neg_pe, np.ndarray):
                neg_pe = torch.from_numpy(neg_pe).unsqueeze(0)
                neg_am = torch.from_numpy(neg_am).unsqueeze(0)
                neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
                neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)

            neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
            neg_am = neg_am.to(accelerator.device)
            neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
            neg_t5_am = neg_t5_am.to(accelerator.device)

            if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
                neg_crossattn_emb = dit.llm_adapter(
                    source_hidden_states=neg_pe,
                    target_input_ids=neg_t5_ids,
                    target_attention_mask=neg_t5_am,
                    source_attention_mask=neg_am,
                )
                neg_crossattn_emb[~neg_t5_am.bool()] = 0
            else:
                neg_crossattn_emb = neg_pe

    # Generate sample
    clean_memory_on_device(accelerator.device)
    latents = do_sample(
        height, width, seed, dit, crossattn_emb,
        sample_steps, dit.t_embedding_norm.weight.dtype,
        accelerator.device, scale, neg_crossattn_emb,
    )

    # Decode latents — move the VAE to the compute device only for the decode
    clean_memory_on_device(accelerator.device)
    org_vae_device = next(vae.parameters()).device
    vae.to(accelerator.device)
    decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale)
    vae.to(org_vae_device)
    clean_memory_on_device(accelerator.device)

    # Convert from [-1, 1] float to an 8-bit HWC image
    image = decoded.float()
    image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
    # Remove temporal dim if present
    if image.ndim == 4:
        image = image[:, 0, :, :]
    decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
    decoded_np = decoded_np.astype(np.uint8)

    image = Image.fromarray(decoded_np)

    # Filename: <output_name>_<epoch-or-step>_<enum>_<timestamp>[_<seed>].png
    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
    seed_suffix = "" if seed is None else f"_{seed}"
    i = prompt_dict.get("enum", 0)
    img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
    image.save(os.path.join(save_dir, img_filename))

    # Log to wandb if enabled
    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
        wandb_tracker = accelerator.get_tracker("wandb")
        import wandb
        wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)
|
||||||
325
library/anima_utils.py
Normal file
325
library/anima_utils.py
Normal file
@@ -0,0 +1,325 @@
|
|||||||
|
# Anima model loading/saving utilities
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
|
||||||
|
|
||||||
|
from .utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from library import anima_models
|
||||||
|
|
||||||
|
|
||||||
|
# Keys that should stay in high precision (float32/bfloat16, not quantized)
|
||||||
|
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
|
||||||
|
|
||||||
|
|
||||||
|
def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]:
    """Read a safetensors file into a dict, optionally casting every tensor to *dtype*."""
    tensors = load_file(path, device=device)
    if dtype is None:
        return tensors
    return {key: tensor.to(dtype) for key, tensor in tensors.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def load_anima_dit(
    dit_path: str,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cpu",
    transformer_dtype: Optional[torch.dtype] = None,
    llm_adapter_path: Optional[str] = None,
    disable_mmap: bool = False,
) -> anima_models.MiniTrainDIT:
    """Load the MiniTrainDIT model from safetensors.

    Args:
        dit_path: Path to DiT safetensors file
        dtype: Base dtype for model parameters
        device: Device to load to
        transformer_dtype: Optional separate dtype for transformer blocks (lower precision)
        llm_adapter_path: Optional separate path for LLM adapter weights
        disable_mmap: If True, disable memory-mapped loading (reduces peak memory)

    Returns:
        The constructed MiniTrainDIT with the checkpoint loaded and per-parameter
        dtypes applied, moved to *device*.
    """
    if transformer_dtype is None:
        transformer_dtype = dtype

    logger.info(f"Loading Anima DiT from {dit_path}")
    if disable_mmap:
        from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap
        state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True)
    else:
        state_dict = load_file(dit_path, device="cpu")

    # Remove 'net.' prefix if present (ComfyUI-style checkpoints)
    new_state_dict = {}
    for k, v in state_dict.items():
        if k.startswith('net.'):
            k = k[len('net.'):]
        new_state_dict[k] = v
    state_dict = new_state_dict

    # Derive config from state_dict
    dit_config = anima_models.get_dit_config(state_dict)

    # Detect LLM adapter: explicit path wins, else check for adapter keys in the DiT file
    if llm_adapter_path is not None:
        use_llm_adapter = True
        dit_config['use_llm_adapter'] = True
        llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu")
    elif 'llm_adapter.out_proj.weight' in state_dict:
        use_llm_adapter = True
        dit_config['use_llm_adapter'] = True
        llm_adapter_state_dict = None  # Loaded as part of DiT
    else:
        use_llm_adapter = False
        llm_adapter_state_dict = None

    logger.info(f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, "
                f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}")

    # Build model normally on CPU — buffers get proper values from __init__
    dit = anima_models.MiniTrainDIT(**dit_config)

    # Merge LLM adapter weights into state_dict if loaded separately
    if use_llm_adapter and llm_adapter_state_dict is not None:
        for k, v in llm_adapter_state_dict.items():
            state_dict[f"llm_adapter.{k}"] = v

    # Load checkpoint: strict=False keeps buffers not in checkpoint (e.g. pos_embedder.seq)
    missing, unexpected = dit.load_state_dict(state_dict, strict=False)
    if missing:
        # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
        unexpected_missing = [k for k in missing if not any(
            buf_name in k for buf_name in ('seq', 'dim_spatial_range', 'dim_temporal_range', 'inv_freq')
        )]
        if unexpected_missing:
            logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}")
    if unexpected:
        logger.info(f"Unexpected keys in checkpoint (ignored): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    # Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest)
    for name, p in dit.named_parameters():
        dtype_to_use = dtype if (
            any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1
        ) else transformer_dtype
        p.data = p.data.to(dtype=dtype_to_use)

    dit.to(device)
    logger.info(f"Loaded Anima DiT successfully. Parameters: {sum(p.numel() for p in dit.parameters()):,}")
    return dit
|
||||||
|
|
||||||
|
|
||||||
|
def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"):
    """Load WanVAE from a safetensors/pth file.

    Args:
        vae_path: Path to the VAE weights (.safetensors or a torch .pth file).
        dtype: Parameter/normalization dtype.
        device: Device the VAE and normalization tensors are placed on.

    Returns (vae_model, mean_tensor, std_tensor, scale).
    """
    from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD

    logger.info(f"Loading Anima VAE from {vae_path}")

    # VAE config (fixed for WanVAE)
    vae_config = dict(
        dim=96,
        z_dim=16,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        # NOTE: 'temperal_downsample' spelling matches the upstream WanVAE_ kwarg
        temperal_downsample=[False, True, True],
        dropout=0.0,
    )

    from library.anima_vae import WanVAE_

    # Build model on the meta device so no memory is allocated before loading
    with torch.device('meta'):
        vae = WanVAE_(**vae_config)

    # Load state dict
    if vae_path.endswith('.safetensors'):
        vae_sd = load_file(vae_path, device='cpu')
    else:
        vae_sd = torch.load(vae_path, map_location='cpu', weights_only=True)

    # assign=True replaces the meta tensors with the loaded ones
    vae.load_state_dict(vae_sd, assign=True)
    vae = vae.eval().requires_grad_(False).to(device, dtype=dtype)

    # Per-channel latent normalization tensors; scale = [mean, 1/std]
    mean = torch.tensor(ANIMA_VAE_MEAN, dtype=dtype, device=device)
    std = torch.tensor(ANIMA_VAE_STD, dtype=dtype, device=device)
    scale = [mean, 1.0 / std]

    logger.info(f"Loaded Anima VAE successfully.")
    return vae, mean, std, scale
|
||||||
|
|
||||||
|
|
||||||
|
def load_qwen3_tokenizer(qwen3_path: str):
    """Load only the Qwen3 tokenizer (no text encoder weights).

    Args:
        qwen3_path: Directory containing model files, or a single weights file.
            For a directory the tokenizer is loaded from it directly; for a
            file the bundled configs/qwen3_06b/ directory supplies the
            tokenizer files.

    Returns:
        The Qwen3 tokenizer, with ``pad_token`` defaulted to ``eos_token``.
    """
    from transformers import AutoTokenizer

    if os.path.isdir(qwen3_path):
        source_dir = qwen3_path
    else:
        # Single-file checkpoint: fall back to the repo-bundled tokenizer config.
        source_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
        if not os.path.exists(source_dir):
            raise FileNotFoundError(
                f"Qwen3 config directory not found at {source_dir}. "
                "Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
                "You can download these from the Qwen3-0.6B HuggingFace repository."
            )

    tokenizer = AutoTokenizer.from_pretrained(source_dir, local_files_only=True)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return tokenizer
|
|
||||||
|
|
||||||
|
def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16, device: str = "cpu"):
    """Load the Qwen3-0.6B text encoder (decoder backbone) and its tokenizer.

    Args:
        qwen3_path: Directory with full model files, or a single weights file
            (.safetensors / torch-serialized). In the file case the bundled
            configs/qwen3_06b/ directory supplies the architecture config.
        dtype: Model dtype.
        device: Device to load to.

    Returns:
        (text_encoder_model, tokenizer)
    """
    import transformers
    from transformers import AutoTokenizer

    logger.info(f"Loading Qwen3 text encoder from {qwen3_path}")

    if os.path.isdir(qwen3_path):
        # Directory layout: transformers loads tokenizer, config and weights directly.
        tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
        model = transformers.AutoModelForCausalLM.from_pretrained(
            qwen3_path, torch_dtype=dtype, local_files_only=True
        ).model
    else:
        # Standalone weights file: build the architecture from the bundled
        # config, then fill in the weights manually.
        config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
        if not os.path.exists(config_dir):
            raise FileNotFoundError(
                f"Qwen3 config directory not found at {config_dir}. "
                "Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
                "You can download these from the Qwen3-0.6B HuggingFace repository."
            )

        tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True)
        qwen3_config = transformers.Qwen3Config.from_pretrained(config_dir, local_files_only=True)
        model = transformers.Qwen3ForCausalLM(qwen3_config).model

        # Read the raw state dict from disk.
        if qwen3_path.endswith('.safetensors'):
            raw_sd = load_file(qwen3_path, device='cpu')
        else:
            raw_sd = torch.load(qwen3_path, map_location='cpu', weights_only=True)

        # Full causal-LM checkpoints prefix keys with 'model.'; strip it so the
        # keys match the bare backbone we instantiated.
        stripped_sd = {}
        for key, value in raw_sd.items():
            stripped_sd[key[len('model.'):] if key.startswith('model.') else key] = value

        info = model.load_state_dict(stripped_sd, strict=False)
        logger.info(f"Loaded Qwen3 state dict: {info}")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The encoder is used only for feature extraction, so KV caching is disabled
    # and gradients are frozen.
    model.config.use_cache = False
    model = model.requires_grad_(False).to(device, dtype=dtype)

    logger.info(f"Loaded Qwen3 text encoder. Parameters: {sum(p.numel() for p in model.parameters()):,}")
    return model, tokenizer
|
|
||||||
|
|
||||||
|
def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
    """Load the T5 tokenizer used for the LLM Adapter's target token IDs.

    Args:
        t5_tokenizer_path: Optional T5 tokenizer directory. When omitted, the
            repo-bundled configs/t5_old/ files are used instead.
    """
    from transformers import T5TokenizerFast

    if t5_tokenizer_path is not None:
        return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)

    # Fall back to the bundled tokenizer files.
    config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 't5_old')
    if not os.path.exists(config_dir):
        raise FileNotFoundError(
            f"T5 tokenizer config directory not found at {config_dir}. "
            "Expected configs/t5_old/ with spiece.model and tokenizer.json. "
            "You can download these from the google/t5-v1_1-xxl HuggingFace repository."
        )

    return T5TokenizerFast(
        vocab_file=os.path.join(config_dir, 'spiece.model'),
        tokenizer_file=os.path.join(config_dir, 'tokenizer.json'),
    )
|
|
||||||
|
|
||||||
|
def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dtype: Optional[torch.dtype] = None):
    """Save an Anima DiT state dict with the 'net.' key prefix for ComfyUI compatibility.

    Args:
        save_path: Output path (.safetensors).
        dit_state_dict: State dict from ``dit.state_dict()``.
        dtype: Optional dtype to cast each tensor to before saving.
    """
    # Optional cast applied uniformly to every tensor.
    cast = (lambda t: t.to(dtype)) if dtype is not None else (lambda t: t)
    # ComfyUI loaders expect every weight under the 'net.' namespace.
    prefixed_sd = {'net.' + key: cast(tensor).contiguous() for key, tensor in dit_state_dict.items()}

    save_file(prefixed_sd, save_path, metadata={'format': 'pt'})
    logger.info(f"Saved Anima model to {save_path}")
|
|
||||||
|
|
||||||
|
def vae_encode(tensor: torch.Tensor, vae, scale):
    """Encode a pixel tensor into normalized latents through WanVAE.

    Args:
        tensor: Input tensor (B, C, T, H, W) in [-1, 1] range.
        vae: WanVAE_ model.
        scale: [mean, 1/std] list used for latent normalization.

    Returns:
        Normalized latents.
    """
    latents = vae.encode(tensor, scale)
    return latents
|
|
||||||
|
|
||||||
|
def vae_decode(latents: torch.Tensor, vae, scale):
    """Decode normalized latents back to pixels through WanVAE.

    Args:
        latents: Normalized latents.
        vae: WanVAE_ model.
        scale: [mean, 1/std] list used for latent denormalization.

    Returns:
        Decoded tensor in [-1, 1] range.
    """
    decoded = vae.decode(latents, scale)
    return decoded
577
library/anima_vae.py
Normal file
577
library/anima_vae.py
Normal file
@@ -0,0 +1,577 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.cuda.amp as amp
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
CACHE_T = 2
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv3d(nn.Conv3d):
    """Conv3d with causal (past-only) temporal padding.

    Spatial padding stays symmetric; the temporal axis is padded only on the
    front, so an output frame never depends on future frames. ``cache_x`` lets
    chunked streaming prepend trailing frames of the previous chunk instead of
    zero padding.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # F.pad order: (w_left, w_right, h_left, h_right, t_front, t_back).
        # All temporal padding goes to the front; conv itself pads nothing.
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        pad = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            # Prepend cached frames from the previous chunk and shrink the
            # explicit front padding by the same amount.
            x = torch.cat([cache_x.to(x.device), x], dim=2)
            pad[4] -= cache_x.shape[2]
        x = F.pad(x, pad)
        return super().forward(x)
|
|
||||||
|
|
||||||
|
class RMS_norm(nn.Module):
    """RMS normalization with a learnable per-channel gain and optional bias.

    ``channel_first`` selects whether the feature dimension is dim 1 (conv
    layouts) or the last dim; ``images`` controls how many trailing singleton
    dims the gain broadcasts over (2 for images, 3 for video).
    """

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        trailing = (1, 1) if images else (1, 1, 1)
        param_shape = (dim, *trailing) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        # Bias is a plain 0.0 when disabled so the forward add is a no-op.
        self.bias = nn.Parameter(torch.zeros(param_shape)) if bias else 0.

    def forward(self, x):
        norm_dim = 1 if self.channel_first else -1
        normalized = F.normalize(x, dim=norm_dim)
        return normalized * self.scale * self.gamma + self.bias
|
|
||||||
|
|
||||||
|
class Upsample(nn.Upsample):
    """nn.Upsample that upcasts to float32 before interpolating.

    Works around missing bfloat16 support in nearest-neighbor interpolation;
    the result is cast back to the input dtype.
    """

    def forward(self, x):
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
|
|
||||||
|
|
||||||
|
class Resample(nn.Module):
    """Spatial/temporal resampling layer for the WanVAE encoder/decoder.

    Modes:
      - 'upsample2d' / 'upsample3d': 2x nearest-exact spatial upsample followed by a
        conv that halves the channel count; '3d' additionally doubles the frame
        count via a causal time conv (its doubled output channels are interleaved
        into doubled frames in forward()).
      - 'downsample2d' / 'downsample3d': 2x strided spatial downsample; '3d'
        additionally halves the frame count via a temporally strided causal conv.
      - 'none': identity.
    """

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', 'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            # Doubles channels; forward() reshapes them into 2x frames.
            self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE: feat_idx is a mutable default deliberately used as an in/out
        # cursor into feat_cache when the caller threads both through.
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: mark with the 'Rep' sentinel (no real history yet);
                    # the temporal upsample is skipped entirely for this chunk.
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != 'Rep':
                        # cache last frame of last two chunk
                        cache_x = torch.cat(
                            [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == 'Rep':
                        # No real history behind the sentinel: left-pad the cache with zeros.
                        cache_x = torch.cat(
                            [torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
                    if feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    # Interleave the doubled channels into doubled frames:
                    # (b, 2c, t, h, w) -> (b, c, 2t, h, w)
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        # Spatial resample is applied per-frame by folding time into the batch.
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    # Prepend the last cached frame so the strided time conv sees continuity.
                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        # Identity-like init: zero everything, then pass input straight through
        # at the center temporal tap.
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        conv_weight.data[:, :, 1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        # Duplicate-identity init: both output halves copy the input at the last tap.
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)
|
|
||||||
|
|
||||||
|
class ResidualBlock(nn.Module):
    """Pre-norm residual block: two causal 3D convs with a learned (or identity) shortcut.

    ``feat_cache``/``feat_idx`` thread per-conv temporal caches through the
    block for chunked streaming; ``feat_idx`` is a single-element list used as
    a mutable cursor.
    """

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # The shortcut path sees the untouched input, captured before the
        # residual path starts mutating x.
        skip = self.shortcut(x)
        for layer in self.residual:
            if feat_cache is not None and isinstance(layer, CausalConv3d):
                idx = feat_idx[0]
                tail = x[:, :, -CACHE_T:, :, :].clone()
                if tail.shape[2] < 2 and feat_cache[idx] is not None:
                    # Borrow the final frame of the previous chunk so the cache
                    # always holds two frames.
                    prev_last = feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(tail.device)
                    tail = torch.cat([prev_last, tail], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = tail
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + skip
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
    """Single-head self-attention applied per frame over spatial positions.

    Time is folded into the batch dimension, so attention mixes only the h*w
    positions within each frame. The output projection is zero-initialized so
    the block starts out as an identity mapping.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        residual = x
        b, c, t, h, w = x.size()
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.norm(x)

        # Project to q/k/v of shape (b*t, 1, h*w, c) each; tokens are spatial positions.
        qkv = self.to_qkv(x)
        qkv = qkv.reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)

        # apply attention
        attn_out = F.scaled_dot_product_attention(q, k, v)
        x = attn_out.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output projection and residual connection
        x = self.proj(x)
        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
        return x + residual
|
|
||||||
|
|
||||||
|
class Encoder3d(nn.Module):
    """Causal 3D encoder: conv stem -> downsample stages -> middle blocks -> z head.

    forward() supports chunked streaming via feat_cache/feat_idx: feat_cache is
    a list with one slot per CausalConv3d in the module (see count_conv3d) and
    feat_idx is a single-element list used as a mutable cursor into it.
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        # Current spatial scale relative to the input; compared against
        # attn_scales to decide where attention blocks are inserted.
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block (skipped after the last stage)
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE: feat_idx is a mutable default deliberately used as an in/out cursor.
        # Stem conv with per-conv temporal caching.
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat(
                    [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat(
                        [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x
|
|
||||||
|
|
||||||
|
class Decoder3d(nn.Module):
    """Causal 3D decoder: z conv stem -> middle blocks -> upsample stages -> RGB head.

    Mirrors Encoder3d; forward() supports chunked streaming via
    feat_cache/feat_idx (one cache slot per CausalConv3d, mutable cursor).
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        # Spatial scale at the decoder input (smallest resolution).
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            # After stage 0, the preceding Resample's Conv2d halved the channel
            # count, so the incoming dim is half of the nominal in_dim.
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block (skipped after the last stage)
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # NOTE: feat_idx is a mutable default deliberately used as an in/out cursor.
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat(
                    [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat(
                        [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x
|
|
||||||
|
|
||||||
|
def count_conv3d(model):
    """Return the number of CausalConv3d modules in ``model``.

    Used to size the per-conv streaming feature caches (one slot each).
    """
    return sum(1 for module in model.modules() if isinstance(module, CausalConv3d))
|
|
||||||
|
|
||||||
|
class WanVAE_(nn.Module):
    """Wan causal video VAE: Encoder3d + Decoder3d with streaming feature caches.

    encode()/decode() process frames in temporal chunks (a single leading frame
    followed by 4-frame chunks on the pixel side; one latent frame at a time on
    the decode side) and (de)normalize latents with the supplied
    scale = [mean, 1/std].
    """

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules; the encoder emits 2*z_dim channels (mu, log_var halves)
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        # NOTE(review): encode() below requires a `scale` argument and returns
        # only mu, so this method (and sample()) would raise a TypeError if
        # called. They appear to be unused legacy paths — confirm before use.
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        # Encode pixels (b, 3, t, h, w) to the normalized latent mean.
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        # First chunk is the single leading frame; subsequent chunks are 4 frames
        # each, streamed through the encoder's feature cache.
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        # Normalize: (mu - mean) * (1/std); scale entries may be tensors or scalars.
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        return mu

    def decode(self, z, scale):
        # Decode normalized latents back to pixels.
        self.clear_cache()
        # z: [b,c,t,h,w]
        # Denormalize: z / (1/std) + mean.
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        # Decode one latent frame at a time, streaming through the decoder cache.
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        # Sample z ~ N(mu, sigma^2) with the reparameterization trick.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        # NOTE(review): see forward() — encode()'s current signature makes this
        # call and the tuple unpacking invalid; appears to be an unused legacy path.
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        # Reset the per-CausalConv3d streaming caches for decoder and encoder.
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        #cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num
429
library/strategy_anima.py
Normal file
429
library/strategy_anima.py
Normal file
@@ -0,0 +1,429 @@
|
|||||||
|
# Anima Strategy Classes
|
||||||
|
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from library import anima_utils, train_util
|
||||||
|
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||||
|
|
||||||
|
from library.utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AnimaTokenizeStrategy(TokenizeStrategy):
    """Dual Qwen3 + T5 tokenization for Anima.

    Qwen3 token IDs feed the Qwen3 text encoder; T5 token IDs serve only as
    target input IDs for the LLM Adapter and are never encoded by a T5 model.

    Tokenizers may be supplied pre-loaded, or loaded from the given paths.
    """

    def __init__(
        self,
        qwen3_tokenizer=None,
        t5_tokenizer=None,
        qwen3_max_length: int = 512,
        t5_max_length: int = 512,
        qwen3_path: Optional[str] = None,
        t5_tokenizer_path: Optional[str] = None,
    ) -> None:
        # Lazily load whichever tokenizers were not handed in directly.
        if qwen3_tokenizer is None:
            if qwen3_path is None:
                raise ValueError("Either qwen3_tokenizer or qwen3_path must be provided")
            qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(qwen3_path)
        if t5_tokenizer is None:
            t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)

        self.qwen3_tokenizer = qwen3_tokenizer
        self.t5_tokenizer = t5_tokenizer
        self.qwen3_max_length = qwen3_max_length
        self.t5_max_length = t5_max_length

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        """Return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]."""
        texts = [text] if isinstance(text, str) else text

        def _encode(tokenizer, max_length):
            # Fixed-length padding keeps batch shapes uniform for caching.
            return tokenizer.batch_encode_plus(
                texts,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=max_length,
            )

        qwen3_enc = _encode(self.qwen3_tokenizer, self.qwen3_max_length)
        t5_enc = _encode(self.t5_tokenizer, self.t5_max_length)

        return [
            qwen3_enc["input_ids"],
            qwen3_enc["attention_mask"],
            t5_enc["input_ids"],
            t5_enc["attention_mask"],
        ]
|
|
||||||
|
|
||||||
|
class AnimaTextEncodingStrategy(TextEncodingStrategy):
    """Text encoding strategy for Anima.

    Encodes Qwen3 tokens through the Qwen3 text encoder to get hidden states.
    T5 tokens are passed through unchanged (only used by LLM Adapter).

    Caption dropout replaces dropped items with the embedding of the empty
    caption "" (matching diffusion-pipe-main) rather than with zeros.
    """

    def __init__(
        self,
        dropout_rate: float = 0.0,
    ) -> None:
        # Probability of replacing a caption's embedding with the unconditional one.
        self.dropout_rate = dropout_rate
        # Cached unconditional embeddings (from encoding empty caption "").
        # Must be initialized via cache_uncond_embeddings() before the text
        # encoder is deleted; otherwise cached-output dropout falls back to zeros.
        self._uncond_prompt_embeds: Optional[torch.Tensor] = None  # (1, seq_len, hidden)
        self._uncond_attn_mask: Optional[torch.Tensor] = None  # (1, seq_len)
        self._uncond_t5_input_ids: Optional[torch.Tensor] = None  # (1, t5_seq_len)
        self._uncond_t5_attn_mask: Optional[torch.Tensor] = None  # (1, t5_seq_len)

    def cache_uncond_embeddings(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
    ) -> None:
        """Pre-encode empty caption "" and cache the unconditional embeddings.

        Must be called before the text encoder is deleted from GPU.
        This matches diffusion-pipe-main behavior where empty caption embeddings
        are pre-cached and swapped in during caption dropout.
        """
        logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
        tokens = tokenize_strategy.tokenize("")
        with torch.no_grad():
            uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
        # Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
        self._uncond_prompt_embeds = uncond_outputs[0].cpu()
        self._uncond_attn_mask = uncond_outputs[1].cpu()
        self._uncond_t5_input_ids = uncond_outputs[2].cpu()
        self._uncond_t5_attn_mask = uncond_outputs[3].cpu()
        logger.info(" Unconditional embeddings cached successfully")

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: List[torch.Tensor],
        enable_dropout: bool = True,
    ) -> List[torch.Tensor]:
        """Encode Qwen3 tokens and return embeddings + T5 token IDs.

        Args:
            models: [qwen3_text_encoder]
            tokens: [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
            enable_dropout: apply caption dropout per item (disabled when caching)

        Returns:
            [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
        """

        qwen3_text_encoder = models[0]
        qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens

        # Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
        batch_size = qwen3_input_ids.shape[0]
        non_drop_indices = []
        for i in range(batch_size):
            drop = enable_dropout and (self.dropout_rate > 0.0 and random.random() < self.dropout_rate)
            if not drop:
                non_drop_indices.append(i)

        encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device

        if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
            # Only encode non-dropped items to save compute
            nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
            nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
        elif len(non_drop_indices) == batch_size:
            nd_input_ids = qwen3_input_ids.to(encoder_device)
            nd_attn_mask = qwen3_attn_mask.to(encoder_device)
        else:
            nd_input_ids = None
            nd_attn_mask = None

        # BUGFIX: initialize nd_encoded_text unconditionally. Previously, when
        # every item in the batch was dropped (nd_input_ids is None), the name
        # was never bound and the shape/dtype lookups below raised NameError.
        nd_encoded_text = None
        if nd_input_ids is not None:
            outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
            nd_encoded_text = outputs.last_hidden_state
            # Zero out padding positions
            nd_encoded_text[~nd_attn_mask.bool()] = 0

        # Build full batch: fill non-dropped with encoded, dropped with unconditional
        if len(non_drop_indices) == batch_size:
            prompt_embeds = nd_encoded_text
            attn_mask = qwen3_attn_mask.to(encoder_device)
        else:
            # Get unconditional embeddings
            if self._uncond_prompt_embeds is not None:
                uncond_pe = self._uncond_prompt_embeds[0]
                uncond_am = self._uncond_attn_mask[0]
                uncond_t5_ids = self._uncond_t5_input_ids[0]
                uncond_t5_am = self._uncond_t5_attn_mask[0]
            else:
                # Encode empty caption on-the-fly (text encoder still available)
                uncond_tokens = tokenize_strategy.tokenize("")
                uncond_ids = uncond_tokens[0].to(encoder_device)
                uncond_mask = uncond_tokens[1].to(encoder_device)
                uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
                uncond_pe = uncond_out.last_hidden_state[0]
                uncond_pe[~uncond_mask[0].bool()] = 0
                uncond_am = uncond_mask[0]
                uncond_t5_ids = uncond_tokens[2][0]
                uncond_t5_am = uncond_tokens[3][0]

            seq_len = qwen3_input_ids.shape[1]
            hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
            dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype

            prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
            attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)

            if len(non_drop_indices) > 0:
                prompt_embeds[non_drop_indices] = nd_encoded_text
                attn_mask[non_drop_indices] = nd_attn_mask

            # Fill dropped items with unconditional embeddings.
            # Clone the T5 tensors first so we don't mutate the caller's batch.
            t5_input_ids = t5_input_ids.clone()
            t5_attn_mask = t5_attn_mask.clone()
            drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
            for i in drop_indices:
                prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
                attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
                t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
                t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)

        return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]

    def drop_cached_text_encoder_outputs(
        self,
        prompt_embeds: torch.Tensor,
        attn_mask: torch.Tensor,
        t5_input_ids: torch.Tensor,
        t5_attn_mask: torch.Tensor,
    ) -> List[torch.Tensor]:
        """Apply dropout to cached text encoder outputs.

        Called during training when using cached outputs.
        Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
        to match diffusion-pipe-main behavior.
        """
        if prompt_embeds is not None and self.dropout_rate > 0.0:
            # Clone to avoid in-place modification of cached tensors
            prompt_embeds = prompt_embeds.clone()
            if attn_mask is not None:
                attn_mask = attn_mask.clone()
            if t5_input_ids is not None:
                t5_input_ids = t5_input_ids.clone()
            if t5_attn_mask is not None:
                t5_attn_mask = t5_attn_mask.clone()

            for i in range(prompt_embeds.shape[0]):
                if random.random() < self.dropout_rate:
                    if self._uncond_prompt_embeds is not None:
                        # Use pre-cached unconditional embeddings
                        prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
                        if attn_mask is not None:
                            attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
                        if t5_input_ids is not None:
                            t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
                        if t5_attn_mask is not None:
                            t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
                    else:
                        # Fallback: zero out (should not happen if cache_uncond_embeddings was called)
                        logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout")
                        prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
                        if attn_mask is not None:
                            attn_mask[i] = torch.zeros_like(attn_mask[i])
                        if t5_input_ids is not None:
                            t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
                        if t5_attn_mask is not None:
                            t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])

        return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||||
|
|
||||||
|
|
||||||
|
class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    """Caching strategy for Anima text encoder outputs.

    Each cache file holds four arrays: prompt_embeds (float), attn_mask (int),
    t5_input_ids (int) and t5_attn_mask (int).
    """

    ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Derive the cache-file path from the image path (extension swapped for the suffix)."""
        base, _ext = os.path.splitext(image_abs_path)
        return base + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        """Return True when a usable cache file already exists on disk."""
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        # Validate that all four arrays are present; unreadable files are fatal.
        try:
            npz = np.load(npz_path)
            for key in ("prompt_embeds", "attn_mask", "t5_input_ids", "t5_attn_mask"):
                if key not in npz:
                    return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """Load the four cached arrays in their canonical order."""
        data = np.load(npz_path)
        return [data[key] for key in ("prompt_embeds", "attn_mask", "t5_input_ids", "t5_attn_mask")]

    def cache_batch_outputs(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        text_encoding_strategy: TextEncodingStrategy,
        infos: List,
    ):
        """Encode a batch of captions and cache the outputs to disk or in memory."""
        encoding_strategy: AnimaTextEncodingStrategy = text_encoding_strategy
        captions = [info.caption for info in infos]

        tokens_and_masks = tokenize_strategy.tokenize(captions)
        with torch.no_grad():
            # Always disable dropout during caching
            prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = encoding_strategy.encode_tokens(
                tokenize_strategy,
                models,
                tokens_and_masks,
                enable_dropout=False,
            )

        # Convert to numpy for caching (bfloat16 has no numpy counterpart)
        if prompt_embeds.dtype == torch.bfloat16:
            prompt_embeds = prompt_embeds.float()
        prompt_embeds = prompt_embeds.cpu().numpy()
        attn_mask = attn_mask.cpu().numpy()
        t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
        t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)

        for i, info in enumerate(infos):
            if self.cache_to_disk:
                np.savez(
                    info.text_encoder_outputs_npz,
                    prompt_embeds=prompt_embeds[i],
                    attn_mask=attn_mask[i],
                    t5_input_ids=t5_input_ids[i],
                    t5_attn_mask=t5_attn_mask[i],
                )
            else:
                info.text_encoder_outputs = (prompt_embeds[i], attn_mask[i], t5_input_ids[i], t5_attn_mask[i])
|
||||||
|
|
||||||
|
|
||||||
|
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
    """Latent caching strategy for Anima using WanVAE.

    WanVAE produces 16-channel latents with an 8x spatial downscale.
    Latent shape for images: (B, 16, 1, H/8, W/8)
    """

    ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        return self.ANIMA_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        """Cache path embeds the image size so multiple resolutions can coexist."""
        base = os.path.splitext(absolute_path)[0]
        size_tag = f"_{image_size[0]:04d}x{image_size[1]:04d}"
        return base + size_tag + self.ANIMA_LATENTS_NPZ_SUFFIX

    def is_disk_cached_latents_expected(
        self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
    ):
        # 8 = WanVAE spatial downscale factor
        return self._default_is_disk_cached_latents_expected(
            8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
        )

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)

    def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        """Cache a batch of latents using WanVAE.

        vae is expected to be the WanVAE_ model (not the wrapper).
        The encoding function handles the mean/std normalization.
        """
        from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD

        device = next(vae.parameters()).device
        dtype = next(vae.parameters()).dtype

        # Normalization constants on the VAE device: latent = (x - mean) * (1 / std)
        mean = torch.tensor(ANIMA_VAE_MEAN, dtype=dtype, device=device)
        std = torch.tensor(ANIMA_VAE_STD, dtype=dtype, device=device)
        scale = [mean, 1.0 / std]

        def encode_by_vae(img_tensor):
            # img_tensor: (B, C, H, W) in [-1, 1] (already normalized by IMAGE_TRANSFORMS).
            # WanVAE expects a temporal axis, so expand to (B, C, 1, H, W) first.
            frames = img_tensor.unsqueeze(2).to(device, dtype=dtype)
            return vae.encode(frames, scale).to("cpu")

        self._default_cache_batch_latents(
            encode_by_vae, device, dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
        )

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(device)
|
||||||
@@ -524,7 +524,7 @@ class LatentsCachingStrategy:
|
|||||||
original_size = original_sizes[i]
|
original_size = original_sizes[i]
|
||||||
crop_ltrb = crop_ltrbs[i]
|
crop_ltrb = crop_ltrbs[i]
|
||||||
|
|
||||||
latents_size = latents.shape[1:3] # H, W
|
latents_size = latents.shape[-2:] # H, W (supports both 4D and 5D latents)
|
||||||
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
|
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
|
||||||
|
|
||||||
if self.cache_to_disk:
|
if self.cache_to_disk:
|
||||||
|
|||||||
@@ -6138,7 +6138,8 @@ def conditional_loss(
|
|||||||
elif loss_type == "huber":
|
elif loss_type == "huber":
|
||||||
if huber_c is None:
|
if huber_c is None:
|
||||||
raise NotImplementedError("huber_c not implemented correctly")
|
raise NotImplementedError("huber_c not implemented correctly")
|
||||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
# Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors)
|
||||||
|
huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
|
||||||
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
loss = torch.mean(loss)
|
loss = torch.mean(loss)
|
||||||
@@ -6147,7 +6148,8 @@ def conditional_loss(
|
|||||||
elif loss_type == "smooth_l1":
|
elif loss_type == "smooth_l1":
|
||||||
if huber_c is None:
|
if huber_c is None:
|
||||||
raise NotImplementedError("huber_c not implemented correctly")
|
raise NotImplementedError("huber_c not implemented correctly")
|
||||||
huber_c = huber_c.view(-1, 1, 1, 1)
|
# Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors)
|
||||||
|
huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
|
||||||
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
|
||||||
if reduction == "mean":
|
if reduction == "mean":
|
||||||
loss = torch.mean(loss)
|
loss = torch.mean(loss)
|
||||||
|
|||||||
635
networks/lora_anima.py
Normal file
635
networks/lora_anima.py
Normal file
@@ -0,0 +1,635 @@
|
|||||||
|
# LoRA network module for Anima
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from library.utils import setup_logging
|
||||||
|
|
||||||
|
setup_logging()
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from networks.lora_flux import LoRAModule, LoRAInfModule
|
||||||
|
|
||||||
|
|
||||||
|
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae,
    text_encoders: list,
    unet,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Create an Anima LoRA network from ``--network_args`` style kwargs.

    Args:
        multiplier: global LoRA scale.
        network_dim: default rank; 4 when None.
        network_alpha: default alpha; 1.0 when None.
        vae: unused here; kept for the shared create_network interface.
        text_encoders: list of text encoder modules (Qwen3).
        unet: the Anima DiT model.
        neuron_dropout: per-neuron dropout probability for LoRA modules.
        **kwargs: network_args (normally string-valued, coming from the CLI).

    Returns:
        A configured LoRANetwork.
    """
    if network_dim is None:
        network_dim = 4
    if network_alpha is None:
        network_alpha = 1.0

    # network_args arrive as strings from the CLI; None means "not specified".
    def _opt_int(name: str) -> Optional[int]:
        value = kwargs.get(name, None)
        return int(value) if value is not None else None

    def _opt_float(name: str) -> Optional[float]:
        value = kwargs.get(name, None)
        return float(value) if value is not None else None

    def _flag(name: str) -> bool:
        # BUGFIX: accept a real boolean True as well as the CLI string "True".
        # The previous `True if v == "True" else False` silently turned a
        # boolean True into False.
        value = kwargs.get(name, False)
        return value is True or value == "True"

    # type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
    type_dims = [
        _opt_int("self_attn_dim"),
        _opt_int("cross_attn_dim"),
        _opt_int("mlp_dim"),
        _opt_int("mod_dim"),
        _opt_int("llm_adapter_dim"),
    ]
    if all(d is None for d in type_dims):
        type_dims = None

    # emb_dims: [x_embedder, t_embedder, final_layer], given as "a,b,c" or "[a,b,c]"
    emb_dims = kwargs.get("emb_dims", None)
    if emb_dims is not None:
        emb_dims = emb_dims.strip()
        if emb_dims.startswith("[") and emb_dims.endswith("]"):
            emb_dims = emb_dims[1:-1]
        emb_dims = [int(d) for d in emb_dims.split(",")]
        assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)"

    # block selection: "all", "none"/"", or comma-separated indices and ranges like "0,2,5-8"
    def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
        if selection == "all":
            return [True] * total_blocks
        if selection == "none" or selection == "":
            return [False] * total_blocks

        selected = [False] * total_blocks
        for r in selection.split(","):
            if "-" in r:
                start, end = map(str.strip, r.split("-"))
                start, end = int(start), int(end)
                assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end
                for i in range(start, end + 1):
                    selected[i] = True
            else:
                index = int(r)
                assert 0 <= index < total_blocks
                selected[index] = True
        return selected

    train_block_indices = kwargs.get("train_block_indices", None)
    if train_block_indices is not None:
        # Fall back to a generous bound when the model does not expose .blocks.
        num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999
        train_block_indices = parse_block_selection(train_block_indices, num_blocks)

    # train LLM adapter
    train_llm_adapter = _flag("train_llm_adapter")

    # rank/module dropout
    rank_dropout = _opt_float("rank_dropout")
    module_dropout = _opt_float("module_dropout")

    # verbose
    verbose = _flag("verbose")

    network = LoRANetwork(
        text_encoders,
        unet,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        train_llm_adapter=train_llm_adapter,
        type_dims=type_dims,
        emb_dims=emb_dims,
        train_block_indices=train_block_indices,
        verbose=verbose,
    )

    # Optional LoRA+ learning-rate ratios.
    loraplus_lr_ratio = _opt_float("loraplus_lr_ratio")
    loraplus_unet_lr_ratio = _opt_float("loraplus_unet_lr_ratio")
    loraplus_text_encoder_lr_ratio = _opt_float("loraplus_text_encoder_lr_ratio")
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network
|
||||||
|
|
||||||
|
|
||||||
|
def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weights_sd=None, for_inference=False, **kwargs):
    """Rebuild a LoRANetwork whose per-module dims/alphas come from a saved state dict.

    Returns (network, weights_sd); the caller loads the weights afterwards.
    """
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # Recover each module's rank and alpha from the stored keys.
    modules_dim = {}
    modules_alpha = {}
    train_llm_adapter = False
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            # rank equals the output dimension of lora_down
            modules_dim[lora_name] = value.size()[0]

        # The presence of any llm_adapter module means the adapter was trained.
        if "llm_adapter" in lora_name:
            train_llm_adapter = True

    module_class = LoRAInfModule if for_inference else LoRAModule

    network = LoRANetwork(
        text_encoders,
        unet,
        multiplier=multiplier,
        modules_dim=modules_dim,
        modules_alpha=modules_alpha,
        module_class=module_class,
        train_llm_adapter=train_llm_adapter,
    )
    return network, weights_sd
|
||||||
|
|
||||||
|
|
||||||
|
class LoRANetwork(torch.nn.Module):
|
||||||
|
# Target modules: DiT blocks
|
||||||
|
ANIMA_TARGET_REPLACE_MODULE = ["Block"]
|
||||||
|
# Target modules: LLM Adapter blocks
|
||||||
|
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
|
||||||
|
# Target modules for text encoder (Qwen3)
|
||||||
|
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"]
|
||||||
|
|
||||||
|
LORA_PREFIX_ANIMA = "lora_unet" # ComfyUI compatible
|
||||||
|
LORA_PREFIX_TEXT_ENCODER = "lora_te1" # Qwen3
|
||||||
|
|
||||||
|
    def __init__(
        self,
        text_encoders: list,
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        module_class: Type[object] = LoRAModule,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        train_llm_adapter: bool = False,
        type_dims: Optional[List[int]] = None,
        emb_dims: Optional[List[int]] = None,
        train_block_indices: Optional[List[bool]] = None,
        verbose: Optional[bool] = False,
    ) -> None:
        """Scan the text encoders and the Anima DiT and attach LoRA modules.

        When modules_dim/modules_alpha are given (re-creating from saved
        weights) only the listed modules are built with their stored
        rank/alpha; otherwise ranks default to lora_dim, optionally
        overridden per layer type via type_dims and per block via
        train_block_indices.
        """
        super().__init__()
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.train_llm_adapter = train_llm_adapter
        self.type_dims = type_dims
        self.emb_dims = emb_dims
        self.train_block_indices = train_block_indices

        # LoRA+ ratios stay None until set_loraplus_lr_ratio() is called.
        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if modules_dim is not None:
            logger.info(f"create LoRA network from weights")
            if self.emb_dims is None:
                self.emb_dims = [0] * 3
        else:
            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")

        # create module instances
        def create_modules(
            is_unet: bool,
            text_encoder_idx: Optional[int],
            root_module: torch.nn.Module,
            target_replace_modules: List[str],
            filter: Optional[str] = None,
            default_dim: Optional[int] = None,
            include_conv2d_if_filter: bool = False,
        ) -> Tuple[List[LoRAModule], List[str]]:
            # Walk root_module and build a LoRA module for each Linear /
            # 1x1-Conv2d child inside the target module classes.
            # target_replace_modules=None means "treat root_module itself as
            # the single target" (used for the embedder/final-layer filters).
            prefix = (
                self.LORA_PREFIX_ANIMA
                if is_unet
                else self.LORA_PREFIX_TEXT_ENCODER
            )

            loras = []
            skipped = []  # names skipped because their resolved dim was 0/None
            for name, module in root_module.named_modules():
                if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
                    if target_replace_modules is None:
                        module = root_module

                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            # e.g. "lora_unet_blocks_0_self_attn_q" (dots -> underscores)
                            lora_name = prefix + "." + (name + "." if name else "") + child_name
                            lora_name = lora_name.replace(".", "_")

                            force_incl_conv2d = False
                            if filter is not None:
                                if filter not in lora_name:
                                    continue
                                force_incl_conv2d = include_conv2d_if_filter

                            dim = None
                            alpha_val = None

                            if modules_dim is not None:
                                # loading from weights: only known modules get LoRA
                                if lora_name in modules_dim:
                                    dim = modules_dim[lora_name]
                                    alpha_val = modules_alpha[lora_name]
                            else:
                                if is_linear or is_conv2d_1x1:
                                    dim = default_dim if default_dim is not None else self.lora_dim
                                    alpha_val = self.alpha

                                    if is_unet and type_dims is not None:
                                        # type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
                                        # Order matters: check most specific identifiers first to avoid mismatches.
                                        identifier_order = [
                                            (4, ("llm_adapter",)),
                                            (3, ("adaln_modulation",)),
                                            (0, ("self_attn",)),
                                            (1, ("cross_attn",)),
                                            (2, ("mlp",)),
                                        ]
                                        for idx, ids in identifier_order:
                                            d = type_dims[idx]
                                            if d is not None and all(id_str in lora_name for id_str in ids):
                                                dim = d  # 0 means skip
                                                break

                                    # block index filtering
                                    if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name:
                                        # Extract block index from lora_name: "lora_unet_blocks_0_self_attn..."
                                        parts = lora_name.split("_")
                                        for pi, part in enumerate(parts):
                                            if part == "blocks" and pi + 1 < len(parts):
                                                try:
                                                    block_index = int(parts[pi + 1])
                                                    if not self.train_block_indices[block_index]:
                                                        dim = 0  # deselected block -> skip
                                                except (ValueError, IndexError):
                                                    pass
                                                break

                                elif force_incl_conv2d:
                                    # non-1x1 Conv2d normally skipped, but the
                                    # x_embedder filter explicitly includes it
                                    dim = default_dim if default_dim is not None else self.lora_dim
                                    alpha_val = self.alpha

                            if dim is None or dim == 0:
                                # only report skips for module kinds we would normally target
                                if is_linear or is_conv2d_1x1:
                                    skipped.append(lora_name)
                                continue

                            lora = module_class(
                                lora_name,
                                child_module,
                                self.multiplier,
                                dim,
                                alpha_val,
                                dropout=dropout,
                                rank_dropout=rank_dropout,
                                module_dropout=module_dropout,
                            )
                            loras.append(lora)

                if target_replace_modules is None:
                    break  # root_module was the single target; don't recurse further
            return loras, skipped

        # Create LoRA for text encoders (Qwen3 - typically not trained for Anima)
        self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
        skipped_te = []
        if text_encoders is not None:
            for i, text_encoder in enumerate(text_encoders):
                if text_encoder is None:
                    continue
                logger.info(f"create LoRA for Text Encoder {i+1}:")
                te_loras, te_skipped = create_modules(
                    False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
                )
                logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
                self.text_encoder_loras.extend(te_loras)
                skipped_te += te_skipped

        # Create LoRA for DiT blocks
        target_modules = list(LoRANetwork.ANIMA_TARGET_REPLACE_MODULE)
        if train_llm_adapter:
            target_modules.extend(LoRANetwork.ANIMA_ADAPTER_TARGET_REPLACE_MODULE)

        self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
        self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)

        # emb_dims: [x_embedder, t_embedder, final_layer]
        if self.emb_dims:
            for filter_name, in_dim in zip(
                ["x_embedder", "t_embedder", "final_layer"],
                self.emb_dims,
            ):
                loras, _ = create_modules(
                    True, None, unet, None,
                    filter=filter_name, default_dim=in_dim,
                    include_conv2d_if_filter=(filter_name == "x_embedder"),
                )
                self.unet_loras.extend(loras)

        logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
        if verbose:
            for lora in self.unet_loras:
                logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")

        skipped = skipped_te + skipped_un
        if verbose and len(skipped) > 0:
            logger.warning(f"dim (rank) is 0, {len(skipped)} LoRA modules are skipped:")
            for name in skipped:
                logger.info(f"\t{name}")

        # assertion: no duplicate names
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)
|
||||||
|
|
||||||
|
def set_multiplier(self, multiplier):
    """Update the LoRA scaling multiplier on the network and on every child module."""
    self.multiplier = multiplier
    for module in (*self.text_encoder_loras, *self.unet_loras):
        module.multiplier = self.multiplier
def set_enabled(self, is_enabled):
    """Toggle every LoRA module (text encoder and DiT) on or off at once."""
    for module in (*self.text_encoder_loras, *self.unet_loras):
        module.enabled = is_enabled
def load_weights(self, file):
    """Load LoRA weights from *file* into this network.

    ``.safetensors`` files go through safetensors; anything else is treated as
    a torch checkpoint loaded onto CPU. Returns the ``load_state_dict`` result
    (missing / unexpected keys), loaded non-strictly.
    """
    _, ext = os.path.splitext(file)
    if ext == ".safetensors":
        from safetensors.torch import load_file

        weights_sd = load_file(file)
    else:
        weights_sd = torch.load(file, map_location="cpu")

    return self.load_state_dict(weights_sd, False)
def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
    """Hook every LoRA module into its wrapped layer and register it on this network.

    A disabled side has its LoRA list cleared instead of being applied.
    """
    if not apply_text_encoder:
        self.text_encoder_loras = []
    else:
        logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")

    if not apply_unet:
        self.unet_loras = []
    else:
        logger.info(f"enable LoRA for DiT: {len(self.unet_loras)} modules")

    for module in self.text_encoder_loras + self.unet_loras:
        module.apply_to()
        # registering makes the LoRA parameters part of this nn.Module
        self.add_module(module.lora_name, module)
def is_mergeable(self):
    """LoRA deltas can always be merged into the base model weights (see merge_to)."""
    return True
def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
    """Fold the LoRA weights in *weights_sd* directly into the base model weights.

    Components with no matching keys in the state dict have their LoRA lists
    cleared so nothing is merged for them.
    """
    # detect which model parts the checkpoint actually covers
    apply_text_encoder = apply_unet = False
    for key in weights_sd.keys():
        if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
            apply_text_encoder = True
        elif key.startswith(LoRANetwork.LORA_PREFIX_ANIMA):
            apply_unet = True

    if not apply_text_encoder:
        self.text_encoder_loras = []
    else:
        logger.info("enable LoRA for text encoder")

    if not apply_unet:
        self.unet_loras = []
    else:
        logger.info("enable LoRA for DiT")

    # hand each module the subset of keys that belongs to it, prefix stripped
    for module in self.text_encoder_loras + self.unet_loras:
        prefix_len = len(module.lora_name) + 1  # +1 drops the separating "."
        sub_sd = {
            key[prefix_len:]: value
            for key, value in weights_sd.items()
            if key.startswith(module.lora_name)
        }
        module.merge_to(sub_sd, dtype, device)

    logger.info(f"weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
    """Store the LoRA+ learning-rate ratios used by the optimizer-param builder.

    Component-specific ratios (unet / text encoder) take precedence over the
    shared ratio when set (the ``or`` fallback below).
    """
    self.loraplus_lr_ratio = loraplus_lr_ratio
    self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
    self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

    logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
    logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
    """Build optimizer parameter groups for text-encoder and DiT LoRA modules.

    Args:
        text_encoder_lr: None, scalar, or list of LRs for the text encoder(s);
            normalized to a one-element list below.
        unet_lr: LR for the DiT LoRA modules; falls back to *default_lr* when None.
        default_lr: fallback learning rate.

    Returns:
        (all_params, lr_descriptions): optimizer param-group dicts and a parallel
        list of human-readable group labels ("textencoder 1", "unet", with a
        " plus" suffix for LoRA+ groups).
    """
    # normalize text_encoder_lr to a list with at least one entry
    if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
        text_encoder_lr = [default_lr]
    elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
        text_encoder_lr = [float(text_encoder_lr)]
    elif len(text_encoder_lr) == 1:
        pass  # already a list with one element

    self.requires_grad_(True)

    all_params = []
    lr_descriptions = []

    def assemble_params(loras, lr, loraplus_ratio):
        # Split parameters into the base "lora" group and the LoRA+ "plus" group
        # (lora_up weights get lr * loraplus_ratio when a ratio is configured).
        param_groups = {"lora": {}, "plus": {}}
        for lora in loras:
            for name, param in lora.named_parameters():
                if loraplus_ratio is not None and "lora_up" in name:
                    param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                else:
                    param_groups["lora"][f"{lora.lora_name}.{name}"] = param

        params = []
        descriptions = []
        for key in param_groups.keys():
            param_data = {"params": param_groups[key].values()}
            if len(param_data["params"]) == 0:
                continue
            if lr is not None:
                if key == "plus":
                    param_data["lr"] = lr * loraplus_ratio
                else:
                    param_data["lr"] = lr
            # a zero or missing LR means this group would not learn — drop it
            if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
                logger.info("NO LR skipping!")
                continue
            params.append(param_data)
            descriptions.append("plus" if key == "plus" else "")
        return params, descriptions

    if self.text_encoder_loras:
        loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
        te1_loras = [
            lora for lora in self.text_encoder_loras
            if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)
        ]
        if len(te1_loras) > 0:
            logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
            params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)
            all_params.extend(params)
            lr_descriptions.extend(["textencoder 1" + (" " + d if d else "") for d in descriptions])

    if self.unet_loras:
        params, descriptions = assemble_params(
            self.unet_loras,
            unet_lr if unet_lr is not None else default_lr,
            self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
        )
        all_params.extend(params)
        lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])

    return all_params, lr_descriptions
def enable_gradient_checkpointing(self):
    """No-op: gradient checkpointing of the LoRA modules themselves is not supported."""
    pass  # not supported
def prepare_grad_etc(self, text_encoder, unet):
    """Enable gradients on every LoRA parameter before training starts."""
    self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
    """Put the LoRA network (all registered child modules) into training mode."""
    self.train()
def get_trainable_params(self):
    """Return an iterator over all LoRA parameters (used when no per-group LRs are needed)."""
    return self.parameters()
def save_weights(self, file, dtype, metadata):
    """Serialize this network's LoRA state dict to *file*.

    Tensors are optionally cast to *dtype* on CPU first. A ``.safetensors``
    target additionally gets sshs model hashes embedded in its metadata;
    any other extension is written with ``torch.save``.
    """
    # an empty metadata dict is treated the same as "no metadata"
    if metadata is not None and len(metadata) == 0:
        metadata = None

    state_dict = self.state_dict()

    if dtype is not None:
        # detach/clone so the live training tensors are left untouched
        for key in list(state_dict.keys()):
            state_dict[key] = state_dict[key].detach().clone().to("cpu").to(dtype)

    if os.path.splitext(file)[1] != ".safetensors":
        torch.save(state_dict, file)
        return

    from safetensors.torch import save_file
    from library import train_util

    if metadata is None:
        metadata = {}
    # embed sshs hashes so tooling can identify the checkpoint
    model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
    metadata["sshs_model_hash"] = model_hash
    metadata["sshs_legacy_hash"] = legacy_hash

    save_file(state_dict, file, metadata)
def backup_weights(self):
    """Snapshot each wrapped module's original weight so it can be restored later.

    The copy is stored on the module itself as ``_lora_org_weight``; modules
    that already hold a backup are skipped so repeated calls never overwrite
    the original with an already-merged weight.
    """
    loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
    for lora in loras:
        org_module = lora.org_module_ref[0]
        if not hasattr(org_module, "_lora_org_weight"):
            sd = org_module.state_dict()
            org_module._lora_org_weight = sd["weight"].detach().clone()
            # the module currently holds its original (un-merged) weight
            org_module._lora_restored = True
def restore_weights(self):
    """Write the backed-up original weight back into each wrapped module.

    Modules whose ``_lora_restored`` flag is already True (never merged, or
    already restored) are left alone. Assumes backup_weights() has run, since
    both ``_lora_restored`` and ``_lora_org_weight`` are read unconditionally.
    """
    loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
    for lora in loras:
        org_module = lora.org_module_ref[0]
        if not org_module._lora_restored:
            sd = org_module.state_dict()
            sd["weight"] = org_module._lora_org_weight
            org_module.load_state_dict(sd)
            org_module._lora_restored = True  # back to the original state
def pre_calculation(self):
    """Merge each LoRA delta into its wrapped module's weight ahead of inference.

    After the in-place merge the LoRA module is disabled so the delta is not
    applied again at forward time; ``_lora_restored`` is cleared so a later
    restore_weights() call will undo the merge.
    """
    loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
    for lora in loras:
        org_module = lora.org_module_ref[0]
        sd = org_module.state_dict()

        org_weight = sd["weight"]
        lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
        sd["weight"] = org_weight + lora_weight
        # the merged delta must not change the layer's weight shape
        assert sd["weight"].shape == org_weight.shape
        org_module.load_state_dict(sd)

        org_module._lora_restored = False
        lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
    """Rescale LoRA pairs whose effective merged norm exceeds *max_norm_value*.

    For every lora_down/lora_up pair in the state dict, the merged delta
    (up @ down, scaled by alpha / dim) is computed on *device*. Pairs above
    the cap are scaled in place; the correction is applied as sqrt(ratio) to
    both up and down so their product is scaled by exactly ratio.

    Args:
        max_norm_value: maximum allowed norm of a merged LoRA delta.
        device: device used for the norm computation.

    Returns:
        Tuple of (number of pairs rescaled, mean post-scaling norm,
        max post-scaling norm). When the network holds no lora_down/lora_up
        pairs, returns (0, 0.0, 0.0) instead of failing on the empty list.
    """
    downkeys = []
    upkeys = []
    alphakeys = []
    norms = []
    keys_scaled = 0

    state_dict = self.state_dict()
    for key in state_dict.keys():
        if "lora_down" in key and "weight" in key:
            downkeys.append(key)
            upkeys.append(key.replace("lora_down", "lora_up"))
            alphakeys.append(key.replace("lora_down.weight", "alpha"))

    for i in range(len(downkeys)):
        down = state_dict[downkeys[i]].to(device)
        up = state_dict[upkeys[i]].to(device)
        alpha = state_dict[alphakeys[i]].to(device)
        dim = down.shape[0]
        scale = alpha / dim

        # merge the pair into the full delta; conv kernels need special handling
        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
        else:
            updown = up @ down

        updown *= scale

        # clamp(min=...) bounds the scaling ratio even for very small norms
        norm = updown.norm().clamp(min=max_norm_value / 2)
        desired = torch.clamp(norm, max=max_norm_value)
        ratio = desired.cpu() / norm.cpu()
        sqrt_ratio = ratio**0.5
        if ratio != 1:
            keys_scaled += 1
            # scale the stored tensors in place (they alias the live parameters)
            state_dict[upkeys[i]] *= sqrt_ratio
            state_dict[downkeys[i]] *= sqrt_ratio
        scalednorm = updown.norm() * ratio
        norms.append(scalednorm.item())

    # guard: without any lora_down/up pairs the mean would raise
    # ZeroDivisionError and max() would raise on an empty sequence
    if not norms:
        return keys_scaled, 0.0, 0.0

    return keys_scaled, sum(norms) / len(norms), max(norms)
||||||
617
tests/test_anima_cache.py
Normal file
617
tests/test_anima_cache.py
Normal file
@@ -0,0 +1,617 @@
|
|||||||
|
"""
|
||||||
|
Diagnostic script to test Anima latent & text encoder caching independently.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python test_anima_cache.py \
|
||||||
|
--image_dir /path/to/images \
|
||||||
|
--qwen3_path /path/to/qwen3 \
|
||||||
|
--vae_path /path/to/vae.safetensors \
|
||||||
|
[--t5_tokenizer_path /path/to/t5] \
|
||||||
|
[--cache_to_disk]
|
||||||
|
|
||||||
|
The image_dir should contain pairs of:
|
||||||
|
image1.png + image1.txt
|
||||||
|
image2.jpg + image2.txt
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
# Helpers

# File extensions accepted as dataset images when scanning a directory.
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"}

# Same preprocessing as sd-scripts training: uint8 HWC image -> float CHW in [-1, 1].
IMAGE_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(), # [0,1]
    transforms.Normalize([0.5], [0.5]), # [-1,1]
])
def find_image_caption_pairs(image_dir: str):
    """Return sorted (image_path, caption) pairs found in *image_dir*.

    The caption is read from the sidecar ``<name>.txt`` file when present,
    otherwise it is the empty string.
    """
    result = []
    for entry in sorted(os.listdir(image_dir)):
        if os.path.splitext(entry)[1].lower() not in IMAGE_EXTENSIONS:
            continue
        img_path = os.path.join(image_dir, entry)
        caption_file = os.path.splitext(img_path)[0] + ".txt"
        caption = ""
        if os.path.exists(caption_file):
            with open(caption_file, "r", encoding="utf-8") as handle:
                caption = handle.read().strip()
        result.append((img_path, caption))
    return result
def print_tensor_info(name: str, t, indent=2):
    """Print a one-line summary (dtype, shape, min/max/mean) of *t*, indented by *indent* spaces."""
    prefix = " " * indent
    if t is None:
        print(f"{prefix}{name}: None")
        return
    if isinstance(t, np.ndarray):
        summary = (f"{prefix}{name}: numpy {t.dtype} shape={t.shape} "
                   f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
    elif isinstance(t, torch.Tensor):
        summary = (f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
                   f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}")
    else:
        summary = f"{prefix}{name}: type={type(t)} value={t}"
    print(summary)
|
# Test 1: Latent Cache
|
||||||
|
|
||||||
|
def test_latent_cache(args, pairs):
    """Diagnostic: encode each image with the Anima VAE and verify the latent cache.

    For every (image, caption) pair this loads the image, applies the training
    transforms, VAE-encodes it, checks for NaN/Inf, and (with --cache_to_disk)
    round-trips the latents through an .npz file and compares.
    """
    print("\n" + "=" * 70)
    print("TEST 1: LATENT CACHING (VAE encode -> cache -> reload)")
    print("=" * 70)

    from library import anima_utils
    from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD

    # Load VAE
    print("\n[1.1] Loading VAE...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vae_dtype = torch.float32
    # NOTE(review): vae_mean / vae_std returned here are unused; the module-level
    # ANIMA_VAE_MEAN / ANIMA_VAE_STD constants are printed instead — presumably equal, confirm.
    vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(
        args.vae_path, dtype=vae_dtype, device=device
    )
    print(f" VAE loaded on {device}, dtype={vae_dtype}")
    print(f" VAE mean (first 4): {ANIMA_VAE_MEAN[:4]}")
    print(f" VAE std (first 4): {ANIMA_VAE_STD[:4]}")

    for img_path, caption in pairs:
        print(f"\n[1.2] Processing: {os.path.basename(img_path)}")

        # Load image
        img = Image.open(img_path).convert("RGB")
        img_np = np.array(img)
        print(f" Raw image: {img_np.shape} dtype={img_np.dtype} "
              f"min={img_np.min()} max={img_np.max()}")

        # Apply IMAGE_TRANSFORMS (same as sd-scripts training)
        img_tensor = IMAGE_TRANSFORMS(img_np)
        print(f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} "
              f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}")

        # Check range is [-1, 1]
        if img_tensor.min() < -1.01 or img_tensor.max() > 1.01:
            print(" ** WARNING: tensor out of [-1, 1] range!")
        else:
            print(" OK: tensor in [-1, 1] range")

        # Encode with VAE
        img_batch = img_tensor.unsqueeze(0).to(device, dtype=vae_dtype)  # (1, C, H, W)
        img_5d = img_batch.unsqueeze(2)  # (1, C, 1, H, W) - add temporal dim
        print(f" VAE input: shape={tuple(img_5d.shape)} dtype={img_5d.dtype}")

        with torch.no_grad():
            latents = vae.encode(img_5d, vae_scale)
        latents_cpu = latents.cpu()
        print_tensor_info("Encoded latents", latents_cpu)

        # Check for NaN/Inf
        if torch.any(torch.isnan(latents_cpu)):
            print(" ** ERROR: NaN in latents!")
        elif torch.any(torch.isinf(latents_cpu)):
            print(" ** ERROR: Inf in latents!")
        else:
            print(" OK: no NaN/Inf")

        # Test disk cache round-trip
        if args.cache_to_disk:
            npz_path = os.path.splitext(img_path)[0] + "_test_latent.npz"
            latents_np = latents_cpu.float().numpy()
            h, w = img_np.shape[:2]
            np.savez(
                npz_path,
                latents=latents_np,
                original_size=np.array([w, h]),
                crop_ltrb=np.array([0, 0, 0, 0]),
            )
            print(f" Saved to: {npz_path}")

            # Reload
            loaded = np.load(npz_path)
            loaded_latents = loaded["latents"]
            print_tensor_info("Reloaded latents", loaded_latents)

            # Compare
            diff = np.abs(latents_np - loaded_latents).max()
            print(f" Max diff (save vs load): {diff:.2e}")
            if diff > 1e-5:
                print(" ** WARNING: latent cache round-trip has significant diff!")
            else:
                print(" OK: round-trip matches")

            os.remove(npz_path)
            print(f" Cleaned up {npz_path}")

    # free GPU memory before the next test
    vae.to("cpu")
    del vae
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
    print("\n[1.3] Latent cache test DONE.")
|
|
||||||
|
# Test 2: Text Encoder Output Cache
|
||||||
|
|
||||||
|
def test_text_encoder_cache(args, pairs):
    """Diagnostic: tokenize/encode the captions and verify the text-encoder cache paths.

    Exercises tokenization, Qwen3 encoding, the numpy/.npz disk round-trip used
    by the caching strategy, the in-memory stack used by ``__getitem__``, the
    unpack performed by get_noise_pred_and_target, and caption dropout.
    """
    print("\n" + "=" * 70)
    print("TEST 2: TEXT ENCODER OUTPUT CACHING")
    print("=" * 70)

    from library import anima_utils

    # Load tokenizers
    print("\n[2.1] Loading tokenizers...")
    qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
    t5_tokenizer = anima_utils.load_t5_tokenizer(
        getattr(args, 't5_tokenizer_path', None)
    )
    print(f" Qwen3 tokenizer vocab: {qwen3_tokenizer.vocab_size}")
    print(f" T5 tokenizer vocab: {t5_tokenizer.vocab_size}")

    # Load text encoder
    print("\n[2.2] Loading Qwen3 text encoder...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
    qwen3_model, _ = anima_utils.load_qwen3_text_encoder(
        args.qwen3_path, dtype=te_dtype, device=device
    )
    qwen3_model.eval()

    # Create strategy objects
    from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy

    tokenize_strategy = AnimaTokenizeStrategy(
        qwen3_tokenizer=qwen3_tokenizer,
        t5_tokenizer=t5_tokenizer,
        qwen3_max_length=args.qwen3_max_length,
        t5_max_length=args.t5_max_length,
    )
    text_encoding_strategy = AnimaTextEncodingStrategy(
        dropout_rate=0.0,
    )

    captions = [cap for _, cap in pairs]
    print(f"\n[2.3] Tokenizing {len(captions)} captions...")
    for i, cap in enumerate(captions):
        print(f" [{i}] \"{cap[:80]}{'...' if len(cap) > 80 else ''}\"")

    tokens_and_masks = tokenize_strategy.tokenize(captions)
    qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens_and_masks

    print(f"\n Tokenization results:")
    print_tensor_info("qwen3_input_ids", qwen3_input_ids)
    print_tensor_info("qwen3_attn_mask", qwen3_attn_mask)
    print_tensor_info("t5_input_ids", t5_input_ids)
    print_tensor_info("t5_attn_mask", t5_attn_mask)

    # Encode
    print(f"\n[2.4] Encoding with Qwen3 text encoder...")
    with torch.no_grad():
        prompt_embeds, attn_mask, t5_ids_out, t5_mask_out = text_encoding_strategy.encode_tokens(
            tokenize_strategy,
            [qwen3_model],
            tokens_and_masks,
            enable_dropout=False,
        )

    print(f" Encoding results:")
    print_tensor_info("prompt_embeds", prompt_embeds)
    print_tensor_info("attn_mask", attn_mask)
    print_tensor_info("t5_input_ids", t5_ids_out)
    print_tensor_info("t5_attn_mask", t5_mask_out)

    # Check for NaN/Inf
    if torch.any(torch.isnan(prompt_embeds)):
        print(" ** ERROR: NaN in prompt_embeds!")
    elif torch.any(torch.isinf(prompt_embeds)):
        print(" ** ERROR: Inf in prompt_embeds!")
    else:
        print(" OK: no NaN/Inf in prompt_embeds")

    # Test cache round-trip (simulate what AnimaTextEncoderOutputsCachingStrategy does)
    print(f"\n[2.5] Testing cache round-trip (encode -> numpy -> npz -> reload -> tensor)...")

    # Convert to numpy (same as cache_batch_outputs in strategy_anima.py)
    pe_cpu = prompt_embeds.cpu()
    if pe_cpu.dtype == torch.bfloat16:
        # numpy has no bfloat16; widen before the conversion
        pe_cpu = pe_cpu.float()
    pe_np = pe_cpu.numpy()
    am_np = attn_mask.cpu().numpy()
    t5_ids_np = t5_ids_out.cpu().numpy().astype(np.int32)
    t5_mask_np = t5_mask_out.cpu().numpy().astype(np.int32)

    print(f" Numpy conversions:")
    print_tensor_info("prompt_embeds_np", pe_np)
    print_tensor_info("attn_mask_np", am_np)
    print_tensor_info("t5_input_ids_np", t5_ids_np)
    print_tensor_info("t5_attn_mask_np", t5_mask_np)

    if args.cache_to_disk:
        # NOTE(review): npz_path is never used — per-sample paths below are used instead
        npz_path = os.path.join(args.image_dir, "_test_te_cache.npz")
        # Save per-sample (simulating cache_batch_outputs)
        for i in range(len(captions)):
            sample_npz = os.path.splitext(pairs[i][0])[0] + "_test_te.npz"
            np.savez(
                sample_npz,
                prompt_embeds=pe_np[i],
                attn_mask=am_np[i],
                t5_input_ids=t5_ids_np[i],
                t5_attn_mask=t5_mask_np[i],
            )
            print(f" Saved: {sample_npz}")

            # Reload (simulating load_outputs_npz)
            data = np.load(sample_npz)
            print(f" Reloaded keys: {list(data.keys())}")
            print_tensor_info(" loaded prompt_embeds", data["prompt_embeds"], indent=4)
            print_tensor_info(" loaded attn_mask", data["attn_mask"], indent=4)
            print_tensor_info(" loaded t5_input_ids", data["t5_input_ids"], indent=4)
            print_tensor_info(" loaded t5_attn_mask", data["t5_attn_mask"], indent=4)

            # Check diff
            diff_pe = np.abs(pe_np[i] - data["prompt_embeds"]).max()
            diff_t5 = np.abs(t5_ids_np[i] - data["t5_input_ids"]).max()
            print(f" Max diff prompt_embeds: {diff_pe:.2e}")
            print(f" Max diff t5_input_ids: {diff_t5:.2e}")
            if diff_pe > 1e-5 or diff_t5 > 0:
                print(" ** WARNING: cache round-trip mismatch!")
            else:
                print(" OK: round-trip matches")

            os.remove(sample_npz)
            print(f" Cleaned up {sample_npz}")

    # Test in-memory cache round-trip (simulating what __getitem__ does)
    print(f"\n[2.6] Testing in-memory cache simulation (tuple -> none_or_stack_elements -> batch)...")

    # Simulate per-sample storage (like info.text_encoder_outputs = tuple)
    per_sample_cached = []
    for i in range(len(captions)):
        per_sample_cached.append((pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]))

    # Simulate none_or_stack_elements with torch.FloatTensor converter
    # This is what train_util.py __getitem__ does at line 1784
    stacked = []
    for elem_idx in range(4):
        arrays = [sample[elem_idx] for sample in per_sample_cached]
        stacked.append(torch.stack([torch.FloatTensor(a) for a in arrays]))

    print(f" Stacked batch (like batch['text_encoder_outputs_list']):")
    names = ["prompt_embeds", "attn_mask", "t5_input_ids", "t5_attn_mask"]
    for name, tensor in zip(names, stacked):
        print_tensor_info(name, tensor)

    # Check condition: len(text_encoder_conds) == 0 or text_encoder_conds[0] is None
    text_encoder_conds = stacked
    cond_check_1 = len(text_encoder_conds) == 0
    cond_check_2 = text_encoder_conds[0] is None
    print(f"\n Condition check (should both be False when caching works):")
    print(f" len(text_encoder_conds) == 0 : {cond_check_1}")
    print(f" text_encoder_conds[0] is None: {cond_check_2}")
    if not cond_check_1 and not cond_check_2:
        print(" OK: cached text encoder outputs would be used")
    else:
        print(" ** BUG: code would try to re-encode (and crash on None input_ids_list)!")

    # Test unpack for get_noise_pred_and_target (line 311)
    print(f"\n[2.7] Testing unpack: prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds")
    try:
        pe_batch, am_batch, t5_ids_batch, t5_mask_batch = text_encoder_conds
        print(f" Unpack OK")
        print_tensor_info("prompt_embeds", pe_batch)
        print_tensor_info("attn_mask", am_batch)
        print_tensor_info("t5_input_ids", t5_ids_batch)
        print_tensor_info("t5_attn_mask", t5_mask_batch)

        # Check t5_input_ids are integers (they were converted to FloatTensor!)
        if t5_ids_batch.dtype != torch.long and t5_ids_batch.dtype != torch.int32:
            print(f"\n ** NOTE: t5_input_ids dtype is {t5_ids_batch.dtype}, will be cast to long at line 316")
            t5_ids_long = t5_ids_batch.to(dtype=torch.long)
            # Check if any precision was lost
            diff = (t5_ids_batch - t5_ids_long.float()).abs().max()
            print(f" Float->Long precision loss: {diff:.2e}")
            if diff > 0.5:
                print(" ** ERROR: token IDs corrupted by float conversion!")
            else:
                print(" OK: float->long conversion is lossless for these IDs")
    except Exception as e:
        print(f" ** ERROR unpacking: {e}")
        traceback.print_exc()

    # Test drop_cached_text_encoder_outputs
    print(f"\n[2.8] Testing drop_cached_text_encoder_outputs (caption dropout)...")
    dropout_strategy = AnimaTextEncodingStrategy(
        dropout_rate=0.5,  # high rate to ensure some drops
    )
    dropped = dropout_strategy.drop_cached_text_encoder_outputs(*stacked)
    print(f" Returned {len(dropped)} tensors")
    for name, tensor in zip(names, dropped):
        print_tensor_info(f"dropped_{name}", tensor)

    # Check which items were dropped
    for i in range(len(captions)):
        is_zero = (dropped[0][i].abs().sum() == 0).item()
        print(f" Sample {i}: {'DROPPED' if is_zero else 'KEPT'}")

    qwen3_model.to("cpu")
    del qwen3_model
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
    # NOTE(review): step label duplicates the "[2.8]" above; likely meant "[2.9]"
    print("\n[2.8] Text encoder cache test DONE.")
|
|
||||||
|
|
||||||
|
# Test 3: Full batch simulation
|
||||||
|
|
||||||
|
def test_full_batch_simulation(args, pairs):
    """Diagnostic: replay the whole caching -> batching -> process_batch pipeline.

    Loads all models, caches text-encoder outputs and latents the same way the
    trainer does, builds a batch dict, and then checks the critical
    ``is_train_text_encoder`` condition that decides between using cached
    conditions and re-encoding (the latter crashes when input_ids_list is None).
    """
    print("\n" + "=" * 70)
    print("TEST 3: FULL BATCH SIMULATION (mimics process_batch flow)")
    print("=" * 70)

    from library import anima_utils
    from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
    from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy

    device = "cuda" if torch.cuda.is_available() else "cpu"
    te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
    vae_dtype = torch.float32

    # Load all models
    print("\n[3.1] Loading models...")
    qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
    t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, 't5_tokenizer_path', None))
    qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
    qwen3_model.eval()
    vae, _, _, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=vae_dtype, device=device)

    tokenize_strategy = AnimaTokenizeStrategy(
        qwen3_tokenizer=qwen3_tokenizer, t5_tokenizer=t5_tokenizer,
        qwen3_max_length=args.qwen3_max_length, t5_max_length=args.t5_max_length,
    )
    text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)

    captions = [cap for _, cap in pairs]

    # --- Simulate caching phase ---
    print("\n[3.2] Simulating text encoder caching phase...")
    tokens_and_masks = tokenize_strategy.tokenize(captions)
    with torch.no_grad():
        te_outputs = text_encoding_strategy.encode_tokens(
            tokenize_strategy, [qwen3_model], tokens_and_masks, enable_dropout=False,
        )
    prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = te_outputs

    # Convert to numpy (same as cache_batch_outputs)
    pe_np = prompt_embeds.cpu().float().numpy()
    am_np = attn_mask.cpu().numpy()
    t5_ids_np = t5_input_ids.cpu().numpy().astype(np.int32)
    t5_mask_np = t5_attn_mask.cpu().numpy().astype(np.int32)

    # Per-sample storage (like info.text_encoder_outputs)
    per_sample_te = [(pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]) for i in range(len(captions))]

    print(f"\n[3.3] Simulating latent caching phase...")
    per_sample_latents = []
    for img_path, _ in pairs:
        img = Image.open(img_path).convert("RGB")
        img_np = np.array(img)
        img_tensor = IMAGE_TRANSFORMS(img_np).unsqueeze(0).unsqueeze(2)  # (1,C,1,H,W)
        img_tensor = img_tensor.to(device, dtype=vae_dtype)
        with torch.no_grad():
            lat = vae.encode(img_tensor, vae_scale).cpu()
        per_sample_latents.append(lat.squeeze(0))  # (C,1,H,W)
        print(f" {os.path.basename(img_path)}: latent shape={tuple(lat.shape)}")

    # --- Simulate batch construction (__getitem__) ---
    print(f"\n[3.4] Simulating batch construction...")

    # Use first image's latents only (images may have different resolutions)
    latents_batch = per_sample_latents[0].unsqueeze(0)  # (1,C,1,H,W)
    print(f" Using first image latent for simulation: shape={tuple(latents_batch.shape)}")

    # Stack text encoder outputs (none_or_stack_elements)
    text_encoder_outputs_list = []
    for elem_idx in range(4):
        arrays = [s[elem_idx] for s in per_sample_te]
        text_encoder_outputs_list.append(torch.stack([torch.FloatTensor(a) for a in arrays]))

    # input_ids_list is None when caching
    input_ids_list = None

    batch = {
        "latents": latents_batch,
        "text_encoder_outputs_list": text_encoder_outputs_list,
        "input_ids_list": input_ids_list,
        "loss_weights": torch.ones(len(captions)),
    }

    print(f" batch keys: {list(batch.keys())}")
    print(f" batch['latents']: shape={tuple(batch['latents'].shape)}")
    print(f" batch['text_encoder_outputs_list']: {len(batch['text_encoder_outputs_list'])} tensors")
    print(f" batch['input_ids_list']: {batch['input_ids_list']}")

    # --- Simulate process_batch logic ---
    print(f"\n[3.5] Simulating process_batch logic...")

    text_encoder_conds = []
    te_out = batch.get("text_encoder_outputs_list", None)
    if te_out is not None:
        text_encoder_conds = te_out
        print(f" text_encoder_conds loaded from cache: {len(text_encoder_conds)} tensors")
    else:
        print(f" text_encoder_conds: empty (no cache)")

    # The critical condition
    train_text_encoder_TRUE = True  # OLD behavior (base class default, no override)
    train_text_encoder_FALSE = False  # NEW behavior (with is_train_text_encoder override)

    cond_old = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_TRUE
    cond_new = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_FALSE

    print(f"\n === CRITICAL CONDITION CHECK ===")
    print(f" len(text_encoder_conds) == 0 : {len(text_encoder_conds) == 0}")
    print(f" text_encoder_conds[0] is None: {text_encoder_conds[0] is None}")
    print(f" train_text_encoder (OLD=True) : {train_text_encoder_TRUE}")
    print(f" train_text_encoder (NEW=False): {train_text_encoder_FALSE}")
    print(f"")
    print(f" Condition with OLD behavior (no override): {cond_old}")
    msg = (
        "ENTERS re-encode block -> accesses batch['input_ids_list'] -> CRASH!"
        if cond_old
        else "SKIPS re-encode block -> uses cache -> OK"
    )

    print(f" -> {msg}")
    print(f" Condition with NEW behavior (override): {cond_new}")
    print(f" -> {'ENTERS re-encode block' if cond_new else 'SKIPS re-encode block -> uses cache -> OK'}")

    if cond_old and not cond_new:
        print(f"\n ** CONFIRMED: the is_train_text_encoder override fixes the crash **")

    # Simulate the rest of process_batch
    print(f"\n[3.6] Simulating get_noise_pred_and_target unpack...")
    try:
        pe, am, t5_ids, t5_mask = text_encoder_conds
        pe = pe.to(device, dtype=te_dtype)
        am = am.to(device)
        t5_ids = t5_ids.to(device, dtype=torch.long)
        t5_mask = t5_mask.to(device)

        print(f" Unpack + device transfer OK:")
        print_tensor_info("prompt_embeds", pe)
        print_tensor_info("attn_mask", am)
        print_tensor_info("t5_input_ids", t5_ids)
        print_tensor_info("t5_attn_mask", t5_mask)

        # Verify t5_input_ids didn't get corrupted by float conversion
        t5_ids_orig = torch.tensor(t5_ids_np, dtype=torch.long, device=device)
        id_match = torch.all(t5_ids == t5_ids_orig).item()
        print(f"\n t5_input_ids integrity (float->long roundtrip): {'OK' if id_match else '** MISMATCH **'}")
        if not id_match:
            diff_count = (t5_ids != t5_ids_orig).sum().item()
            print(f" {diff_count} token IDs differ!")
            # Show example
            idx = torch.where(t5_ids != t5_ids_orig)
            if len(idx[0]) > 0:
                i, j = idx[0][0].item(), idx[1][0].item()
                print(f" Example: position [{i},{j}] original={t5_ids_orig[i,j].item()} loaded={t5_ids[i,j].item()}")

    except Exception as e:
        print(f" ** ERROR: {e}")
        traceback.print_exc()

    # Cleanup
    vae.to("cpu")
    qwen3_model.to("cpu")
    del vae, qwen3_model
    torch.cuda.empty_cache() if torch.cuda.is_available() else None
    print("\n[3.7] Full batch simulation DONE.")
|
|
||||||
|
|
||||||
|
# Main
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: run the selected Anima cache tests and print a summary."""
    parser = argparse.ArgumentParser(description="Test Anima caching mechanisms")
    parser.add_argument("--image_dir", type=str, required=True,
                        help="Directory with image+txt pairs")
    parser.add_argument("--qwen3_path", type=str, required=True,
                        help="Path to Qwen3 model (directory or safetensors)")
    parser.add_argument("--vae_path", type=str, required=True,
                        help="Path to WanVAE safetensors")
    parser.add_argument("--t5_tokenizer_path", type=str, default=None,
                        help="Path to T5 tokenizer (optional, uses bundled config)")
    parser.add_argument("--qwen3_max_length", type=int, default=512)
    parser.add_argument("--t5_max_length", type=int, default=512)
    parser.add_argument("--cache_to_disk", action="store_true",
                        help="Also test disk cache round-trip")
    parser.add_argument("--skip_latent", action="store_true",
                        help="Skip latent cache test")
    parser.add_argument("--skip_text", action="store_true",
                        help="Skip text encoder cache test")
    parser.add_argument("--skip_full", action="store_true",
                        help="Skip full batch simulation")
    args = parser.parse_args()

    # Locate the image/caption pairs; bail out early when there are none.
    pairs = find_image_caption_pairs(args.image_dir)
    if not pairs:
        print(f"ERROR: No image+txt pairs found in {args.image_dir}")
        print("Expected: image.png + image.txt, image.jpg + image.txt, etc.")
        sys.exit(1)

    print(f"Found {len(pairs)} image-caption pairs:")
    for img_path, cap in pairs:
        print(f"  {os.path.basename(img_path)}: \"{cap[:60]}{'...' if len(cap) > 60 else ''}\"")

    # (skip flag, results key, failure label, test callable) — one row per suite.
    suites = [
        (args.skip_latent, "latent_cache", "LATENT CACHE TEST", test_latent_cache),
        (args.skip_text, "text_encoder_cache", "TEXT ENCODER CACHE TEST", test_text_encoder_cache),
        (args.skip_full, "full_batch_sim", "FULL BATCH SIMULATION", test_full_batch_simulation),
    ]

    results = {}
    for skip, key, label, runner in suites:
        if skip:
            continue
        try:
            runner(args, pairs)
            results[key] = "PASS"
        except Exception as e:
            print(f"\n** {label} FAILED: {e}")
            traceback.print_exc()
            results[key] = f"FAIL: {e}"

    # Summary
    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    for test, result in results.items():
        status = "OK" if result == "PASS" else "FAIL"
        print(f"  [{status}] {test}: {result}")
    print()
||||||
|
|
||||||
|
# Allow running this cache-test script directly from the command line.
if __name__ == "__main__":
    main()
|
||||||
242
tests/test_anima_real_training.py
Normal file
242
tests/test_anima_real_training.py
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
"""
|
||||||
|
Test script that actually runs anima_train.py and anima_train_network.py
|
||||||
|
for a few steps to verify --cache_text_encoder_outputs works.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python test_anima_real_training.py \
|
||||||
|
--image_dir /path/to/images_with_txt \
|
||||||
|
--dit_path /path/to/dit.safetensors \
|
||||||
|
--qwen3_path /path/to/qwen3 \
|
||||||
|
--vae_path /path/to/vae.safetensors \
|
||||||
|
[--t5_tokenizer_path /path/to/t5] \
|
||||||
|
[--resolution 512]
|
||||||
|
|
||||||
|
This will run 4 tests:
|
||||||
|
1. anima_train.py (full finetune, no cache)
|
||||||
|
2. anima_train.py (full finetune, --cache_text_encoder_outputs)
|
||||||
|
3. anima_train_network.py (LoRA, no cache)
|
||||||
|
4. anima_train_network.py (LoRA, --cache_text_encoder_outputs)
|
||||||
|
|
||||||
|
Each test runs only 2 training steps then stops.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset_toml(image_dir: str, resolution: int, toml_path: str) -> str:
    """Create a minimal dataset toml config.

    Args:
        image_dir: directory containing the image+caption pairs. Backslashes
            (Windows path separators) are normalized to forward slashes so the
            embedded value is a valid TOML basic string.
        resolution: training resolution written to the [general] section.
        toml_path: destination file path for the generated config.

    Returns:
        toml_path, for call-site convenience.
    """
    # A raw backslash is an invalid escape sequence inside a TOML basic
    # string, so a Windows path written verbatim would make the config
    # unparseable. Forward slashes are accepted by Python path handling
    # on all platforms, so normalize instead of escaping.
    safe_image_dir = image_dir.replace("\\", "/")
    content = f"""[general]
resolution = {resolution}
enable_bucket = true
bucket_reso_steps = 8
min_bucket_reso = 256
max_bucket_reso = 1024

[[datasets]]
batch_size = 1

[[datasets.subsets]]
image_dir = "{safe_image_dir}"
num_repeats = 1
caption_extension = ".txt"
"""
    with open(toml_path, "w", encoding="utf-8") as f:
        f.write(content)
    return toml_path
|
||||||
|
|
||||||
|
|
||||||
|
def run_test(test_name: str, cmd: list, timeout: int = 300) -> dict:
    """Run a training command and capture result.

    Executes *cmd* as a subprocess (cwd = this script's directory), echoes the
    tail of its output, and classifies the outcome into a dict with "status"
    (PASS / FAIL / TIMEOUT / ERROR) and a human-readable "detail".
    """
    banner = "=" * 70
    print(f"\n{banner}")
    print(f"TEST: {test_name}")
    print(f"{banner}")
    print(f"Command: {' '.join(cmd)}\n")

    try:
        proc = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout,
            cwd=os.path.dirname(os.path.abspath(__file__)),
        )

        rc = proc.returncode
        combined = proc.stdout + "\n" + proc.stderr
        out_lines = combined.strip().split("\n")

        # Echo only the tail so long training logs stay readable.
        print(f"--- Last 30 lines of output ---")
        for ln in out_lines[-30:]:
            print(f"  {ln}")
        print(f"--- End output ---\n")

        if rc == 0:
            print(f"RESULT: PASS (exit code 0)")
            return {"status": "PASS", "detail": "completed successfully"}

        # Known failure signature of the cache_text_encoder_outputs bug.
        if "TypeError: 'NoneType' object is not iterable" in combined:
            print(f"RESULT: FAIL - input_ids_list is None (the cache_text_encoder_outputs bug)")
            return {"status": "FAIL", "detail": "input_ids_list is None - cache TE outputs bug"}

        if "steps: 0%" in combined and "Error" in combined:
            # Surface the last error-looking line as the failure detail.
            suspects = [ln for ln in out_lines if "Error" in ln or "Traceback" in ln or "raise" in ln.lower()]
            detail = suspects[-1] if suspects else f"exit code {rc}"
            print(f"RESULT: FAIL - {detail}")
            return {"status": "FAIL", "detail": detail}

        print(f"RESULT: FAIL (exit code {rc})")
        return {"status": "FAIL", "detail": f"exit code {rc}"}

    except subprocess.TimeoutExpired:
        print(f"RESULT: TIMEOUT (>{timeout}s)")
        return {"status": "TIMEOUT", "detail": f"exceeded {timeout}s"}
    except Exception as e:
        print(f"RESULT: ERROR - {e}")
        return {"status": "ERROR", "detail": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: run up to 4 short real-training smoke tests.

    Tests anima_train.py (full finetune) and anima_train_network.py (LoRA),
    each with and without --cache_text_encoder_outputs, for 2 steps each.
    Exits non-zero if any test fails.
    """
    parser = argparse.ArgumentParser(description="Test Anima real training with cache flags")
    parser.add_argument("--image_dir", type=str, required=True,
                        help="Directory with image+txt pairs")
    parser.add_argument("--dit_path", type=str, required=True,
                        help="Path to Anima DiT safetensors")
    parser.add_argument("--qwen3_path", type=str, required=True,
                        help="Path to Qwen3 model")
    parser.add_argument("--vae_path", type=str, required=True,
                        help="Path to WanVAE safetensors")
    parser.add_argument("--t5_tokenizer_path", type=str, default=None)
    parser.add_argument("--resolution", type=int, default=512)
    parser.add_argument("--timeout", type=int, default=300,
                        help="Timeout per test in seconds (default: 300)")
    parser.add_argument("--only", type=str, default=None,
                        choices=["finetune", "lora"],
                        help="Only run finetune or lora tests")
    args = parser.parse_args()

    # Validate paths before spending time launching subprocesses.
    for name, path in [("image_dir", args.image_dir), ("dit_path", args.dit_path),
                       ("qwen3_path", args.qwen3_path), ("vae_path", args.vae_path)]:
        if not os.path.exists(path):
            print(f"ERROR: {name} does not exist: {path}")
            sys.exit(1)

    # Create temp dir for outputs (dataset config, checkpoints, logs).
    tmp_dir = tempfile.mkdtemp(prefix="anima_test_")
    print(f"Temp directory: {tmp_dir}")

    # Create dataset toml
    toml_path = os.path.join(tmp_dir, "dataset.toml")
    create_dataset_toml(args.image_dir, args.resolution, toml_path)
    print(f"Dataset config: {toml_path}")

    output_dir = os.path.join(tmp_dir, "output")
    os.makedirs(output_dir, exist_ok=True)

    # Use the same interpreter that is running this test script.
    python = sys.executable

    # Common args for both scripts
    common_anima_args = [
        "--dit_path", args.dit_path,
        "--qwen3_path", args.qwen3_path,
        "--vae_path", args.vae_path,
        "--pretrained_model_name_or_path", args.dit_path,  # required by base parser
        "--output_dir", output_dir,
        "--output_name", "test",
        "--dataset_config", toml_path,
        "--max_train_steps", "2",
        "--learning_rate", "1e-5",
        "--mixed_precision", "bf16",
        "--save_every_n_steps", "999",  # don't save
        "--max_data_loader_n_workers", "0",  # single process for clarity
        "--logging_dir", os.path.join(tmp_dir, "logs"),
        "--cache_latents",
    ]
    if args.t5_tokenizer_path:
        common_anima_args += ["--t5_tokenizer_path", args.t5_tokenizer_path]

    results = {}

    # TEST 1: anima_train.py - NO cache_text_encoder_outputs
    if args.only is None or args.only == "finetune":
        cmd = [python, "anima_train.py"] + common_anima_args + [
            "--optimizer_type", "AdamW8bit",
        ]
        results["finetune_no_cache"] = run_test(
            "anima_train.py (full finetune, NO text encoder cache)",
            cmd, args.timeout,
        )

        # TEST 2: anima_train.py - WITH cache_text_encoder_outputs
        cmd = [python, "anima_train.py"] + common_anima_args + [
            "--optimizer_type", "AdamW8bit",
            "--cache_text_encoder_outputs",
        ]
        results["finetune_with_cache"] = run_test(
            "anima_train.py (full finetune, WITH --cache_text_encoder_outputs)",
            cmd, args.timeout,
        )

    # TEST 3: anima_train_network.py - NO cache_text_encoder_outputs
    if args.only is None or args.only == "lora":
        lora_args = common_anima_args + [
            "--optimizer_type", "AdamW8bit",
            "--network_module", "networks.lora_anima",
            "--network_dim", "4",
            "--network_alpha", "1",
        ]

        cmd = [python, "anima_train_network.py"] + lora_args
        results["lora_no_cache"] = run_test(
            "anima_train_network.py (LoRA, NO text encoder cache)",
            cmd, args.timeout,
        )

        # TEST 4: anima_train_network.py - WITH cache_text_encoder_outputs
        cmd = [python, "anima_train_network.py"] + lora_args + [
            "--cache_text_encoder_outputs",
        ]
        results["lora_with_cache"] = run_test(
            "anima_train_network.py (LoRA, WITH --cache_text_encoder_outputs)",
            cmd, args.timeout,
        )

    # SUMMARY
    print(f"\n{'=' * 70}")
    print("SUMMARY")
    print(f"{'=' * 70}")
    all_pass = True
    for test_name, result in results.items():
        status = result["status"]
        icon = "OK" if status == "PASS" else "FAIL"
        if status != "PASS":
            all_pass = False
        print(f"  [{icon:4s}] {test_name}: {result['detail']}")

    print(f"\nTemp directory (can delete): {tmp_dir}")

    # Cleanup is best-effort: a held file handle (e.g. on Windows) should not
    # turn a passing run into a failure.
    try:
        shutil.rmtree(tmp_dir)
        print("Temp directory cleaned up.")
    except Exception:
        print(f"Note: could not clean up {tmp_dir}")

    if all_pass:
        print("\nAll tests PASSED!")
    else:
        print("\nSome tests FAILED!")
        sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this real-training test script directly from the command line.
if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user