mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 05:44:56 +00:00
Add/modify some implementation for anima (#2261)
* fix: update extend-exclude list in _typos.toml to include configs * fix: exclude anima tests from pytest * feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE * fix: update default value for --discrete_flow_shift in anima training guide * feat: add Qwen-Image VAE * feat: simplify encode_tokens * feat: use unified attention module, add wrapper for state dict compatibility * feat: loading with dynamic fp8 optimization and LoRA support * feat: add anima minimal inference script (WIP) * format: format * feat: simplify target module selection by regular expression patterns * feat: kept caption dropout rate in cache and handle in training script * feat: update train_llm_adapter and verbose default values to string type * fix: use strategy instead of using tokenizers directly * feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock * feat: support 5d tensor in get_noisy_model_input_and_timesteps * feat: update loss calculation to support 5d tensor * fix: update argument names in anima_train_utils to align with other archtectures * feat: simplify Anima training script and update empty caption handling * feat: support LoRA format without `net.` prefix * fix: update to work fp8_scaled option * feat: add regex-based learning rates and dimensions handling in create_network * fix: improve regex matching for module selection and learning rates in LoRANetwork * fix: update logging message for regex match in LoRANetwork * fix: keep latents 4D except DiT call * feat: enhance block swap functionality for inference and training in Anima model * feat: refactor Anima training script * feat: optimize VAE processing by adjusting tensor dimensions and data types * fix: wait all block trasfer before siwtching offloader mode * feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude! * feat: support LORA for Qwen3 * feat: update Anima SAI model spec metadata handling * fix: remove unused code * feat: split CFG processing in do_sample function to reduce memory usage * feat: add VAE chunking and caching options to reduce memory usage * feat: optimize RMSNorm forward method and remove unused torch_attention_op * Update library/strategy_anima.py Use torch.all instead of all. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/safetensors_utils.py Fix duplicated new_key for concat_hook. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_minimal_inference.py Remove unused code. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_train.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/anima_train_utils.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: review with Copilot * feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet) * feat: add process_escape function to handle escape sequences in prompts * feat: enhance LoRA weight handling in model loading and add text encoder loading function * feat: improve ComfyUI conversion script with prefix constants and module name adjustments * feat: update caption dropout documentation to clarify cache regeneration requirement * feat: add clarification on learning rate adjustments * feat: add note on PyTorch version requirement to prevent NaN loss --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -32,6 +32,7 @@ hime="hime"
|
||||
OT="OT"
|
||||
byt="byt"
|
||||
tak="tak"
|
||||
temperal="temperal"
|
||||
|
||||
[files]
|
||||
extend-exclude = ["_typos.toml", "venv"]
|
||||
extend-exclude = ["_typos.toml", "venv", "configs"]
|
||||
|
||||
1082
anima_minimal_inference.py
Normal file
1082
anima_minimal_inference.py
Normal file
File diff suppressed because it is too large
Load Diff
340
anima_train.py
340
anima_train.py
@@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import copy
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
@@ -12,8 +13,9 @@ import toml
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library import utils
|
||||
from library import flux_train_utils, qwen_image_autoencoder_kl
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -49,21 +51,18 @@ def train(args):
|
||||
args.skip_cache_check = args.skip_latents_validity_check
|
||||
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning(
|
||||
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
|
||||
)
|
||||
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
|
||||
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||
args.gradient_checkpointing = True
|
||||
|
||||
if getattr(args, 'unsloth_offload_checkpointing', False):
|
||||
if args.unsloth_offload_checkpointing:
|
||||
if not args.gradient_checkpointing:
|
||||
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||
args.gradient_checkpointing = True
|
||||
assert not args.cpu_offload_checkpointing, \
|
||||
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
|
||||
assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
|
||||
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
@@ -71,17 +70,7 @@ def train(args):
|
||||
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not getattr(args, 'unsloth_offload_checkpointing', False), \
|
||||
"blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||
|
||||
# Flash attention: validate availability
|
||||
if getattr(args, 'flash_attn', False):
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
logger.info("Flash Attention enabled for DiT blocks")
|
||||
except ImportError:
|
||||
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
|
||||
args.flash_attn = False
|
||||
) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
@@ -104,9 +93,7 @@ def train(args):
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0}".format(", ".join(ignored))
|
||||
)
|
||||
logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
@@ -145,26 +132,13 @@ def train(args):
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
|
||||
|
||||
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
|
||||
# dataset-level caption dropout, so we save the rate and zero out subset-level
|
||||
# caption_dropout_rate to allow text encoder output caching.
|
||||
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||
if caption_dropout_rate > 0:
|
||||
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
|
||||
for dataset in train_dataset_group.datasets:
|
||||
for subset in dataset.subsets:
|
||||
subset.caption_dropout_rate = 0.0
|
||||
train_dataset_group.verify_bucket_reso_steps(16) # Qwen-Image VAE spatial downscale = 8 * patch size = 2
|
||||
|
||||
if args.debug_dataset:
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
False,
|
||||
False,
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
|
||||
)
|
||||
)
|
||||
train_dataset_group.set_current_strategies()
|
||||
@@ -175,13 +149,11 @@ def train(args):
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used"
|
||||
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
assert train_dataset_group.is_text_encoder_output_cacheable(
|
||||
cache_supports_dropout=True
|
||||
), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
|
||||
|
||||
# prepare accelerator
|
||||
@@ -191,24 +163,10 @@ def train(args):
|
||||
# mixed precision dtype
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# parse transformer_dtype
|
||||
transformer_dtype = None
|
||||
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
|
||||
transformer_dtype_map = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
|
||||
|
||||
# Load tokenizers and set strategies
|
||||
logger.info("Loading tokenizers...")
|
||||
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
|
||||
args.qwen3_path, dtype=weight_dtype, device="cpu"
|
||||
)
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(
|
||||
getattr(args, 't5_tokenizer_path', None)
|
||||
)
|
||||
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)
|
||||
|
||||
# Set tokenize strategy
|
||||
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
|
||||
@@ -219,11 +177,7 @@ def train(args):
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||
|
||||
# Set text encoding strategy
|
||||
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
|
||||
dropout_rate=caption_dropout_rate,
|
||||
)
|
||||
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
|
||||
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
|
||||
|
||||
# Prepare text encoder (always frozen for Anima)
|
||||
@@ -237,10 +191,7 @@ def train(args):
|
||||
qwen3_text_encoder.eval()
|
||||
|
||||
text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
is_partial=False,
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||
|
||||
@@ -259,27 +210,21 @@ def train(args):
|
||||
logger.info(f" cache TE outputs for: {p}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
[qwen3_text_encoder],
|
||||
tokens_and_masks,
|
||||
enable_dropout=False,
|
||||
tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
|
||||
)
|
||||
|
||||
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
|
||||
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||
if caption_dropout_rate > 0.0:
|
||||
with accelerator.autocast():
|
||||
text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# free text encoder memory
|
||||
qwen3_text_encoder = None
|
||||
gc.collect() # Force garbage collection to free memory
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# Load VAE and cache latents
|
||||
logger.info("Loading Anima VAE...")
|
||||
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
|
||||
vae = qwen_image_autoencoder_kl.load_vae(
|
||||
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
|
||||
)
|
||||
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -294,24 +239,16 @@ def train(args):
|
||||
|
||||
# Load DiT (MiniTrainDIT + optional LLM Adapter)
|
||||
logger.info("Loading Anima DiT...")
|
||||
dit = anima_utils.load_anima_dit(
|
||||
args.dit_path,
|
||||
dtype=weight_dtype,
|
||||
device="cpu",
|
||||
transformer_dtype=transformer_dtype,
|
||||
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
|
||||
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
|
||||
dit = anima_utils.load_anima_model(
|
||||
"cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
|
||||
)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
dit.enable_gradient_checkpointing(
|
||||
cpu_offload=args.cpu_offload_checkpointing,
|
||||
unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False),
|
||||
unsloth_offload=args.unsloth_offload_checkpointing,
|
||||
)
|
||||
|
||||
if getattr(args, 'flash_attn', False):
|
||||
dit.set_flash_attn(True)
|
||||
|
||||
train_dit = args.learning_rate != 0
|
||||
dit.requires_grad_(train_dit)
|
||||
if not train_dit:
|
||||
@@ -327,19 +264,17 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
# Move scale tensors to same device as VAE for on-the-fly encoding
|
||||
vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
|
||||
|
||||
# Setup optimizer with parameter groups
|
||||
if train_dit:
|
||||
param_groups = anima_train_utils.get_anima_param_groups(
|
||||
dit,
|
||||
base_lr=args.learning_rate,
|
||||
self_attn_lr=getattr(args, 'self_attn_lr', None),
|
||||
cross_attn_lr=getattr(args, 'cross_attn_lr', None),
|
||||
mlp_lr=getattr(args, 'mlp_lr', None),
|
||||
mod_lr=getattr(args, 'mod_lr', None),
|
||||
llm_adapter_lr=getattr(args, 'llm_adapter_lr', None),
|
||||
self_attn_lr=args.self_attn_lr,
|
||||
cross_attn_lr=args.cross_attn_lr,
|
||||
mlp_lr=args.mlp_lr,
|
||||
mod_lr=args.mod_lr,
|
||||
llm_adapter_lr=args.llm_adapter_lr,
|
||||
)
|
||||
else:
|
||||
param_groups = []
|
||||
@@ -361,57 +296,7 @@ def train(args):
|
||||
# prepare optimizer
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
if args.blockwise_fused_optimizers:
|
||||
# Split params into per-block groups for blockwise fused optimizer
|
||||
# Build param_id → lr mapping from param_groups to propagate per-component LRs
|
||||
param_lr_map = {}
|
||||
for group in param_groups:
|
||||
for p in group['params']:
|
||||
param_lr_map[id(p)] = group['lr']
|
||||
|
||||
grouped_params = []
|
||||
param_group = {}
|
||||
named_parameters = list(dit.named_parameters())
|
||||
for name, p in named_parameters:
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
# Determine block type and index
|
||||
if name.startswith("blocks."):
|
||||
block_index = int(name.split(".")[1])
|
||||
block_type = "blocks"
|
||||
elif name.startswith("llm_adapter.blocks."):
|
||||
block_index = int(name.split(".")[2])
|
||||
block_type = "llm_adapter"
|
||||
else:
|
||||
block_index = -1
|
||||
block_type = "other"
|
||||
|
||||
param_group_key = (block_type, block_index)
|
||||
if param_group_key not in param_group:
|
||||
param_group[param_group_key] = []
|
||||
param_group[param_group_key].append(p)
|
||||
|
||||
for param_group_key, params in param_group.items():
|
||||
# Use per-component LR from param_groups if available
|
||||
lr = param_lr_map.get(id(params[0]), args.learning_rate)
|
||||
grouped_params.append({"params": params, "lr": lr})
|
||||
num_params = sum(p.numel() for p in params)
|
||||
accelerator.print(f"block {param_group_key}: {num_params} parameters, lr={lr}")
|
||||
|
||||
# Create per-group optimizers
|
||||
optimizers = []
|
||||
for group in grouped_params:
|
||||
_, _, opt = train_util.get_optimizer(args, trainable_params=[group])
|
||||
optimizers.append(opt)
|
||||
optimizer = optimizers[0] # avoid error in following code
|
||||
|
||||
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
|
||||
|
||||
if train_util.is_schedulefree_optimizer(optimizers[0], args):
|
||||
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
|
||||
optimizer_train_fn = lambda: None
|
||||
optimizer_eval_fn = lambda: None
|
||||
elif args.fused_backward_pass:
|
||||
if args.fused_backward_pass:
|
||||
# Pass per-component param_groups directly to preserve per-component LRs
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
|
||||
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
|
||||
@@ -442,21 +327,19 @@ def train(args):
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr scheduler
|
||||
if args.blockwise_fused_optimizers:
|
||||
lr_schedulers = [train_util.get_scheduler_fix(args, opt, accelerator.num_processes) for opt in optimizers]
|
||||
lr_scheduler = lr_schedulers[0] # avoid error in following code
|
||||
else:
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# full fp16/bf16 training
|
||||
dit_weight_dtype = weight_dtype
|
||||
if args.full_fp16:
|
||||
assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
dit.to(weight_dtype)
|
||||
elif args.full_bf16:
|
||||
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
dit.to(weight_dtype)
|
||||
else:
|
||||
dit_weight_dtype = torch.float32 # If neither full_fp16 nor full_bf16, the model weights should be in float32
|
||||
dit.to(dit_weight_dtype) # convert dit to target weight dtype
|
||||
|
||||
# move text encoder to GPU if not cached
|
||||
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
|
||||
@@ -498,6 +381,7 @@ def train(args):
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
if args.fused_backward_pass:
|
||||
# use fused optimizer for backward pass: other optimizers will be supported in the future
|
||||
import library.adafactor_fused
|
||||
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
@@ -517,55 +401,29 @@ def train(args):
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
|
||||
|
||||
elif args.blockwise_fused_optimizers:
|
||||
# Prepare additional optimizers and lr schedulers
|
||||
for i in range(1, len(optimizers)):
|
||||
optimizers[i] = accelerator.prepare(optimizers[i])
|
||||
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
|
||||
|
||||
# Counters for blockwise gradient hook
|
||||
optimizer_hooked_count = {}
|
||||
num_parameters_per_group = [0] * len(optimizers)
|
||||
parameter_optimizer_map = {}
|
||||
|
||||
for opt_idx, opt in enumerate(optimizers):
|
||||
for param_group in opt.param_groups:
|
||||
for parameter in param_group["params"]:
|
||||
if parameter.requires_grad:
|
||||
|
||||
def grad_hook(parameter: torch.Tensor):
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
|
||||
|
||||
i = parameter_optimizer_map[parameter]
|
||||
optimizer_hooked_count[i] += 1
|
||||
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
|
||||
optimizers[i].step()
|
||||
optimizers[i].zero_grad(set_to_none=True)
|
||||
|
||||
parameter.register_post_accumulate_grad_hook(grad_hook)
|
||||
parameter_optimizer_map[parameter] = opt_idx
|
||||
num_parameters_per_group[opt_idx] += 1
|
||||
|
||||
# Training loop
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
accelerator.print("running training")
|
||||
accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs: {num_train_epochs}")
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps: {args.max_train_steps}")
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
# Copy for noise and timestep generation, because noise_scheduler may be changed during training in future
|
||||
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
@@ -580,6 +438,7 @@ def train(args):
|
||||
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
import wandb
|
||||
|
||||
wandb.define_metric("epoch")
|
||||
wandb.define_metric("loss/epoch", step_metric="epoch")
|
||||
|
||||
@@ -589,8 +448,15 @@ def train(args):
|
||||
# For --sample_at_first
|
||||
optimizer_eval_fn()
|
||||
anima_train_utils.sample_images(
|
||||
accelerator, args, 0, global_step, dit, vae, vae_scale,
|
||||
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
|
||||
accelerator,
|
||||
args,
|
||||
0,
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
optimizer_train_fn()
|
||||
@@ -600,11 +466,11 @@ def train(args):
|
||||
# Show model info
|
||||
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
|
||||
if unwrapped_dit is not None:
|
||||
logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}")
|
||||
logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
|
||||
if qwen3_text_encoder is not None:
|
||||
logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
|
||||
logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
|
||||
if vae is not None:
|
||||
logger.info(f"vae device: {next(vae.parameters()).device}")
|
||||
logger.info(f"vae device: {vae.device}")
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
epoch = 0
|
||||
@@ -618,19 +484,17 @@ def train(args):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
|
||||
if args.blockwise_fused_optimizers:
|
||||
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
|
||||
|
||||
with accelerator.accumulate(*training_models):
|
||||
# Get latents
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
|
||||
latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
|
||||
if latents.ndim == 5: # Fallback for 5D latents (old cache)
|
||||
latents = latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
|
||||
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
|
||||
images = images.unsqueeze(2) # (B, C, 1, H, W)
|
||||
latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype)
|
||||
latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)
|
||||
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
@@ -640,23 +504,24 @@ def train(args):
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
# Cached outputs
|
||||
caption_dropout_rates = text_encoder_outputs_list[-1]
|
||||
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
|
||||
|
||||
# Apply caption dropout to cached outputs
|
||||
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
|
||||
*text_encoder_outputs_list
|
||||
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
|
||||
)
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
|
||||
else:
|
||||
# Encode on-the-fly
|
||||
input_ids_list = batch["input_ids_list"]
|
||||
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = input_ids_list
|
||||
with torch.no_grad():
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
[qwen3_text_encoder],
|
||||
[qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask],
|
||||
tokenize_strategy, [qwen3_text_encoder], input_ids_list
|
||||
)
|
||||
|
||||
# Move to device
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
|
||||
attn_mask = attn_mask.to(accelerator.device)
|
||||
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
@@ -664,9 +529,11 @@ def train(args):
|
||||
# Noise and timesteps
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, latents, noise, accelerator.device, weight_dtype
|
||||
# Get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
|
||||
)
|
||||
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
|
||||
|
||||
# NaN checks
|
||||
if torch.any(torch.isnan(noisy_model_input)):
|
||||
@@ -678,15 +545,10 @@ def train(args):
|
||||
bs = latents.shape[0]
|
||||
h_latent = latents.shape[-2]
|
||||
w_latent = latents.shape[-1]
|
||||
padding_mask = torch.zeros(
|
||||
bs, 1, h_latent, w_latent,
|
||||
dtype=weight_dtype, device=accelerator.device
|
||||
)
|
||||
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)
|
||||
|
||||
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
|
||||
if is_swapping_blocks:
|
||||
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
|
||||
|
||||
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W)
|
||||
with accelerator.autocast():
|
||||
model_pred = dit(
|
||||
noisy_model_input,
|
||||
@@ -697,6 +559,7 @@ def train(args):
|
||||
t5_input_ids=t5_input_ids,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
)
|
||||
model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W)
|
||||
|
||||
# Compute loss (rectified flow: target = noise - latents)
|
||||
target = noise - latents
|
||||
@@ -708,12 +571,10 @@ def train(args):
|
||||
|
||||
# Loss
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
|
||||
loss = train_util.conditional_loss(
|
||||
model_pred.float(), target.float(), args.loss_type, "none", huber_c
|
||||
)
|
||||
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
|
||||
loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,)
|
||||
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
@@ -724,7 +585,7 @@ def train(args):
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
|
||||
if not args.fused_backward_pass:
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
@@ -737,9 +598,6 @@ def train(args):
|
||||
else:
|
||||
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
|
||||
lr_scheduler.step()
|
||||
if args.blockwise_fused_optimizers:
|
||||
for i in range(1, len(optimizers)):
|
||||
lr_schedulers[i].step()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step
|
||||
if accelerator.sync_gradients:
|
||||
@@ -748,8 +606,15 @@ def train(args):
|
||||
|
||||
optimizer_eval_fn()
|
||||
anima_train_utils.sample_images(
|
||||
accelerator, args, None, global_step, dit, vae, vae_scale,
|
||||
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
|
||||
accelerator,
|
||||
args,
|
||||
None,
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
@@ -773,8 +638,10 @@ def train(args):
|
||||
if len(accelerator.trackers) > 0:
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs_with_names(
|
||||
logs, lr_scheduler, args.optimizer_type,
|
||||
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else []
|
||||
logs,
|
||||
lr_scheduler,
|
||||
args.optimizer_type,
|
||||
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
@@ -807,8 +674,15 @@ def train(args):
|
||||
)
|
||||
|
||||
anima_train_utils.sample_images(
|
||||
accelerator, args, epoch + 1, global_step, dit, vae, vae_scale,
|
||||
qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
|
||||
accelerator,
|
||||
args,
|
||||
epoch + 1,
|
||||
global_step,
|
||||
dit,
|
||||
vae,
|
||||
qwen3_text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
@@ -852,11 +726,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
anima_train_utils.add_anima_training_arguments(parser)
|
||||
sai_model_spec.add_model_spec_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--blockwise_fused_optimizers",
|
||||
action="store_true",
|
||||
help="enable blockwise optimizers for fused backward pass and optimizer step",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu_offload_checkpointing",
|
||||
action="store_true",
|
||||
@@ -884,4 +753,7 @@ if __name__ == "__main__":
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
if args.attn_mode == "sdpa":
|
||||
args.attn_mode = "torch" # backward compatibility
|
||||
|
||||
train(args)
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
# Anima LoRA training script
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from accelerate import Accelerator
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import anima_models, anima_train_utils, anima_utils, strategy_anima, strategy_base, train_util
|
||||
from library import (
|
||||
anima_models,
|
||||
anima_train_utils,
|
||||
anima_utils,
|
||||
flux_train_utils,
|
||||
qwen_image_autoencoder_kl,
|
||||
sd3_train_utils,
|
||||
strategy_anima,
|
||||
strategy_base,
|
||||
train_util,
|
||||
)
|
||||
import train_network
|
||||
from library.utils import setup_logging
|
||||
|
||||
@@ -24,13 +34,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
self.vae = None
|
||||
self.vae_scale = None
|
||||
self.qwen3_text_encoder = None
|
||||
self.qwen3_tokenizer = None
|
||||
self.t5_tokenizer = None
|
||||
self.tokenize_strategy = None
|
||||
self.text_encoding_strategy = None
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
@@ -38,140 +41,118 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
if args.fp8_base or args.fp8_base_unet:
|
||||
logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。")
|
||||
args.fp8_base = False
|
||||
args.fp8_base_unet = False
|
||||
args.fp8_scaled = False # Anima DiT does not support fp8_scaled
|
||||
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
logger.warning(
|
||||
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
|
||||
)
|
||||
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
# Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
|
||||
# dataset-level caption dropout, so zero out subset-level rates to allow caching.
|
||||
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||
if caption_dropout_rate > 0:
|
||||
logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
|
||||
if hasattr(train_dataset_group, 'datasets'):
|
||||
for dataset in train_dataset_group.datasets:
|
||||
for subset in dataset.subsets:
|
||||
subset.caption_dropout_rate = 0.0
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
assert train_dataset_group.is_text_encoder_output_cacheable(
|
||||
cache_supports_dropout=True
|
||||
), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
|
||||
|
||||
assert (
|
||||
args.network_train_unet_only or not args.cache_text_encoder_outputs
|
||||
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
|
||||
|
||||
if getattr(args, 'unsloth_offload_checkpointing', False):
|
||||
if args.unsloth_offload_checkpointing:
|
||||
if not args.gradient_checkpointing:
|
||||
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
|
||||
args.gradient_checkpointing = True
|
||||
assert not args.cpu_offload_checkpointing, \
|
||||
"Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
|
||||
assert (
|
||||
not args.cpu_offload_checkpointing
|
||||
), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
|
||||
assert (
|
||||
args.blocks_to_swap is None or args.blocks_to_swap == 0
|
||||
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
|
||||
|
||||
# Flash attention: validate availability
|
||||
if getattr(args, 'flash_attn', False):
|
||||
try:
|
||||
import flash_attn # noqa: F401
|
||||
logger.info("Flash Attention enabled for DiT blocks")
|
||||
except ImportError:
|
||||
logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
|
||||
args.flash_attn = False
|
||||
|
||||
if getattr(args, 'blockwise_fused_optimizers', False):
|
||||
raise ValueError("blockwise_fused_optimizers is not supported with LoRA/NetworkTrainer")
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
|
||||
train_dataset_group.verify_bucket_reso_steps(16) # WanVAE spatial downscale = 8 and patch size = 2
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(8)
|
||||
val_dataset_group.verify_bucket_reso_steps(16)
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
|
||||
# Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
|
||||
logger.info("Loading Qwen3 text encoder...")
|
||||
self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(
|
||||
args.qwen3_path, dtype=weight_dtype, device="cpu"
|
||||
)
|
||||
self.qwen3_text_encoder.eval()
|
||||
qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
|
||||
qwen3_text_encoder.eval()
|
||||
|
||||
# Parse transformer_dtype
|
||||
transformer_dtype = None
|
||||
if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
|
||||
transformer_dtype_map = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
|
||||
# Load VAE
|
||||
logger.info("Loading Anima VAE...")
|
||||
vae = qwen_image_autoencoder_kl.load_vae(
|
||||
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
|
||||
)
|
||||
vae.to(weight_dtype)
|
||||
vae.eval()
|
||||
|
||||
# Return format: (model_type, text_encoders, vae, unet)
|
||||
return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
|
||||
|
||||
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
|
||||
loading_dtype = None if args.fp8_scaled else weight_dtype
|
||||
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
|
||||
|
||||
attn_mode = "torch"
|
||||
if args.xformers:
|
||||
attn_mode = "xformers"
|
||||
if args.attn_mode is not None:
|
||||
attn_mode = args.attn_mode
|
||||
|
||||
# Load DiT
|
||||
logger.info("Loading Anima DiT...")
|
||||
dit = anima_utils.load_anima_dit(
|
||||
args.dit_path,
|
||||
dtype=weight_dtype,
|
||||
device="cpu",
|
||||
transformer_dtype=transformer_dtype,
|
||||
llm_adapter_path=getattr(args, 'llm_adapter_path', None),
|
||||
disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
|
||||
logger.info(f"Loading Anima DiT model with attn_mode={attn_mode}, split_attn: {args.split_attn}...")
|
||||
model = anima_utils.load_anima_model(
|
||||
accelerator.device,
|
||||
args.pretrained_model_name_or_path,
|
||||
attn_mode,
|
||||
args.split_attn,
|
||||
loading_device,
|
||||
loading_dtype,
|
||||
args.fp8_scaled,
|
||||
)
|
||||
|
||||
# Flash attention
|
||||
if getattr(args, 'flash_attn', False):
|
||||
dit.set_flash_attn(True)
|
||||
|
||||
# Store unsloth preference so that when the base NetworkTrainer calls
|
||||
# dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
|
||||
# The base trainer only passes cpu_offload, so we store the flag on the model.
|
||||
self._use_unsloth_offload_checkpointing = getattr(args, 'unsloth_offload_checkpointing', False)
|
||||
self._use_unsloth_offload_checkpointing = args.unsloth_offload_checkpointing
|
||||
|
||||
# Block swap
|
||||
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
|
||||
if self.is_swapping_blocks:
|
||||
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
|
||||
dit.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
|
||||
|
||||
# Load VAE
|
||||
logger.info("Loading Anima VAE...")
|
||||
self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(
|
||||
args.vae_path, dtype=weight_dtype, device="cpu"
|
||||
)
|
||||
|
||||
# Return format: (model_type, text_encoders, vae, unet)
|
||||
return "anima", [self.qwen3_text_encoder], self.vae, dit
|
||||
return model, text_encoders
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
# Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
|
||||
self.tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
|
||||
qwen3_path=args.qwen3_path,
|
||||
t5_tokenizer_path=getattr(args, 't5_tokenizer_path', None),
|
||||
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
|
||||
qwen3_path=args.qwen3,
|
||||
t5_tokenizer_path=args.t5_tokenizer_path,
|
||||
qwen3_max_length=args.qwen3_max_token_length,
|
||||
t5_max_length=args.t5_max_token_length,
|
||||
)
|
||||
# Store references so load_target_model can reuse them
|
||||
self.qwen3_tokenizer = self.tokenize_strategy.qwen3_tokenizer
|
||||
self.t5_tokenizer = self.tokenize_strategy.t5_tokenizer
|
||||
return self.tokenize_strategy
|
||||
return tokenize_strategy
|
||||
|
||||
def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
|
||||
return [tokenize_strategy.qwen3_tokenizer]
|
||||
|
||||
def get_latents_caching_strategy(self, args):
|
||||
return strategy_anima.AnimaLatentsCachingStrategy(
|
||||
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
|
||||
)
|
||||
return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)
|
||||
|
||||
def get_text_encoding_strategy(self, args):
|
||||
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||
self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
|
||||
dropout_rate=caption_dropout_rate,
|
||||
)
|
||||
return self.text_encoding_strategy
|
||||
return strategy_anima.AnimaTextEncodingStrategy()
|
||||
|
||||
def post_process_network(self, args, accelerator, network, text_encoders, unet):
|
||||
# Qwen3 text encoder is always frozen for Anima
|
||||
pass
|
||||
|
||||
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
|
||||
@@ -179,19 +160,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
return None # no text encoders needed for encoding
|
||||
return text_encoders
|
||||
|
||||
def get_text_encoders_train_flags(self, args, text_encoders):
|
||||
return [False] # Qwen3 always frozen
|
||||
|
||||
def is_train_text_encoder(self, args):
|
||||
return False # Qwen3 text encoder is always frozen for Anima
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
is_partial=False,
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -200,15 +172,14 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
):
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
logger.info("move vae and unet to cpu to save memory")
|
||||
org_vae_device = next(vae.parameters()).device
|
||||
org_unet_device = unet.device
|
||||
# We cannot move DiT to CPU because of block swap, so only move VAE
|
||||
logger.info("move vae to cpu to save memory")
|
||||
org_vae_device = vae.device
|
||||
vae.to("cpu")
|
||||
unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
logger.info("move text encoder to gpu")
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[0].to(accelerator.device)
|
||||
|
||||
with accelerator.autocast():
|
||||
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
||||
@@ -229,59 +200,52 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
logger.info(f" cache TE outputs for: {p}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(p)
|
||||
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
text_encoders,
|
||||
tokens_and_masks,
|
||||
enable_dropout=False,
|
||||
tokenize_strategy, text_encoders, tokens_and_masks
|
||||
)
|
||||
self.sample_prompts_te_outputs = sample_prompts_te_outputs
|
||||
|
||||
# Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
|
||||
caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
|
||||
text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
if caption_dropout_rate > 0.0:
|
||||
tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
|
||||
with accelerator.autocast():
|
||||
text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# move text encoder back to cpu
|
||||
logger.info("move text encoder back to cpu")
|
||||
text_encoders[0].to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if not args.lowram:
|
||||
logger.info("move vae and unet back to original device")
|
||||
logger.info("move vae back to original device")
|
||||
vae.to(org_vae_device)
|
||||
unet.to(org_unet_device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
else:
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
# move text encoder to device for encoding during training/validation
|
||||
text_encoders[0].to(accelerator.device)
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
|
||||
te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
|
||||
qwen3_te = te[0] if te is not None else None
|
||||
|
||||
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
anima_train_utils.sample_images(
|
||||
accelerator, args, epoch, global_step, unet, vae, self.vae_scale,
|
||||
qwen3_te, self.tokenize_strategy, self.text_encoding_strategy,
|
||||
accelerator,
|
||||
args,
|
||||
epoch,
|
||||
global_step,
|
||||
unet,
|
||||
vae,
|
||||
qwen3_te,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
self.sample_prompts_te_outputs,
|
||||
)
|
||||
|
||||
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
|
||||
noise_scheduler = anima_train_utils.FlowMatchEulerDiscreteScheduler(
|
||||
num_train_timesteps=1000, shift=args.discrete_flow_shift
|
||||
)
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, vae, images):
|
||||
# images are already [-1,1] from IMAGE_TRANSFORMS, add temporal dim
|
||||
images = images.unsqueeze(2) # (B, C, 1, H, W)
|
||||
# Ensure scale tensors are on the same device as images
|
||||
vae_device = images.device
|
||||
scale = [s.to(vae_device) if isinstance(s, torch.Tensor) else s for s in self.vae_scale]
|
||||
return vae.encode(images, scale)
|
||||
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
|
||||
return vae.encode_pixels_to_latents(images) # Keep 4D for input/output
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
# Latents already normalized by vae.encode with scale
|
||||
@@ -301,13 +265,18 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
train_unet,
|
||||
is_train=True,
|
||||
):
|
||||
anima: anima_models.Anima = unet
|
||||
|
||||
# Sample noise
|
||||
if latents.ndim == 5: # Fallback for 5D latents (old cache)
|
||||
latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
# Get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, latents, noise, accelerator.device, weight_dtype
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
|
||||
|
||||
# Gradient checkpointing support
|
||||
if args.gradient_checkpointing:
|
||||
@@ -329,161 +298,103 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
bs = latents.shape[0]
|
||||
h_latent = latents.shape[-2]
|
||||
w_latent = latents.shape[-1]
|
||||
padding_mask = torch.zeros(
|
||||
bs, 1, h_latent, w_latent,
|
||||
dtype=weight_dtype, device=accelerator.device
|
||||
)
|
||||
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
|
||||
|
||||
# Prepare block swap
|
||||
if self.is_swapping_blocks:
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
|
||||
# Call model (LLM adapter runs inside forward for DDP gradient sync)
|
||||
# Call model
|
||||
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W]
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
model_pred = unet(
|
||||
model_pred = anima(
|
||||
noisy_model_input,
|
||||
timesteps,
|
||||
prompt_embeds,
|
||||
padding_mask=padding_mask,
|
||||
target_input_ids=t5_input_ids,
|
||||
target_attention_mask=t5_attn_mask,
|
||||
source_attention_mask=attn_mask,
|
||||
t5_input_ids=t5_input_ids,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
)
|
||||
model_pred = model_pred.squeeze(2) # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W]
|
||||
|
||||
# Rectified flow target: noise - latents
|
||||
target = noise - latents
|
||||
|
||||
# Loss weighting
|
||||
weighting = anima_train_utils.compute_loss_weighting_for_anima(
|
||||
weighting_scheme=args.weighting_scheme, sigmas=sigmas
|
||||
)
|
||||
|
||||
# Differential output preservation
|
||||
if "custom_attributes" in batch:
|
||||
diff_output_pr_indices = []
|
||||
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
||||
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
|
||||
diff_output_pr_indices.append(i)
|
||||
|
||||
if len(diff_output_pr_indices) > 0:
|
||||
network.set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
if self.is_swapping_blocks:
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
model_pred_prior = unet(
|
||||
noisy_model_input[diff_output_pr_indices],
|
||||
timesteps[diff_output_pr_indices],
|
||||
prompt_embeds[diff_output_pr_indices],
|
||||
padding_mask=padding_mask[diff_output_pr_indices],
|
||||
source_attention_mask=attn_mask[diff_output_pr_indices],
|
||||
t5_input_ids=t5_input_ids[diff_output_pr_indices],
|
||||
t5_attn_mask=t5_attn_mask[diff_output_pr_indices],
|
||||
)
|
||||
network.set_multiplier(1.0)
|
||||
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
|
||||
def process_batch(
|
||||
self, batch, text_encoders, unet, network, vae, noise_scheduler,
|
||||
vae_dtype, weight_dtype, accelerator, args,
|
||||
text_encoding_strategy, tokenize_strategy,
|
||||
is_train=True, train_text_encoder=True, train_unet=True,
|
||||
self,
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True,
|
||||
) -> torch.Tensor:
|
||||
"""Override base process_batch for 5D video latents (B, C, T, H, W).
|
||||
|
||||
Base class assumes 4D (B, C, H, W) for loss.mean([1,2,3]) and weighting broadcast.
|
||||
"""
|
||||
import typing
|
||||
from library.custom_train_functions import apply_masked_loss
|
||||
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
else:
|
||||
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
|
||||
else:
|
||||
chunks = [
|
||||
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
|
||||
]
|
||||
list_latents = []
|
||||
for chunk in chunks:
|
||||
with torch.no_grad():
|
||||
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
|
||||
list_latents.append(chunk)
|
||||
latents = torch.cat(list_latents, dim=0)
|
||||
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
|
||||
|
||||
latents = self.shift_scale_latents(args, latents)
|
||||
"""Override base process_batch for caption dropout with cached text encoder outputs."""
|
||||
|
||||
# Text encoder conditions
|
||||
text_encoder_conds = []
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoder_conds = text_encoder_outputs_list
|
||||
caption_dropout_rates = text_encoder_outputs_list[-1]
|
||||
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
|
||||
|
||||
if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
|
||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
# Apply caption dropout to cached outputs
|
||||
text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
|
||||
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
|
||||
)
|
||||
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
||||
|
||||
return super().process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
is_train,
|
||||
train_text_encoder,
|
||||
train_unet,
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args, accelerator, noise_scheduler, latents, batch,
|
||||
text_encoder_conds, unet, network, weight_dtype, train_unet, is_train=is_train,
|
||||
)
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
|
||||
# Reduce all non-batch dims: (B, C, T, H, W) -> (B,) for 5D, (B, C, H, W) -> (B,) for 4D
|
||||
reduce_dims = list(range(1, loss.ndim))
|
||||
loss = loss.mean(reduce_dims)
|
||||
|
||||
# Apply weighting after reducing to (B,)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
|
||||
loss_weights = batch["loss_weights"]
|
||||
loss = loss * loss_weights
|
||||
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
return loss.mean()
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True)
|
||||
return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, anima="preview").to_metadata_dict()
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
metadata["ss_weighting_scheme"] = args.weighting_scheme
|
||||
metadata["ss_logit_mean"] = args.logit_mean
|
||||
metadata["ss_logit_std"] = args.logit_std
|
||||
metadata["ss_mode_scale"] = args.mode_scale
|
||||
metadata["ss_timestep_sampling"] = args.timestep_sampling
|
||||
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
|
||||
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
|
||||
metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')
|
||||
metadata["ss_sigmoid_scale"] = getattr(args, 'sigmoid_scale', 1.0)
|
||||
|
||||
def is_text_encoder_not_needed_for_training(self, args):
|
||||
return args.cache_text_encoder_outputs
|
||||
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
|
||||
|
||||
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
|
||||
# Set first parameter's requires_grad to True to workaround Accelerate gradient checkpointing bug
|
||||
first_param = next(text_encoder.parameters())
|
||||
first_param.requires_grad_(True)
|
||||
|
||||
def prepare_unet_with_accelerator(
|
||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||
@@ -496,23 +407,16 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
|
||||
if not self.is_swapping_blocks:
|
||||
return super().prepare_unet_with_accelerator(args, accelerator, unet)
|
||||
|
||||
dit = unet
|
||||
dit = accelerator.prepare(dit, device_placement=[not self.is_swapping_blocks])
|
||||
accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
|
||||
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
|
||||
model = unet
|
||||
model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
|
||||
accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
|
||||
accelerator.unwrap_model(model).prepare_block_swap_before_forward()
|
||||
|
||||
return dit
|
||||
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
|
||||
# Drop cached text encoder outputs for caption dropout
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
||||
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
||||
return model
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
if self.is_swapping_blocks:
|
||||
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
|
||||
|
||||
@@ -520,6 +424,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_network.setup_parser()
|
||||
train_util.add_dit_training_arguments(parser)
|
||||
anima_train_utils.add_anima_training_arguments(parser)
|
||||
# parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
|
||||
parser.add_argument(
|
||||
"--unsloth_offload_checkpointing",
|
||||
action="store_true",
|
||||
@@ -536,5 +441,8 @@ if __name__ == "__main__":
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
if args.attn_mode == "sdpa":
|
||||
args.attn_mode = "torch" # backward compatibility
|
||||
|
||||
trainer = AnimaNetworkTrainer()
|
||||
trainer.train(args)
|
||||
|
||||
@@ -11,7 +11,9 @@ This document explains how to train LoRA (Low-Rank Adaptation) models for Anima
|
||||
|
||||
## 1. Introduction / はじめに
|
||||
|
||||
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a WanVAE (16-channel, 8x spatial downscale).
|
||||
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a Qwen-Image VAE (16-channel, 8x spatial downscale).
|
||||
|
||||
Qwen-Image VAE and Qwen-Image VAE have same architecture, but [official Anima weight is named for Qwen-Image VAE](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae).
|
||||
|
||||
This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
|
||||
|
||||
@@ -24,7 +26,9 @@ This guide assumes you already understand the basics of LoRA training. For commo
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびWanVAE (16チャンネル、8倍空間ダウンスケール) を使用します。
|
||||
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびQwen-Image VAE (16チャンネル、8倍空間ダウンスケール) を使用します。
|
||||
|
||||
Qwen-Image VAEとQwen-Image VAEは同じアーキテクチャですが、[Anima公式の重みはQwen-Image VAE用](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae)のようです。
|
||||
|
||||
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
|
||||
|
||||
@@ -40,11 +44,11 @@ This guide assumes you already understand the basics of LoRA training. For commo
|
||||
`anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are:
|
||||
|
||||
* **Target models:** Anima DiT models.
|
||||
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a WanVAE (16-channel latent space with 8x spatial downscale).
|
||||
* **Arguments:** Options exist to specify the Anima DiT model, Qwen3 text encoder, WanVAE, LLM adapter, and T5 tokenizer separately.
|
||||
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used.
|
||||
* **Anima specific options:** Additional parameters for component-wise learning rates (self_attn, cross_attn, mlp, mod, llm_adapter), timestep sampling, discrete flow shift, and flash attention.
|
||||
* **6 Parameter Groups:** Independent learning rates for `base`, `self_attn`, `cross_attn`, `mlp`, `adaln_modulation`, and `llm_adapter` components.
|
||||
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a Qwen-Image VAE (16-channel latent space with 8x spatial downscale).
|
||||
* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the Qwen-Image VAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
|
||||
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported.
|
||||
* **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`).
|
||||
* **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
@@ -52,11 +56,11 @@ This guide assumes you already understand the basics of LoRA training. For commo
|
||||
`anima_train_network.py`は`train_network.py`をベースに、Anima モデルに対応するための変更が加えられています。主な違いは以下の通りです。
|
||||
|
||||
* **対象モデル:** Anima DiTモデルを対象とします。
|
||||
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびWanVAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
|
||||
* **引数:** Anima DiTモデル、Qwen3テキストエンコーダー、WanVAE、LLM Adapter、T5トークナイザーを個別に指定する引数があります。
|
||||
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はAnimaの学習では使用されません。
|
||||
* **Anima特有の引数:** コンポーネント別学習率(self_attn, cross_attn, mlp, mod, llm_adapter)、タイムステップサンプリング、離散フローシフト、Flash Attentionに関する引数が追加されています。
|
||||
* **6パラメータグループ:** `base`、`self_attn`、`cross_attn`、`mlp`、`adaln_modulation`、`llm_adapter`の各コンポーネントに対して独立した学習率を設定できます。
|
||||
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびQwen-Image VAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
|
||||
* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、Qwen-Image VAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path`、`--t5_tokenizer_path`で個別に指定できます。
|
||||
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。
|
||||
* **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma`、`uniform`、`sigmoid`、`shift`、`flux_shift`)を使用します。
|
||||
* **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims`、`network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns`と`include_patterns`で制御します。
|
||||
</details>
|
||||
|
||||
## 3. Preparation / 準備
|
||||
@@ -65,16 +69,16 @@ The following files are required before starting training:
|
||||
|
||||
1. **Training script:** `anima_train_network.py`
|
||||
2. **Anima DiT model file:** `.safetensors` file for the base DiT model.
|
||||
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory or a single `.safetensors` file (requires `configs/qwen3_06b/` config files).
|
||||
4. **WanVAE model file:** `.safetensors` or `.pth` file for the VAE.
|
||||
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory, or a single `.safetensors` file (uses the bundled config files in `configs/qwen3_06b/`).
|
||||
4. **Qwen-Image VAE model file:** `.safetensors` or `.pth` file for the VAE.
|
||||
5. **LLM Adapter model file (optional):** `.safetensors` file. If not provided separately, the adapter is loaded from the DiT file if the key `llm_adapter.out_proj.weight` exists.
|
||||
6. **T5 Tokenizer (optional):** If not specified, uses the bundled tokenizer at `configs/t5_old/`.
|
||||
7. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_anima_dataset_config.toml` as an example.
|
||||
|
||||
Model files can be obtained from the [Anima HuggingFace repository](https://huggingface.co/circlestone-labs/Anima).
|
||||
|
||||
**Notes:**
|
||||
* When using a single `.safetensors` file for Qwen3, download the `config.json`, `tokenizer.json`, `tokenizer_config.json`, and `vocab.json` from the [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFace repository into the `configs/qwen3_06b/` directory.
|
||||
* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
|
||||
* Models are saved with a `net.` prefix on all keys for ComfyUI compatibility.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
@@ -83,16 +87,16 @@ The following files are required before starting training:
|
||||
|
||||
1. **学習スクリプト:** `anima_train_network.py`
|
||||
2. **Anima DiTモデルファイル:** ベースとなるDiTモデルの`.safetensors`ファイル。
|
||||
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(`configs/qwen3_06b/`の設定ファイルが必要)。
|
||||
4. **WanVAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
|
||||
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(バンドル版の`configs/qwen3_06b/`の設定ファイルが使用されます)。
|
||||
4. **Qwen-Image VAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
|
||||
5. **LLM Adapterモデルファイル(オプション):** `.safetensors`ファイル。個別に指定しない場合、DiTファイル内に`llm_adapter.out_proj.weight`キーが存在すればそこから読み込まれます。
|
||||
6. **T5トークナイザー(オプション):** 指定しない場合、`configs/t5_old/`のバンドル版トークナイザーを使用します。
|
||||
7. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。例として`my_anima_dataset_config.toml`を使用します。
|
||||
|
||||
モデルファイルは[HuggingFaceのAnimaリポジトリ](https://huggingface.co/circlestone-labs/Anima)から入手できます。
|
||||
|
||||
**注意:**
|
||||
* Qwen3の単体`.safetensors`ファイルを使用する場合、[Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFaceリポジトリから`config.json`、`tokenizer.json`、`tokenizer_config.json`、`vocab.json`をダウンロードし、`configs/qwen3_06b/`ディレクトリに配置してください。
|
||||
* T5トークナイザーはトークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
|
||||
* モデルはComfyUI互換のため、すべてのキーに`net.`プレフィックスを付けて保存されます。
|
||||
* T5トークナイザーを別途指定する場合、トークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
|
||||
</details>
|
||||
|
||||
## 4. Running the Training / 学習の実行
|
||||
@@ -103,33 +107,38 @@ Example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
|
||||
--dit_path="<path to Anima DiT model>" \
|
||||
--qwen3_path="<path to Qwen3-0.6B model or directory>" \
|
||||
--vae_path="<path to WanVAE model>" \
|
||||
--llm_adapter_path="<path to LLM adapter model>" \
|
||||
--pretrained_model_name_or_path="<path to Anima DiT model>" \
|
||||
--qwen3="<path to Qwen3-0.6B model or directory>" \
|
||||
--vae="<path to Qwen-Image VAE model>" \
|
||||
--dataset_config="my_anima_dataset_config.toml" \
|
||||
--output_dir="<output directory>" \
|
||||
--output_name="my_anima_lora" \
|
||||
--save_model_as=safetensors \
|
||||
--network_module=networks.lora_anima \
|
||||
--network_dim=8 \
|
||||
--network_alpha=8 \
|
||||
--learning_rate=1e-4 \
|
||||
--optimizer_type="AdamW8bit" \
|
||||
--lr_scheduler="constant" \
|
||||
--timestep_sample_method="logit_normal" \
|
||||
--discrete_flow_shift=3.0 \
|
||||
--timestep_sampling="sigmoid" \
|
||||
--discrete_flow_shift=1.0 \
|
||||
--max_train_epochs=10 \
|
||||
--save_every_n_epochs=1 \
|
||||
--mixed_precision="bf16" \
|
||||
--gradient_checkpointing \
|
||||
--cache_latents \
|
||||
--cache_text_encoder_outputs \
|
||||
--blocks_to_swap=18
|
||||
--vae_chunk_size=64 \
|
||||
--vae_disable_cache
|
||||
```
|
||||
|
||||
*(Write the command on one line or use `\` or `^` for line breaks.)*
|
||||
|
||||
The learning rate of `1e-4` is just an example. Adjust it according to your dataset and objectives. This value is for `alpha=1.0` (default). If increasing `--network_alpha`, consider lowering the learning rate.
|
||||
|
||||
If loss becomes NaN, ensure you are using PyTorch version 2.5 or higher.
|
||||
|
||||
**Note:** `--vae_chunk_size` and `--vae_disable_cache` are custom options in this repository to reduce memory usage of the Qwen-Image VAE.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
@@ -138,6 +147,13 @@ accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
|
||||
コマンドラインの例は英語のドキュメントを参照してください。
|
||||
|
||||
※実際には1行で書くか、適切な改行文字(`\` または `^`)を使用してください。
|
||||
|
||||
学習率1e-4はあくまで一例です。データセットや目的に応じて適切に調整してください。またこの値はalpha=1.0(デフォルト)での値です。`--network_alpha`を増やす場合は学習率を下げることを検討してください。
|
||||
|
||||
lossがNaNになる場合は、PyTorchのバージョンが2.5以上であることを確認してください。
|
||||
|
||||
注意: `--vae_chunk_size`および`--vae_disable_cache`は当リポジトリ独自のオプションで、Qwen-Image VAEのメモリ使用量を削減するために使用します。
|
||||
|
||||
</details>
|
||||
|
||||
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
|
||||
@@ -146,12 +162,15 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
|
||||
|
||||
#### Model Options [Required] / モデル関連 [必須]
|
||||
|
||||
* `--dit_path="<path to Anima DiT model>"` **[Required]**
|
||||
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[Required]**
|
||||
- Path to the Anima DiT model `.safetensors` file. The model config (channels, blocks, heads) is auto-detected from the state dict. ComfyUI format with `net.` prefix is supported.
|
||||
* `--qwen3_path="<path to Qwen3-0.6B model>"` **[Required]**
|
||||
* `--qwen3="<path to Qwen3-0.6B model>"` **[Required]**
|
||||
- Path to the Qwen3-0.6B text encoder. Can be a HuggingFace model directory or a single `.safetensors` file. The text encoder is always frozen during training.
|
||||
* `--vae_path="<path to WanVAE model>"` **[Required]**
|
||||
- Path to the WanVAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
|
||||
* `--vae="<path to Qwen-Image VAE model>"` **[Required]**
|
||||
- Path to the Qwen-Image VAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
|
||||
|
||||
#### Model Options [Optional] / モデル関連 [オプション]
|
||||
|
||||
* `--llm_adapter_path="<path to LLM adapter>"` *[Optional]*
|
||||
- Path to a separate LLM adapter weights file. If omitted, the adapter is loaded from the DiT file when the key `llm_adapter.out_proj.weight` exists.
|
||||
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[Optional]*
|
||||
@@ -159,53 +178,58 @@ Besides the arguments explained in the [train_network.py guide](train_network.md
|
||||
|
||||
#### Anima Training Parameters / Anima 学習パラメータ
|
||||
|
||||
* `--timestep_sample_method=<choice>`
|
||||
- Timestep sampling method. Choose from `logit_normal` (default) or `uniform`.
|
||||
* `--timestep_sampling=<choice>`
|
||||
- Timestep sampling method. Choose from `sigma`, `uniform`, `sigmoid` (default), `shift`, `flux_shift`. Same options as FLUX training. See the [flux_train_network.py guide](flux_train_network.md) for details on each method.
|
||||
* `--discrete_flow_shift=<float>`
|
||||
- Shift for the timestep distribution in Rectified Flow training. Default `3.0`. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
|
||||
- Shift for the timestep distribution in Rectified Flow training. Default `1.0`. This value is used when `--timestep_sampling` is set to **`shift`**. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
|
||||
* `--sigmoid_scale=<float>`
|
||||
- Scale factor for `logit_normal` timestep sampling. Default `1.0`.
|
||||
- Scale factor when `--timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default `1.0`.
|
||||
* `--qwen3_max_token_length=<integer>`
|
||||
- Maximum token length for the Qwen3 tokenizer. Default `512`.
|
||||
* `--t5_max_token_length=<integer>`
|
||||
- Maximum token length for the T5 tokenizer. Default `512`.
|
||||
* `--flash_attn`
|
||||
- Use Flash Attention for DiT self/cross-attention. Requires `pip install flash-attn`. Falls back to PyTorch SDPA if the package is not installed. Note: Flash Attention is only applied to DiT blocks; the LLM Adapter uses standard attention because it requires attention masks.
|
||||
* `--transformer_dtype=<choice>`
|
||||
- Separate dtype for transformer blocks. Choose from `float16`, `bfloat16`, `float32`. If not specified, uses the same dtype as `--mixed_precision`.
|
||||
* `--attn_mode=<choice>`
|
||||
- Attention implementation to use. Choose from `torch` (default), `xformers`, `flash`, `sageattn`. `xformers` requires `--split_attn`. `sageattn` does not support training (inference only). This option overrides `--xformers`.
|
||||
* `--split_attn`
|
||||
- Split attention computation to reduce memory usage. Required when using `--attn_mode xformers`.
|
||||
|
||||
#### Component-wise Learning Rates / コンポーネント別学習率
|
||||
|
||||
Anima supports 6 independent learning rate groups. Set to `0` to freeze a component:
|
||||
These options set separate learning rates for each component of the Anima model. They are primarily used for full fine-tuning. Set to `0` to freeze a component:
|
||||
|
||||
* `--self_attn_lr=<float>` - Learning rate for self-attention layers. Default: same as `--learning_rate`.
|
||||
* `--cross_attn_lr=<float>` - Learning rate for cross-attention layers. Default: same as `--learning_rate`.
|
||||
* `--mlp_lr=<float>` - Learning rate for MLP layers. Default: same as `--learning_rate`.
|
||||
* `--mod_lr=<float>` - Learning rate for AdaLN modulation layers. Default: same as `--learning_rate`.
|
||||
* `--mod_lr=<float>` - Learning rate for AdaLN modulation layers. Default: same as `--learning_rate`. Note: modulation layers are not included in LoRA by default.
|
||||
* `--llm_adapter_lr=<float>` - Learning rate for LLM adapter layers. Default: same as `--learning_rate`.
|
||||
|
||||
For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Section 5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御).
|
||||
|
||||
#### Memory and Speed / メモリ・速度関連
|
||||
|
||||
* `--blocks_to_swap=<integer>` **[Experimental]**
|
||||
* `--blocks_to_swap=<integer>`
|
||||
- Number of Transformer blocks to swap between CPU and GPU. More blocks reduce VRAM but slow training. Maximum values depend on model size:
|
||||
- 28-block model: max **26**
|
||||
- 28-block model: max **26** (Anima-Preview)
|
||||
- 36-block model: max **34**
|
||||
- 20-block model: max **18**
|
||||
- Cannot be used with `--cpu_offload_checkpointing` or `--unsloth_offload_checkpointing`.
|
||||
* `--unsloth_offload_checkpointing`
|
||||
- Offload activations to CPU RAM using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
|
||||
- Offload activations to CPU RAM using async non-blocking transfers (faster than `--cpu_offload_checkpointing`). Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
|
||||
* `--cache_text_encoder_outputs`
|
||||
- Cache Qwen3 text encoder outputs to reduce VRAM usage. Recommended when not training text encoder LoRA.
|
||||
* `--cache_text_encoder_outputs_to_disk`
|
||||
- Cache text encoder outputs to disk. Auto-enables `--cache_text_encoder_outputs`.
|
||||
* `--cache_latents`, `--cache_latents_to_disk`
|
||||
- Cache WanVAE latent outputs.
|
||||
* `--fp8_base`
|
||||
- Use FP8 precision for the base model to reduce VRAM usage.
|
||||
- Cache Qwen-Image VAE latent outputs.
|
||||
* `--vae_chunk_size=<integer>`
|
||||
- Chunk size for Qwen-Image VAE processing. Reduces VRAM usage at the cost of speed. Default is no chunking.
|
||||
* `--vae_disable_cache`
|
||||
- Disable internal caching in Qwen-Image VAE to reduce VRAM usage.
|
||||
|
||||
#### Incompatible or Deprecated Options / 非互換・非推奨の引数
|
||||
#### Incompatible or Unsupported Options / 非互換・非サポートの引数
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip` - Options for Stable Diffusion v1/v2 that are not used for Anima training.
|
||||
* `--fp8_base` - Not supported for Anima. If specified, it will be disabled with a warning.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
@@ -214,39 +238,50 @@ Anima supports 6 independent learning rate groups. Set to `0` to freeze a compon
|
||||
|
||||
#### モデル関連 [必須]
|
||||
|
||||
* `--dit_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。
|
||||
* `--qwen3_path="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。
|
||||
* `--vae_path="<path to WanVAE model>"` **[必須]** - WanVAEモデルのパスを指定します。
|
||||
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。モデルの設定はstate dictから自動検出されます。`net.`プレフィックス付きのComfyUIフォーマットもサポートしています。
|
||||
* `--qwen3="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイルが使用できます。
|
||||
* `--vae="<path to Qwen-Image VAE model>"` **[必須]** - Qwen-Image VAEモデルのパスを指定します。
|
||||
|
||||
#### モデル関連 [オプション]
|
||||
|
||||
* `--llm_adapter_path="<path to LLM adapter>"` *[オプション]* - 個別のLLM Adapterの重みファイルのパス。
|
||||
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[オプション]* - T5トークナイザーディレクトリのパス。
|
||||
|
||||
#### Anima 学習パラメータ
|
||||
|
||||
* `--timestep_sample_method` - タイムステップのサンプリング方法。`logit_normal`(デフォルト)または`uniform`。
|
||||
* `--discrete_flow_shift` - Rectified Flow学習のタイムステップ分布シフト。デフォルト`3.0`。
|
||||
* `--sigmoid_scale` - logit_normalタイムステップサンプリングのスケール係数。デフォルト`1.0`。
|
||||
* `--timestep_sampling` - タイムステップのサンプリング方法。`sigma`、`uniform`、`sigmoid`(デフォルト)、`shift`、`flux_shift`から選択。FLUX学習と同じオプションです。各方法の詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
|
||||
* `--discrete_flow_shift` - Rectified Flow学習のタイムステップ分布シフト。デフォルト`1.0`。`--timestep_sampling`が`shift`の場合に使用されます。
|
||||
* `--sigmoid_scale` - `sigmoid`、`shift`、`flux_shift`タイムステップサンプリングのスケール係数。デフォルト`1.0`。
|
||||
* `--qwen3_max_token_length` - Qwen3トークナイザーの最大トークン長。デフォルト`512`。
|
||||
* `--t5_max_token_length` - T5トークナイザーの最大トークン長。デフォルト`512`。
|
||||
* `--flash_attn` - DiTのself/cross-attentionにFlash Attentionを使用。`pip install flash-attn`が必要。
|
||||
* `--transformer_dtype` - Transformerブロック用の個別dtype。
|
||||
* `--attn_mode` - 使用するAttentionの実装。`torch`(デフォルト)、`xformers`、`flash`、`sageattn`から選択。`xformers`は`--split_attn`の指定が必要です。`sageattn`はトレーニングをサポートしていません(推論のみ)。
|
||||
* `--split_attn` - メモリ使用量を減らすためにattention時にバッチを分割します。`--attn_mode xformers`使用時に必要です。
|
||||
|
||||
#### コンポーネント別学習率
|
||||
|
||||
Animaは6つの独立した学習率グループをサポートします。`0`に設定するとそのコンポーネントをフリーズします:
|
||||
これらのオプションは、Animaモデルの各コンポーネントに個別の学習率を設定します。主にフルファインチューニング用です。`0`に設定するとそのコンポーネントをフリーズします:
|
||||
|
||||
* `--self_attn_lr` - Self-attention層の学習率。
|
||||
* `--cross_attn_lr` - Cross-attention層の学習率。
|
||||
* `--mlp_lr` - MLP層の学習率。
|
||||
* `--mod_lr` - AdaLNモジュレーション層の学習率。
|
||||
* `--mod_lr` - AdaLNモジュレーション層の学習率。モジュレーション層はデフォルトではLoRAに含まれません。
|
||||
* `--llm_adapter_lr` - LLM Adapter層の学習率。
|
||||
|
||||
LoRA学習の場合は、`--network_args`の`network_reg_lrs`を使用してください。[セクション5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御)を参照。
|
||||
|
||||
#### メモリ・速度関連
|
||||
|
||||
* `--blocks_to_swap` **[実験的機能]** - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。
|
||||
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。
|
||||
* `--blocks_to_swap` - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。`--cpu_offload_checkpointing`および`--unsloth_offload_checkpointing`とは併用できません。
|
||||
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。`--cpu_offload_checkpointing`および`--blocks_to_swap`とは併用できません。
|
||||
* `--cache_text_encoder_outputs` - Qwen3の出力をキャッシュしてメモリ使用量を削減。
|
||||
* `--cache_latents`, `--cache_latents_to_disk` - WanVAEの出力をキャッシュ。
|
||||
* `--fp8_base` - ベースモデルにFP8精度を使用。
|
||||
* `--cache_latents`, `--cache_latents_to_disk` - Qwen-Image VAEの出力をキャッシュ。
|
||||
* `--vae_chunk_size` - Qwen-Image VAEのチャンク処理サイズ。メモリ使用量を削減しますが速度が低下します。デフォルトはチャンク処理なし。
|
||||
* `--vae_disable_cache` - Qwen-Image VAEの内部キャッシュを無効化してメモリ使用量を削減します。
|
||||
|
||||
#### 非互換・非サポートの引数
|
||||
|
||||
* `--v2`, `--v_parameterization`, `--clip_skip` - Stable Diffusion v1/v2向けの引数。Animaの学習では使用されません。
|
||||
* `--fp8_base` - Animaではサポートされていません。指定した場合、警告とともに無効化されます。
|
||||
</details>
|
||||
|
||||
### 4.2. Starting Training / 学習の開始
|
||||
@@ -262,67 +297,66 @@ After setting the required arguments, run the command to begin training. The ove
|
||||
|
||||
## 5. LoRA Target Modules / LoRAの学習対象モジュール
|
||||
|
||||
When training LoRA with `anima_train_network.py`, the following modules are targeted:
|
||||
When training LoRA with `anima_train_network.py`, the following modules are targeted by default:
|
||||
|
||||
* **DiT Blocks (`Block`)**: Self-attention, cross-attention, MLP, and AdaLN modulation layers within each transformer block.
|
||||
* **DiT Blocks (`Block`)**: Self-attention (`self_attn`), cross-attention (`cross_attn`), and MLP (`mlp`) layers within each transformer block. Modulation (`adaln_modulation`), norm, embedder, and final layers are excluded by default.
|
||||
* **Embedding layers (`PatchEmbed`, `TimestepEmbedding`) and Final layer (`FinalLayer`)**: Excluded by default but can be included using `include_patterns`.
|
||||
* **LLM Adapter Blocks (`LLMAdapterTransformerBlock`)**: Only when `--network_args "train_llm_adapter=True"` is specified.
|
||||
* **Text Encoder (Qwen3)**: Only when `--network_train_unet_only` is NOT specified.
|
||||
* **Text Encoder (Qwen3)**: Only when `--network_train_unet_only` is NOT specified and `--cache_text_encoder_outputs` is NOT used.
|
||||
|
||||
The LoRA network module is `networks.lora_anima`.
|
||||
|
||||
### 5.1. Layer-specific Rank Configuration / 各層に対するランク指定
|
||||
### 5.1. Module Selection with Patterns / パターンによるモジュール選択
|
||||
|
||||
You can specify different ranks (network_dim) for each component of the Anima model. Setting `0` disables LoRA for that component.
|
||||
|
||||
| network_args | Target Component |
|
||||
|---|---|
|
||||
| `self_attn_dim` | Self-attention layers in DiT blocks |
|
||||
| `cross_attn_dim` | Cross-attention layers in DiT blocks |
|
||||
| `mlp_dim` | MLP layers in DiT blocks |
|
||||
| `mod_dim` | AdaLN modulation layers in DiT blocks |
|
||||
| `llm_adapter_dim` | LLM adapter layers (requires `train_llm_adapter=True`) |
|
||||
|
||||
Example usage:
|
||||
By default, the following modules are excluded from LoRA via the built-in exclude pattern:
|
||||
```
|
||||
--network_args "self_attn_dim=8" "cross_attn_dim=4" "mlp_dim=8" "mod_dim=4"
|
||||
.*(_modulation|_norm|_embedder|final_layer).*
|
||||
```
|
||||
|
||||
### 5.2. Embedding Layer LoRA / 埋め込み層LoRA
|
||||
You can customize which modules are included or excluded using regex patterns in `--network_args`:
|
||||
|
||||
You can apply LoRA to embedding/output layers by specifying `emb_dims` in network_args as a comma-separated list of 3 numbers:
|
||||
* `exclude_patterns` - Exclude modules matching these patterns (in addition to the default exclusion).
|
||||
* `include_patterns` - Force-include modules matching these patterns, overriding exclusion.
|
||||
|
||||
Patterns are matched against the full module name using `re.fullmatch()`.
|
||||
|
||||
Example to include the final layer:
|
||||
```
|
||||
--network_args "emb_dims=[8,4,8]"
|
||||
--network_args "include_patterns=['.*final_layer.*']"
|
||||
```
|
||||
|
||||
Each number corresponds to:
|
||||
1. `x_embedder` (patch embedding)
|
||||
2. `t_embedder` (timestep embedding)
|
||||
3. `final_layer` (output layer)
|
||||
|
||||
Setting `0` disables LoRA for that layer.
|
||||
|
||||
### 5.3. Block Selection for Training / 学習するブロックの指定
|
||||
|
||||
You can specify which DiT blocks to train using `train_block_indices` in network_args. The indices are 0-based. Default is to train all blocks.
|
||||
|
||||
Specify indices as comma-separated integers or ranges:
|
||||
|
||||
Example to additionally exclude MLP layers:
|
||||
```
|
||||
--network_args "train_block_indices=0-5,10,15-27"
|
||||
--network_args "exclude_patterns=['.*mlp.*']"
|
||||
```
|
||||
|
||||
Special values: `all` (train all blocks), `none` (skip all blocks).
|
||||
### 5.2. Regex-based Rank and Learning Rate Control / 正規表現によるランク・学習率の制御
|
||||
|
||||
### 5.4. LLM Adapter LoRA / LLM Adapter LoRA
|
||||
You can specify different ranks (network_dim) and learning rates for modules matching specific regex patterns:
|
||||
|
||||
* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`.
|
||||
* Example: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
|
||||
* This sets the rank to 8 for self-attention modules, 4 for cross-attention modules, and 8 for MLP modules.
|
||||
* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`.
|
||||
* Example: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
|
||||
* This sets the learning rate to `1e-4` for self-attention modules and `5e-5` for cross-attention modules.
|
||||
|
||||
**Notes:**
|
||||
|
||||
* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
|
||||
* Patterns are matched using `re.fullmatch()` against the module's original name (e.g., `blocks.0.self_attn.q_proj`).
|
||||
|
||||
### 5.3. LLM Adapter LoRA / LLM Adapter LoRA
|
||||
|
||||
To apply LoRA to the LLM Adapter blocks:
|
||||
|
||||
```
|
||||
--network_args "train_llm_adapter=True" "llm_adapter_dim=4"
|
||||
--network_args "train_llm_adapter=True"
|
||||
```
|
||||
|
||||
### 5.5. Other Network Args / その他のネットワーク引数
|
||||
In preliminary tests, lowering the learning rate for the LLM Adapter seems to improve stability. Adjust it using something like: `"network_reg_lrs=.*llm_adapter.*=5e-5"`.
|
||||
|
||||
### 5.4. Other Network Args / その他のネットワーク引数
|
||||
|
||||
* `--network_args "verbose=True"` - Print all LoRA module names and their dimensions.
|
||||
* `--network_args "rank_dropout=0.1"` - Rank dropout rate.
|
||||
@@ -336,42 +370,52 @@ To apply LoRA to the LLM Adapter blocks:
|
||||
|
||||
`anima_train_network.py`でLoRAを学習させる場合、デフォルトでは以下のモジュールが対象となります。
|
||||
|
||||
* **DiTブロック (`Block`)**: 各Transformerブロック内のSelf-attention、Cross-attention、MLP、AdaLNモジュレーション層。
|
||||
* **DiTブロック (`Block`)**: 各Transformerブロック内のSelf-attention(`self_attn`)、Cross-attention(`cross_attn`)、MLP(`mlp`)層。モジュレーション(`adaln_modulation`)、norm、embedder、final layerはデフォルトで除外されます。
|
||||
* **埋め込み層 (`PatchEmbed`, `TimestepEmbedding`) と最終層 (`FinalLayer`)**: デフォルトで除外されますが、`include_patterns`で含めることができます。
|
||||
* **LLM Adapterブロック (`LLMAdapterTransformerBlock`)**: `--network_args "train_llm_adapter=True"`を指定した場合のみ。
|
||||
* **テキストエンコーダー (Qwen3)**: `--network_train_unet_only`を指定しない場合のみ。
|
||||
* **テキストエンコーダー (Qwen3)**: `--network_train_unet_only`を指定せず、かつ`--cache_text_encoder_outputs`を使用しない場合のみ。
|
||||
|
||||
### 5.1. 各層のランクを指定する
|
||||
### 5.1. パターンによるモジュール選択
|
||||
|
||||
`--network_args`で各コンポーネントに異なるランクを指定できます。`0`を指定するとその層にはLoRAが適用されません。
|
||||
デフォルトでは以下のモジュールが組み込みの除外パターンによりLoRAから除外されます:
|
||||
```
|
||||
.*(_modulation|_norm|_embedder|final_layer).*
|
||||
```
|
||||
|
||||
|network_args|対象コンポーネント|
|
||||
|---|---|
|
||||
|`self_attn_dim`|DiTブロック内のSelf-attention層|
|
||||
|`cross_attn_dim`|DiTブロック内のCross-attention層|
|
||||
|`mlp_dim`|DiTブロック内のMLP層|
|
||||
|`mod_dim`|DiTブロック内のAdaLNモジュレーション層|
|
||||
|`llm_adapter_dim`|LLM Adapter層(`train_llm_adapter=True`が必要)|
|
||||
`--network_args`で正規表現パターンを使用して、含めるモジュールと除外するモジュールをカスタマイズできます:
|
||||
|
||||
### 5.2. 埋め込み層LoRA
|
||||
* `exclude_patterns` - これらのパターンにマッチするモジュールを除外(デフォルトの除外に追加)。
|
||||
* `include_patterns` - これらのパターンにマッチするモジュールを強制的に含める(除外を上書き)。
|
||||
|
||||
`emb_dims`で埋め込み/出力層にLoRAを適用できます。3つの数値をカンマ区切りで指定します。
|
||||
パターンは`re.fullmatch()`を使用して完全なモジュール名に対してマッチングされます。
|
||||
|
||||
各数値は `x_embedder`(パッチ埋め込み)、`t_embedder`(タイムステップ埋め込み)、`final_layer`(出力層)に対応します。
|
||||
### 5.2. 正規表現によるランク・学習率の制御
|
||||
|
||||
### 5.3. 学習するブロックの指定
|
||||
正規表現にマッチするモジュールに対して、異なるランクや学習率を指定できます:
|
||||
|
||||
`train_block_indices`でLoRAを適用するDiTブロックを指定できます。
|
||||
* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。`pattern=rank`形式の文字列をカンマで区切って指定します。
|
||||
* 例: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
|
||||
* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。`pattern=lr`形式の文字列をカンマで区切って指定します。
|
||||
* 例: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
|
||||
|
||||
### 5.4. LLM Adapter LoRA
|
||||
**注意点:**
|
||||
* `network_reg_dims`および`network_reg_lrs`での設定は、全体設定である`--network_dim`や`--learning_rate`よりも優先されます。
|
||||
* パターンはモジュールのオリジナル名(例: `blocks.0.self_attn.q_proj`)に対して`re.fullmatch()`でマッチングされます。
|
||||
|
||||
LLM AdapterブロックにLoRAを適用するには:`--network_args "train_llm_adapter=True" "llm_adapter_dim=4"`
|
||||
### 5.3. LLM Adapter LoRA
|
||||
|
||||
### 5.5. その他のネットワーク引数
|
||||
LLM AdapterブロックにLoRAを適用するには:`--network_args "train_llm_adapter=True"`
|
||||
|
||||
簡易な検証ではLLM Adapterの学習率はある程度下げた方が安定するようです。`"network_reg_lrs=.*llm_adapter.*=5e-5"`などで調整してください。
|
||||
|
||||
### 5.4. その他のネットワーク引数
|
||||
|
||||
* `verbose=True` - 全LoRAモジュール名とdimを表示
|
||||
* `rank_dropout` - ランクドロップアウト率
|
||||
* `module_dropout` - モジュールドロップアウト率
|
||||
* `loraplus_lr_ratio` - LoRA+学習率比率
|
||||
* `loraplus_unet_lr_ratio` - DiT専用のLoRA+学習率比率
|
||||
* `loraplus_text_encoder_lr_ratio` - テキストエンコーダー専用のLoRA+学習率比率
|
||||
|
||||
</details>
|
||||
|
||||
@@ -394,8 +438,6 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
|
||||
|
||||
#### Key VRAM Reduction Options
|
||||
|
||||
- **`--fp8_base`**: Enables training in FP8 format for the DiT model.
|
||||
|
||||
- **`--blocks_to_swap <number>`**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. See model-specific max values in section 4.1.
|
||||
|
||||
- **`--unsloth_offload_checkpointing`**: Offloads gradient checkpoints to CPU using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--blocks_to_swap`.
|
||||
@@ -404,7 +446,7 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
|
||||
|
||||
- **`--cache_text_encoder_outputs`**: Caches Qwen3 outputs so the text encoder can be freed from VRAM during training.
|
||||
|
||||
- **`--cache_latents`**: Caches WanVAE outputs so the VAE can be freed from VRAM during training.
|
||||
- **`--cache_latents`**: Caches Qwen-Image VAE outputs so the VAE can be freed from VRAM during training.
|
||||
|
||||
- **Using Adafactor optimizer**: Can reduce VRAM usage:
|
||||
```
|
||||
@@ -417,12 +459,11 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
|
||||
Animaモデルは大きい場合があるため、VRAMが限られたGPUでは最適化が必要です。
|
||||
|
||||
主要なVRAM削減オプション:
|
||||
- `--fp8_base`: FP8形式での学習を有効化
|
||||
- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ
|
||||
- `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード
|
||||
- `--gradient_checkpointing`: 標準的な勾配チェックポイント
|
||||
- `--cache_text_encoder_outputs`: Qwen3の出力をキャッシュ
|
||||
- `--cache_latents`: WanVAEの出力をキャッシュ
|
||||
- `--cache_latents`: Qwen-Image VAEの出力をキャッシュ
|
||||
- Adafactorオプティマイザの使用
|
||||
|
||||
</details>
|
||||
@@ -431,21 +472,24 @@ Animaモデルは大きい場合があるため、VRAMが限られたGPUでは
|
||||
|
||||
#### Timestep Sampling
|
||||
|
||||
The `--timestep_sample_method` option specifies how timesteps (0-1) are sampled:
|
||||
The `--timestep_sampling` option specifies how timesteps are sampled. The available methods are the same as FLUX training:
|
||||
|
||||
- `logit_normal` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
|
||||
- `sigma`: Sigma-based sampling like SD3.
|
||||
- `uniform`: Uniform random sampling from [0, 1].
|
||||
- `sigmoid` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
|
||||
- `shift`: Like `sigmoid`, but applies the discrete flow shift formula: `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
|
||||
- `flux_shift`: Resolution-dependent shift used in FLUX training.
|
||||
|
||||
See the [flux_train_network.py guide](flux_train_network.md) for detailed descriptions.
|
||||
|
||||
#### Discrete Flow Shift
|
||||
|
||||
The `--discrete_flow_shift` option (default `3.0`) shifts the timestep distribution toward higher noise levels. The formula is:
|
||||
The `--discrete_flow_shift` option (default `1.0`) only applies when `--timestep_sampling` is set to `shift`. The formula is:
|
||||
|
||||
```
|
||||
t_shifted = (t * shift) / (1 + (shift - 1) * t)
|
||||
```
|
||||
|
||||
Timesteps are clamped to `[1e-5, 1-1e-5]` after shifting.
|
||||
|
||||
#### Loss Weighting
|
||||
|
||||
The `--weighting_scheme` option specifies loss weighting by timestep:
|
||||
@@ -454,23 +498,34 @@ The `--weighting_scheme` option specifies loss weighting by timestep:
|
||||
- `sigma_sqrt`: Weight by `sigma^(-2)`.
|
||||
- `cosmap`: Weight by `2 / (pi * (1 - 2*sigma + 2*sigma^2))`.
|
||||
- `none`: Same as uniform.
|
||||
- `logit_normal`, `mode`: Additional schemes from SD3 training. See the [`sd3_train_network.md` guide](sd3_train_network.md) for details.
|
||||
|
||||
#### Caption Dropout
|
||||
|
||||
Use `--caption_dropout_rate` for embedding-level caption dropout. This is handled by `AnimaTextEncodingStrategy` and is compatible with text encoder output caching. The subset-level `caption_dropout_rate` is automatically zeroed when this is set.
|
||||
Caption dropout uses the `caption_dropout_rate` setting from the dataset configuration (per-subset in TOML). When using `--cache_text_encoder_outputs`, the dropout rate is stored with each cached entry and applied during training, so caption dropout is compatible with text encoder output caching.
|
||||
|
||||
**If you change the `caption_dropout_rate` setting, you must delete and regenerate the cache.**
|
||||
|
||||
Note: Currently, only Anima supports combining `caption_dropout_rate` with text encoder output caching.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
#### タイムステップサンプリング
|
||||
|
||||
`--timestep_sample_method`でタイムステップのサンプリング方法を指定します:
|
||||
- `logit_normal`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。
|
||||
`--timestep_sampling`でタイムステップのサンプリング方法を指定します。FLUX学習と同じ方法が利用できます:
|
||||
|
||||
- `sigma`: SD3と同様のシグマベースサンプリング。
|
||||
- `uniform`: [0, 1]の一様分布からサンプリング。
|
||||
- `sigmoid`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。汎用的なオプション。
|
||||
- `shift`: `sigmoid`と同様だが、離散フローシフトの式を適用。
|
||||
- `flux_shift`: FLUX学習で使用される解像度依存のシフト。
|
||||
|
||||
詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
|
||||
|
||||
#### 離散フローシフト
|
||||
|
||||
`--discrete_flow_shift`(デフォルト`3.0`)はタイムステップ分布を高ノイズ側にシフトします。
|
||||
`--discrete_flow_shift`(デフォルト`1.0`)は`--timestep_sampling`が`shift`の場合のみ適用されます。
|
||||
|
||||
#### 損失の重み付け
|
||||
|
||||
@@ -478,7 +533,11 @@ Use `--caption_dropout_rate` for embedding-level caption dropout. This is handle
|
||||
|
||||
#### キャプションドロップアウト
|
||||
|
||||
`--caption_dropout_rate`で埋め込みレベルのキャプションドロップアウトを使用します。テキストエンコーダー出力のキャッシュと互換性があります。
|
||||
キャプションドロップアウトにはデータセット設定(TOMLでのサブセット単位)の`caption_dropout_rate`を使用します。`--cache_text_encoder_outputs`使用時は、ドロップアウト率が各キャッシュエントリとともに保存され、学習中に適用されるため、テキストエンコーダー出力キャッシュと同時に使用できます。
|
||||
|
||||
**`caption_dropout_rate`の設定を変えた場合、キャッシュを削除し、再生成する必要があります。**
|
||||
|
||||
※`caption_dropout_rate`をテキストエンコーダー出力キャッシュと組み合わせられるのは、今のところAnimaのみです。
|
||||
|
||||
</details>
|
||||
|
||||
@@ -487,17 +546,23 @@ Use `--caption_dropout_rate` for embedding-level caption dropout. This is handle
|
||||
Anima LoRA training supports training Qwen3 text encoder LoRA:
|
||||
|
||||
- To train only DiT: specify `--network_train_unet_only`
|
||||
- To train DiT and Qwen3: omit `--network_train_unet_only`
|
||||
- To train DiT and Qwen3: omit `--network_train_unet_only` and do NOT use `--cache_text_encoder_outputs`
|
||||
|
||||
You can specify a separate learning rate for Qwen3 with `--text_encoder_lr`. If not specified, the default `--learning_rate` is used.
|
||||
|
||||
Note: When `--cache_text_encoder_outputs` is used, text encoder outputs are pre-computed and the text encoder is removed from GPU, so text encoder LoRA cannot be trained.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレーニングできます。
|
||||
|
||||
- DiTのみ学習: `--network_train_unet_only`を指定
|
||||
- DiTとQwen3を学習: `--network_train_unet_only`を省略
|
||||
- DiTとQwen3を学習: `--network_train_unet_only`を省略し、`--cache_text_encoder_outputs`を使用しない
|
||||
|
||||
Qwen3に個別の学習率を指定するには`--text_encoder_lr`を使用します。未指定の場合は`--learning_rate`が使われます。
|
||||
|
||||
注意: `--cache_text_encoder_outputs`を使用する場合、テキストエンコーダーの出力が事前に計算されGPUから解放されるため、テキストエンコーダーLoRAは学習できません。
|
||||
|
||||
</details>
|
||||
|
||||
@@ -528,16 +593,47 @@ Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレー
|
||||
|
||||
</details>
|
||||
|
||||
## 9. Others / その他
|
||||
## 9. Related Tools / 関連ツール
|
||||
|
||||
### `networks/anima_convert_lora_to_comfy.py`
|
||||
|
||||
A script to convert LoRA models to ComfyUI-compatible format. ComfyUI does not directly support sd-scripts format Qwen3 LoRA, so conversion is necessary (conversion may not be needed for DiT-only LoRA). You can convert from the sd-scripts format to ComfyUI format with:
|
||||
|
||||
```bash
|
||||
python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
|
||||
```
|
||||
|
||||
Using the `--reverse` option allows conversion in the opposite direction (ComfyUI format to sd-scripts format). However, reverse conversion is only possible for LoRAs converted by this script. LoRAs created with other training tools cannot be converted.
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
**`networks/convert_anima_lora_to_comfy.py`**
|
||||
|
||||
LoRAモデルをComfyUI互換形式に変換するスクリプト。ComfyUIがsd-scripts形式のQwen3 LoRAを直接サポートしていないため、変換が必要です(DiTのみのLoRAの場合は変換不要のようです)。sd-scripts形式からComfyUI形式への変換は以下のコマンドで行います:
|
||||
|
||||
```bash
|
||||
python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
|
||||
```
|
||||
|
||||
`--reverse`オプションを付けると、逆変換(ComfyUI形式からsd-scripts形式)も可能です。ただし、逆変換ができるのはこのスクリプトで変換したLoRAに限ります。他の学習ツールで作成したLoRAは変換できません。
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## 10. Others / その他
|
||||
|
||||
### Metadata Saved in LoRA Models
|
||||
|
||||
The following Anima-specific metadata is saved in the LoRA model file:
|
||||
The following metadata is saved in the LoRA model file:
|
||||
|
||||
* `ss_weighting_scheme`
|
||||
* `ss_discrete_flow_shift`
|
||||
* `ss_timestep_sample_method`
|
||||
* `ss_logit_mean`
|
||||
* `ss_logit_std`
|
||||
* `ss_mode_scale`
|
||||
* `ss_timestep_sampling`
|
||||
* `ss_sigmoid_scale`
|
||||
* `ss_discrete_flow_shift`
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
@@ -546,11 +642,14 @@ The following Anima-specific metadata is saved in the LoRA model file:
|
||||
|
||||
### LoRAモデルに保存されるメタデータ
|
||||
|
||||
以下のAnima固有のメタデータがLoRAモデルファイルに保存されます:
|
||||
以下のメタデータがLoRAモデルファイルに保存されます:
|
||||
|
||||
* `ss_weighting_scheme`
|
||||
* `ss_discrete_flow_shift`
|
||||
* `ss_timestep_sample_method`
|
||||
* `ss_logit_mean`
|
||||
* `ss_logit_std`
|
||||
* `ss_mode_scale`
|
||||
* `ss_timestep_sampling`
|
||||
* `ss_sigmoid_scale`
|
||||
* `ss_discrete_flow_shift`
|
||||
|
||||
</details>
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Original code: NVIDIA CORPORATION & AFFILIATES, licensed under Apache-2.0
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -13,9 +13,7 @@ import torch.nn.functional as F
|
||||
|
||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||
|
||||
from library import custom_offloading_utils
|
||||
from library.device_utils import clean_memory_on_device
|
||||
|
||||
from library import custom_offloading_utils, attention
|
||||
|
||||
|
||||
def to_device(x, device):
|
||||
@@ -39,11 +37,13 @@ def to_cpu(x):
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
# Unsloth Offloaded Gradient Checkpointing
|
||||
# Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team
|
||||
try:
|
||||
from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable
|
||||
except ImportError:
|
||||
|
||||
def detach_variable(inputs, device=None):
|
||||
"""Detach tensors from computation graph, optionally moving to a device.
|
||||
|
||||
@@ -80,11 +80,11 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@torch.amp.custom_fwd(device_type='cuda')
|
||||
@torch.amp.custom_fwd(device_type="cuda")
|
||||
def forward(ctx, forward_function, hidden_states, *args):
|
||||
# Remember the original device for backward pass (multi-GPU support)
|
||||
ctx.input_device = hidden_states.device
|
||||
saved_hidden_states = hidden_states.to('cpu', non_blocking=True)
|
||||
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
output = forward_function(hidden_states, *args)
|
||||
ctx.save_for_backward(saved_hidden_states)
|
||||
@@ -96,7 +96,7 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.amp.custom_bwd(device_type='cuda')
|
||||
@torch.amp.custom_bwd(device_type="cuda")
|
||||
def backward(ctx, *grads):
|
||||
(hidden_states,) = ctx.saved_tensors
|
||||
hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach()
|
||||
@@ -108,8 +108,9 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
|
||||
|
||||
output_tensors = []
|
||||
grad_tensors = []
|
||||
for out, grad in zip(outputs if isinstance(outputs, tuple) else (outputs,),
|
||||
grads if isinstance(grads, tuple) else (grads,)):
|
||||
for out, grad in zip(
|
||||
outputs if isinstance(outputs, tuple) else (outputs,), grads if isinstance(grads, tuple) else (grads,)
|
||||
):
|
||||
if isinstance(out, torch.Tensor) and out.requires_grad:
|
||||
output_tensors.append(out)
|
||||
grad_tensors.append(grad)
|
||||
@@ -123,26 +124,6 @@ def unsloth_checkpoint(function, *args):
|
||||
return UnslothOffloadedGradientCheckpointer.apply(function, *args)
|
||||
|
||||
|
||||
# Flash Attention support
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
|
||||
FLASH_ATTN_AVAILABLE = True
|
||||
except ImportError:
|
||||
_flash_attn_func = None
|
||||
FLASH_ATTN_AVAILABLE = False
|
||||
|
||||
|
||||
def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes multi-head attention using Flash Attention.
|
||||
|
||||
Input format: (batch, seq_len, n_heads, head_dim)
|
||||
Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
|
||||
"""
|
||||
# flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
|
||||
out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
|
||||
return rearrange(out, "b s h d -> b s (h d)")
|
||||
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -174,14 +155,10 @@ def _apply_rotary_pos_emb_base(
|
||||
|
||||
if start_positions is not None:
|
||||
max_offset = torch.max(start_positions)
|
||||
assert (
|
||||
max_offset + cur_seq_len <= max_seq_len
|
||||
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
|
||||
assert max_offset + cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
|
||||
freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)
|
||||
|
||||
assert (
|
||||
cur_seq_len <= max_seq_len
|
||||
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
|
||||
assert cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
|
||||
freqs = freqs[:cur_seq_len]
|
||||
|
||||
if tensor_format == "bshd":
|
||||
@@ -205,13 +182,9 @@ def apply_rotary_pos_emb(
|
||||
cu_seqlens: Union[torch.Tensor, None] = None,
|
||||
cp_size: int = 1,
|
||||
) -> torch.Tensor:
|
||||
assert not (
|
||||
cp_size > 1 and start_positions is not None
|
||||
), "start_positions != None with CP SIZE > 1 is not supported!"
|
||||
assert not (cp_size > 1 and start_positions is not None), "start_positions != None with CP SIZE > 1 is not supported!"
|
||||
|
||||
assert (
|
||||
tensor_format != "thd" or cu_seqlens is not None
|
||||
), "cu_seqlens must not be None when tensor_format is 'thd'."
|
||||
assert tensor_format != "thd" or cu_seqlens is not None, "cu_seqlens must not be None when tensor_format is 'thd'."
|
||||
|
||||
assert fused == False
|
||||
|
||||
@@ -223,9 +196,7 @@ def apply_rotary_pos_emb(
|
||||
_apply_rotary_pos_emb_base(
|
||||
x.unsqueeze(1),
|
||||
freqs,
|
||||
start_positions=(
|
||||
start_positions[idx : idx + 1] if start_positions is not None else None
|
||||
),
|
||||
start_positions=(start_positions[idx : idx + 1] if start_positions is not None else None),
|
||||
interleaved=interleaved,
|
||||
)
|
||||
for idx, x in enumerate(torch.split(t, seqlens))
|
||||
@@ -262,8 +233,8 @@ class RMSNorm(torch.nn.Module):
|
||||
def _norm(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
@torch.amp.autocast(device_type='cuda', dtype=torch.float32)
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
with torch.autocast(device_type=x.device.type, dtype=torch.float32):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
|
||||
@@ -298,22 +269,6 @@ class GPT2FeedForward(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes multi-head attention using PyTorch's native scaled_dot_product_attention.
|
||||
|
||||
Input/output format: (batch, seq_len, n_heads, head_dim)
|
||||
"""
|
||||
in_q_shape = q_B_S_H_D.shape
|
||||
in_k_shape = k_B_S_H_D.shape
|
||||
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
|
||||
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
|
||||
result_B_S_HD = rearrange(
|
||||
F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)"
|
||||
)
|
||||
return result_B_S_HD
|
||||
|
||||
|
||||
# Attention module for DiT
|
||||
class Attention(nn.Module):
|
||||
"""Multi-head attention supporting both self-attention and cross-attention.
|
||||
@@ -354,8 +309,6 @@ class Attention(nn.Module):
|
||||
self.output_proj = nn.Linear(inner_dim, query_dim, bias=False)
|
||||
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
|
||||
|
||||
self.attn_op = torch_attention_op
|
||||
|
||||
self._query_dim = query_dim
|
||||
self._context_dim = context_dim
|
||||
self._inner_dim = inner_dim
|
||||
@@ -399,18 +352,25 @@ class Attention(nn.Module):
|
||||
|
||||
return q, k, v
|
||||
|
||||
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
result = self.attn_op(q, k, v) # [B, S, H, D]
|
||||
return self.output_dropout(self.output_proj(result))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
rope_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
|
||||
return self.compute_attention(q, k, v)
|
||||
if q.dtype != v.dtype:
|
||||
if (not attn_params.supports_fp32 or attn_params.requires_same_dtype) and torch.is_autocast_enabled():
|
||||
# FlashAttention requires fp16/bf16, xformers require same dtype; only cast when autocast is active.
|
||||
target_dtype = v.dtype # v has fp16/bf16 dtype
|
||||
q = q.to(target_dtype)
|
||||
k = k.to(target_dtype)
|
||||
# return self.compute_attention(q, k, v)
|
||||
qkv = [q, k, v]
|
||||
del q, k, v
|
||||
result = attention.attention(qkv, attn_params=attn_params)
|
||||
return self.output_dropout(self.output_proj(result))
|
||||
|
||||
|
||||
# Positional Embeddings
|
||||
@@ -484,12 +444,8 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
|
||||
dim_t = self._dim_t
|
||||
|
||||
self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device)
|
||||
self.dim_spatial_range = (
|
||||
torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
|
||||
)
|
||||
self.dim_temporal_range = (
|
||||
torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
|
||||
)
|
||||
self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
|
||||
self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
|
||||
|
||||
def generate_embeddings(
|
||||
self,
|
||||
@@ -664,31 +620,30 @@ class TimestepEmbedding(nn.Module):
|
||||
return emb_B_T_D, adaln_lora_B_T_3D
|
||||
|
||||
|
||||
class FourierFeatures(nn.Module):
|
||||
"""Fourier feature transform: [B] -> [B, D]."""
|
||||
# Commented out Fourier Features (not used in Anima). Kept for reference.
|
||||
# class FourierFeatures(nn.Module):
|
||||
# """Fourier feature transform: [B] -> [B, D]."""
|
||||
|
||||
def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
|
||||
super().__init__()
|
||||
self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
|
||||
self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
|
||||
self.gain = np.sqrt(2) if normalize else 1
|
||||
self.bandwidth = bandwidth
|
||||
self.num_channels = num_channels
|
||||
self.reset_parameters()
|
||||
# def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
|
||||
# super().__init__()
|
||||
# self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
|
||||
# self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
|
||||
# self.gain = np.sqrt(2) if normalize else 1
|
||||
# self.bandwidth = bandwidth
|
||||
# self.num_channels = num_channels
|
||||
# self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(0)
|
||||
self.freqs = (
|
||||
2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
|
||||
)
|
||||
self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
|
||||
# def reset_parameters(self) -> None:
|
||||
# generator = torch.Generator()
|
||||
# generator.manual_seed(0)
|
||||
# self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
|
||||
# self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
|
||||
|
||||
def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
|
||||
in_dtype = x.dtype
|
||||
x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
|
||||
x = x.cos().mul(self.gain * gain).to(in_dtype)
|
||||
return x
|
||||
# def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
|
||||
# in_dtype = x.dtype
|
||||
# x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
|
||||
# x = x.cos().mul(self.gain * gain).to(in_dtype)
|
||||
# return x
|
||||
|
||||
|
||||
# Patch Embedding
|
||||
@@ -713,9 +668,7 @@ class PatchEmbed(nn.Module):
|
||||
m=spatial_patch_size,
|
||||
n=spatial_patch_size,
|
||||
),
|
||||
nn.Linear(
|
||||
in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False
|
||||
),
|
||||
nn.Linear(in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False),
|
||||
)
|
||||
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
|
||||
|
||||
@@ -765,9 +718,7 @@ class FinalLayer(nn.Module):
|
||||
nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False),
|
||||
)
|
||||
else:
|
||||
self.adaln_modulation = nn.Sequential(
|
||||
nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)
|
||||
)
|
||||
self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False))
|
||||
|
||||
self.init_weights()
|
||||
|
||||
@@ -790,9 +741,9 @@ class FinalLayer(nn.Module):
|
||||
):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
shift_B_T_D, scale_B_T_D = (
|
||||
self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
|
||||
).chunk(2, dim=-1)
|
||||
shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk(
|
||||
2, dim=-1
|
||||
)
|
||||
else:
|
||||
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
|
||||
|
||||
@@ -833,7 +784,11 @@ class Block(nn.Module):
|
||||
|
||||
self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.cross_attn = Attention(
|
||||
x_dim, context_dim, num_heads, x_dim // num_heads, qkv_format="bshd",
|
||||
x_dim,
|
||||
context_dim,
|
||||
num_heads,
|
||||
x_dim // num_heads,
|
||||
qkv_format="bshd",
|
||||
)
|
||||
|
||||
self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
|
||||
@@ -904,6 +859,7 @@ class Block(nn.Module):
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
@@ -919,13 +875,13 @@ class Block(nn.Module):
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
|
||||
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
|
||||
self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
|
||||
).chunk(3, dim=-1)
|
||||
shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
|
||||
3, dim=-1
|
||||
)
|
||||
else:
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk(
|
||||
3, dim=-1
|
||||
)
|
||||
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
|
||||
emb_B_T_D
|
||||
).chunk(3, dim=-1)
|
||||
@@ -954,11 +910,14 @@ class Block(nn.Module):
|
||||
result = rearrange(
|
||||
self.self_attn(
|
||||
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
|
||||
attn_params,
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T, h=H, w=W,
|
||||
t=T,
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result
|
||||
|
||||
@@ -967,11 +926,14 @@ class Block(nn.Module):
|
||||
result = rearrange(
|
||||
self.cross_attn(
|
||||
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
|
||||
attn_params,
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
),
|
||||
"b (t h w) d -> b t h w d",
|
||||
t=T, h=H, w=W,
|
||||
t=T,
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
|
||||
@@ -987,6 +949,7 @@ class Block(nn.Module):
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
@@ -996,8 +959,13 @@ class Block(nn.Module):
|
||||
# Unsloth: async non-blocking CPU RAM offload (fastest offload method)
|
||||
return unsloth_checkpoint(
|
||||
self._forward,
|
||||
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
|
||||
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
|
||||
x_B_T_H_W_D,
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
)
|
||||
elif self.cpu_offload_checkpointing:
|
||||
# Standard cpu offload: blocking transfers
|
||||
@@ -1008,36 +976,54 @@ class Block(nn.Module):
|
||||
device_inputs = to_device(inputs, device)
|
||||
outputs = func(*device_inputs)
|
||||
return to_cpu(outputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
return torch_checkpoint(
|
||||
create_custom_forward(self._forward),
|
||||
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
|
||||
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
|
||||
x_B_T_H_W_D,
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# Standard gradient checkpointing (no offload)
|
||||
return torch_checkpoint(
|
||||
self._forward,
|
||||
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
|
||||
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
|
||||
x_B_T_H_W_D,
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
return self._forward(
|
||||
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
|
||||
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
|
||||
x_B_T_H_W_D,
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
)
|
||||
|
||||
|
||||
# Main DiT Model: MiniTrainDIT
|
||||
class MiniTrainDIT(nn.Module):
|
||||
# Main DiT Model: MiniTrainDIT (renamed to Anima)
|
||||
class Anima(nn.Module):
|
||||
"""Cosmos-Predict2 DiT model for image/video generation.
|
||||
|
||||
28 transformer blocks with AdaLN-LoRA modulation, 3D RoPE, and optional LLM Adapter.
|
||||
"""
|
||||
|
||||
LATENT_CHANNELS = 16
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_img_h: int,
|
||||
@@ -1069,6 +1055,8 @@ class MiniTrainDIT(nn.Module):
|
||||
extra_t_extrapolation_ratio: float = 1.0,
|
||||
rope_enable_fps_modulation: bool = True,
|
||||
use_llm_adapter: bool = False,
|
||||
attn_mode: str = "torch",
|
||||
split_attn: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.max_img_h = max_img_h
|
||||
@@ -1097,6 +1085,9 @@ class MiniTrainDIT(nn.Module):
|
||||
self.rope_enable_fps_modulation = rope_enable_fps_modulation
|
||||
self.use_llm_adapter = use_llm_adapter
|
||||
|
||||
self.attn_mode = attn_mode
|
||||
self.split_attn = split_attn
|
||||
|
||||
# Block swap support
|
||||
self.blocks_to_swap = None
|
||||
self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None
|
||||
@@ -1156,7 +1147,6 @@ class MiniTrainDIT(nn.Module):
|
||||
self.final_layer.init_weights()
|
||||
self.t_embedding_norm.reset_parameters()
|
||||
|
||||
|
||||
def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False):
|
||||
for block in self.blocks:
|
||||
block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload)
|
||||
@@ -1169,18 +1159,9 @@ class MiniTrainDIT(nn.Module):
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
|
||||
def set_flash_attn(self, use_flash_attn: bool):
|
||||
"""Toggle flash attention for all DiT blocks (self-attn + cross-attn).
|
||||
|
||||
LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
|
||||
"""
|
||||
if use_flash_attn and not FLASH_ATTN_AVAILABLE:
|
||||
raise ImportError("flash_attn package is required for --flash_attn but is not installed")
|
||||
attn_op = flash_attention_op if use_flash_attn else torch_attention_op
|
||||
for block in self.blocks:
|
||||
block.self_attn.attn_op = attn_op
|
||||
block.cross_attn.attn_op = attn_op
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
def build_patch_embed(self) -> None:
|
||||
in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels
|
||||
@@ -1232,9 +1213,7 @@ class MiniTrainDIT(nn.Module):
|
||||
padding_mask = transforms.functional.resize(
|
||||
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
x_B_C_T_H_W = torch.cat(
|
||||
[x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
|
||||
)
|
||||
x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1)
|
||||
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
|
||||
|
||||
if self.extra_per_block_abs_pos_emb:
|
||||
@@ -1258,7 +1237,6 @@ class MiniTrainDIT(nn.Module):
|
||||
)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
|
||||
def enable_block_swap(self, num_blocks: int, device: torch.device):
|
||||
self.blocks_to_swap = num_blocks
|
||||
|
||||
@@ -1266,9 +1244,7 @@ class MiniTrainDIT(nn.Module):
|
||||
self.blocks_to_swap <= self.num_blocks - 2
|
||||
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
|
||||
|
||||
self.offloader = custom_offloading_utils.ModelOffloader(
|
||||
self.blocks, self.blocks_to_swap, device
|
||||
)
|
||||
self.offloader = custom_offloading_utils.ModelOffloader(self.blocks, self.blocks_to_swap, device)
|
||||
logger.info(f"Anima: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
|
||||
|
||||
def move_to_device_except_swap_blocks(self, device: torch.device):
|
||||
@@ -1282,12 +1258,26 @@ class MiniTrainDIT(nn.Module):
|
||||
if self.blocks_to_swap:
|
||||
self.blocks = save_blocks
|
||||
|
||||
def switch_block_swap_for_inference(self):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
self.offloader.set_forward_only(True)
|
||||
self.prepare_block_swap_before_forward()
|
||||
print(f"Anima: Block swap set to forward only.")
|
||||
|
||||
def switch_block_swap_for_training(self):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
self.offloader.set_forward_only(False)
|
||||
self.prepare_block_swap_before_forward()
|
||||
print(f"Anima: Block swap set to forward and backward.")
|
||||
|
||||
def prepare_block_swap_before_forward(self):
|
||||
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
||||
return
|
||||
self.offloader.prepare_block_devices_before_forward(self.blocks)
|
||||
|
||||
def forward(
|
||||
def forward_mini_train_dit(
|
||||
self,
|
||||
x_B_C_T_H_W: torch.Tensor,
|
||||
timesteps_B_T: torch.Tensor,
|
||||
@@ -1310,7 +1300,7 @@ class MiniTrainDIT(nn.Module):
|
||||
t5_attn_mask: Optional T5 attention mask
|
||||
"""
|
||||
# Run LLM adapter inside forward for correct DDP gradient synchronization
|
||||
if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, 'llm_adapter'):
|
||||
if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, "llm_adapter"):
|
||||
crossattn_emb = self.llm_adapter(
|
||||
source_hidden_states=crossattn_emb,
|
||||
target_input_ids=t5_input_ids,
|
||||
@@ -1337,16 +1327,13 @@ class MiniTrainDIT(nn.Module):
|
||||
"extra_per_block_pos_emb": extra_pos_emb,
|
||||
}
|
||||
|
||||
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
|
||||
|
||||
for block_idx, block in enumerate(self.blocks):
|
||||
if self.blocks_to_swap:
|
||||
self.offloader.wait_for_block(block_idx)
|
||||
|
||||
x_B_T_H_W_D = block(
|
||||
x_B_T_H_W_D,
|
||||
t_embedding_B_T_D,
|
||||
crossattn_emb,
|
||||
**block_kwargs,
|
||||
)
|
||||
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.offloader.submit_move_blocks(self.blocks, block_idx)
|
||||
@@ -1355,6 +1342,36 @@ class MiniTrainDIT(nn.Module):
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
fps: Optional[torch.Tensor] = None,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
target_input_ids: Optional[torch.Tensor] = None,
|
||||
target_attention_mask: Optional[torch.Tensor] = None,
|
||||
source_attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask)
|
||||
return self.forward_mini_train_dit(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
|
||||
|
||||
def _preprocess_text_embeds(
|
||||
self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
|
||||
):
|
||||
if target_input_ids is not None:
|
||||
context = self.llm_adapter(
|
||||
source_hidden_states,
|
||||
target_input_ids,
|
||||
target_attention_mask=target_attention_mask,
|
||||
source_attention_mask=source_attention_mask,
|
||||
)
|
||||
context[~target_attention_mask.bool()] = 0 # zero out padding tokens
|
||||
return context
|
||||
else:
|
||||
return source_hidden_states
|
||||
|
||||
|
||||
# LLM Adapter: Bridges Qwen3 embeddings to T5-compatible space
|
||||
class LLMAdapterRMSNorm(nn.Module):
|
||||
@@ -1485,24 +1502,37 @@ class LLMAdapterTransformerBlock(nn.Module):
|
||||
|
||||
self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim)
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(model_dim, int(model_dim * mlp_ratio)),
|
||||
nn.GELU(),
|
||||
nn.Linear(int(model_dim * mlp_ratio), model_dim)
|
||||
nn.Linear(model_dim, int(model_dim * mlp_ratio)), nn.GELU(), nn.Linear(int(model_dim * mlp_ratio), model_dim)
|
||||
)
|
||||
|
||||
def forward(self, x, context, target_attention_mask=None, source_attention_mask=None,
|
||||
position_embeddings=None, position_embeddings_context=None):
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context,
|
||||
target_attention_mask=None,
|
||||
source_attention_mask=None,
|
||||
position_embeddings=None,
|
||||
position_embeddings_context=None,
|
||||
):
|
||||
if self.has_self_attn:
|
||||
# Self-attention: target_attention_mask is not expected to be all zeros
|
||||
normed = self.norm_self_attn(x)
|
||||
attn_out = self.self_attn(normed, mask=target_attention_mask,
|
||||
attn_out = self.self_attn(
|
||||
normed,
|
||||
mask=target_attention_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
position_embeddings_context=position_embeddings)
|
||||
position_embeddings_context=position_embeddings,
|
||||
)
|
||||
x = x + attn_out
|
||||
|
||||
normed = self.norm_cross_attn(x)
|
||||
attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context,
|
||||
attn_out = self.cross_attn(
|
||||
normed,
|
||||
mask=source_attention_mask,
|
||||
context=context,
|
||||
position_embeddings=position_embeddings,
|
||||
position_embeddings_context=position_embeddings_context)
|
||||
position_embeddings_context=position_embeddings_context,
|
||||
)
|
||||
x = x + attn_out
|
||||
|
||||
x = x + self.mlp(self.norm_mlp(x))
|
||||
@@ -1518,8 +1548,9 @@ class LLMAdapter(nn.Module):
|
||||
Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states.
|
||||
"""
|
||||
|
||||
def __init__(self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16,
|
||||
embed=None, self_attn=False, layer_norm=False):
|
||||
def __init__(
|
||||
self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, embed=None, self_attn=False, layer_norm=False
|
||||
):
|
||||
super().__init__()
|
||||
if embed is not None:
|
||||
self.embed = nn.Embedding.from_pretrained(embed.weight)
|
||||
@@ -1530,11 +1561,12 @@ class LLMAdapter(nn.Module):
|
||||
else:
|
||||
self.in_proj = nn.Identity()
|
||||
self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads)
|
||||
self.blocks = nn.ModuleList([
|
||||
LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads,
|
||||
self_attn=self_attn, layer_norm=layer_norm)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, self_attn=self_attn, layer_norm=layer_norm)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
]
|
||||
)
|
||||
self.out_proj = nn.Linear(model_dim, target_dim)
|
||||
self.norm = LLMAdapterRMSNorm(target_dim)
|
||||
|
||||
@@ -1556,75 +1588,67 @@ class LLMAdapter(nn.Module):
|
||||
position_embeddings = self.rotary_emb(x, position_ids)
|
||||
position_embeddings_context = self.rotary_emb(x, position_ids_context)
|
||||
for block in self.blocks:
|
||||
x = block(x, context, target_attention_mask=target_attention_mask,
|
||||
x = block(
|
||||
x,
|
||||
context,
|
||||
target_attention_mask=target_attention_mask,
|
||||
source_attention_mask=source_attention_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
position_embeddings_context=position_embeddings_context)
|
||||
position_embeddings_context=position_embeddings_context,
|
||||
)
|
||||
return self.norm(self.out_proj(x))
|
||||
|
||||
|
||||
# VAE Wrapper
|
||||
# Not used currently, but kept for reference
|
||||
|
||||
# VAE normalization constants
|
||||
ANIMA_VAE_MEAN = [
|
||||
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
||||
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
||||
]
|
||||
ANIMA_VAE_STD = [
|
||||
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
||||
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
||||
]
|
||||
# def get_dit_config(state_dict, key_prefix=""):
|
||||
# """Derive DiT configuration from state_dict weight shapes."""
|
||||
# dit_config = {}
|
||||
# dit_config["max_img_h"] = 512
|
||||
# dit_config["max_img_w"] = 512
|
||||
# dit_config["max_frames"] = 128
|
||||
# concat_padding_mask = True
|
||||
# dit_config["in_channels"] = (state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[1] // 4) - int(
|
||||
# concat_padding_mask
|
||||
# )
|
||||
# dit_config["out_channels"] = 16
|
||||
# dit_config["patch_spatial"] = 2
|
||||
# dit_config["patch_temporal"] = 1
|
||||
# dit_config["model_channels"] = state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[0]
|
||||
# dit_config["concat_padding_mask"] = concat_padding_mask
|
||||
# dit_config["crossattn_emb_channels"] = 1024
|
||||
# dit_config["pos_emb_cls"] = "rope3d"
|
||||
# dit_config["pos_emb_learnable"] = True
|
||||
# dit_config["pos_emb_interpolation"] = "crop"
|
||||
# dit_config["min_fps"] = 1
|
||||
# dit_config["max_fps"] = 30
|
||||
|
||||
# DiT config detection from state_dict
|
||||
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
|
||||
# dit_config["use_adaln_lora"] = True
|
||||
# dit_config["adaln_lora_dim"] = 256
|
||||
# if dit_config["model_channels"] == 2048:
|
||||
# dit_config["num_blocks"] = 28
|
||||
# dit_config["num_heads"] = 16
|
||||
# elif dit_config["model_channels"] == 5120:
|
||||
# dit_config["num_blocks"] = 36
|
||||
# dit_config["num_heads"] = 40
|
||||
# elif dit_config["model_channels"] == 1280:
|
||||
# dit_config["num_blocks"] = 20
|
||||
# dit_config["num_heads"] = 20
|
||||
|
||||
# if dit_config["in_channels"] == 16:
|
||||
# dit_config["extra_per_block_abs_pos_emb"] = False
|
||||
# dit_config["rope_h_extrapolation_ratio"] = 4.0
|
||||
# dit_config["rope_w_extrapolation_ratio"] = 4.0
|
||||
# dit_config["rope_t_extrapolation_ratio"] = 1.0
|
||||
# elif dit_config["in_channels"] == 17:
|
||||
# dit_config["extra_per_block_abs_pos_emb"] = False
|
||||
# dit_config["rope_h_extrapolation_ratio"] = 3.0
|
||||
# dit_config["rope_w_extrapolation_ratio"] = 3.0
|
||||
# dit_config["rope_t_extrapolation_ratio"] = 1.0
|
||||
|
||||
def get_dit_config(state_dict, key_prefix=''):
|
||||
"""Derive DiT configuration from state_dict weight shapes."""
|
||||
dit_config = {}
|
||||
dit_config["max_img_h"] = 512
|
||||
dit_config["max_img_w"] = 512
|
||||
dit_config["max_frames"] = 128
|
||||
concat_padding_mask = True
|
||||
dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["patch_spatial"] = 2
|
||||
dit_config["patch_temporal"] = 1
|
||||
dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["concat_padding_mask"] = concat_padding_mask
|
||||
dit_config["crossattn_emb_channels"] = 1024
|
||||
dit_config["pos_emb_cls"] = "rope3d"
|
||||
dit_config["pos_emb_learnable"] = True
|
||||
dit_config["pos_emb_interpolation"] = "crop"
|
||||
dit_config["min_fps"] = 1
|
||||
dit_config["max_fps"] = 30
|
||||
# dit_config["extra_h_extrapolation_ratio"] = 1.0
|
||||
# dit_config["extra_w_extrapolation_ratio"] = 1.0
|
||||
# dit_config["extra_t_extrapolation_ratio"] = 1.0
|
||||
# dit_config["rope_enable_fps_modulation"] = False
|
||||
|
||||
dit_config["use_adaln_lora"] = True
|
||||
dit_config["adaln_lora_dim"] = 256
|
||||
if dit_config["model_channels"] == 2048:
|
||||
dit_config["num_blocks"] = 28
|
||||
dit_config["num_heads"] = 16
|
||||
elif dit_config["model_channels"] == 5120:
|
||||
dit_config["num_blocks"] = 36
|
||||
dit_config["num_heads"] = 40
|
||||
elif dit_config["model_channels"] == 1280:
|
||||
dit_config["num_blocks"] = 20
|
||||
dit_config["num_heads"] = 20
|
||||
|
||||
if dit_config["in_channels"] == 16:
|
||||
dit_config["extra_per_block_abs_pos_emb"] = False
|
||||
dit_config["rope_h_extrapolation_ratio"] = 4.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 4.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 1.0
|
||||
elif dit_config["in_channels"] == 17:
|
||||
dit_config["extra_per_block_abs_pos_emb"] = False
|
||||
dit_config["rope_h_extrapolation_ratio"] = 3.0
|
||||
dit_config["rope_w_extrapolation_ratio"] = 3.0
|
||||
dit_config["rope_t_extrapolation_ratio"] = 1.0
|
||||
|
||||
dit_config["extra_h_extrapolation_ratio"] = 1.0
|
||||
dit_config["extra_w_extrapolation_ratio"] = 1.0
|
||||
dit_config["extra_t_extrapolation_ratio"] = 1.0
|
||||
dit_config["rope_enable_fps_modulation"] = False
|
||||
|
||||
return dit_config
|
||||
# return dit_config
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
# Anima Training Utilities
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from safetensors.torch import save_file
|
||||
from accelerate import Accelerator, PartialState
|
||||
from accelerate import Accelerator
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.device_utils import init_ipex, clean_memory_on_device, synchronize_device
|
||||
from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -25,29 +25,14 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import anima_models, anima_utils, strategy_base, train_util
|
||||
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
|
||||
|
||||
|
||||
# Anima-specific training arguments
|
||||
|
||||
|
||||
def add_anima_training_arguments(parser: argparse.ArgumentParser):
|
||||
"""Add Anima-specific training arguments to the parser."""
|
||||
parser.add_argument(
|
||||
"--dit_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Anima DiT model safetensors file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to WanVAE safetensors/pth file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--qwen3_path",
|
||||
"--qwen3",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to Qwen3-0.6B model (safetensors file or directory)",
|
||||
@@ -86,7 +71,7 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
|
||||
"--mod_lr",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze",
|
||||
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5_tokenizer_path",
|
||||
@@ -113,110 +98,52 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
|
||||
help="Timestep distribution shift for rectified flow training (default: 1.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timestep_sample_method",
|
||||
"--timestep_sampling",
|
||||
type=str,
|
||||
default="logit_normal",
|
||||
choices=["logit_normal", "uniform"],
|
||||
help="Timestep sampling method (default: logit_normal)",
|
||||
default="sigmoid",
|
||||
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
|
||||
help="Timestep sampling method (default: sigmoid (logit normal))",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigmoid_scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Scale factor for logit_normal timestep sampling (default: 1.0)",
|
||||
help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
|
||||
)
|
||||
# Note: --caption_dropout_rate is defined by base add_dataset_arguments().
|
||||
# Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
|
||||
# instead of dataset-level caption dropout, so the subset caption_dropout_rate
|
||||
# is zeroed out in the training scripts to allow caching.
|
||||
parser.add_argument(
|
||||
"--transformer_dtype",
|
||||
type=str,
|
||||
"--attn_mode",
|
||||
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
|
||||
default=None,
|
||||
choices=["float16", "bfloat16", "float32", None],
|
||||
help="Separate dtype for transformer blocks. If None, uses same as mixed_precision",
|
||||
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
|
||||
" / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flash_attn",
|
||||
"--split_attn",
|
||||
action="store_true",
|
||||
help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). "
|
||||
"Falls back to PyTorch SDPA if flash-attn is not installed.",
|
||||
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_chunk_size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
|
||||
+ " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_disable_cache",
|
||||
action="store_true",
|
||||
help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
|
||||
+ " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
|
||||
)
|
||||
|
||||
|
||||
# Noise & Timestep sampling (Rectified Flow)
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args,
|
||||
latents: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Generate noisy model input and timesteps for rectified flow training.
|
||||
|
||||
Rectified flow: noisy_input = (1 - t) * latents + t * noise
|
||||
Target: noise - latents
|
||||
|
||||
Args:
|
||||
args: Training arguments with timestep_sample_method, sigmoid_scale, discrete_flow_shift
|
||||
latents: Clean latent tensors
|
||||
noise: Random noise tensors
|
||||
device: Target device
|
||||
dtype: Target dtype
|
||||
|
||||
Returns:
|
||||
(noisy_model_input, timesteps, sigmas)
|
||||
"""
|
||||
bs = latents.shape[0]
|
||||
|
||||
timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal')
|
||||
sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0)
|
||||
shift = getattr(args, 'discrete_flow_shift', 1.0)
|
||||
|
||||
if timestep_sample_method == 'logit_normal':
|
||||
dist = torch.distributions.normal.Normal(0, 1)
|
||||
elif timestep_sample_method == 'uniform':
|
||||
dist = torch.distributions.uniform.Uniform(0, 1)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
|
||||
|
||||
t = dist.sample((bs,)).to(device)
|
||||
|
||||
if timestep_sample_method == 'logit_normal':
|
||||
t = t * sigmoid_scale
|
||||
t = torch.sigmoid(t)
|
||||
|
||||
# Apply shift
|
||||
if shift is not None and shift != 1.0:
|
||||
t = (t * shift) / (1 + (shift - 1) * t)
|
||||
|
||||
# Clamp to avoid exact 0 or 1
|
||||
t = t.clamp(1e-5, 1.0 - 1e-5)
|
||||
|
||||
# Create noisy input: (1 - t) * latents + t * noise
|
||||
t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
|
||||
|
||||
ip_noise_gamma = getattr(args, 'ip_noise_gamma', None)
|
||||
if ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if getattr(args, 'ip_noise_gamma_random_strength', False):
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
|
||||
noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
|
||||
else:
|
||||
noisy_model_input = (1 - t_expanded) * latents + t_expanded * noise
|
||||
|
||||
# Sigmas for potential loss weighting
|
||||
sigmas = t.view(-1, 1)
|
||||
|
||||
return noisy_model_input.to(dtype), t.to(dtype), sigmas.to(dtype)
|
||||
|
||||
|
||||
# Loss weighting
|
||||
|
||||
|
||||
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute loss weighting for Anima training.
|
||||
|
||||
Same schemes as SD3 but can add Anima-specific ones.
|
||||
Same schemes as SD3 but can add Anima-specific ones if needed in future.
|
||||
"""
|
||||
if weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
@@ -243,7 +170,7 @@ def get_anima_param_groups(
|
||||
"""Create parameter groups for Anima training with separate learning rates.
|
||||
|
||||
Args:
|
||||
dit: MiniTrainDIT model
|
||||
dit: Anima model
|
||||
base_lr: Base learning rate
|
||||
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
|
||||
cross_attn_lr: LR for cross-attention layers
|
||||
@@ -276,15 +203,15 @@ def get_anima_param_groups(
|
||||
# Store original name for debugging
|
||||
p.original_name = name
|
||||
|
||||
if 'llm_adapter' in name:
|
||||
if "llm_adapter" in name:
|
||||
llm_adapter_params.append(p)
|
||||
elif '.self_attn' in name:
|
||||
elif ".self_attn" in name:
|
||||
self_attn_params.append(p)
|
||||
elif '.cross_attn' in name:
|
||||
elif ".cross_attn" in name:
|
||||
cross_attn_params.append(p)
|
||||
elif '.mlp' in name:
|
||||
elif ".mlp" in name:
|
||||
mlp_params.append(p)
|
||||
elif '.adaln_modulation' in name:
|
||||
elif ".adaln_modulation" in name:
|
||||
mod_params.append(p)
|
||||
else:
|
||||
base_params.append(p)
|
||||
@@ -311,9 +238,9 @@ def get_anima_param_groups(
|
||||
p.requires_grad_(False)
|
||||
logger.info(f" Frozen {name} params ({len(params)} parameters)")
|
||||
elif len(params) > 0:
|
||||
param_groups.append({'params': params, 'lr': lr})
|
||||
param_groups.append({"params": params, "lr": lr})
|
||||
|
||||
total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad)
|
||||
total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
|
||||
logger.info(f"Total trainable parameters: {total_trainable:,}")
|
||||
|
||||
return param_groups
|
||||
@@ -325,16 +252,17 @@ def save_anima_model_on_train_end(
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
dit: anima_models.Anima,
|
||||
):
|
||||
"""Save Anima model at the end of training."""
|
||||
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True
|
||||
)
|
||||
sai_metadata = train_util.get_sai_model_spec_dataclass(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
|
||||
).to_metadata_dict()
|
||||
dit_sd = dit.state_dict()
|
||||
# Save with 'net.' prefix for ComfyUI compatibility
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
|
||||
|
||||
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
|
||||
|
||||
@@ -347,15 +275,16 @@ def save_anima_model_on_epoch_end_or_stepwise(
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
dit: anima_models.Anima,
|
||||
):
|
||||
"""Save Anima model at epoch end or specific steps."""
|
||||
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True
|
||||
)
|
||||
sai_metadata = train_util.get_sai_model_spec_dataclass(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
|
||||
).to_metadata_dict()
|
||||
dit_sd = dit.state_dict()
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
|
||||
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args,
|
||||
@@ -376,12 +305,13 @@ def do_sample(
|
||||
height: int,
|
||||
width: int,
|
||||
seed: Optional[int],
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
dit: anima_models.Anima,
|
||||
crossattn_emb: torch.Tensor,
|
||||
steps: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
guidance_scale: float = 1.0,
|
||||
flow_shift: float = 3.0,
|
||||
neg_crossattn_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Generate a sample using Euler discrete sampling for rectified flow.
|
||||
@@ -389,12 +319,13 @@ def do_sample(
|
||||
Args:
|
||||
height, width: Output image dimensions
|
||||
seed: Random seed (None for random)
|
||||
dit: MiniTrainDIT model
|
||||
dit: Anima model
|
||||
crossattn_emb: Cross-attention embeddings (B, N, D)
|
||||
steps: Number of sampling steps
|
||||
dtype: Compute dtype
|
||||
device: Compute device
|
||||
guidance_scale: CFG scale (1.0 = no guidance)
|
||||
flow_shift: Flow shift parameter for rectified flow
|
||||
neg_crossattn_emb: Negative cross-attention embeddings for CFG
|
||||
|
||||
Returns:
|
||||
@@ -410,12 +341,13 @@ def do_sample(
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
noise = torch.randn(
|
||||
latent.size(), dtype=torch.float32, generator=generator, device="cpu"
|
||||
).to(dtype).to(device)
|
||||
noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
|
||||
|
||||
# Timestep schedule: linear from 1.0 to 0.0
|
||||
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
|
||||
flow_shift = float(flow_shift)
|
||||
if flow_shift != 1.0:
|
||||
sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
|
||||
|
||||
# Start from pure noise
|
||||
x = noise.clone()
|
||||
@@ -429,19 +361,13 @@ def do_sample(
|
||||
sigma = sigmas[i]
|
||||
t = sigma.unsqueeze(0) # (1,)
|
||||
|
||||
dit.prepare_block_swap_before_forward()
|
||||
|
||||
if use_cfg:
|
||||
# CFG: concat positive and negative
|
||||
x_input = torch.cat([x, x], dim=0)
|
||||
t_input = torch.cat([t, t], dim=0)
|
||||
crossattn_input = torch.cat([crossattn_emb, neg_crossattn_emb], dim=0)
|
||||
padding_input = torch.cat([padding_mask, padding_mask], dim=0)
|
||||
# CFG: two separate passes to reduce memory usage
|
||||
pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask)
|
||||
pos_out = pos_out.float()
|
||||
neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask)
|
||||
neg_out = neg_out.float()
|
||||
|
||||
model_output = dit(x_input, t_input, crossattn_input, padding_mask=padding_input)
|
||||
model_output = model_output.float()
|
||||
|
||||
pos_out, neg_out = model_output.chunk(2)
|
||||
model_output = neg_out + guidance_scale * (pos_out - neg_out)
|
||||
else:
|
||||
model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
|
||||
@@ -452,7 +378,6 @@ def do_sample(
|
||||
x = x + model_output * dt
|
||||
x = x.to(dtype)
|
||||
|
||||
dit.prepare_block_swap_before_forward()
|
||||
return x
|
||||
|
||||
|
||||
@@ -461,9 +386,8 @@ def sample_images(
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
dit,
|
||||
dit: anima_models.Anima,
|
||||
vae,
|
||||
vae_scale,
|
||||
text_encoder,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
@@ -497,6 +421,8 @@ def sample_images(
|
||||
if text_encoder is not None:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
dit.switch_block_swap_for_inference()
|
||||
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
save_dir = os.path.join(args.output_dir, "sample")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
@@ -511,11 +437,21 @@ def sample_images(
|
||||
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for prompt_dict in prompts:
|
||||
dit.prepare_block_swap_before_forward()
|
||||
_sample_image_inference(
|
||||
accelerator, args, dit, text_encoder, vae, vae_scale,
|
||||
tokenize_strategy, text_encoding_strategy,
|
||||
save_dir, prompt_dict, epoch, steps,
|
||||
sample_prompts_te_outputs, prompt_replacement,
|
||||
accelerator,
|
||||
args,
|
||||
dit,
|
||||
text_encoder,
|
||||
vae,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
# Restore RNG state
|
||||
@@ -523,14 +459,24 @@ def sample_images(
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
dit.switch_block_swap_for_training()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
def _sample_image_inference(
|
||||
accelerator, args, dit, text_encoder, vae, vae_scale,
|
||||
tokenize_strategy, text_encoding_strategy,
|
||||
save_dir, prompt_dict, epoch, steps,
|
||||
sample_prompts_te_outputs, prompt_replacement,
|
||||
accelerator,
|
||||
args,
|
||||
dit,
|
||||
text_encoder,
|
||||
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
):
|
||||
"""Generate a single sample image."""
|
||||
prompt = prompt_dict.get("prompt", "")
|
||||
@@ -540,6 +486,7 @@ def _sample_image_inference(
|
||||
height = prompt_dict.get("height", 512)
|
||||
scale = prompt_dict.get("scale", 7.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
flow_shift = prompt_dict.get("flow_shift", 3.0)
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
@@ -553,7 +500,9 @@ def _sample_image_inference(
|
||||
height = max(64, height - height % 16)
|
||||
width = max(64, width - width % 16)
|
||||
|
||||
logger.info(f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}")
|
||||
logger.info(
|
||||
f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
|
||||
)
|
||||
|
||||
# Encode prompt
|
||||
def encode_prompt(prpt):
|
||||
@@ -579,13 +528,13 @@ def _sample_image_inference(
|
||||
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
|
||||
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
|
||||
attn_mask = attn_mask.to(accelerator.device)
|
||||
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
|
||||
# Process through LLM adapter if available
|
||||
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
|
||||
if dit.use_llm_adapter:
|
||||
crossattn_emb = dit.llm_adapter(
|
||||
source_hidden_states=prompt_embeds,
|
||||
target_input_ids=t5_input_ids,
|
||||
@@ -608,12 +557,12 @@ def _sample_image_inference(
|
||||
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
|
||||
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
|
||||
|
||||
neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
|
||||
neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
|
||||
neg_am = neg_am.to(accelerator.device)
|
||||
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
|
||||
neg_t5_am = neg_t5_am.to(accelerator.device)
|
||||
|
||||
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
|
||||
if dit.use_llm_adapter:
|
||||
neg_crossattn_emb = dit.llm_adapter(
|
||||
source_hidden_states=neg_pe,
|
||||
target_input_ids=neg_t5_ids,
|
||||
@@ -627,16 +576,16 @@ def _sample_image_inference(
|
||||
# Generate sample
|
||||
clean_memory_on_device(accelerator.device)
|
||||
latents = do_sample(
|
||||
height, width, seed, dit, crossattn_emb,
|
||||
sample_steps, dit.t_embedding_norm.weight.dtype,
|
||||
accelerator.device, scale, neg_crossattn_emb,
|
||||
height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
|
||||
)
|
||||
|
||||
# Decode latents
|
||||
gc.collect()
|
||||
synchronize_device(accelerator.device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
org_vae_device = next(vae.parameters()).device
|
||||
org_vae_device = vae.device
|
||||
vae.to(accelerator.device)
|
||||
decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale)
|
||||
decoded = vae.decode_to_pixels(latents)
|
||||
vae.to(org_vae_device)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -662,4 +611,5 @@ def _sample_image_inference(
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
import wandb
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)
|
||||
|
||||
@@ -3,10 +3,14 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file, save_file
|
||||
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from library.fp8_optimization_utils import apply_fp8_monkey_patch
|
||||
from library.lora_utils import load_safetensors_with_lora_and_fp8
|
||||
from library import anima_models
|
||||
from library.safetensors_utils import WeightTransformHooks
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -14,150 +18,134 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import anima_models
|
||||
|
||||
# Original Anima high-precision keys. Kept for reference, but not used currently.
|
||||
# # Keys that should stay in high precision (float32/bfloat16, not quantized)
|
||||
# KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
|
||||
|
||||
|
||||
# Keys that should stay in high precision (float32/bfloat16, not quantized)
|
||||
KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
|
||||
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
|
||||
# ".embed." excludes Embedding in LLMAdapter
|
||||
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
|
||||
|
||||
|
||||
def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]:
|
||||
"""Load a safetensors file and optionally cast to dtype."""
|
||||
sd = load_file(path, device=device)
|
||||
if dtype is not None:
|
||||
sd = {k: v.to(dtype) for k, v in sd.items()}
|
||||
return sd
|
||||
|
||||
|
||||
def load_anima_dit(
|
||||
def load_anima_model(
|
||||
device: Union[str, torch.device],
|
||||
dit_path: str,
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
transformer_dtype: Optional[torch.dtype] = None,
|
||||
llm_adapter_path: Optional[str] = None,
|
||||
disable_mmap: bool = False,
|
||||
) -> anima_models.MiniTrainDIT:
|
||||
"""Load the MiniTrainDIT model from safetensors.
|
||||
attn_mode: str,
|
||||
split_attn: bool,
|
||||
loading_device: Union[str, torch.device],
|
||||
dit_weight_dtype: Optional[torch.dtype],
|
||||
fp8_scaled: bool = False,
|
||||
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
|
||||
lora_multipliers: Optional[list[float]] = None,
|
||||
) -> anima_models.Anima:
|
||||
"""
|
||||
Load Anima model from the specified checkpoint.
|
||||
|
||||
Args:
|
||||
dit_path: Path to DiT safetensors file
|
||||
dtype: Base dtype for model parameters
|
||||
device: Device to load to
|
||||
transformer_dtype: Optional separate dtype for transformer blocks (lower precision)
|
||||
llm_adapter_path: Optional separate path for LLM adapter weights
|
||||
disable_mmap: If True, disable memory-mapped loading (reduces peak memory)
|
||||
device (Union[str, torch.device]): Device for optimization or merging
|
||||
dit_path (str): Path to the DiT model checkpoint.
|
||||
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
|
||||
split_attn (bool): Whether to use split attention.
|
||||
loading_device (Union[str, torch.device]): Device to load the model weights on.
|
||||
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
|
||||
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
|
||||
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
|
||||
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
|
||||
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
|
||||
"""
|
||||
if transformer_dtype is None:
|
||||
transformer_dtype = dtype
|
||||
# dit_weight_dtype is None for fp8_scaled
|
||||
assert (
|
||||
not fp8_scaled and dit_weight_dtype is not None
|
||||
) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
|
||||
|
||||
logger.info(f"Loading Anima DiT from {dit_path}")
|
||||
if disable_mmap:
|
||||
from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap
|
||||
state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True)
|
||||
else:
|
||||
state_dict = load_file(dit_path, device="cpu")
|
||||
device = torch.device(device)
|
||||
loading_device = torch.device(loading_device)
|
||||
|
||||
# Remove 'net.' prefix if present
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('net.'):
|
||||
k = k[len('net.'):]
|
||||
new_state_dict[k] = v
|
||||
state_dict = new_state_dict
|
||||
# We currently support fixed DiT config for Anima models
|
||||
dit_config = {
|
||||
"max_img_h": 512,
|
||||
"max_img_w": 512,
|
||||
"max_frames": 128,
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"patch_spatial": 2,
|
||||
"patch_temporal": 1,
|
||||
"model_channels": 2048,
|
||||
"concat_padding_mask": True,
|
||||
"crossattn_emb_channels": 1024,
|
||||
"pos_emb_cls": "rope3d",
|
||||
"pos_emb_learnable": True,
|
||||
"pos_emb_interpolation": "crop",
|
||||
"min_fps": 1,
|
||||
"max_fps": 30,
|
||||
"use_adaln_lora": True,
|
||||
"adaln_lora_dim": 256,
|
||||
"num_blocks": 28,
|
||||
"num_heads": 16,
|
||||
"extra_per_block_abs_pos_emb": False,
|
||||
"rope_h_extrapolation_ratio": 4.0,
|
||||
"rope_w_extrapolation_ratio": 4.0,
|
||||
"rope_t_extrapolation_ratio": 1.0,
|
||||
"extra_h_extrapolation_ratio": 1.0,
|
||||
"extra_w_extrapolation_ratio": 1.0,
|
||||
"extra_t_extrapolation_ratio": 1.0,
|
||||
"rope_enable_fps_modulation": False,
|
||||
"use_llm_adapter": True,
|
||||
"attn_mode": attn_mode,
|
||||
"split_attn": split_attn,
|
||||
}
|
||||
with init_empty_weights():
|
||||
model = anima_models.Anima(**dit_config)
|
||||
if dit_weight_dtype is not None:
|
||||
model.to(dit_weight_dtype)
|
||||
|
||||
# Derive config from state_dict
|
||||
dit_config = anima_models.get_dit_config(state_dict)
|
||||
|
||||
# Detect LLM adapter
|
||||
if llm_adapter_path is not None:
|
||||
use_llm_adapter = True
|
||||
dit_config['use_llm_adapter'] = True
|
||||
llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu")
|
||||
elif 'llm_adapter.out_proj.weight' in state_dict:
|
||||
use_llm_adapter = True
|
||||
dit_config['use_llm_adapter'] = True
|
||||
llm_adapter_state_dict = None # Loaded as part of DiT
|
||||
else:
|
||||
use_llm_adapter = False
|
||||
llm_adapter_state_dict = None
|
||||
|
||||
logger.info(f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, "
|
||||
f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}")
|
||||
|
||||
# Build model normally on CPU — buffers get proper values from __init__
|
||||
dit = anima_models.MiniTrainDIT(**dit_config)
|
||||
|
||||
# Merge LLM adapter weights into state_dict if loaded separately
|
||||
if use_llm_adapter and llm_adapter_state_dict is not None:
|
||||
for k, v in llm_adapter_state_dict.items():
|
||||
state_dict[f"llm_adapter.{k}"] = v
|
||||
|
||||
# Load checkpoint: strict=False keeps buffers not in checkpoint (e.g. pos_embedder.seq)
|
||||
missing, unexpected = dit.load_state_dict(state_dict, strict=False)
|
||||
if missing:
|
||||
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
|
||||
unexpected_missing = [k for k in missing if not any(
|
||||
buf_name in k for buf_name in ('seq', 'dim_spatial_range', 'dim_temporal_range', 'inv_freq')
|
||||
)]
|
||||
if unexpected_missing:
|
||||
logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}")
|
||||
if unexpected:
|
||||
logger.info(f"Unexpected keys in checkpoint (ignored): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
|
||||
|
||||
# Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest)
|
||||
for name, p in dit.named_parameters():
|
||||
dtype_to_use = dtype if (
|
||||
any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1
|
||||
) else transformer_dtype
|
||||
p.data = p.data.to(dtype=dtype_to_use)
|
||||
|
||||
dit.to(device)
|
||||
logger.info(f"Loaded Anima DiT successfully. Parameters: {sum(p.numel() for p in dit.parameters()):,}")
|
||||
return dit
|
||||
|
||||
|
||||
def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"):
|
||||
"""Load WanVAE from a safetensors/pth file.
|
||||
|
||||
Returns (vae_model, mean_tensor, std_tensor, scale).
|
||||
"""
|
||||
from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
|
||||
|
||||
logger.info(f"Loading Anima VAE from {vae_path}")
|
||||
|
||||
# VAE config (fixed for WanVAE)
|
||||
vae_config = dict(
|
||||
dim=96,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[False, True, True],
|
||||
dropout=0.0,
|
||||
# load model weights with dynamic fp8 optimization and LoRA merging if needed
|
||||
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
|
||||
rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
|
||||
sd = load_safetensors_with_lora_and_fp8(
|
||||
model_files=dit_path,
|
||||
lora_weights_list=lora_weights_list,
|
||||
lora_multipliers=lora_multipliers,
|
||||
fp8_optimization=fp8_scaled,
|
||||
calc_device=device,
|
||||
move_to_device=(loading_device == device),
|
||||
dit_weight_dtype=dit_weight_dtype,
|
||||
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
|
||||
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
|
||||
weight_transform_hooks=rename_hooks,
|
||||
)
|
||||
|
||||
from library.anima_vae import WanVAE_
|
||||
if fp8_scaled:
|
||||
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
|
||||
|
||||
# Build model
|
||||
with torch.device('meta'):
|
||||
vae = WanVAE_(**vae_config)
|
||||
if loading_device.type != "cpu":
|
||||
# make sure all the model weights are on the loading_device
|
||||
logger.info(f"Moving weights to {loading_device}")
|
||||
for key in sd.keys():
|
||||
sd[key] = sd[key].to(loading_device)
|
||||
|
||||
# Load state dict
|
||||
if vae_path.endswith('.safetensors'):
|
||||
vae_sd = load_file(vae_path, device='cpu')
|
||||
else:
|
||||
vae_sd = torch.load(vae_path, map_location='cpu', weights_only=True)
|
||||
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
|
||||
if missing:
|
||||
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
|
||||
unexpected_missing = [
|
||||
k
|
||||
for k in missing
|
||||
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
|
||||
]
|
||||
if unexpected_missing:
|
||||
# Raise error to avoid silent failures
|
||||
raise RuntimeError(
|
||||
f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
|
||||
)
|
||||
missing = {} # all missing keys were expected
|
||||
if unexpected:
|
||||
# Raise error to avoid silent failures
|
||||
raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
|
||||
logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
|
||||
|
||||
vae.load_state_dict(vae_sd, assign=True)
|
||||
vae = vae.eval().requires_grad_(False).to(device, dtype=dtype)
|
||||
|
||||
# Create normalization tensors
|
||||
mean = torch.tensor(ANIMA_VAE_MEAN, dtype=dtype, device=device)
|
||||
std = torch.tensor(ANIMA_VAE_STD, dtype=dtype, device=device)
|
||||
scale = [mean, 1.0 / std]
|
||||
|
||||
logger.info(f"Loaded Anima VAE successfully.")
|
||||
return vae, mean, std, scale
|
||||
return model
|
||||
|
||||
|
||||
def load_qwen3_tokenizer(qwen3_path: str):
|
||||
@@ -175,7 +163,7 @@ def load_qwen3_tokenizer(qwen3_path: str):
|
||||
if os.path.isdir(qwen3_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
|
||||
else:
|
||||
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
|
||||
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
|
||||
if not os.path.exists(config_dir):
|
||||
raise FileNotFoundError(
|
||||
f"Qwen3 config directory not found at {config_dir}. "
|
||||
@@ -190,7 +178,13 @@ def load_qwen3_tokenizer(qwen3_path: str):
|
||||
return tokenizer
|
||||
|
||||
|
||||
def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16, device: str = "cpu"):
|
||||
def load_qwen3_text_encoder(
|
||||
qwen3_path: str,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
device: str = "cpu",
|
||||
lora_weights: Optional[List[Dict[str, torch.Tensor]]] = None,
|
||||
lora_multipliers: Optional[List[float]] = None,
|
||||
):
|
||||
"""Load Qwen3-0.6B text encoder.
|
||||
|
||||
Args:
|
||||
@@ -209,12 +203,10 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
|
||||
if os.path.isdir(qwen3_path):
|
||||
# Directory with full model
|
||||
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
qwen3_path, torch_dtype=dtype, local_files_only=True
|
||||
).model
|
||||
model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
|
||||
else:
|
||||
# Single safetensors file - use configs/qwen3_06b/ for config
|
||||
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
|
||||
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
|
||||
if not os.path.exists(config_dir):
|
||||
raise FileNotFoundError(
|
||||
f"Qwen3 config directory not found at {config_dir}. "
|
||||
@@ -227,16 +219,28 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
|
||||
model = transformers.Qwen3ForCausalLM(qwen3_config).model
|
||||
|
||||
# Load weights
|
||||
if qwen3_path.endswith('.safetensors'):
|
||||
state_dict = load_file(qwen3_path, device='cpu')
|
||||
if qwen3_path.endswith(".safetensors"):
|
||||
if lora_weights is None:
|
||||
state_dict = load_file(qwen3_path, device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(qwen3_path, map_location='cpu', weights_only=True)
|
||||
state_dict = load_safetensors_with_lora_and_fp8(
|
||||
model_files=qwen3_path,
|
||||
lora_weights_list=lora_weights,
|
||||
lora_multipliers=lora_multipliers,
|
||||
fp8_optimization=False,
|
||||
calc_device=device,
|
||||
move_to_device=True,
|
||||
dit_weight_dtype=None,
|
||||
)
|
||||
else:
|
||||
assert lora_weights is None, "LoRA weights merging is only supported for safetensors checkpoints"
|
||||
state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
|
||||
|
||||
# Remove 'model.' prefix if present
|
||||
new_sd = {}
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('model.'):
|
||||
new_sd[k[len('model.'):]] = v
|
||||
if k.startswith("model."):
|
||||
new_sd[k[len("model.") :]] = v
|
||||
else:
|
||||
new_sd[k] = v
|
||||
|
||||
@@ -265,11 +269,11 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
|
||||
return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
|
||||
|
||||
# Use bundled config
|
||||
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 't5_old')
|
||||
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
|
||||
if os.path.exists(config_dir):
|
||||
return T5TokenizerFast(
|
||||
vocab_file=os.path.join(config_dir, 'spiece.model'),
|
||||
tokenizer_file=os.path.join(config_dir, 'tokenizer.json'),
|
||||
vocab_file=os.path.join(config_dir, "spiece.model"),
|
||||
tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
|
||||
)
|
||||
|
||||
raise FileNotFoundError(
|
||||
@@ -279,47 +283,27 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
|
||||
)
|
||||
|
||||
|
||||
def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dtype: Optional[torch.dtype] = None):
|
||||
def save_anima_model(
|
||||
save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Dict[str, any], dtype: Optional[torch.dtype] = None
|
||||
):
|
||||
"""Save Anima DiT model with 'net.' prefix for ComfyUI compatibility.
|
||||
|
||||
Args:
|
||||
save_path: Output path (.safetensors)
|
||||
dit_state_dict: State dict from dit.state_dict()
|
||||
metadata: Metadata dict to include in the safetensors file
|
||||
dtype: Optional dtype to cast to before saving
|
||||
"""
|
||||
prefixed_sd = {}
|
||||
for k, v in dit_state_dict.items():
|
||||
if dtype is not None:
|
||||
v = v.to(dtype)
|
||||
prefixed_sd['net.' + k] = v.contiguous()
|
||||
# v = v.to(dtype)
|
||||
v = v.detach().clone().to("cpu").to(dtype) # Reduce GPU memory usage during save
|
||||
prefixed_sd["net." + k] = v.contiguous()
|
||||
|
||||
save_file(prefixed_sd, save_path, metadata={'format': 'pt'})
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["format"] = "pt" # For compatibility with the official .safetensors file
|
||||
|
||||
save_file(prefixed_sd, save_path, metadata=metadata) # safetensors.save_file cosumes a lot of memory, but Anima is small enough
|
||||
logger.info(f"Saved Anima model to {save_path}")
|
||||
|
||||
|
||||
def vae_encode(tensor: torch.Tensor, vae, scale):
|
||||
"""Encode tensor through WanVAE with normalization.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor (B, C, T, H, W) in [-1, 1] range
|
||||
vae: WanVAE_ model
|
||||
scale: [mean, 1/std] list
|
||||
|
||||
Returns:
|
||||
Normalized latents
|
||||
"""
|
||||
return vae.encode(tensor, scale)
|
||||
|
||||
|
||||
def vae_decode(latents: torch.Tensor, vae, scale):
|
||||
"""Decode latents through WanVAE with denormalization.
|
||||
|
||||
Args:
|
||||
latents: Normalized latents
|
||||
vae: WanVAE_ model
|
||||
scale: [mean, 1/std] list
|
||||
|
||||
Returns:
|
||||
Decoded tensor in [-1, 1] range
|
||||
"""
|
||||
return vae.decode(latents, scale)
|
||||
|
||||
@@ -1,577 +0,0 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.cuda.amp as amp
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
CACHE_T = 2
|
||||
|
||||
|
||||
class CausalConv3d(nn.Conv3d):
|
||||
"""
|
||||
Causal 3d convolusion.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._padding = (self.padding[2], self.padding[2], self.padding[1],
|
||||
self.padding[1], 2 * self.padding[0], 0)
|
||||
self.padding = (0, 0, 0)
|
||||
|
||||
def forward(self, x, cache_x=None):
|
||||
padding = list(self._padding)
|
||||
if cache_x is not None and self._padding[4] > 0:
|
||||
cache_x = cache_x.to(x.device)
|
||||
x = torch.cat([cache_x, x], dim=2)
|
||||
padding[4] -= cache_x.shape[2]
|
||||
x = F.pad(x, padding)
|
||||
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
|
||||
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
||||
super().__init__()
|
||||
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
||||
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
||||
|
||||
self.channel_first = channel_first
|
||||
self.scale = dim**0.5
|
||||
self.gamma = nn.Parameter(torch.ones(shape))
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(
|
||||
x, dim=(1 if self.channel_first else
|
||||
-1)) * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class Upsample(nn.Upsample):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Fix bfloat16 support for nearest neighbor interpolation.
|
||||
"""
|
||||
return super().forward(x.float()).type_as(x)
|
||||
|
||||
|
||||
class Resample(nn.Module):
|
||||
|
||||
def __init__(self, dim, mode):
|
||||
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
||||
'downsample3d')
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.mode = mode
|
||||
|
||||
# layers
|
||||
if mode == 'upsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
elif mode == 'upsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
|
||||
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
elif mode == 'downsample2d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
elif mode == 'downsample3d':
|
||||
self.resample = nn.Sequential(
|
||||
nn.ZeroPad2d((0, 1, 0, 1)),
|
||||
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
||||
self.time_conv = CausalConv3d(
|
||||
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
||||
|
||||
else:
|
||||
self.resample = nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
b, c, t, h, w = x.size()
|
||||
if self.mode == 'upsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = 'Rep'
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] != 'Rep':
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
if cache_x.shape[2] < 2 and feat_cache[
|
||||
idx] is not None and feat_cache[idx] == 'Rep':
|
||||
cache_x = torch.cat([
|
||||
torch.zeros_like(cache_x).to(cache_x.device),
|
||||
cache_x
|
||||
],
|
||||
dim=2)
|
||||
if feat_cache[idx] == 'Rep':
|
||||
x = self.time_conv(x)
|
||||
else:
|
||||
x = self.time_conv(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
|
||||
x = x.reshape(b, 2, c, t, h, w)
|
||||
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
||||
3)
|
||||
x = x.reshape(b, c, t * 2, h, w)
|
||||
t = x.shape[2]
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.resample(x)
|
||||
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
||||
|
||||
if self.mode == 'downsample3d':
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
if feat_cache[idx] is None:
|
||||
feat_cache[idx] = x.clone()
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
|
||||
cache_x = x[:, :, -1:, :, :].clone()
|
||||
x = self.time_conv(
|
||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
return x
|
||||
|
||||
def init_weight(self, conv):
|
||||
conv_weight = conv.weight
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
one_matrix = torch.eye(c1, c2)
|
||||
init_matrix = one_matrix
|
||||
nn.init.zeros_(conv_weight)
|
||||
conv_weight.data[:, :, 1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
def init_weight2(self, conv):
|
||||
conv_weight = conv.weight.data
|
||||
nn.init.zeros_(conv_weight)
|
||||
c1, c2, t, h, w = conv_weight.size()
|
||||
init_matrix = torch.eye(c1 // 2, c2)
|
||||
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
||||
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
||||
conv.weight.data.copy_(conv_weight)
|
||||
nn.init.zeros_(conv.bias.data)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, out_dim, dropout=0.0):
|
||||
super().__init__()
|
||||
self.in_dim = in_dim
|
||||
self.out_dim = out_dim
|
||||
|
||||
# layers
|
||||
self.residual = nn.Sequential(
|
||||
RMS_norm(in_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
||||
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||
if in_dim != out_dim else nn.Identity()
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
h = self.shortcut(x)
|
||||
for layer in self.residual:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
Causal self-attention with a single head.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
# layers
|
||||
self.norm = RMS_norm(dim)
|
||||
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
||||
self.proj = nn.Conv2d(dim, dim, 1)
|
||||
|
||||
# zero out the last layer params
|
||||
nn.init.zeros_(self.proj.weight)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
b, c, t, h, w = x.size()
|
||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||
x = self.norm(x)
|
||||
# compute query, key, value
|
||||
q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
|
||||
-1).permute(0, 1, 3,
|
||||
2).contiguous().chunk(
|
||||
3, dim=-1)
|
||||
|
||||
# apply attention
|
||||
x = F.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
)
|
||||
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
||||
|
||||
# output
|
||||
x = self.proj(x)
|
||||
x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
|
||||
return x + identity
|
||||
|
||||
|
||||
class Encoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [1] + dim_mult]
|
||||
scale = 1.0
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
||||
|
||||
# downsample blocks
|
||||
downsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
for _ in range(num_res_blocks):
|
||||
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
downsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# downsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'downsample3d' if temperal_downsample[
|
||||
i] else 'downsample2d'
|
||||
downsamples.append(Resample(out_dim, mode=mode))
|
||||
scale /= 2.0
|
||||
self.downsamples = nn.Sequential(*downsamples)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
|
||||
ResidualBlock(out_dim, out_dim, dropout))
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## downsamples
|
||||
for layer in self.downsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder3d(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_upsample=[False, True, True],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_upsample = temperal_upsample
|
||||
|
||||
# dimensions
|
||||
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
||||
scale = 1.0 / 2**(len(dim_mult) - 2)
|
||||
|
||||
# init block
|
||||
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
||||
|
||||
# middle blocks
|
||||
self.middle = nn.Sequential(
|
||||
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
||||
ResidualBlock(dims[0], dims[0], dropout))
|
||||
|
||||
# upsample blocks
|
||||
upsamples = []
|
||||
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
||||
# residual (+attention) blocks
|
||||
if i == 1 or i == 2 or i == 3:
|
||||
in_dim = in_dim // 2
|
||||
for _ in range(num_res_blocks + 1):
|
||||
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
||||
if scale in attn_scales:
|
||||
upsamples.append(AttentionBlock(out_dim))
|
||||
in_dim = out_dim
|
||||
|
||||
# upsample block
|
||||
if i != len(dim_mult) - 1:
|
||||
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
||||
upsamples.append(Resample(out_dim, mode=mode))
|
||||
scale *= 2.0
|
||||
self.upsamples = nn.Sequential(*upsamples)
|
||||
|
||||
# output blocks
|
||||
self.head = nn.Sequential(
|
||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||
CausalConv3d(out_dim, 3, 3, padding=1))
|
||||
|
||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||
## conv1
|
||||
if feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = self.conv1(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = self.conv1(x)
|
||||
|
||||
## middle
|
||||
for layer in self.middle:
|
||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## upsamples
|
||||
for layer in self.upsamples:
|
||||
if feat_cache is not None:
|
||||
x = layer(x, feat_cache, feat_idx)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
## head
|
||||
for layer in self.head:
|
||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||
idx = feat_idx[0]
|
||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
||||
# cache last frame of last two chunk
|
||||
cache_x = torch.cat([
|
||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
||||
cache_x.device), cache_x
|
||||
],
|
||||
dim=2)
|
||||
x = layer(x, feat_cache[idx])
|
||||
feat_cache[idx] = cache_x
|
||||
feat_idx[0] += 1
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def count_conv3d(model):
|
||||
count = 0
|
||||
for m in model.modules():
|
||||
if isinstance(m, CausalConv3d):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
class WanVAE_(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim=128,
|
||||
z_dim=4,
|
||||
dim_mult=[1, 2, 4, 4],
|
||||
num_res_blocks=2,
|
||||
attn_scales=[],
|
||||
temperal_downsample=[True, True, False],
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.z_dim = z_dim
|
||||
self.dim_mult = dim_mult
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attn_scales = attn_scales
|
||||
self.temperal_downsample = temperal_downsample
|
||||
self.temperal_upsample = temperal_downsample[::-1]
|
||||
|
||||
# modules
|
||||
self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_downsample, dropout)
|
||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||
self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
|
||||
attn_scales, self.temperal_upsample, dropout)
|
||||
|
||||
def forward(self, x):
|
||||
mu, log_var = self.encode(x)
|
||||
z = self.reparameterize(mu, log_var)
|
||||
x_recon = self.decode(z)
|
||||
return x_recon, mu, log_var
|
||||
|
||||
def encode(self, x, scale):
|
||||
self.clear_cache()
|
||||
## cache
|
||||
t = x.shape[2]
|
||||
iter_ = 1 + (t - 1) // 4
|
||||
for i in range(iter_):
|
||||
self._enc_conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.encoder(
|
||||
x[:, :, :1, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
else:
|
||||
out_ = self.encoder(
|
||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||
feat_cache=self._enc_feat_map,
|
||||
feat_idx=self._enc_conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
mu = (mu - scale[0]) * scale[1]
|
||||
self.clear_cache()
|
||||
return mu
|
||||
|
||||
def decode(self, z, scale):
|
||||
self.clear_cache()
|
||||
# z: [b,c,t,h,w]
|
||||
if isinstance(scale[0], torch.Tensor):
|
||||
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
||||
1, self.z_dim, 1, 1, 1)
|
||||
else:
|
||||
z = z / scale[1] + scale[0]
|
||||
iter_ = z.shape[2]
|
||||
x = self.conv2(z)
|
||||
for i in range(iter_):
|
||||
self._conv_idx = [0]
|
||||
if i == 0:
|
||||
out = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
else:
|
||||
out_ = self.decoder(
|
||||
x[:, :, i:i + 1, :, :],
|
||||
feat_cache=self._feat_map,
|
||||
feat_idx=self._conv_idx)
|
||||
out = torch.cat([out, out_], 2)
|
||||
self.clear_cache()
|
||||
return out
|
||||
|
||||
def reparameterize(self, mu, log_var):
|
||||
std = torch.exp(0.5 * log_var)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
|
||||
|
||||
def sample(self, imgs, deterministic=False):
|
||||
mu, log_var = self.encode(imgs)
|
||||
if deterministic:
|
||||
return mu
|
||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||
return mu + std * torch.randn_like(std)
|
||||
|
||||
def clear_cache(self):
|
||||
self._conv_num = count_conv3d(self.decoder)
|
||||
self._conv_idx = [0]
|
||||
self._feat_map = [None] * self._conv_num
|
||||
#cache encode
|
||||
self._enc_conv_num = count_conv3d(self.encoder)
|
||||
self._enc_conv_idx = [0]
|
||||
self._enc_feat_map = [None] * self._enc_conv_num
|
||||
@@ -37,6 +37,14 @@ class AttentionParams:
|
||||
cu_seqlens: Optional[torch.Tensor] = None
|
||||
max_seqlen: Optional[int] = None
|
||||
|
||||
@property
|
||||
def supports_fp32(self) -> bool:
|
||||
return self.attn_mode not in ["flash"]
|
||||
|
||||
@property
|
||||
def requires_same_dtype(self) -> bool:
|
||||
return self.attn_mode in ["xformers"]
|
||||
|
||||
@staticmethod
|
||||
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
|
||||
return AttentionParams(attn_mode, split_attn)
|
||||
@@ -95,7 +103,7 @@ def attention(
|
||||
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
|
||||
k: Key tensor [B, L, H, D].
|
||||
v: Value tensor [B, L, H, D].
|
||||
attn_param: Attention parameters including mask and sequence lengths.
|
||||
attn_params: Attention parameters including mask and sequence lengths.
|
||||
drop_rate: Attention dropout rate.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -195,6 +195,9 @@ class ModelOffloader(Offloader):
|
||||
self.remove_handles.append(handle)
|
||||
|
||||
def set_forward_only(self, forward_only: bool):
|
||||
# switching must wait for all pending transfers
|
||||
for block_idx in list(self.futures.keys()):
|
||||
self._wait_blocks_move(block_idx)
|
||||
self.forward_only = forward_only
|
||||
|
||||
def __del__(self):
|
||||
@@ -237,6 +240,10 @@ class ModelOffloader(Offloader):
|
||||
if self.debug:
|
||||
print(f"Prepare block devices before forward")
|
||||
|
||||
# wait for all pending transfers
|
||||
for block_idx in list(self.futures.keys()):
|
||||
self._wait_blocks_move(block_idx)
|
||||
|
||||
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
|
||||
b.to(self.device)
|
||||
weighs_to_device(b, self.device) # make sure weights are on device
|
||||
|
||||
@@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz, _, h, w = latents.shape
|
||||
bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
@@ -512,7 +512,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
sigmas = sigmas.view(-1, 1, 1, 1) if latents.ndim == 4 else sigmas.view(-1, 1, 1, 1, 1)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
|
||||
@@ -9,7 +9,7 @@ import logging
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import clean_memory_on_device
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -220,6 +220,8 @@ def quantize_weight(
|
||||
tensor_max = torch.max(torch.abs(tensor).view(-1))
|
||||
scale = tensor_max / max_value
|
||||
|
||||
# print(f"Optimizing {key} with scale: {scale}")
|
||||
|
||||
# numerical safety
|
||||
scale = torch.clamp(scale, min=1e-8)
|
||||
scale = scale.to(torch.float32) # ensure scale is in float32 for division
|
||||
@@ -245,6 +247,8 @@ def load_safetensors_with_fp8_optimization(
|
||||
weight_hook=None,
|
||||
quantization_mode: str = "block",
|
||||
block_size: Optional[int] = 64,
|
||||
disable_numpy_memmap: bool = False,
|
||||
weight_transform_hooks: Optional[WeightTransformHooks] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
|
||||
@@ -260,6 +264,8 @@ def load_safetensors_with_fp8_optimization(
|
||||
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
|
||||
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
|
||||
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
|
||||
disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
|
||||
weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading
|
||||
|
||||
Returns:
|
||||
dict: FP8 optimized state dict
|
||||
@@ -288,7 +294,9 @@ def load_safetensors_with_fp8_optimization(
|
||||
# Process each file
|
||||
state_dict = {}
|
||||
for model_file in model_files:
|
||||
with MemoryEfficientSafeOpen(model_file) as f:
|
||||
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
|
||||
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
|
||||
|
||||
keys = f.keys()
|
||||
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
|
||||
value = f.get_tensor(key)
|
||||
@@ -311,6 +319,11 @@ def load_safetensors_with_fp8_optimization(
|
||||
value = value.to(calc_device)
|
||||
|
||||
original_dtype = value.dtype
|
||||
if original_dtype.itemsize == 1:
|
||||
raise ValueError(
|
||||
f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
|
||||
+ f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
|
||||
)
|
||||
quantized_weight, scale_tensor = quantize_weight(
|
||||
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
|
||||
)
|
||||
@@ -387,7 +400,7 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
|
||||
else:
|
||||
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
|
||||
|
||||
o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1)
|
||||
o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
|
||||
return o.to(input_dtype)
|
||||
|
||||
else:
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
from library.device_utils import synchronize_device
|
||||
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen
|
||||
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -44,7 +44,7 @@ def filter_lora_state_dict(
|
||||
|
||||
def load_safetensors_with_lora_and_fp8(
|
||||
model_files: Union[str, List[str]],
|
||||
lora_weights_list: Optional[Dict[str, torch.Tensor]],
|
||||
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
|
||||
lora_multipliers: Optional[List[float]],
|
||||
fp8_optimization: bool,
|
||||
calc_device: torch.device,
|
||||
@@ -52,19 +52,23 @@ def load_safetensors_with_lora_and_fp8(
|
||||
dit_weight_dtype: Optional[torch.dtype] = None,
|
||||
target_keys: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
disable_numpy_memmap: bool = False,
|
||||
weight_transform_hooks: Optional[WeightTransformHooks] = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
|
||||
|
||||
Args:
|
||||
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
|
||||
lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
|
||||
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
|
||||
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
|
||||
fp8_optimization (bool): Whether to apply FP8 optimization.
|
||||
calc_device (torch.device): Device to calculate on.
|
||||
move_to_device (bool): Whether to move tensors to the calculation device after loading.
|
||||
target_keys (Optional[List[str]]): Keys to target for optimization.
|
||||
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
|
||||
disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
|
||||
weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
|
||||
"""
|
||||
|
||||
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
|
||||
@@ -73,19 +77,9 @@ def load_safetensors_with_lora_and_fp8(
|
||||
|
||||
extended_model_files = []
|
||||
for model_file in model_files:
|
||||
basename = os.path.basename(model_file)
|
||||
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
|
||||
if match:
|
||||
prefix = basename[: match.start(2)]
|
||||
count = int(match.group(3))
|
||||
state_dict = {}
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(model_file), filename)
|
||||
if os.path.exists(filepath):
|
||||
extended_model_files.append(filepath)
|
||||
else:
|
||||
raise FileNotFoundError(f"File {filepath} not found")
|
||||
split_filenames = get_split_weight_filenames(model_file)
|
||||
if split_filenames is not None:
|
||||
extended_model_files.extend(split_filenames)
|
||||
else:
|
||||
extended_model_files.append(model_file)
|
||||
model_files = extended_model_files
|
||||
@@ -114,7 +108,7 @@ def load_safetensors_with_lora_and_fp8(
|
||||
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
|
||||
|
||||
# make hook for LoRA merging
|
||||
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
|
||||
def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
|
||||
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
|
||||
|
||||
if not model_weight_key.endswith(".weight"):
|
||||
@@ -126,13 +120,18 @@ def load_safetensors_with_lora_and_fp8(
|
||||
|
||||
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
|
||||
# check if this weight has LoRA weights
|
||||
lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
|
||||
lora_name = "lora_unet_" + lora_name.replace(".", "_")
|
||||
lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
|
||||
found = False
|
||||
for prefix in ["lora_unet_", ""]:
|
||||
lora_name = prefix + lora_name_without_prefix.replace(".", "_")
|
||||
down_key = lora_name + ".lora_down.weight"
|
||||
up_key = lora_name + ".lora_up.weight"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
|
||||
continue
|
||||
if down_key in lora_weight_keys and up_key in lora_weight_keys:
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
continue # no LoRA weights for this model weight
|
||||
|
||||
# get LoRA weights
|
||||
down_weight = lora_sd[down_key]
|
||||
@@ -145,6 +144,13 @@ def load_safetensors_with_lora_and_fp8(
|
||||
down_weight = down_weight.to(calc_device)
|
||||
up_weight = up_weight.to(calc_device)
|
||||
|
||||
original_dtype = model_weight.dtype
|
||||
if original_dtype.itemsize == 1: # fp8
|
||||
# temporarily convert to float16 for calculation
|
||||
model_weight = model_weight.to(torch.float16)
|
||||
down_weight = down_weight.to(torch.float16)
|
||||
up_weight = up_weight.to(torch.float16)
|
||||
|
||||
# W <- W + U * D
|
||||
if len(model_weight.size()) == 2:
|
||||
# linear
|
||||
@@ -166,6 +172,9 @@ def load_safetensors_with_lora_and_fp8(
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
model_weight = model_weight + multiplier * conved * scale
|
||||
|
||||
if original_dtype.itemsize == 1: # fp8
|
||||
model_weight = model_weight.to(original_dtype) # convert back to original dtype
|
||||
|
||||
# remove LoRA keys from set
|
||||
lora_weight_keys.remove(down_key)
|
||||
lora_weight_keys.remove(up_key)
|
||||
@@ -187,6 +196,8 @@ def load_safetensors_with_lora_and_fp8(
|
||||
target_keys,
|
||||
exclude_keys,
|
||||
weight_hook=weight_hook,
|
||||
disable_numpy_memmap=disable_numpy_memmap,
|
||||
weight_transform_hooks=weight_transform_hooks,
|
||||
)
|
||||
|
||||
for lora_weight_keys in list_of_lora_weight_keys:
|
||||
@@ -208,6 +219,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
|
||||
target_keys: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
weight_hook: callable = None,
|
||||
disable_numpy_memmap: bool = False,
|
||||
weight_transform_hooks: Optional[WeightTransformHooks] = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
|
||||
@@ -218,7 +231,14 @@ def load_safetensors_with_fp8_optimization_and_hook(
|
||||
)
|
||||
# dit_weight_dtype is not used because we use fp8 optimization
|
||||
state_dict = load_safetensors_with_fp8_optimization(
|
||||
model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
|
||||
model_files,
|
||||
calc_device,
|
||||
target_keys,
|
||||
exclude_keys,
|
||||
move_to_device=move_to_device,
|
||||
weight_hook=weight_hook,
|
||||
disable_numpy_memmap=disable_numpy_memmap,
|
||||
weight_transform_hooks=weight_transform_hooks,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
@@ -226,7 +246,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
|
||||
)
|
||||
state_dict = {}
|
||||
for model_file in model_files:
|
||||
with MemoryEfficientSafeOpen(model_file) as f:
|
||||
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
|
||||
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
|
||||
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
|
||||
if weight_hook is None and move_to_device:
|
||||
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
|
||||
|
||||
1735
library/qwen_image_autoencoder_kl.py
Normal file
1735
library/qwen_image_autoencoder_kl.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
import re
|
||||
import numpy as np
|
||||
@@ -44,6 +45,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
|
||||
validated[key] = value
|
||||
return validated
|
||||
|
||||
# print(f"Using memory efficient save file: {filename}")
|
||||
|
||||
header = {}
|
||||
offset = 0
|
||||
@@ -88,15 +90,17 @@ class MemoryEfficientSafeOpen:
|
||||
by using memory mapping for large tensors and avoiding unnecessary copies.
|
||||
"""
|
||||
|
||||
def __init__(self, filename):
|
||||
def __init__(self, filename, disable_numpy_memmap=False):
|
||||
"""Initialize the SafeTensor reader.
|
||||
|
||||
Args:
|
||||
filename (str): Path to the safetensors file to read.
|
||||
disable_numpy_memmap (bool): If True, disable numpy memory mapping for large tensors, using standard file read instead.
|
||||
"""
|
||||
self.filename = filename
|
||||
self.file = open(filename, "rb")
|
||||
self.header, self.header_size = self._read_header()
|
||||
self.disable_numpy_memmap = disable_numpy_memmap
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter context manager."""
|
||||
@@ -178,7 +182,8 @@ class MemoryEfficientSafeOpen:
|
||||
# Use memmap for large tensors to avoid intermediate copies.
|
||||
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
|
||||
# So we only use memmap if device is not cpu.
|
||||
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
|
||||
# If disable_numpy_memmap is True, skip numpy memory mapping to load with standard file read.
|
||||
if not self.disable_numpy_memmap and num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
|
||||
# Create memory map for zero-copy reading
|
||||
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
|
||||
byte_tensor = torch.from_numpy(mm) # zero copy
|
||||
@@ -285,7 +290,11 @@ class MemoryEfficientSafeOpen:
|
||||
|
||||
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
|
||||
path: str,
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
disable_numpy_memmap: bool = False,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
@@ -293,7 +302,7 @@ def load_safetensors(
|
||||
# logger.info(f"Loading without mmap (experimental)")
|
||||
state_dict = {}
|
||||
device = torch.device(device) if device is not None else None
|
||||
with MemoryEfficientSafeOpen(path) as f:
|
||||
with MemoryEfficientSafeOpen(path, disable_numpy_memmap=disable_numpy_memmap) as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
|
||||
synchronize_device(device)
|
||||
@@ -309,6 +318,29 @@ def load_safetensors(
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_split_weight_filenames(file_path: str) -> Optional[list[str]]:
|
||||
"""
|
||||
Get the list of split weight filenames (full paths) if the file name ends with 00001-of-00004 etc.
|
||||
Returns None if the file is not split.
|
||||
"""
|
||||
basename = os.path.basename(file_path)
|
||||
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
|
||||
if match:
|
||||
prefix = basename[: match.start(2)]
|
||||
count = int(match.group(3))
|
||||
filenames = []
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(file_path), filename)
|
||||
if os.path.exists(filepath):
|
||||
filenames.append(filepath)
|
||||
else:
|
||||
raise FileNotFoundError(f"File {filepath} not found")
|
||||
return filenames
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def load_split_weights(
|
||||
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
@@ -319,19 +351,11 @@ def load_split_weights(
|
||||
device = torch.device(device)
|
||||
|
||||
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
|
||||
basename = os.path.basename(file_path)
|
||||
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
|
||||
if match:
|
||||
prefix = basename[: match.start(2)]
|
||||
count = int(match.group(3))
|
||||
split_filenames = get_split_weight_filenames(file_path)
|
||||
if split_filenames is not None:
|
||||
state_dict = {}
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(file_path), filename)
|
||||
if os.path.exists(filepath):
|
||||
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
|
||||
else:
|
||||
raise FileNotFoundError(f"File {filepath} not found")
|
||||
for filename in split_filenames:
|
||||
state_dict.update(load_safetensors(filename, device=device, disable_mmap=disable_mmap, dtype=dtype))
|
||||
else:
|
||||
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
return state_dict
|
||||
@@ -349,3 +373,106 @@ def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with
|
||||
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
|
||||
return key
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightTransformHooks:
|
||||
split_hook: Optional[callable] = None
|
||||
concat_hook: Optional[callable] = None
|
||||
rename_hook: Optional[callable] = None
|
||||
|
||||
|
||||
class TensorWeightAdapter:
|
||||
"""
|
||||
A wrapper for weight conversion hooks (split and concat) to be used with MemoryEfficientSafeOpen.
|
||||
This wrapper adapts the original MemoryEfficientSafeOpen to apply the provided split and concat hooks
|
||||
when loading tensors.
|
||||
|
||||
split_hook: A callable that takes (original_key: str, original_tensor: torch.Tensor) and returns (new_keys: list[str], new_tensors: list[torch.Tensor]).
|
||||
concat_hook: A callable that takes (original_key: str, tensors: dict[str, torch.Tensor]) and returns (new_key: str, concatenated_tensor: torch.Tensor).
|
||||
rename_hook: A callable that takes (original_key: str) and returns (new_key: str).
|
||||
|
||||
If tensors is None, the hook should return only the new keys (for split) or new key (for concat), without tensors.
|
||||
|
||||
No need to implement __enter__ and __exit__ methods, as they are handled by the original MemoryEfficientSafeOpen.
|
||||
Do not use this wrapper as a context manager directly, like `with WeightConvertHookWrapper(...) as f:`.
|
||||
|
||||
**concat_hook is not tested yet.**
|
||||
"""
|
||||
|
||||
def __init__(self, weight_convert_hook: WeightTransformHooks, original_f: MemoryEfficientSafeOpen):
|
||||
self.original_f = original_f
|
||||
self.new_key_to_original_key_map: dict[str, Union[str, list[str]]] = (
|
||||
{}
|
||||
) # for split: new_key -> original_key; for concat: new_key -> list of original_keys; for direct mapping: new_key -> original_key
|
||||
self.concat_key_set = set() # set of concatenated keys
|
||||
self.split_key_set = set() # set of split keys
|
||||
self.new_keys = []
|
||||
self.tensor_cache = {} # cache for split tensors
|
||||
self.split_hook = weight_convert_hook.split_hook
|
||||
self.concat_hook = weight_convert_hook.concat_hook
|
||||
self.rename_hook = weight_convert_hook.rename_hook
|
||||
|
||||
for key in self.original_f.keys():
|
||||
if self.split_hook is not None:
|
||||
converted_keys, _ = self.split_hook(key, None) # get new keys only
|
||||
if converted_keys is not None:
|
||||
for converted_key in converted_keys:
|
||||
self.new_key_to_original_key_map[converted_key] = key
|
||||
self.split_key_set.add(converted_key)
|
||||
self.new_keys.extend(converted_keys)
|
||||
continue # skip concat_hook if split_hook is applied
|
||||
|
||||
if self.concat_hook is not None:
|
||||
converted_key, _ = self.concat_hook(key, None) # get new key only
|
||||
if converted_key is not None:
|
||||
if converted_key not in self.concat_key_set: # first time seeing this concatenated key
|
||||
self.concat_key_set.add(converted_key)
|
||||
self.new_key_to_original_key_map[converted_key] = []
|
||||
self.new_keys.append(converted_key)
|
||||
|
||||
# multiple original keys map to the same concatenated key
|
||||
self.new_key_to_original_key_map[converted_key].append(key)
|
||||
continue # skip to next key
|
||||
|
||||
# direct mapping
|
||||
if self.rename_hook is not None:
|
||||
new_key = self.rename_hook(key)
|
||||
self.new_key_to_original_key_map[new_key] = key
|
||||
else:
|
||||
new_key = key
|
||||
|
||||
self.new_keys.append(new_key)
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
return self.new_keys
|
||||
|
||||
def get_tensor(self, new_key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
|
||||
# load tensor by new_key, applying split or concat hooks as needed
|
||||
if new_key not in self.new_key_to_original_key_map:
|
||||
# direct mapping
|
||||
return self.original_f.get_tensor(new_key, device=device, dtype=dtype)
|
||||
|
||||
elif new_key in self.split_key_set:
|
||||
# split hook: split key is requested multiple times, so we cache the result
|
||||
original_key = self.new_key_to_original_key_map[new_key]
|
||||
if original_key not in self.tensor_cache: # not yet split
|
||||
original_tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
|
||||
new_keys, new_tensors = self.split_hook(original_key, original_tensor) # apply split hook
|
||||
for k, t in zip(new_keys, new_tensors):
|
||||
self.tensor_cache[k] = t
|
||||
return self.tensor_cache.pop(new_key) # return and remove from cache
|
||||
|
||||
elif new_key in self.concat_key_set:
|
||||
# concat hook: concatenated key is requested only once, so we do not cache the result
|
||||
tensors = {}
|
||||
for original_key in self.new_key_to_original_key_map[new_key]:
|
||||
tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
|
||||
tensors[original_key] = tensor
|
||||
_, concatenated_tensors = self.concat_hook(self.new_key_to_original_key_map[new_key][0], tensors) # apply concat hook
|
||||
return concatenated_tensors
|
||||
|
||||
else:
|
||||
# direct mapping
|
||||
original_key = self.new_key_to_original_key_map[new_key]
|
||||
return self.original_f.get_tensor(original_key, device=device, dtype=dtype)
|
||||
|
||||
@@ -81,6 +81,8 @@ ARCH_LUMINA_2 = "lumina-2"
|
||||
ARCH_LUMINA_UNKNOWN = "lumina"
|
||||
ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1"
|
||||
ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image"
|
||||
ARCH_ANIMA_PREVIEW = "anima-preview"
|
||||
ARCH_ANIMA_UNKNOWN = "anima-unknown"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||
@@ -92,6 +94,7 @@ IMPL_FLUX = "https://github.com/black-forest-labs/flux"
|
||||
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
|
||||
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
|
||||
IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1"
|
||||
IMPL_ANIMA = "https://huggingface.co/circlestone-labs/Anima"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
PRED_TYPE_V = "v"
|
||||
@@ -220,6 +223,12 @@ def determine_architecture(
|
||||
arch = ARCH_HUNYUAN_IMAGE_2_1
|
||||
else:
|
||||
arch = ARCH_HUNYUAN_IMAGE_UNKNOWN
|
||||
elif "anima" in model_config:
|
||||
anima_type = model_config["anima"]
|
||||
if anima_type == "preview":
|
||||
arch = ARCH_ANIMA_PREVIEW
|
||||
else:
|
||||
arch = ARCH_ANIMA_UNKNOWN
|
||||
elif v2:
|
||||
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
|
||||
else:
|
||||
@@ -252,6 +261,8 @@ def determine_implementation(
|
||||
return IMPL_FLUX
|
||||
elif "lumina" in model_config:
|
||||
return IMPL_LUMINA
|
||||
elif "anima" in model_config:
|
||||
return IMPL_ANIMA
|
||||
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
return IMPL_STABILITY_AI
|
||||
else:
|
||||
@@ -325,7 +336,7 @@ def determine_resolution(
|
||||
reso = (reso[0], reso[0])
|
||||
else:
|
||||
# Determine default resolution based on model type
|
||||
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config:
|
||||
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config or "anima" in model_config:
|
||||
reso = (1024, 1024)
|
||||
elif v2 and v_parameterization:
|
||||
reso = (768, 768)
|
||||
|
||||
@@ -9,6 +9,7 @@ import torch
|
||||
|
||||
from library import anima_utils, train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
from library import qwen_image_autoencoder_kl
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
@@ -45,8 +46,8 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
|
||||
|
||||
self.qwen3_tokenizer = qwen3_tokenizer
|
||||
self.t5_tokenizer = t5_tokenizer
|
||||
self.qwen3_max_length = qwen3_max_length
|
||||
self.t5_tokenizer = t5_tokenizer
|
||||
self.t5_max_length = t5_max_length
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
@@ -54,26 +55,17 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
|
||||
|
||||
# Tokenize with Qwen3
|
||||
qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.qwen3_max_length,
|
||||
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
|
||||
)
|
||||
qwen3_input_ids = qwen3_encoding["input_ids"]
|
||||
qwen3_attn_mask = qwen3_encoding["attention_mask"]
|
||||
|
||||
# Tokenize with T5 (for LLM Adapter target tokens)
|
||||
t5_encoding = self.t5_tokenizer.batch_encode_plus(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.t5_max_length,
|
||||
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
|
||||
)
|
||||
t5_input_ids = t5_encoding["input_ids"]
|
||||
t5_attn_mask = t5_encoding["attention_mask"]
|
||||
|
||||
return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
|
||||
@@ -84,46 +76,11 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
T5 tokens are passed through unchanged (only used by LLM Adapter).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dropout_rate: float = 0.0,
|
||||
) -> None:
|
||||
self.dropout_rate = dropout_rate
|
||||
# Cached unconditional embeddings (from encoding empty caption "")
|
||||
# Must be initialized via cache_uncond_embeddings() before text encoder is deleted
|
||||
self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden)
|
||||
self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len)
|
||||
self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len)
|
||||
self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len)
|
||||
|
||||
def cache_uncond_embeddings(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
) -> None:
|
||||
"""Pre-encode empty caption "" and cache the unconditional embeddings.
|
||||
|
||||
Must be called before the text encoder is deleted from GPU.
|
||||
This matches diffusion-pipe-main behavior where empty caption embeddings
|
||||
are pre-cached and swapped in during caption dropout.
|
||||
"""
|
||||
logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
|
||||
tokens = tokenize_strategy.tokenize("")
|
||||
with torch.no_grad():
|
||||
uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
|
||||
# Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
|
||||
self._uncond_prompt_embeds = uncond_outputs[0].cpu()
|
||||
self._uncond_attn_mask = uncond_outputs[1].cpu()
|
||||
self._uncond_t5_input_ids = uncond_outputs[2].cpu()
|
||||
self._uncond_t5_attn_mask = uncond_outputs[3].cpu()
|
||||
logger.info(" Unconditional embeddings cached successfully")
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def encode_tokens(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: List[torch.Tensor],
|
||||
enable_dropout: bool = True,
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
|
||||
|
||||
@@ -134,82 +91,20 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
Returns:
|
||||
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
"""
|
||||
# Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()
|
||||
|
||||
qwen3_text_encoder = models[0]
|
||||
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
|
||||
|
||||
# Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
|
||||
batch_size = qwen3_input_ids.shape[0]
|
||||
non_drop_indices = []
|
||||
for i in range(batch_size):
|
||||
drop = enable_dropout and (self.dropout_rate > 0.0 and random.random() < self.dropout_rate)
|
||||
if not drop:
|
||||
non_drop_indices.append(i)
|
||||
encoder_device = qwen3_text_encoder.device
|
||||
|
||||
encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device
|
||||
qwen3_input_ids = qwen3_input_ids.to(encoder_device)
|
||||
qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
|
||||
outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
|
||||
prompt_embeds = outputs.last_hidden_state
|
||||
prompt_embeds[~qwen3_attn_mask.bool()] = 0
|
||||
|
||||
if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
|
||||
# Only encode non-dropped items to save compute
|
||||
nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
|
||||
nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
|
||||
elif len(non_drop_indices) == batch_size:
|
||||
nd_input_ids = qwen3_input_ids.to(encoder_device)
|
||||
nd_attn_mask = qwen3_attn_mask.to(encoder_device)
|
||||
else:
|
||||
nd_input_ids = None
|
||||
nd_attn_mask = None
|
||||
|
||||
if nd_input_ids is not None:
|
||||
outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
|
||||
nd_encoded_text = outputs.last_hidden_state
|
||||
# Zero out padding positions
|
||||
nd_encoded_text[~nd_attn_mask.bool()] = 0
|
||||
|
||||
# Build full batch: fill non-dropped with encoded, dropped with unconditional
|
||||
if len(non_drop_indices) == batch_size:
|
||||
prompt_embeds = nd_encoded_text
|
||||
attn_mask = qwen3_attn_mask.to(encoder_device)
|
||||
else:
|
||||
# Get unconditional embeddings
|
||||
if self._uncond_prompt_embeds is not None:
|
||||
uncond_pe = self._uncond_prompt_embeds[0]
|
||||
uncond_am = self._uncond_attn_mask[0]
|
||||
uncond_t5_ids = self._uncond_t5_input_ids[0]
|
||||
uncond_t5_am = self._uncond_t5_attn_mask[0]
|
||||
else:
|
||||
# Encode empty caption on-the-fly (text encoder still available)
|
||||
uncond_tokens = tokenize_strategy.tokenize("")
|
||||
uncond_ids = uncond_tokens[0].to(encoder_device)
|
||||
uncond_mask = uncond_tokens[1].to(encoder_device)
|
||||
uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
|
||||
uncond_pe = uncond_out.last_hidden_state[0]
|
||||
uncond_pe[~uncond_mask[0].bool()] = 0
|
||||
uncond_am = uncond_mask[0]
|
||||
uncond_t5_ids = uncond_tokens[2][0]
|
||||
uncond_t5_am = uncond_tokens[3][0]
|
||||
|
||||
seq_len = qwen3_input_ids.shape[1]
|
||||
hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
|
||||
dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
|
||||
|
||||
prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
|
||||
attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
|
||||
|
||||
if len(non_drop_indices) > 0:
|
||||
prompt_embeds[non_drop_indices] = nd_encoded_text
|
||||
attn_mask[non_drop_indices] = nd_attn_mask
|
||||
|
||||
# Fill dropped items with unconditional embeddings
|
||||
t5_input_ids = t5_input_ids.clone()
|
||||
t5_attn_mask = t5_attn_mask.clone()
|
||||
drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
|
||||
for i in drop_indices:
|
||||
prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
|
||||
attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
|
||||
t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
|
||||
t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
|
||||
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
def drop_cached_text_encoder_outputs(
|
||||
self,
|
||||
@@ -217,6 +112,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
attn_mask: torch.Tensor,
|
||||
t5_input_ids: torch.Tensor,
|
||||
t5_attn_mask: torch.Tensor,
|
||||
caption_dropout_rates: Optional[torch.Tensor] = None,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Apply dropout to cached text encoder outputs.
|
||||
|
||||
@@ -224,7 +120,9 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
|
||||
to match diffusion-pipe-main behavior.
|
||||
"""
|
||||
if prompt_embeds is not None and self.dropout_rate > 0.0:
|
||||
if caption_dropout_rates is None or torch.all(caption_dropout_rates == 0.0).item():
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
# Clone to avoid in-place modification of cached tensors
|
||||
prompt_embeds = prompt_embeds.clone()
|
||||
if attn_mask is not None:
|
||||
@@ -235,26 +133,17 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
|
||||
t5_attn_mask = t5_attn_mask.clone()
|
||||
|
||||
for i in range(prompt_embeds.shape[0]):
|
||||
if random.random() < self.dropout_rate:
|
||||
if self._uncond_prompt_embeds is not None:
|
||||
if random.random() < caption_dropout_rates[i].item():
|
||||
# Use pre-cached unconditional embeddings
|
||||
prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
|
||||
prompt_embeds[i] = 0
|
||||
if attn_mask is not None:
|
||||
attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
|
||||
attn_mask[i] = 0
|
||||
if t5_input_ids is not None:
|
||||
t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
|
||||
t5_input_ids[i, 0] = 1 # Set to </s> token ID
|
||||
t5_input_ids[i, 1:] = 0
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
|
||||
else:
|
||||
# Fallback: zero out (should not happen if cache_uncond_embeddings was called)
|
||||
logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout")
|
||||
prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
|
||||
if attn_mask is not None:
|
||||
attn_mask[i] = torch.zeros_like(attn_mask[i])
|
||||
if t5_input_ids is not None:
|
||||
t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
|
||||
if t5_attn_mask is not None:
|
||||
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
|
||||
t5_attn_mask[i, 0] = 1
|
||||
t5_attn_mask[i, 1:] = 0
|
||||
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
|
||||
@@ -297,6 +186,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
if "caption_dropout_rate" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
@@ -309,7 +200,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
attn_mask = data["attn_mask"]
|
||||
t5_input_ids = data["t5_input_ids"]
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
|
||||
caption_dropout_rate = data["caption_dropout_rate"]
|
||||
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self,
|
||||
@@ -323,12 +215,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
# Always disable dropout during caching
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = anima_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
models,
|
||||
tokens_and_masks,
|
||||
enable_dropout=False,
|
||||
tokenize_strategy, models, tokens_and_masks
|
||||
)
|
||||
|
||||
# Convert to numpy for caching
|
||||
@@ -344,6 +232,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
attn_mask_i = attn_mask[i]
|
||||
t5_input_ids_i = t5_input_ids[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
@@ -352,9 +241,10 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
attn_mask=attn_mask_i,
|
||||
t5_input_ids=t5_input_ids_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
caption_dropout_rate=caption_dropout_rate,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i)
|
||||
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
|
||||
|
||||
|
||||
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
@@ -374,18 +264,10 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
return self.ANIMA_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ self.ANIMA_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def is_disk_cached_latents_expected(
|
||||
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
||||
):
|
||||
return self._default_is_disk_cached_latents_expected(
|
||||
8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
|
||||
)
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
@@ -393,32 +275,23 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
|
||||
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
"""Cache batch of latents using WanVAE.
|
||||
"""Cache batch of latents using Qwen Image VAE.
|
||||
|
||||
vae is expected to be the WanVAE_ model (not the wrapper).
|
||||
vae is expected to be the Qwen Image VAE (AutoencoderKLQwenImage).
|
||||
The encoding function handles the mean/std normalization.
|
||||
"""
|
||||
from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
|
||||
|
||||
vae_device = next(vae.parameters()).device
|
||||
vae_dtype = next(vae.parameters()).dtype
|
||||
|
||||
# Create scale tensors on VAE device
|
||||
mean = torch.tensor(ANIMA_VAE_MEAN, dtype=vae_dtype, device=vae_device)
|
||||
std = torch.tensor(ANIMA_VAE_STD, dtype=vae_dtype, device=vae_device)
|
||||
scale = [mean, 1.0 / std]
|
||||
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage = vae
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
def encode_by_vae(img_tensor):
|
||||
"""Encode image tensor to latents.
|
||||
|
||||
img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS)
|
||||
Need to add temporal dim to get (B, C, T=1, H, W) for WanVAE
|
||||
Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
|
||||
Returns latents in (B, 16, 1, H/8, W/8) shape on CPU.
|
||||
"""
|
||||
# Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
|
||||
img_tensor = img_tensor.unsqueeze(2)
|
||||
img_tensor = img_tensor.to(vae_device, dtype=vae_dtype)
|
||||
|
||||
latents = vae.encode(img_tensor, scale)
|
||||
latents = vae.encode_pixels_to_latents(img_tensor) # Keep 4D for input/output
|
||||
return latents.to("cpu")
|
||||
|
||||
self._default_cache_batch_latents(
|
||||
|
||||
@@ -179,12 +179,15 @@ def split_train_val(
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
||||
def __init__(
|
||||
self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
|
||||
) -> None:
|
||||
self.image_key: str = image_key
|
||||
self.num_repeats: int = num_repeats
|
||||
self.caption: str = caption
|
||||
self.is_reg: bool = is_reg
|
||||
self.absolute_path: str = absolute_path
|
||||
self.caption_dropout_rate: float = caption_dropout_rate
|
||||
self.image_size: Tuple[int, int] = None
|
||||
self.resized_size: Tuple[int, int] = None
|
||||
self.bucket_reso: Tuple[int, int] = None
|
||||
@@ -197,7 +200,7 @@ class ImageInfo:
|
||||
)
|
||||
self.cond_img_path: Optional[str] = None
|
||||
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
||||
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
|
||||
self.text_encoder_outputs_npz: Optional[str] = None # filename. set in cache_text_encoder_outputs
|
||||
|
||||
# new
|
||||
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
|
||||
@@ -1096,11 +1099,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def is_latent_cacheable(self):
|
||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||
|
||||
def is_text_encoder_output_cacheable(self):
|
||||
def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False):
|
||||
return all(
|
||||
[
|
||||
not (
|
||||
subset.caption_dropout_rate > 0
|
||||
subset.caption_dropout_rate > 0 and not cache_supports_dropout
|
||||
or subset.shuffle_caption
|
||||
or subset.token_warmup_step > 0
|
||||
or subset.caption_tag_dropout_rate > 0
|
||||
@@ -2137,7 +2140,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
num_train_images += num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
|
||||
info.resize_interpolation = (
|
||||
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
)
|
||||
@@ -2338,7 +2341,7 @@ class FineTuningDataset(BaseDataset):
|
||||
if caption is None:
|
||||
caption = ""
|
||||
|
||||
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
||||
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
|
||||
image_info.resize_interpolation = (
|
||||
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
)
|
||||
@@ -2661,8 +2664,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
||||
|
||||
def is_text_encoder_output_cacheable(self) -> bool:
|
||||
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
|
||||
def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False) -> bool:
|
||||
return all([dataset.is_text_encoder_output_cacheable(cache_supports_dropout) for dataset in self.datasets])
|
||||
|
||||
def set_current_strategies(self):
|
||||
for dataset in self.datasets:
|
||||
@@ -3578,6 +3581,7 @@ def get_sai_model_spec_dataclass(
|
||||
flux: str = None,
|
||||
lumina: str = None,
|
||||
hunyuan_image: str = None,
|
||||
anima: str = None,
|
||||
optional_metadata: dict[str, str] | None = None,
|
||||
) -> sai_model_spec.ModelSpecMetadata:
|
||||
"""
|
||||
@@ -3609,7 +3613,8 @@ def get_sai_model_spec_dataclass(
|
||||
model_config["lumina"] = lumina
|
||||
if hunyuan_image is not None:
|
||||
model_config["hunyuan_image"] = hunyuan_image
|
||||
|
||||
if anima is not None:
|
||||
model_config["anima"] = anima
|
||||
# Use the dataclass function directly
|
||||
return sai_model_spec.build_metadata_dataclass(
|
||||
state_dict,
|
||||
|
||||
160
networks/convert_anima_lora_to_comfy.py
Normal file
160
networks/convert_anima_lora_to_comfy.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import argparse
|
||||
from safetensors.torch import save_file
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
from library import train_util
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMFYUI_DIT_PREFIX = "diffusion_model."
|
||||
COMFYUI_QWEN3_PREFIX = "text_encoders.qwen3_06b.transformer.model."
|
||||
|
||||
|
||||
def main(args):
|
||||
# load source safetensors
|
||||
logger.info(f"Loading source file {args.src_path}")
|
||||
state_dict = {}
|
||||
with safe_open(args.src_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
|
||||
logger.info(f"Converting...")
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
count = 0
|
||||
|
||||
for k in keys:
|
||||
if not args.reverse:
|
||||
is_dit_lora = k.startswith("lora_unet_")
|
||||
module_and_weight_name = "_".join(k.split("_")[2:]) # Remove `lora_unet_`or `lora_te_` prefix
|
||||
|
||||
# Split at the first dot, e.g., "block1_linear.weight" -> "block1_linear", "weight"
|
||||
module_name, weight_name = module_and_weight_name.split(".", 1)
|
||||
|
||||
# Weight name conversion: lora_up/lora_down to lora_A/lora_B
|
||||
if weight_name.startswith("lora_up"):
|
||||
weight_name = weight_name.replace("lora_up", "lora_B")
|
||||
elif weight_name.startswith("lora_down"):
|
||||
weight_name = weight_name.replace("lora_down", "lora_A")
|
||||
else:
|
||||
# Keep other weight names as-is: e.g. alpha
|
||||
pass
|
||||
|
||||
# Module name conversion: convert dots to underscores
|
||||
original_module_name = module_name.replace("_", ".") # Convert to dot notation
|
||||
|
||||
# Convert back illegal dots in module names
|
||||
# DiT
|
||||
original_module_name = original_module_name.replace("llm.adapter", "llm_adapter")
|
||||
original_module_name = original_module_name.replace(".linear.", ".linear_")
|
||||
original_module_name = original_module_name.replace("t.embedding.norm", "t_embedding_norm")
|
||||
original_module_name = original_module_name.replace("x.embedder", "x_embedder")
|
||||
original_module_name = original_module_name.replace("adaln.modulation.cross_attn", "adaln_modulation_cross_attn")
|
||||
original_module_name = original_module_name.replace("adaln.modulation.mlp", "adaln_modulation_mlp")
|
||||
original_module_name = original_module_name.replace("cross.attn", "cross_attn")
|
||||
original_module_name = original_module_name.replace("k.proj", "k_proj")
|
||||
original_module_name = original_module_name.replace("k.norm", "k_norm")
|
||||
original_module_name = original_module_name.replace("q.proj", "q_proj")
|
||||
original_module_name = original_module_name.replace("q.norm", "q_norm")
|
||||
original_module_name = original_module_name.replace("v.proj", "v_proj")
|
||||
original_module_name = original_module_name.replace("o.proj", "o_proj")
|
||||
original_module_name = original_module_name.replace("output.proj", "output_proj")
|
||||
original_module_name = original_module_name.replace("self.attn", "self_attn")
|
||||
original_module_name = original_module_name.replace("final.layer", "final_layer")
|
||||
original_module_name = original_module_name.replace("adaln.modulation", "adaln_modulation")
|
||||
original_module_name = original_module_name.replace("norm.cross.attn", "norm_cross_attn")
|
||||
original_module_name = original_module_name.replace("norm.mlp", "norm_mlp")
|
||||
original_module_name = original_module_name.replace("norm.self.attn", "norm_self_attn")
|
||||
original_module_name = original_module_name.replace("out.proj", "out_proj")
|
||||
|
||||
# Qwen3
|
||||
original_module_name = original_module_name.replace("embed.tokens", "embed_tokens")
|
||||
original_module_name = original_module_name.replace("input.layernorm", "input_layernorm")
|
||||
original_module_name = original_module_name.replace("down.proj", "down_proj")
|
||||
original_module_name = original_module_name.replace("gate.proj", "gate_proj")
|
||||
original_module_name = original_module_name.replace("up.proj", "up_proj")
|
||||
original_module_name = original_module_name.replace("post.attention.layernorm", "post_attention_layernorm")
|
||||
|
||||
# Prefix conversion
|
||||
new_prefix = COMFYUI_DIT_PREFIX if is_dit_lora else COMFYUI_QWEN3_PREFIX
|
||||
|
||||
new_k = f"{new_prefix}{original_module_name}.{weight_name}"
|
||||
else:
|
||||
if k.startswith(COMFYUI_DIT_PREFIX):
|
||||
is_dit_lora = True
|
||||
module_and_weight_name = k[len(COMFYUI_DIT_PREFIX) :]
|
||||
elif k.startswith(COMFYUI_QWEN3_PREFIX):
|
||||
is_dit_lora = False
|
||||
module_and_weight_name = k[len(COMFYUI_QWEN3_PREFIX) :]
|
||||
else:
|
||||
logger.warning(f"Skipping unrecognized key {k}")
|
||||
continue
|
||||
|
||||
# Get weight name
|
||||
if ".lora_" in module_and_weight_name:
|
||||
module_name, weight_name = module_and_weight_name.rsplit(".lora_", 1)
|
||||
weight_name = "lora_" + weight_name
|
||||
else:
|
||||
module_name, weight_name = module_and_weight_name.rsplit(".", 1) # Keep other weight names as-is: e.g. alpha
|
||||
|
||||
# Weight name conversion: lora_A/lora_B to lora_up/lora_down
|
||||
# Note: we only convert lora_A and lora_B weights, other weights are kept as-is
|
||||
if weight_name.startswith("lora_B"):
|
||||
weight_name = weight_name.replace("lora_B", "lora_up")
|
||||
elif weight_name.startswith("lora_A"):
|
||||
weight_name = weight_name.replace("lora_A", "lora_down")
|
||||
|
||||
# Module name conversion: convert dots to underscores
|
||||
module_name = module_name.replace(".", "_") # Convert to underscore notation
|
||||
|
||||
# Prefix conversion
|
||||
prefix = "lora_unet_" if is_dit_lora else "lora_te_"
|
||||
|
||||
new_k = f"{prefix}{module_name}.{weight_name}"
|
||||
|
||||
state_dict[new_k] = state_dict.pop(k)
|
||||
count += 1
|
||||
|
||||
logger.info(f"Converted {count} keys")
|
||||
if count == 0:
|
||||
logger.warning("No keys were converted. Please check if the source file is in the expected format.")
|
||||
elif count > 0 and count < len(keys):
|
||||
logger.warning(
|
||||
f"Only {count} out of {len(keys)} keys were converted. Please check if there are unexpected keys in the source file."
|
||||
)
|
||||
|
||||
# Calculate hash
|
||||
if metadata is not None:
|
||||
logger.info(f"Calculating hashes and creating metadata...")
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
# save destination safetensors
|
||||
logger.info(f"Saving destination file {args.dst_path}")
|
||||
save_file(state_dict, args.dst_path, metadata=metadata)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert LoRA format")
|
||||
parser.add_argument(
|
||||
"src_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="source path, sd-scripts format (or ComfyUI compatible format if --reverse is set, only supported for LoRAs converted by this script)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"dst_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination path, ComfyUI compatible format (or sd-scripts format if --reverse is set)",
|
||||
)
|
||||
parser.add_argument("--reverse", action="store_true", help="reverse conversion direction")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -1,18 +1,17 @@
|
||||
# LoRA network module for Anima
|
||||
import math
|
||||
import ast
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from library.utils import setup_logging
|
||||
from networks.lora_flux import LoRAModule, LoRAInfModule
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from networks.lora_flux import LoRAModule, LoRAInfModule
|
||||
|
||||
|
||||
def create_network(
|
||||
multiplier: float,
|
||||
@@ -29,68 +28,28 @@ def create_network(
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
# type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
|
||||
self_attn_dim = kwargs.get("self_attn_dim", None)
|
||||
cross_attn_dim = kwargs.get("cross_attn_dim", None)
|
||||
mlp_dim = kwargs.get("mlp_dim", None)
|
||||
mod_dim = kwargs.get("mod_dim", None)
|
||||
llm_adapter_dim = kwargs.get("llm_adapter_dim", None)
|
||||
|
||||
if self_attn_dim is not None:
|
||||
self_attn_dim = int(self_attn_dim)
|
||||
if cross_attn_dim is not None:
|
||||
cross_attn_dim = int(cross_attn_dim)
|
||||
if mlp_dim is not None:
|
||||
mlp_dim = int(mlp_dim)
|
||||
if mod_dim is not None:
|
||||
mod_dim = int(mod_dim)
|
||||
if llm_adapter_dim is not None:
|
||||
llm_adapter_dim = int(llm_adapter_dim)
|
||||
|
||||
type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
|
||||
if all([d is None for d in type_dims]):
|
||||
type_dims = None
|
||||
|
||||
# emb_dims: [x_embedder, t_embedder, final_layer]
|
||||
emb_dims = kwargs.get("emb_dims", None)
|
||||
if emb_dims is not None:
|
||||
emb_dims = emb_dims.strip()
|
||||
if emb_dims.startswith("[") and emb_dims.endswith("]"):
|
||||
emb_dims = emb_dims[1:-1]
|
||||
emb_dims = [int(d) for d in emb_dims.split(",")]
|
||||
assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)"
|
||||
|
||||
# block selection
|
||||
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
|
||||
if selection == "all":
|
||||
return [True] * total_blocks
|
||||
if selection == "none" or selection == "":
|
||||
return [False] * total_blocks
|
||||
|
||||
selected = [False] * total_blocks
|
||||
ranges = selection.split(",")
|
||||
for r in ranges:
|
||||
if "-" in r:
|
||||
start, end = map(str.strip, r.split("-"))
|
||||
start, end = int(start), int(end)
|
||||
assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end
|
||||
for i in range(start, end + 1):
|
||||
selected[i] = True
|
||||
else:
|
||||
index = int(r)
|
||||
assert 0 <= index < total_blocks
|
||||
selected[index] = True
|
||||
return selected
|
||||
|
||||
train_block_indices = kwargs.get("train_block_indices", None)
|
||||
if train_block_indices is not None:
|
||||
num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999
|
||||
train_block_indices = parse_block_selection(train_block_indices, num_blocks)
|
||||
|
||||
# train LLM adapter
|
||||
train_llm_adapter = kwargs.get("train_llm_adapter", False)
|
||||
train_llm_adapter = kwargs.get("train_llm_adapter", "false")
|
||||
if train_llm_adapter is not None:
|
||||
train_llm_adapter = True if train_llm_adapter == "True" else False
|
||||
train_llm_adapter = True if train_llm_adapter.lower() == "true" else False
|
||||
|
||||
exclude_patterns = kwargs.get("exclude_patterns", None)
|
||||
if exclude_patterns is None:
|
||||
exclude_patterns = []
|
||||
else:
|
||||
exclude_patterns = ast.literal_eval(exclude_patterns)
|
||||
if not isinstance(exclude_patterns, list):
|
||||
exclude_patterns = [exclude_patterns]
|
||||
|
||||
# add default exclude patterns
|
||||
exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")
|
||||
|
||||
# regular expression for module selection: exclude and include
|
||||
include_patterns = kwargs.get("include_patterns", None)
|
||||
if include_patterns is not None:
|
||||
include_patterns = ast.literal_eval(include_patterns)
|
||||
if not isinstance(include_patterns, list):
|
||||
include_patterns = [include_patterns]
|
||||
|
||||
# rank/module dropout
|
||||
rank_dropout = kwargs.get("rank_dropout", None)
|
||||
@@ -101,9 +60,43 @@ def create_network(
|
||||
module_dropout = float(module_dropout)
|
||||
|
||||
# verbose
|
||||
verbose = kwargs.get("verbose", False)
|
||||
verbose = kwargs.get("verbose", "false")
|
||||
if verbose is not None:
|
||||
verbose = True if verbose == "True" else False
|
||||
verbose = True if verbose.lower() == "true" else False
|
||||
|
||||
# regex-specific learning rates / dimensions
|
||||
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
|
||||
"""
|
||||
Parse a string of key-value pairs separated by commas.
|
||||
"""
|
||||
pairs = {}
|
||||
for pair in kv_pair_str.split(","):
|
||||
pair = pair.strip()
|
||||
if not pair:
|
||||
continue
|
||||
if "=" not in pair:
|
||||
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
|
||||
continue
|
||||
key, value = pair.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
try:
|
||||
pairs[key] = int(value) if is_int else float(value)
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid value for {key}: {value}")
|
||||
return pairs
|
||||
|
||||
network_reg_lrs = kwargs.get("network_reg_lrs", None)
|
||||
if network_reg_lrs is not None:
|
||||
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
|
||||
else:
|
||||
reg_lrs = None
|
||||
|
||||
network_reg_dims = kwargs.get("network_reg_dims", None)
|
||||
if network_reg_dims is not None:
|
||||
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
|
||||
else:
|
||||
reg_dims = None
|
||||
|
||||
network = LoRANetwork(
|
||||
text_encoders,
|
||||
@@ -115,9 +108,10 @@ def create_network(
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
train_llm_adapter=train_llm_adapter,
|
||||
type_dims=type_dims,
|
||||
emb_dims=emb_dims,
|
||||
train_block_indices=train_block_indices,
|
||||
exclude_patterns=exclude_patterns,
|
||||
include_patterns=include_patterns,
|
||||
reg_dims=reg_dims,
|
||||
reg_lrs=reg_lrs,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@@ -137,6 +131,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
@@ -173,15 +168,15 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
|
||||
|
||||
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
# Target modules: DiT blocks
|
||||
ANIMA_TARGET_REPLACE_MODULE = ["Block"]
|
||||
# Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
|
||||
ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
|
||||
# Target modules: LLM Adapter blocks
|
||||
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
|
||||
# Target modules for text encoder (Qwen3)
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"]
|
||||
|
||||
LORA_PREFIX_ANIMA = "lora_unet" # ComfyUI compatible
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te1" # Qwen3
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te" # Qwen3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -197,9 +192,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
modules_dim: Optional[Dict[str, int]] = None,
|
||||
modules_alpha: Optional[Dict[str, int]] = None,
|
||||
train_llm_adapter: bool = False,
|
||||
type_dims: Optional[List[int]] = None,
|
||||
emb_dims: Optional[List[int]] = None,
|
||||
train_block_indices: Optional[List[bool]] = None,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
include_patterns: Optional[List[str]] = None,
|
||||
reg_dims: Optional[Dict[str, int]] = None,
|
||||
reg_lrs: Optional[Dict[str, float]] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -210,21 +206,36 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.train_llm_adapter = train_llm_adapter
|
||||
self.type_dims = type_dims
|
||||
self.emb_dims = emb_dims
|
||||
self.train_block_indices = train_block_indices
|
||||
self.reg_dims = reg_dims
|
||||
self.reg_lrs = reg_lrs
|
||||
|
||||
self.loraplus_lr_ratio = None
|
||||
self.loraplus_unet_lr_ratio = None
|
||||
self.loraplus_text_encoder_lr_ratio = None
|
||||
|
||||
if modules_dim is not None:
|
||||
logger.info(f"create LoRA network from weights")
|
||||
if self.emb_dims is None:
|
||||
self.emb_dims = [0] * 3
|
||||
logger.info("create LoRA network from weights")
|
||||
else:
|
||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
logger.info(
|
||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||
)
|
||||
|
||||
# compile regular expression if specified
|
||||
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
|
||||
re_patterns = []
|
||||
if patterns is not None:
|
||||
for pattern in patterns:
|
||||
try:
|
||||
re_pattern = re.compile(pattern)
|
||||
except re.error as e:
|
||||
logger.error(f"Invalid pattern '{pattern}': {e}")
|
||||
continue
|
||||
re_patterns.append(re_pattern)
|
||||
return re_patterns
|
||||
|
||||
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
|
||||
include_re_patterns = str_to_re_patterns(include_patterns)
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -232,15 +243,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
text_encoder_idx: Optional[int],
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules: List[str],
|
||||
filter: Optional[str] = None,
|
||||
default_dim: Optional[int] = None,
|
||||
include_conv2d_if_filter: bool = False,
|
||||
) -> Tuple[List[LoRAModule], List[str]]:
|
||||
prefix = (
|
||||
self.LORA_PREFIX_ANIMA
|
||||
if is_unet
|
||||
else self.LORA_PREFIX_TEXT_ENCODER
|
||||
)
|
||||
prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER
|
||||
|
||||
loras = []
|
||||
skipped = []
|
||||
@@ -255,14 +260,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + (name + "." if name else "") + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
original_name = (name + "." if name else "") + child_name
|
||||
lora_name = f"{prefix}.{original_name}".replace(".", "_")
|
||||
|
||||
force_incl_conv2d = False
|
||||
if filter is not None:
|
||||
if filter not in lora_name:
|
||||
# exclude/include filter (fullmatch: pattern must match the entire original_name)
|
||||
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
|
||||
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
|
||||
if excluded and not included:
|
||||
if verbose:
|
||||
logger.info(f"exclude: {original_name}")
|
||||
continue
|
||||
force_incl_conv2d = include_conv2d_if_filter
|
||||
|
||||
dim = None
|
||||
alpha_val = None
|
||||
@@ -272,41 +279,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
dim = modules_dim[lora_name]
|
||||
alpha_val = modules_alpha[lora_name]
|
||||
else:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
if self.reg_dims is not None:
|
||||
for reg, d in self.reg_dims.items():
|
||||
if re.fullmatch(reg, original_name):
|
||||
dim = d
|
||||
alpha_val = self.alpha
|
||||
|
||||
if is_unet and type_dims is not None:
|
||||
# type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
|
||||
# Order matters: check most specific identifiers first to avoid mismatches.
|
||||
identifier_order = [
|
||||
(4, ("llm_adapter",)),
|
||||
(3, ("adaln_modulation",)),
|
||||
(0, ("self_attn",)),
|
||||
(1, ("cross_attn",)),
|
||||
(2, ("mlp",)),
|
||||
]
|
||||
for idx, ids in identifier_order:
|
||||
d = type_dims[idx]
|
||||
if d is not None and all(id_str in lora_name for id_str in ids):
|
||||
dim = d # 0 means skip
|
||||
logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
|
||||
break
|
||||
|
||||
# block index filtering
|
||||
if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name:
|
||||
# Extract block index from lora_name: "lora_unet_blocks_0_self_attn..."
|
||||
parts = lora_name.split("_")
|
||||
for pi, part in enumerate(parts):
|
||||
if part == "blocks" and pi + 1 < len(parts):
|
||||
try:
|
||||
block_index = int(parts[pi + 1])
|
||||
if not self.train_block_indices[block_index]:
|
||||
dim = 0
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
break
|
||||
|
||||
elif force_incl_conv2d:
|
||||
# fallback to default dim if not matched by reg_dims or reg_dims is not specified
|
||||
if dim is None:
|
||||
if is_linear or is_conv2d_1x1:
|
||||
dim = default_dim if default_dim is not None else self.lora_dim
|
||||
alpha_val = self.alpha
|
||||
|
||||
@@ -325,6 +307,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
)
|
||||
lora.original_name = original_name
|
||||
loras.append(lora)
|
||||
|
||||
if target_replace_modules is None:
|
||||
@@ -339,9 +322,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
if text_encoder is None:
|
||||
continue
|
||||
logger.info(f"create LoRA for Text Encoder {i+1}:")
|
||||
te_loras, te_skipped = create_modules(
|
||||
False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
)
|
||||
te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
|
||||
self.text_encoder_loras.extend(te_loras)
|
||||
skipped_te += te_skipped
|
||||
@@ -354,19 +335,6 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
|
||||
# emb_dims: [x_embedder, t_embedder, final_layer]
|
||||
if self.emb_dims:
|
||||
for filter_name, in_dim in zip(
|
||||
["x_embedder", "t_embedder", "final_layer"],
|
||||
self.emb_dims,
|
||||
):
|
||||
loras, _ = create_modules(
|
||||
True, None, unet, None,
|
||||
filter=filter_name, default_dim=in_dim,
|
||||
include_conv2d_if_filter=(filter_name == "x_embedder"),
|
||||
)
|
||||
self.unet_loras.extend(loras)
|
||||
|
||||
logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
|
||||
if verbose:
|
||||
for lora in self.unet_loras:
|
||||
@@ -396,6 +364,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
@@ -446,7 +415,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
logger.info(f"weights are merged")
|
||||
logger.info("weights are merged")
|
||||
|
||||
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
||||
self.loraplus_lr_ratio = loraplus_lr_ratio
|
||||
@@ -471,8 +440,29 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
def assemble_params(loras, lr, loraplus_ratio):
|
||||
param_groups = {"lora": {}, "plus": {}}
|
||||
reg_groups = {}
|
||||
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
|
||||
|
||||
for lora in loras:
|
||||
matched_reg_lr = None
|
||||
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
|
||||
if re.fullmatch(regex_str, lora.original_name):
|
||||
matched_reg_lr = (i, reg_lr)
|
||||
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
|
||||
break
|
||||
|
||||
for name, param in lora.named_parameters():
|
||||
if matched_reg_lr is not None:
|
||||
reg_idx, reg_lr = matched_reg_lr
|
||||
group_key = f"reg_lr_{reg_idx}"
|
||||
if group_key not in reg_groups:
|
||||
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
|
||||
if loraplus_ratio is not None and "lora_up" in name:
|
||||
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
|
||||
continue
|
||||
|
||||
if loraplus_ratio is not None and "lora_up" in name:
|
||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
||||
else:
|
||||
@@ -480,6 +470,23 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
params = []
|
||||
descriptions = []
|
||||
for group_key, group in reg_groups.items():
|
||||
reg_lr = group["lr"]
|
||||
for key in ("lora", "plus"):
|
||||
param_data = {"params": group[key].values()}
|
||||
if len(param_data["params"]) == 0:
|
||||
continue
|
||||
if key == "plus":
|
||||
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
|
||||
else:
|
||||
param_data["lr"] = reg_lr
|
||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
||||
logger.info("NO LR skipping!")
|
||||
continue
|
||||
params.append(param_data)
|
||||
desc = f"reg_lr_{group_key.split('_')[-1]}"
|
||||
descriptions.append(desc + (" plus" if key == "plus" else ""))
|
||||
|
||||
for key in param_groups.keys():
|
||||
param_data = {"params": param_groups[key].values()}
|
||||
if len(param_data["params"]) == 0:
|
||||
@@ -498,10 +505,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if self.text_encoder_loras:
|
||||
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
|
||||
te1_loras = [
|
||||
lora for lora in self.text_encoder_loras
|
||||
if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)
|
||||
]
|
||||
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
|
||||
if len(te1_loras) > 0:
|
||||
logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
|
||||
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
Diagnostic script to test Anima latent & text encoder caching independently.
|
||||
|
||||
Usage:
|
||||
python test_anima_cache.py \
|
||||
python manual_test_anima_cache.py \
|
||||
--image_dir /path/to/images \
|
||||
--qwen3_path /path/to/qwen3 \
|
||||
--vae_path /path/to/vae.safetensors \
|
||||
@@ -30,10 +30,12 @@ from torchvision import transforms
|
||||
|
||||
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"}
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose([
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(), # [0,1]
|
||||
transforms.Normalize([0.5], [0.5]), # [-1,1]
|
||||
])
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def find_image_caption_pairs(image_dir: str):
|
||||
@@ -60,35 +62,32 @@ def print_tensor_info(name: str, t, indent=2):
|
||||
print(f"{prefix}{name}: None")
|
||||
return
|
||||
if isinstance(t, np.ndarray):
|
||||
print(f"{prefix}{name}: numpy {t.dtype} shape={t.shape} "
|
||||
f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
|
||||
print(f"{prefix}{name}: numpy {t.dtype} shape={t.shape} " f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
|
||||
elif isinstance(t, torch.Tensor):
|
||||
print(f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
|
||||
f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}")
|
||||
print(
|
||||
f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
|
||||
f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}"
|
||||
)
|
||||
else:
|
||||
print(f"{prefix}{name}: type={type(t)} value={t}")
|
||||
|
||||
|
||||
# Test 1: Latent Cache
|
||||
|
||||
|
||||
def test_latent_cache(args, pairs):
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 1: LATENT CACHING (VAE encode -> cache -> reload)")
|
||||
print("=" * 70)
|
||||
|
||||
from library import anima_utils
|
||||
from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
|
||||
from library import qwen_image_autoencoder_kl
|
||||
|
||||
# Load VAE
|
||||
print("\n[1.1] Loading VAE...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
vae_dtype = torch.float32
|
||||
vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(
|
||||
args.vae_path, dtype=vae_dtype, device=device
|
||||
)
|
||||
vae = qwen_image_autoencoder_kl.load_vae(args.vae_path, dtype=vae_dtype, device=device)
|
||||
print(f" VAE loaded on {device}, dtype={vae_dtype}")
|
||||
print(f" VAE mean (first 4): {ANIMA_VAE_MEAN[:4]}")
|
||||
print(f" VAE std (first 4): {ANIMA_VAE_STD[:4]}")
|
||||
|
||||
for img_path, caption in pairs:
|
||||
print(f"\n[1.2] Processing: {os.path.basename(img_path)}")
|
||||
@@ -96,13 +95,13 @@ def test_latent_cache(args, pairs):
|
||||
# Load image
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
img_np = np.array(img)
|
||||
print(f" Raw image: {img_np.shape} dtype={img_np.dtype} "
|
||||
f"min={img_np.min()} max={img_np.max()}")
|
||||
print(f" Raw image: {img_np.shape} dtype={img_np.dtype} " f"min={img_np.min()} max={img_np.max()}")
|
||||
|
||||
# Apply IMAGE_TRANSFORMS (same as sd-scripts training)
|
||||
img_tensor = IMAGE_TRANSFORMS(img_np)
|
||||
print(f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} "
|
||||
f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}")
|
||||
print(
|
||||
f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} " f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}"
|
||||
)
|
||||
|
||||
# Check range is [-1, 1]
|
||||
if img_tensor.min() < -1.01 or img_tensor.max() > 1.01:
|
||||
@@ -116,7 +115,7 @@ def test_latent_cache(args, pairs):
|
||||
print(f" VAE input: shape={tuple(img_5d.shape)} dtype={img_5d.dtype}")
|
||||
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_5d, vae_scale)
|
||||
latents = vae.encode_pixels_to_latents(img_5d)
|
||||
latents_cpu = latents.cpu()
|
||||
print_tensor_info("Encoded latents", latents_cpu)
|
||||
|
||||
@@ -165,7 +164,9 @@ def test_latent_cache(args, pairs):
|
||||
|
||||
# Test 2: Text Encoder Output Cache
|
||||
|
||||
|
||||
def test_text_encoder_cache(args, pairs):
|
||||
# TODO Rewrite this
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 2: TEXT ENCODER OUTPUT CACHING")
|
||||
print("=" * 70)
|
||||
@@ -175,9 +176,7 @@ def test_text_encoder_cache(args, pairs):
|
||||
# Load tokenizers
|
||||
print("\n[2.1] Loading tokenizers...")
|
||||
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(
|
||||
getattr(args, 't5_tokenizer_path', None)
|
||||
)
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
|
||||
print(f" Qwen3 tokenizer vocab: {qwen3_tokenizer.vocab_size}")
|
||||
print(f" T5 tokenizer vocab: {t5_tokenizer.vocab_size}")
|
||||
|
||||
@@ -185,9 +184,7 @@ def test_text_encoder_cache(args, pairs):
|
||||
print("\n[2.2] Loading Qwen3 text encoder...")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(
|
||||
args.qwen3_path, dtype=te_dtype, device=device
|
||||
)
|
||||
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
|
||||
qwen3_model.eval()
|
||||
|
||||
# Create strategy objects
|
||||
@@ -199,9 +196,7 @@ def test_text_encoder_cache(args, pairs):
|
||||
qwen3_max_length=args.qwen3_max_length,
|
||||
t5_max_length=args.t5_max_length,
|
||||
)
|
||||
text_encoding_strategy = AnimaTextEncodingStrategy(
|
||||
dropout_rate=0.0,
|
||||
)
|
||||
text_encoding_strategy = AnimaTextEncodingStrategy()
|
||||
|
||||
captions = [cap for _, cap in pairs]
|
||||
print(f"\n[2.3] Tokenizing {len(captions)} captions...")
|
||||
@@ -221,10 +216,7 @@ def test_text_encoder_cache(args, pairs):
|
||||
print(f"\n[2.4] Encoding with Qwen3 text encoder...")
|
||||
with torch.no_grad():
|
||||
prompt_embeds, attn_mask, t5_ids_out, t5_mask_out = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
[qwen3_model],
|
||||
tokens_and_masks,
|
||||
enable_dropout=False,
|
||||
tokenize_strategy, [qwen3_model], tokens_and_masks
|
||||
)
|
||||
|
||||
print(f" Encoding results:")
|
||||
@@ -374,13 +366,13 @@ def test_text_encoder_cache(args, pairs):
|
||||
|
||||
# Test 3: Full batch simulation
|
||||
|
||||
|
||||
def test_full_batch_simulation(args, pairs):
|
||||
print("\n" + "=" * 70)
|
||||
print("TEST 3: FULL BATCH SIMULATION (mimics process_batch flow)")
|
||||
print("=" * 70)
|
||||
|
||||
from library import anima_utils
|
||||
from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
|
||||
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -390,14 +382,16 @@ def test_full_batch_simulation(args, pairs):
|
||||
# Load all models
|
||||
print("\n[3.1] Loading models...")
|
||||
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, 't5_tokenizer_path', None))
|
||||
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
|
||||
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
|
||||
qwen3_model.eval()
|
||||
vae, _, _, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=vae_dtype, device=device)
|
||||
|
||||
tokenize_strategy = AnimaTokenizeStrategy(
|
||||
qwen3_tokenizer=qwen3_tokenizer, t5_tokenizer=t5_tokenizer,
|
||||
qwen3_max_length=args.qwen3_max_length, t5_max_length=args.t5_max_length,
|
||||
qwen3_tokenizer=qwen3_tokenizer,
|
||||
t5_tokenizer=t5_tokenizer,
|
||||
qwen3_max_length=args.qwen3_max_length,
|
||||
t5_max_length=args.t5_max_length,
|
||||
)
|
||||
text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
|
||||
|
||||
@@ -408,7 +402,10 @@ def test_full_batch_simulation(args, pairs):
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
te_outputs = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, [qwen3_model], tokens_and_masks, enable_dropout=False,
|
||||
tokenize_strategy,
|
||||
[qwen3_model],
|
||||
tokens_and_masks,
|
||||
enable_dropout=False,
|
||||
)
|
||||
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = te_outputs
|
||||
|
||||
@@ -541,26 +538,19 @@ def test_full_batch_simulation(args, pairs):
|
||||
|
||||
# Main
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Test Anima caching mechanisms")
|
||||
parser.add_argument("--image_dir", type=str, required=True,
|
||||
help="Directory with image+txt pairs")
|
||||
parser.add_argument("--qwen3_path", type=str, required=True,
|
||||
help="Path to Qwen3 model (directory or safetensors)")
|
||||
parser.add_argument("--vae_path", type=str, required=True,
|
||||
help="Path to WanVAE safetensors")
|
||||
parser.add_argument("--t5_tokenizer_path", type=str, default=None,
|
||||
help="Path to T5 tokenizer (optional, uses bundled config)")
|
||||
parser.add_argument("--image_dir", type=str, required=True, help="Directory with image+txt pairs")
|
||||
parser.add_argument("--qwen3_path", type=str, required=True, help="Path to Qwen3 model (directory or safetensors)")
|
||||
parser.add_argument("--vae_path", type=str, required=True, help="Path to WanVAE safetensors")
|
||||
parser.add_argument("--t5_tokenizer_path", type=str, default=None, help="Path to T5 tokenizer (optional, uses bundled config)")
|
||||
parser.add_argument("--qwen3_max_length", type=int, default=512)
|
||||
parser.add_argument("--t5_max_length", type=int, default=512)
|
||||
parser.add_argument("--cache_to_disk", action="store_true",
|
||||
help="Also test disk cache round-trip")
|
||||
parser.add_argument("--skip_latent", action="store_true",
|
||||
help="Skip latent cache test")
|
||||
parser.add_argument("--skip_text", action="store_true",
|
||||
help="Skip text encoder cache test")
|
||||
parser.add_argument("--skip_full", action="store_true",
|
||||
help="Skip full batch simulation")
|
||||
parser.add_argument("--cache_to_disk", action="store_true", help="Also test disk cache round-trip")
|
||||
parser.add_argument("--skip_latent", action="store_true", help="Skip latent cache test")
|
||||
parser.add_argument("--skip_text", action="store_true", help="Skip text encoder cache test")
|
||||
parser.add_argument("--skip_full", action="store_true", help="Skip full batch simulation")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Find pairs
|
||||
@@ -470,7 +470,7 @@ class NetworkTrainer:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = loss.mean(dim=list(range(1, loss.ndim))) # mean over all dims except batch
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
Reference in New Issue
Block a user