Compare commits

...

8 Commits

Author SHA1 Message Date
Kohya S.
ae72efb92b Merge pull request #2264 from kohya-ss/release/v0.10.1
Release 0.10.1
2026-02-13 08:34:42 +09:00
Kohya S
449e70b4cf README: Update change history for version 0.10.1 with Anima model support 2026-02-13 08:31:22 +09:00
Kohya S.
b237b8deb3 Merge pull request #2263 from kohya-ss/sd3
feat: Anima support
2026-02-13 08:20:44 +09:00
Kohya S.
34e7138b6a Add/modify some implementation for anima (#2261)
* fix: update extend-exclude list in _typos.toml to include configs

* fix: exclude anima tests from pytest

* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE

* fix: update default value for --discrete_flow_shift in anima training guide

* feat: add Qwen-Image VAE

* feat: simplify encode_tokens

* feat: use unified attention module, add wrapper for state dict compatibility

* feat: loading with dynamic fp8 optimization and LoRA support

* feat: add anima minimal inference script (WIP)

* format: format

* feat: simplify target module selection by regular expression patterns

* feat: kept caption dropout rate in cache and handle in training script

* feat: update train_llm_adapter and verbose default values to string type

* fix: use strategy instead of using tokenizers directly

* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

* feat: support 5d tensor in get_noisy_model_input_and_timesteps

* feat: update loss calculation to support 5d tensor

* fix: update argument names in anima_train_utils to align with other archtectures

* feat: simplify Anima training script and update empty caption handling

* feat: support LoRA format without `net.` prefix

* fix: update to work fp8_scaled option

* feat: add regex-based learning rates and dimensions handling in create_network

* fix: improve regex matching for module selection and learning rates in LoRANetwork

* fix: update logging message for regex match in LoRANetwork

* fix: keep latents 4D except DiT call

* feat: enhance block swap functionality for inference and training in Anima model

* feat: refactor Anima training script

* feat: optimize VAE processing by adjusting tensor dimensions and data types

* fix: wait all block trasfer before siwtching offloader mode

* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!

* feat: support LORA for Qwen3

* feat: update Anima SAI model spec metadata handling

* fix: remove unused code

* feat: split CFG processing in do_sample function to reduce memory usage

* feat: add VAE chunking and caching options to reduce memory usage

* feat: optimize RMSNorm forward method and remove unused torch_attention_op

* Update library/strategy_anima.py

Use torch.all instead of all.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/safetensors_utils.py

Fix duplicated new_key for concat_hook.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_minimal_inference.py

Remove unused code.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_train.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/anima_train_utils.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: review with Copilot

* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)

* feat: add process_escape function to handle escape sequences in prompts

* feat: enhance LoRA weight handling in model loading and add text encoder loading function

* feat: improve ComfyUI conversion script with prefix constants and module name adjustments

* feat: update caption dropout documentation to clarify cache regeneration requirement

* feat: add clarification on learning rate adjustments

* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-13 08:15:06 +09:00
Kohya S
9144463f7b Merge branch 'dev' into sd3 2026-02-13 08:14:21 +09:00
duongve13112002
e21a7736f8 Support Anima model (#2260)
* Support Anima model

* Update document and fix bug

* Fix latent normlization

* Fix typo

* Fix cache embedding

* fix typo in tests/test_anima_cache.py

* Remove redundant argument apply_t5_attn_mask

* Improving caching with argument caption_dropout_rate

* Fix W&B logging bugs

* Fix discrete_flow_shift default value
2026-02-08 10:18:55 +09:00
Kohya S.
8b5ce3e641 Merge pull request #2255 from cgcalatrava/fix-diffusers-unet-import
Fix AttributeError for UNet2DConditionModel with newer diffusers versions
2026-01-20 07:50:04 +09:00
cgcalatrava
da07e4c617 Make UNet2DConditionModel import compatible with old and new diffusers versions 2026-01-19 20:53:00 +01:00
35 changed files with 464474 additions and 64 deletions

View File

@@ -50,6 +50,11 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- **Version 0.10.1 (2026-02-13):**
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima)モデルのLoRA学習およびfine-tuningをサポートしました。[PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) および[PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261)
- 素晴らしいモデルを公開された CircleStone Labs、および PR #2260を提出していただいたduongve13112002氏に深く感謝します
- 詳細は[ドキュメント](./docs/anima_train_network.md)をご覧ください。
- **Version 0.10.0 (2026-01-19):**
- `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。
- ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。

View File

@@ -47,6 +47,11 @@ If you find this project helpful, please consider supporting its development via
### Change History
- **Version 0.10.1 (2026-02-13):**
- [Anima Preview](https://huggingface.co/circlestone-labs/Anima) model LoRA training and fine-tuning are now supported. See [PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) and [PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261).
- Many thanks to CircleStone Labs for releasing this amazing model, and to duongve13112002 for submitting great PR #2260.
- For details, please refer to the [documentation](./docs/anima_train_network.md).
- **Version 0.10.0 (2026-01-19):**
- `sd3` branch is merged to `main` branch. From this version, FLUX.1 and SD3/SD3.5 etc. are supported in the `main` branch.
- There are still some missing parts in the documentation, so please let us know if you find any issues via Issues etc.

View File

@@ -32,6 +32,7 @@ hime="hime"
OT="OT"
byt="byt"
tak="tak"
temperal="temperal"
[files]
extend-exclude = ["_typos.toml", "venv"]
extend-exclude = ["_typos.toml", "venv", "configs"]

1082
anima_minimal_inference.py Normal file

File diff suppressed because it is too large Load Diff

759
anima_train.py Normal file
View File

@@ -0,0 +1,759 @@
# Anima full finetune training script
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import gc
import math
import os
from multiprocessing import Value
from typing import List
import toml
from tqdm import tqdm
import torch
from library import flux_train_utils, qwen_image_autoencoder_kl
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
init_ipex()
from accelerate.utils import set_seed
from library import deepspeed_utils, anima_models, anima_train_utils, anima_utils, strategy_base, strategy_anima, sai_model_spec
import library.train_util as train_util
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
# backward compatibility
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)
# prepare caching strategy: must be set before preparing dataset
if args.cache_latents:
latents_caching_strategy = strategy_anima.AnimaLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# prepare dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(16) # Qwen-Image VAE spatial downscale = 8 * patch size = 2
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
)
)
train_dataset_group.set_current_strategies()
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
logger.error("No data found. Please verify the metadata file and train_data_dir option.")
return
if cache_latents:
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
if args.cache_text_encoder_outputs:
assert train_dataset_group.is_text_encoder_output_cacheable(
cache_supports_dropout=True
), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
# prepare accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precision dtype
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# Load tokenizers and set strategies
logger.info("Loading tokenizers...")
qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)
# Set tokenize strategy
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_tokenizer=qwen3_tokenizer,
t5_tokenizer=t5_tokenizer,
qwen3_max_length=args.qwen3_max_token_length,
t5_max_length=args.t5_max_token_length,
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# Prepare text encoder (always frozen for Anima)
qwen3_text_encoder.to(weight_dtype)
qwen3_text_encoder.requires_grad_(False)
# Cache text encoder outputs
sample_prompts_te_outputs = None
if args.cache_text_encoder_outputs:
qwen3_text_encoder.to(accelerator.device)
qwen3_text_encoder.eval()
text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([qwen3_text_encoder], accelerator)
# cache sample prompt embeddings
if args.sample_prompts is not None:
logger.info(f"Cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {}
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
)
accelerator.wait_for_everyone()
# free text encoder memory
qwen3_text_encoder = None
gc.collect() # Force garbage collection to free memory
clean_memory_on_device(accelerator.device)
# Load VAE and cache latents
logger.info("Loading Anima VAE...")
vae = qwen_image_autoencoder_kl.load_vae(
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
)
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# Load DiT (MiniTrainDIT + optional LLM Adapter)
logger.info("Loading Anima DiT...")
dit = anima_utils.load_anima_model(
"cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
)
if args.gradient_checkpointing:
dit.enable_gradient_checkpointing(
cpu_offload=args.cpu_offload_checkpointing,
unsloth_offload=args.unsloth_offload_checkpointing,
)
train_dit = args.learning_rate != 0
dit.requires_grad_(train_dit)
if not train_dit:
dit.to(accelerator.device, dtype=weight_dtype)
# Block swap
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if is_swapping_blocks:
logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
dit.enable_block_swap(args.blocks_to_swap, accelerator.device)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# Setup optimizer with parameter groups
if train_dit:
param_groups = anima_train_utils.get_anima_param_groups(
dit,
base_lr=args.learning_rate,
self_attn_lr=args.self_attn_lr,
cross_attn_lr=args.cross_attn_lr,
mlp_lr=args.mlp_lr,
mod_lr=args.mod_lr,
llm_adapter_lr=args.llm_adapter_lr,
)
else:
param_groups = []
training_models = []
if train_dit:
training_models.append(dit)
# calculate trainable parameters
n_params = 0
for group in param_groups:
for p in group["params"]:
n_params += p.numel()
accelerator.print(f"train dit: {train_dit}")
accelerator.print(f"number of training models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params:,}")
# prepare optimizer
accelerator.print("prepare optimizer, data loader etc.")
if args.fused_backward_pass:
# Pass per-component param_groups directly to preserve per-component LRs
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
# prepare dataloader
train_dataset_group.set_current_strategies()
n_workers = min(args.max_data_loader_n_workers, os.cpu_count())
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# calculate training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs: {args.max_train_steps}")
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr scheduler
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# full fp16/bf16 training
dit_weight_dtype = weight_dtype
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
accelerator.print("enable full fp16 training.")
elif args.full_bf16:
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
accelerator.print("enable full bf16 training.")
else:
dit_weight_dtype = torch.float32 # If neither full_fp16 nor full_bf16, the model weights should be in float32
dit.to(dit_weight_dtype) # convert dit to target weight dtype
# move text encoder to GPU if not cached
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
qwen3_text_encoder.to(accelerator.device)
clean_memory_on_device(accelerator.device)
# Prepare with accelerator
# Temporarily move non-training models off GPU to reduce memory during DDP init
# if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
# qwen3_text_encoder.to("cpu")
# if not cache_latents and vae is not None:
# vae.to("cpu")
# clean_memory_on_device(accelerator.device)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
if train_dit:
dit = accelerator.prepare(dit, device_placement=[not is_swapping_blocks])
if is_swapping_blocks:
accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# Move non-training models back to GPU
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
qwen3_text_encoder.to(accelerator.device)
if not cache_latents and vae is not None:
vae.to(accelerator.device, dtype=weight_dtype)
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def create_grad_hook(p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, p_group)
tensor.grad = None
return grad_hook
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
# Training loop
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
# Copy for noise and timestep generation, because noise_scheduler may be changed during training in future
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb
wandb.define_metric("epoch")
wandb.define_metric("loss/epoch", step_metric="epoch")
if is_swapping_blocks:
accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
# For --sample_at_first
optimizer_eval_fn()
anima_train_utils.sample_images(
accelerator,
args,
0,
global_step,
dit,
vae,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
optimizer_train_fn()
if len(accelerator.trackers) > 0:
accelerator.log({}, step=0)
# Show model info
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
if unwrapped_dit is not None:
logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
if qwen3_text_encoder is not None:
logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
if vae is not None:
logger.info(f"vae device: {vae.device}")
loss_recorder = train_util.LossRecorder()
epoch = 0
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(*training_models):
# Get latents
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
if latents.ndim == 5: # Fallback for 5D latents (old cache)
latents = latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W)
else:
with torch.no_grad():
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
# Get text encoder outputs
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Cached outputs
caption_dropout_rates = text_encoder_outputs_list[-1]
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
# Apply caption dropout to cached outputs
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
else:
# Encode on-the-fly
input_ids_list = batch["input_ids_list"]
with torch.no_grad():
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
tokenize_strategy, [qwen3_text_encoder], input_ids_list
)
# Move to device
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Noise and timesteps
noise = torch.randn_like(latents)
# Get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
)
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
# NaN checks
if torch.any(torch.isnan(noisy_model_input)):
accelerator.print("NaN found in noisy_model_input, replacing with zeros")
noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
# Create padding mask
# padding_mask: (B, 1, H_latent, W_latent)
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W)
with accelerator.autocast():
model_pred = dit(
noisy_model_input,
timesteps,
prompt_embeds,
padding_mask=padding_mask,
source_attention_mask=attn_mask,
t5_input_ids=t5_input_ids,
t5_attn_mask=t5_attn_mask,
)
model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W)
# Compute loss (rectified flow: target = noise - latents)
target = noise - latents
# Weighting
weighting = anima_train_utils.compute_loss_weighting_for_anima(
weighting_scheme=args.weighting_scheme, sigmas=sigmas
)
# Loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,)
if weighting is not None:
loss = loss * weighting
loss_weights = batch["loss_weights"]
loss = loss * loss_weights
loss = loss.mean()
accelerator.backward(loss)
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
# Checks if the accelerator has performed an optimization step
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
optimizer_eval_fn()
anima_train_utils.sample_images(
accelerator,
args,
None,
global_step,
dit,
vae,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
# Save at specific steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(dit) if train_dit else None,
)
optimizer_train_fn()
current_loss = loss.detach().item()
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs_with_names(
logs,
lr_scheduler,
args.optimizer_type,
["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
)
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average, "epoch": epoch + 1}
accelerator.log(logs, step=global_step)
accelerator.wait_for_everyone()
optimizer_eval_fn()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(dit) if train_dit else None,
)
anima_train_utils.sample_images(
accelerator,
args,
epoch + 1,
global_step,
dit,
vae,
qwen3_text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs,
)
# End training
is_main_process = accelerator.is_main_process
dit = accelerator.unwrap_model(dit)
accelerator.end_training()
optimizer_eval_fn()
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator
if is_main_process and train_dit:
anima_train_utils.save_anima_model_on_train_end(
args,
save_dtype,
epoch,
global_step,
dit,
)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
add_custom_train_arguments(parser)
train_util.add_dit_training_arguments(parser)
anima_train_utils.add_anima_training_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
help="offload gradient checkpointing to CPU (reduces VRAM at cost of speed)",
)
parser.add_argument(
"--unsloth_offload_checkpointing",
action="store_true",
help="offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
"Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="[Deprecated] use 'skip_cache_check' instead",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
train(args)

448
anima_train_network.py Normal file
View File

@@ -0,0 +1,448 @@
# Anima LoRA training script
import argparse
from typing import Any, Optional, Union
import torch
import torch.nn as nn
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import (
anima_models,
anima_train_utils,
anima_utils,
flux_train_utils,
qwen_image_autoencoder_kl,
sd3_train_utils,
strategy_anima,
strategy_base,
train_util,
)
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class AnimaNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
if args.fp8_base or args.fp8_base_unet:
logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。")
args.fp8_base = False
args.fp8_base_unet = False
args.fp8_scaled = False # Anima DiT does not support fp8_scaled
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert train_dataset_group.is_text_encoder_output_cacheable(
cache_supports_dropout=True
), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
assert (
args.network_train_unet_only or not args.cache_text_encoder_outputs
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
assert (
not args.cpu_offload_checkpointing
), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
train_dataset_group.verify_bucket_reso_steps(16) # WanVAE spatial downscale = 8 and patch size = 2
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(16)
def load_target_model(self, args, weight_dtype, accelerator):
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
# Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
logger.info("Loading Qwen3 text encoder...")
qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
qwen3_text_encoder.eval()
# Load VAE
logger.info("Loading Anima VAE...")
vae = qwen_image_autoencoder_kl.load_vae(
args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
)
vae.to(weight_dtype)
vae.eval()
# Return format: (model_type, text_encoders, vae, unet)
return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
loading_dtype = None if args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
attn_mode = "torch"
if args.xformers:
attn_mode = "xformers"
if args.attn_mode is not None:
attn_mode = args.attn_mode
# Load DiT
logger.info(f"Loading Anima DiT model with attn_mode={attn_mode}, split_attn: {args.split_attn}...")
model = anima_utils.load_anima_model(
accelerator.device,
args.pretrained_model_name_or_path,
attn_mode,
args.split_attn,
loading_device,
loading_dtype,
args.fp8_scaled,
)
# Store unsloth preference so that when the base NetworkTrainer calls
# dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
# The base trainer only passes cpu_offload, so we store the flag on the model.
self._use_unsloth_offload_checkpointing = args.unsloth_offload_checkpointing
# Block swap
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
return model, text_encoders
def get_tokenize_strategy(self, args):
# Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
qwen3_path=args.qwen3,
t5_tokenizer_path=args.t5_tokenizer_path,
qwen3_max_length=args.qwen3_max_token_length,
t5_max_length=args.t5_max_token_length,
)
return tokenize_strategy
def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
return [tokenize_strategy.qwen3_tokenizer]
def get_latents_caching_strategy(self, args):
return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)
def get_text_encoding_strategy(self, args):
return strategy_anima.AnimaTextEncodingStrategy()
def post_process_network(self, args, accelerator, network, text_encoders, unet):
pass
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
return None # no text encoders needed for encoding
return text_encoders
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
# We cannot move DiT to CPU because of block swap, so only move VAE
logger.info("move vae to cpu to save memory")
org_vae_device = vae.device
vae.to("cpu")
clean_memory_on_device(accelerator.device)
logger.info("move text encoder to gpu")
text_encoders[0].to(accelerator.device)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {}
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy, text_encoders, tokens_and_masks
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# move text encoder back to cpu
logger.info("move text encoder back to cpu")
text_encoders[0].to("cpu")
if not args.lowram:
logger.info("move vae back to original device")
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
else:
# move text encoder to device for encoding during training/validation
text_encoders[0].to(accelerator.device)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
qwen3_te = te[0] if te is not None else None
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
anima_train_utils.sample_images(
accelerator,
args,
epoch,
global_step,
unet,
vae,
qwen3_te,
tokenize_strategy,
text_encoding_strategy,
self.sample_prompts_te_outputs,
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
return noise_scheduler
def encode_images_to_latents(self, args, vae, images):
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
return vae.encode_pixels_to_latents(images) # Keep 4D for input/output
def shift_scale_latents(self, args, latents):
# Latents already normalized by vae.encode with scale
return latents
def get_noise_pred_and_target(
self,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
is_train=True,
):
anima: anima_models.Anima = unet
# Sample noise
if latents.ndim == 5: # Fallback for 5D latents (old cache)
latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
noise = torch.randn_like(latents)
# Get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
# Gradient checkpointing support
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
# Unpack text encoder conditions
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds
# Move to device
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Create padding mask
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
# Call model
noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W]
with torch.set_grad_enabled(is_train), accelerator.autocast():
model_pred = anima(
noisy_model_input,
timesteps,
prompt_embeds,
padding_mask=padding_mask,
target_input_ids=t5_input_ids,
target_attention_mask=t5_attn_mask,
source_attention_mask=attn_mask,
)
model_pred = model_pred.squeeze(2) # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W]
# Rectified flow target: noise - latents
target = noise - latents
# Loss weighting
weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
return model_pred, target, timesteps, weighting
def process_batch(
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
"""Override base process_batch for caption dropout with cached text encoder outputs."""
# Text encoder conditions
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
if text_encoder_outputs_list is not None:
caption_dropout_rates = text_encoder_outputs_list[-1]
text_encoder_outputs_list = text_encoder_outputs_list[:-1]
# Apply caption dropout to cached outputs
text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
*text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
return super().process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train,
train_text_encoder,
train_unet,
)
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, anima="preview").to_metadata_dict()
def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
# Set first parameter's requires_grad to True to workaround Accelerate gradient checkpointing bug
first_param = next(text_encoder.parameters())
first_param.requires_grad_(True)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
# The base NetworkTrainer only calls enable_gradient_checkpointing(cpu_offload=True/False),
# so we re-apply with unsloth_offload if needed (after base has already enabled it).
if self._use_unsloth_offload_checkpointing and args.gradient_checkpointing:
unet.enable_gradient_checkpointing(unsloth_offload=True)
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
model = unet
model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
accelerator.unwrap_model(model).prepare_block_swap_before_forward()
return model
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
anima_train_utils.add_anima_training_arguments(parser)
# parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument(
"--unsloth_offload_checkpointing",
action="store_true",
help="offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
"Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
trainer = AnimaNetworkTrainer()
trainer.train(args)

View File

@@ -0,0 +1,30 @@
{
"architectures": [
"Qwen3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 151643,
"eos_token_id": 151643,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 32768,
"max_window_layers": 28,
"model_type": "qwen3",
"num_attention_heads": 16,
"num_hidden_layers": 28,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000,
"sliding_window": null,
"tie_word_embeddings": true,
"torch_dtype": "bfloat16",
"transformers_version": "4.51.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151936
}

151388
configs/qwen3_06b/merges.txt Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,239 @@
{
"add_bos_token": false,
"add_prefix_space": false,
"added_tokens_decoder": {
"151643": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151644": {
"content": "<|im_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151645": {
"content": "<|im_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151646": {
"content": "<|object_ref_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151647": {
"content": "<|object_ref_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151648": {
"content": "<|box_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151649": {
"content": "<|box_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151650": {
"content": "<|quad_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151651": {
"content": "<|quad_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151652": {
"content": "<|vision_start|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151653": {
"content": "<|vision_end|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151654": {
"content": "<|vision_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151655": {
"content": "<|image_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151656": {
"content": "<|video_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"151657": {
"content": "<tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151658": {
"content": "</tool_call>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151659": {
"content": "<|fim_prefix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151660": {
"content": "<|fim_middle|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151661": {
"content": "<|fim_suffix|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151662": {
"content": "<|fim_pad|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151663": {
"content": "<|repo_name|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151664": {
"content": "<|file_sep|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151665": {
"content": "<tool_response>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151666": {
"content": "</tool_response>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151667": {
"content": "<think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
},
"151668": {
"content": "</think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": false
}
},
"additional_special_tokens": [
"<|im_start|>",
"<|im_end|>",
"<|object_ref_start|>",
"<|object_ref_end|>",
"<|box_start|>",
"<|box_end|>",
"<|quad_start|>",
"<|quad_end|>",
"<|vision_start|>",
"<|vision_end|>",
"<|vision_pad|>",
"<|image_pad|>",
"<|video_pad|>"
],
"bos_token": null,
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
"clean_up_tokenization_spaces": false,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 131072,
"pad_token": "<|endoftext|>",
"split_special_tokens": false,
"tokenizer_class": "Qwen2Tokenizer",
"unk_token": null
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,51 @@
{
"architectures": [
"T5WithLMHeadModel"
],
"d_ff": 65536,
"d_kv": 128,
"d_model": 1024,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"n_positions": 512,
"num_heads": 128,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"task_specific_params": {
"summarization": {
"early_stopping": true,
"length_penalty": 2.0,
"max_length": 200,
"min_length": 30,
"no_repeat_ngram_size": 3,
"num_beams": 4,
"prefix": "summarize: "
},
"translation_en_to_de": {
"early_stopping": true,
"max_length": 300,
"num_beams": 4,
"prefix": "translate English to German: "
},
"translation_en_to_fr": {
"early_stopping": true,
"max_length": 300,
"num_beams": 4,
"prefix": "translate English to French: "
},
"translation_en_to_ro": {
"early_stopping": true,
"max_length": 300,
"num_beams": 4,
"prefix": "translate English to Romanian: "
}
},
"vocab_size": 32128
}

BIN
configs/t5_old/spiece.model Normal file

Binary file not shown.

File diff suppressed because one or more lines are too long

655
docs/anima_train_network.md Normal file
View File

@@ -0,0 +1,655 @@
# LoRA Training Guide for Anima using `anima_train_network.py` / `anima_train_network.py` を用いたAnima モデルのLoRA学習ガイド
This document explains how to train LoRA (Low-Rank Adaptation) models for Anima using `anima_train_network.py` in the `sd-scripts` repository.
<details>
<summary>日本語</summary>
このドキュメントでは、`sd-scripts`リポジトリに含まれる`anima_train_network.py`を使用して、Anima モデルに対するLoRA (Low-Rank Adaptation) モデルを学習する基本的な手順について解説します。
</details>
## 1. Introduction / はじめに
`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a Qwen-Image VAE (16-channel, 8x spatial downscale).
Qwen-Image VAE and Qwen-Image VAE have same architecture, but [official Anima weight is named for Qwen-Image VAE](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae).
This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
**Prerequisites:**
* The `sd-scripts` repository has been cloned and the Python environment is ready.
* A training dataset has been prepared. See the [Dataset Configuration Guide](./config_README-en.md).
* Anima model files for training are available.
<details>
<summary>日本語</summary>
`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびQwen-Image VAE (16チャンネル、8倍空間ダウンスケール) を使用します。
Qwen-Image VAEとQwen-Image VAEは同じアーキテクチャですが、[Anima公式の重みはQwen-Image VAE用](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae)のようです。
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
**前提条件:**
* `sd-scripts`リポジトリのクローンとPython環境のセットアップが完了していること。
* 学習用データセットの準備が完了していること。(データセットの準備については[データセット設定ガイド](./config_README-en.md)を参照してください)
* 学習対象のAnimaモデルファイルが準備できていること。
</details>
## 2. Differences from `train_network.py` / `train_network.py` との違い
`anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are:
* **Target models:** Anima DiT models.
* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a Qwen-Image VAE (16-channel latent space with 8x spatial downscale).
* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the Qwen-Image VAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported.
* **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`).
* **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`.
<details>
<summary>日本語</summary>
`anima_train_network.py``train_network.py`をベースに、Anima モデルに対応するための変更が加えられています。主な違いは以下の通りです。
* **対象モデル:** Anima DiTモデルを対象とします。
* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびQwen-Image VAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、Qwen-Image VAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path``--t5_tokenizer_path`で個別に指定できます。
* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。
* **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma``uniform``sigmoid``shift``flux_shift`)を使用します。
* **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims``network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns``include_patterns`で制御します。
</details>
## 3. Preparation / 準備
The following files are required before starting training:
1. **Training script:** `anima_train_network.py`
2. **Anima DiT model file:** `.safetensors` file for the base DiT model.
3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory, or a single `.safetensors` file (uses the bundled config files in `configs/qwen3_06b/`).
4. **Qwen-Image VAE model file:** `.safetensors` or `.pth` file for the VAE.
5. **LLM Adapter model file (optional):** `.safetensors` file. If not provided separately, the adapter is loaded from the DiT file if the key `llm_adapter.out_proj.weight` exists.
6. **T5 Tokenizer (optional):** If not specified, uses the bundled tokenizer at `configs/t5_old/`.
7. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_anima_dataset_config.toml` as an example.
Model files can be obtained from the [Anima HuggingFace repository](https://huggingface.co/circlestone-labs/Anima).
**Notes:**
* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
<details>
<summary>日本語</summary>
学習を開始する前に、以下のファイルが必要です。
1. **学習スクリプト:** `anima_train_network.py`
2. **Anima DiTモデルファイル:** ベースとなるDiTモデルの`.safetensors`ファイル。
3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(バンドル版の`configs/qwen3_06b/`の設定ファイルが使用されます)。
4. **Qwen-Image VAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
5. **LLM Adapterモデルファイルオプション:** `.safetensors`ファイル。個別に指定しない場合、DiTファイル内に`llm_adapter.out_proj.weight`キーが存在すればそこから読み込まれます。
6. **T5トークナイザーオプション:** 指定しない場合、`configs/t5_old/`のバンドル版トークナイザーを使用します。
7. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。例として`my_anima_dataset_config.toml`を使用します。
モデルファイルは[HuggingFaceのAnimaリポジトリ](https://huggingface.co/circlestone-labs/Anima)から入手できます。
**注意:**
* T5トークナイザーを別途指定する場合、トークナイザーファイルのみ必要ですT5モデルの重みは不要`google/t5-v1_1-xxl`の語彙を使用します。
</details>
## 4. Running the Training / 学習の実行
Execute `anima_train_network.py` from the terminal to start training. The overall command-line format is the same as `train_network.py`, but Anima specific options must be supplied.
Example command:
```bash
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
--pretrained_model_name_or_path="<path to Anima DiT model>" \
--qwen3="<path to Qwen3-0.6B model or directory>" \
--vae="<path to Qwen-Image VAE model>" \
--dataset_config="my_anima_dataset_config.toml" \
--output_dir="<output directory>" \
--output_name="my_anima_lora" \
--save_model_as=safetensors \
--network_module=networks.lora_anima \
--network_dim=8 \
--learning_rate=1e-4 \
--optimizer_type="AdamW8bit" \
--lr_scheduler="constant" \
--timestep_sampling="sigmoid" \
--discrete_flow_shift=1.0 \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
--gradient_checkpointing \
--cache_latents \
--cache_text_encoder_outputs \
--vae_chunk_size=64 \
--vae_disable_cache
```
*(Write the command on one line or use `\` or `^` for line breaks.)*
The learning rate of `1e-4` is just an example. Adjust it according to your dataset and objectives. This value is for `alpha=1.0` (default). If increasing `--network_alpha`, consider lowering the learning rate.
If loss becomes NaN, ensure you are using PyTorch version 2.5 or higher.
**Note:** `--vae_chunk_size` and `--vae_disable_cache` are custom options in this repository to reduce memory usage of the Qwen-Image VAE.
<details>
<summary>日本語</summary>
学習は、ターミナルから`anima_train_network.py`を実行することで開始します。基本的なコマンドラインの構造は`train_network.py`と同様ですが、Anima特有の引数を指定する必要があります。
コマンドラインの例は英語のドキュメントを参照してください。
※実際には1行で書くか、適切な改行文字`\` または `^`)を使用してください。
学習率1e-4はあくまで一例です。データセットや目的に応じて適切に調整してください。またこの値はalpha=1.0(デフォルト)での値です。`--network_alpha`を増やす場合は学習率を下げることを検討してください。
lossがNaNになる場合は、PyTorchのバージョンが2.5以上であることを確認してください。
注意: `--vae_chunk_size`および`--vae_disable_cache`は当リポジトリ独自のオプションで、Qwen-Image VAEのメモリ使用量を削減するために使用します。
</details>
### 4.1. Explanation of Key Options / 主要なコマンドライン引数の解説
Besides the arguments explained in the [train_network.py guide](train_network.md), specify the following Anima specific options. For shared options (`--output_dir`, `--output_name`, `--network_module`, etc.), see that guide.
#### Model Options [Required] / モデル関連 [必須]
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[Required]**
- Path to the Anima DiT model `.safetensors` file. The model config (channels, blocks, heads) is auto-detected from the state dict. ComfyUI format with `net.` prefix is supported.
* `--qwen3="<path to Qwen3-0.6B model>"` **[Required]**
- Path to the Qwen3-0.6B text encoder. Can be a HuggingFace model directory or a single `.safetensors` file. The text encoder is always frozen during training.
* `--vae="<path to Qwen-Image VAE model>"` **[Required]**
- Path to the Qwen-Image VAE model `.safetensors` or `.pth` file. Fixed config: `dim=96, z_dim=16`.
#### Model Options [Optional] / モデル関連 [オプション]
* `--llm_adapter_path="<path to LLM adapter>"` *[Optional]*
- Path to a separate LLM adapter weights file. If omitted, the adapter is loaded from the DiT file when the key `llm_adapter.out_proj.weight` exists.
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[Optional]*
- Path to the T5 tokenizer directory. If omitted, uses the bundled config at `configs/t5_old/`.
#### Anima Training Parameters / Anima 学習パラメータ
* `--timestep_sampling=<choice>`
- Timestep sampling method. Choose from `sigma`, `uniform`, `sigmoid` (default), `shift`, `flux_shift`. Same options as FLUX training. See the [flux_train_network.py guide](flux_train_network.md) for details on each method.
* `--discrete_flow_shift=<float>`
- Shift for the timestep distribution in Rectified Flow training. Default `1.0`. This value is used when `--timestep_sampling` is set to **`shift`**. The shift formula is `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
* `--sigmoid_scale=<float>`
- Scale factor when `--timestep_sampling` is set to `sigmoid`, `shift`, or `flux_shift`. Default `1.0`.
* `--qwen3_max_token_length=<integer>`
- Maximum token length for the Qwen3 tokenizer. Default `512`.
* `--t5_max_token_length=<integer>`
- Maximum token length for the T5 tokenizer. Default `512`.
* `--attn_mode=<choice>`
- Attention implementation to use. Choose from `torch` (default), `xformers`, `flash`, `sageattn`. `xformers` requires `--split_attn`. `sageattn` does not support training (inference only). This option overrides `--xformers`.
* `--split_attn`
- Split attention computation to reduce memory usage. Required when using `--attn_mode xformers`.
#### Component-wise Learning Rates / コンポーネント別学習率
These options set separate learning rates for each component of the Anima model. They are primarily used for full fine-tuning. Set to `0` to freeze a component:
* `--self_attn_lr=<float>` - Learning rate for self-attention layers. Default: same as `--learning_rate`.
* `--cross_attn_lr=<float>` - Learning rate for cross-attention layers. Default: same as `--learning_rate`.
* `--mlp_lr=<float>` - Learning rate for MLP layers. Default: same as `--learning_rate`.
* `--mod_lr=<float>` - Learning rate for AdaLN modulation layers. Default: same as `--learning_rate`. Note: modulation layers are not included in LoRA by default.
* `--llm_adapter_lr=<float>` - Learning rate for LLM adapter layers. Default: same as `--learning_rate`.
For LoRA training, use `network_reg_lrs` in `--network_args` instead. See [Section 5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御).
#### Memory and Speed / メモリ・速度関連
* `--blocks_to_swap=<integer>`
- Number of Transformer blocks to swap between CPU and GPU. More blocks reduce VRAM but slow training. Maximum values depend on model size:
- 28-block model: max **26** (Anima-Preview)
- 36-block model: max **34**
- 20-block model: max **18**
- Cannot be used with `--cpu_offload_checkpointing` or `--unsloth_offload_checkpointing`.
* `--unsloth_offload_checkpointing`
- Offload activations to CPU RAM using async non-blocking transfers (faster than `--cpu_offload_checkpointing`). Cannot be combined with `--cpu_offload_checkpointing` or `--blocks_to_swap`.
* `--cache_text_encoder_outputs`
- Cache Qwen3 text encoder outputs to reduce VRAM usage. Recommended when not training text encoder LoRA.
* `--cache_text_encoder_outputs_to_disk`
- Cache text encoder outputs to disk. Auto-enables `--cache_text_encoder_outputs`.
* `--cache_latents`, `--cache_latents_to_disk`
- Cache Qwen-Image VAE latent outputs.
* `--vae_chunk_size=<integer>`
- Chunk size for Qwen-Image VAE processing. Reduces VRAM usage at the cost of speed. Default is no chunking.
* `--vae_disable_cache`
- Disable internal caching in Qwen-Image VAE to reduce VRAM usage.
#### Incompatible or Unsupported Options / 非互換・非サポートの引数
* `--v2`, `--v_parameterization`, `--clip_skip` - Options for Stable Diffusion v1/v2 that are not used for Anima training.
* `--fp8_base` - Not supported for Anima. If specified, it will be disabled with a warning.
<details>
<summary>日本語</summary>
[`train_network.py`のガイド](train_network.md)で説明されている引数に加え、以下のAnima特有の引数を指定します。共通の引数については、上記ガイドを参照してください。
#### モデル関連 [必須]
* `--pretrained_model_name_or_path="<path to Anima DiT model>"` **[必須]** - Anima DiTモデルの`.safetensors`ファイルのパスを指定します。モデルの設定はstate dictから自動検出されます。`net.`プレフィックス付きのComfyUIフォーマットもサポートしています。
* `--qwen3="<path to Qwen3-0.6B model>"` **[必須]** - Qwen3-0.6Bテキストエンコーダーのパスを指定します。HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイルが使用できます。
* `--vae="<path to Qwen-Image VAE model>"` **[必須]** - Qwen-Image VAEモデルのパスを指定します。
#### モデル関連 [オプション]
* `--llm_adapter_path="<path to LLM adapter>"` *[オプション]* - 個別のLLM Adapterの重みファイルのパス。
* `--t5_tokenizer_path="<path to T5 tokenizer>"` *[オプション]* - T5トークナイザーディレクトリのパス。
#### Anima 学習パラメータ
* `--timestep_sampling` - タイムステップのサンプリング方法。`sigma``uniform``sigmoid`(デフォルト)、`shift``flux_shift`から選択。FLUX学習と同じオプションです。各方法の詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
* `--discrete_flow_shift` - Rectified Flow学習のタイムステップ分布シフト。デフォルト`1.0``--timestep_sampling``shift`の場合に使用されます。
* `--sigmoid_scale` - `sigmoid``shift``flux_shift`タイムステップサンプリングのスケール係数。デフォルト`1.0`
* `--qwen3_max_token_length` - Qwen3トークナイザーの最大トークン長。デフォルト`512`
* `--t5_max_token_length` - T5トークナイザーの最大トークン長。デフォルト`512`
* `--attn_mode` - 使用するAttentionの実装。`torch`(デフォルト)、`xformers``flash``sageattn`から選択。`xformers``--split_attn`の指定が必要です。`sageattn`はトレーニングをサポートしていません(推論のみ)。
* `--split_attn` - メモリ使用量を減らすためにattention時にバッチを分割します。`--attn_mode xformers`使用時に必要です。
#### コンポーネント別学習率
これらのオプションは、Animaモデルの各コンポーネントに個別の学習率を設定します。主にフルファインチューニング用です。`0`に設定するとそのコンポーネントをフリーズします:
* `--self_attn_lr` - Self-attention層の学習率。
* `--cross_attn_lr` - Cross-attention層の学習率。
* `--mlp_lr` - MLP層の学習率。
* `--mod_lr` - AdaLNモジュレーション層の学習率。モジュレーション層はデフォルトではLoRAに含まれません。
* `--llm_adapter_lr` - LLM Adapter層の学習率。
LoRA学習の場合は、`--network_args``network_reg_lrs`を使用してください。[セクション5.2](#52-regex-based-rank-and-learning-rate-control--正規表現によるランク学習率の制御)を参照。
#### メモリ・速度関連
* `--blocks_to_swap` - TransformerブロックをCPUとGPUでスワップしてVRAMを節約。`--cpu_offload_checkpointing`および`--unsloth_offload_checkpointing`とは併用できません。
* `--unsloth_offload_checkpointing` - 非同期転送でアクティベーションをCPU RAMにオフロード。`--cpu_offload_checkpointing`および`--blocks_to_swap`とは併用できません。
* `--cache_text_encoder_outputs` - Qwen3の出力をキャッシュしてメモリ使用量を削減。
* `--cache_latents`, `--cache_latents_to_disk` - Qwen-Image VAEの出力をキャッシュ。
* `--vae_chunk_size` - Qwen-Image VAEのチャンク処理サイズ。メモリ使用量を削減しますが速度が低下します。デフォルトはチャンク処理なし。
* `--vae_disable_cache` - Qwen-Image VAEの内部キャッシュを無効化してメモリ使用量を削減します。
#### 非互換・非サポートの引数
* `--v2`, `--v_parameterization`, `--clip_skip` - Stable Diffusion v1/v2向けの引数。Animaの学習では使用されません。
* `--fp8_base` - Animaではサポートされていません。指定した場合、警告とともに無効化されます。
</details>
### 4.2. Starting Training / 学習の開始
After setting the required arguments, run the command to begin training. The overall flow and how to check logs are the same as in the [train_network.py guide](train_network.md#32-starting-the-training--学習の開始).
<details>
<summary>日本語</summary>
必要な引数を設定したら、コマンドを実行して学習を開始します。全体の流れやログの確認方法は、[train_network.pyのガイド](train_network.md#32-starting-the-training--学習の開始)と同様です。
</details>
## 5. LoRA Target Modules / LoRAの学習対象モジュール
When training LoRA with `anima_train_network.py`, the following modules are targeted by default:
* **DiT Blocks (`Block`)**: Self-attention (`self_attn`), cross-attention (`cross_attn`), and MLP (`mlp`) layers within each transformer block. Modulation (`adaln_modulation`), norm, embedder, and final layers are excluded by default.
* **Embedding layers (`PatchEmbed`, `TimestepEmbedding`) and Final layer (`FinalLayer`)**: Excluded by default but can be included using `include_patterns`.
* **LLM Adapter Blocks (`LLMAdapterTransformerBlock`)**: Only when `--network_args "train_llm_adapter=True"` is specified.
* **Text Encoder (Qwen3)**: Only when `--network_train_unet_only` is NOT specified and `--cache_text_encoder_outputs` is NOT used.
The LoRA network module is `networks.lora_anima`.
### 5.1. Module Selection with Patterns / パターンによるモジュール選択
By default, the following modules are excluded from LoRA via the built-in exclude pattern:
```
.*(_modulation|_norm|_embedder|final_layer).*
```
You can customize which modules are included or excluded using regex patterns in `--network_args`:
* `exclude_patterns` - Exclude modules matching these patterns (in addition to the default exclusion).
* `include_patterns` - Force-include modules matching these patterns, overriding exclusion.
Patterns are matched against the full module name using `re.fullmatch()`.
Example to include the final layer:
```
--network_args "include_patterns=['.*final_layer.*']"
```
Example to additionally exclude MLP layers:
```
--network_args "exclude_patterns=['.*mlp.*']"
```
### 5.2. Regex-based Rank and Learning Rate Control / 正規表現によるランク・学習率の制御
You can specify different ranks (network_dim) and learning rates for modules matching specific regex patterns:
* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`.
* Example: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
* This sets the rank to 8 for self-attention modules, 4 for cross-attention modules, and 8 for MLP modules.
* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`.
* Example: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
* This sets the learning rate to `1e-4` for self-attention modules and `5e-5` for cross-attention modules.
**Notes:**
* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
* Patterns are matched using `re.fullmatch()` against the module's original name (e.g., `blocks.0.self_attn.q_proj`).
### 5.3. LLM Adapter LoRA / LLM Adapter LoRA
To apply LoRA to the LLM Adapter blocks:
```
--network_args "train_llm_adapter=True"
```
In preliminary tests, lowering the learning rate for the LLM Adapter seems to improve stability. Adjust it using something like: `"network_reg_lrs=.*llm_adapter.*=5e-5"`.
### 5.4. Other Network Args / その他のネットワーク引数
* `--network_args "verbose=True"` - Print all LoRA module names and their dimensions.
* `--network_args "rank_dropout=0.1"` - Rank dropout rate.
* `--network_args "module_dropout=0.1"` - Module dropout rate.
* `--network_args "loraplus_lr_ratio=2.0"` - LoRA+ learning rate ratio.
* `--network_args "loraplus_unet_lr_ratio=2.0"` - LoRA+ learning rate ratio for DiT only.
* `--network_args "loraplus_text_encoder_lr_ratio=2.0"` - LoRA+ learning rate ratio for text encoder only.
<details>
<summary>日本語</summary>
`anima_train_network.py`でLoRAを学習させる場合、デフォルトでは以下のモジュールが対象となります。
* **DiTブロック (`Block`)**: 各Transformerブロック内のSelf-attention`self_attn`、Cross-attention`cross_attn`、MLP`mlp`)層。モジュレーション(`adaln_modulation`、norm、embedder、final layerはデフォルトで除外されます。
* **埋め込み層 (`PatchEmbed`, `TimestepEmbedding`) と最終層 (`FinalLayer`)**: デフォルトで除外されますが、`include_patterns`で含めることができます。
* **LLM Adapterブロック (`LLMAdapterTransformerBlock`)**: `--network_args "train_llm_adapter=True"`を指定した場合のみ。
* **テキストエンコーダー (Qwen3)**: `--network_train_unet_only`を指定せず、かつ`--cache_text_encoder_outputs`を使用しない場合のみ。
### 5.1. パターンによるモジュール選択
デフォルトでは以下のモジュールが組み込みの除外パターンによりLoRAから除外されます
```
.*(_modulation|_norm|_embedder|final_layer).*
```
`--network_args`で正規表現パターンを使用して、含めるモジュールと除外するモジュールをカスタマイズできます:
* `exclude_patterns` - これらのパターンにマッチするモジュールを除外(デフォルトの除外に追加)。
* `include_patterns` - これらのパターンにマッチするモジュールを強制的に含める(除外を上書き)。
パターンは`re.fullmatch()`を使用して完全なモジュール名に対してマッチングされます。
### 5.2. 正規表現によるランク・学習率の制御
正規表現にマッチするモジュールに対して、異なるランクや学習率を指定できます:
* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。`pattern=rank`形式の文字列をカンマで区切って指定します。
* 例: `--network_args "network_reg_dims=.*self_attn.*=8,.*cross_attn.*=4,.*mlp.*=8"`
* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。`pattern=lr`形式の文字列をカンマで区切って指定します。
* 例: `--network_args "network_reg_lrs=.*self_attn.*=1e-4,.*cross_attn.*=5e-5"`
**注意点:**
* `network_reg_dims`および`network_reg_lrs`での設定は、全体設定である`--network_dim``--learning_rate`よりも優先されます。
* パターンはモジュールのオリジナル名(例: `blocks.0.self_attn.q_proj`)に対して`re.fullmatch()`でマッチングされます。
### 5.3. LLM Adapter LoRA
LLM AdapterブロックにLoRAを適用するには`--network_args "train_llm_adapter=True"`
簡易な検証ではLLM Adapterの学習率はある程度下げた方が安定するようです。`"network_reg_lrs=.*llm_adapter.*=5e-5"`などで調整してください。
### 5.4. その他のネットワーク引数
* `verbose=True` - 全LoRAモジュール名とdimを表示
* `rank_dropout` - ランクドロップアウト率
* `module_dropout` - モジュールドロップアウト率
* `loraplus_lr_ratio` - LoRA+学習率比率
* `loraplus_unet_lr_ratio` - DiT専用のLoRA+学習率比率
* `loraplus_text_encoder_lr_ratio` - テキストエンコーダー専用のLoRA+学習率比率
</details>
## 6. Using the Trained Model / 学習済みモデルの利用
When training finishes, a LoRA model file (e.g. `my_anima_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Anima, such as ComfyUI with appropriate nodes.
<details>
<summary>日本語</summary>
学習が完了すると、指定した`output_dir`にLoRAモデルファイル例: `my_anima_lora.safetensors`が保存されます。このファイルは、Anima モデルに対応した推論環境(例: ComfyUI + 適切なノード)で使用できます。
</details>
## 7. Advanced Settings / 高度な設定
### 7.1. VRAM Usage Optimization / VRAM使用量の最適化
Anima models can be large, so GPUs with limited VRAM may require optimization:
#### Key VRAM Reduction Options
- **`--blocks_to_swap <number>`**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. See model-specific max values in section 4.1.
- **`--unsloth_offload_checkpointing`**: Offloads gradient checkpoints to CPU using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--blocks_to_swap`.
- **`--gradient_checkpointing`**: Standard gradient checkpointing to reduce VRAM at the cost of compute.
- **`--cache_text_encoder_outputs`**: Caches Qwen3 outputs so the text encoder can be freed from VRAM during training.
- **`--cache_latents`**: Caches Qwen-Image VAE outputs so the VAE can be freed from VRAM during training.
- **Using Adafactor optimizer**: Can reduce VRAM usage:
```
--optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --lr_scheduler constant_with_warmup --max_grad_norm 0.0
```
<details>
<summary>日本語</summary>
Animaモデルは大きい場合があるため、VRAMが限られたGPUでは最適化が必要です。
主要なVRAM削減オプション
- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ
- `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード
- `--gradient_checkpointing`: 標準的な勾配チェックポイント
- `--cache_text_encoder_outputs`: Qwen3の出力をキャッシュ
- `--cache_latents`: Qwen-Image VAEの出力をキャッシュ
- Adafactorオプティマイザの使用
</details>
### 7.2. Training Settings / 学習設定
#### Timestep Sampling
The `--timestep_sampling` option specifies how timesteps are sampled. The available methods are the same as FLUX training:
- `sigma`: Sigma-based sampling like SD3.
- `uniform`: Uniform random sampling from [0, 1].
- `sigmoid` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
- `shift`: Like `sigmoid`, but applies the discrete flow shift formula: `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
- `flux_shift`: Resolution-dependent shift used in FLUX training.
See the [flux_train_network.py guide](flux_train_network.md) for detailed descriptions.
#### Discrete Flow Shift
The `--discrete_flow_shift` option (default `1.0`) only applies when `--timestep_sampling` is set to `shift`. The formula is:
```
t_shifted = (t * shift) / (1 + (shift - 1) * t)
```
#### Loss Weighting
The `--weighting_scheme` option specifies loss weighting by timestep:
- `uniform` (default): Equal weight for all timesteps.
- `sigma_sqrt`: Weight by `sigma^(-2)`.
- `cosmap`: Weight by `2 / (pi * (1 - 2*sigma + 2*sigma^2))`.
- `none`: Same as uniform.
- `logit_normal`, `mode`: Additional schemes from SD3 training. See the [`sd3_train_network.md` guide](sd3_train_network.md) for details.
#### Caption Dropout
Caption dropout uses the `caption_dropout_rate` setting from the dataset configuration (per-subset in TOML). When using `--cache_text_encoder_outputs`, the dropout rate is stored with each cached entry and applied during training, so caption dropout is compatible with text encoder output caching.
**If you change the `caption_dropout_rate` setting, you must delete and regenerate the cache.**
Note: Currently, only Anima supports combining `caption_dropout_rate` with text encoder output caching.
<details>
<summary>日本語</summary>
#### タイムステップサンプリング
`--timestep_sampling`でタイムステップのサンプリング方法を指定します。FLUX学習と同じ方法が利用できます
- `sigma`: SD3と同様のシグマベースサンプリング。
- `uniform`: [0, 1]の一様分布からサンプリング。
- `sigmoid`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。汎用的なオプション。
- `shift`: `sigmoid`と同様だが、離散フローシフトの式を適用。
- `flux_shift`: FLUX学習で使用される解像度依存のシフト。
詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
#### 離散フローシフト
`--discrete_flow_shift`(デフォルト`1.0`)は`--timestep_sampling`が`shift`の場合のみ適用されます。
#### 損失の重み付け
`--weighting_scheme`でタイムステップごとの損失の重み付けを指定します。
#### キャプションドロップアウト
キャプションドロップアウトにはデータセット設定TOMLでのサブセット単位の`caption_dropout_rate`を使用します。`--cache_text_encoder_outputs`使用時は、ドロップアウト率が各キャッシュエントリとともに保存され、学習中に適用されるため、テキストエンコーダー出力キャッシュと同時に使用できます。
**`caption_dropout_rate`の設定を変えた場合、キャッシュを削除し、再生成する必要があります。**
※`caption_dropout_rate`をテキストエンコーダー出力キャッシュと組み合わせられるのは、今のところAnimaのみです。
</details>
### 7.3. Text Encoder LoRA Support / Text Encoder LoRAのサポート
Anima LoRA training supports training Qwen3 text encoder LoRA:
- To train only DiT: specify `--network_train_unet_only`
- To train DiT and Qwen3: omit `--network_train_unet_only` and do NOT use `--cache_text_encoder_outputs`
You can specify a separate learning rate for Qwen3 with `--text_encoder_lr`. If not specified, the default `--learning_rate` is used.
Note: When `--cache_text_encoder_outputs` is used, text encoder outputs are pre-computed and the text encoder is removed from GPU, so text encoder LoRA cannot be trained.
<details>
<summary>日本語</summary>
Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレーニングできます。
- DiTのみ学習: `--network_train_unet_only`を指定
- DiTとQwen3を学習: `--network_train_unet_only`を省略し、`--cache_text_encoder_outputs`を使用しない
Qwen3に個別の学習率を指定するには`--text_encoder_lr`を使用します。未指定の場合は`--learning_rate`が使われます。
注意: `--cache_text_encoder_outputs`を使用する場合、テキストエンコーダーの出力が事前に計算されGPUから解放されるため、テキストエンコーダーLoRAは学習できません。
</details>
## 8. Other Training Options / その他の学習オプション
- **`--loss_type`**: Loss function for training. Default `l2`.
- `l1`: L1 loss.
- `l2`: L2 loss (mean squared error).
- `huber`: Huber loss.
- `smooth_l1`: Smooth L1 loss.
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: Parameters for Huber loss when `--loss_type` is `huber` or `smooth_l1`.
- **`--ip_noise_gamma`**, **`--ip_noise_gamma_random_strength`**: Input Perturbation noise gamma values.
- **`--fused_backward_pass`**: Fuses the backward pass and optimizer step to reduce VRAM usage. Only works with Adafactor. For details, see the [`sdxl_train_network.py` guide](sdxl_train_network.md).
- **`--weighting_scheme`**, **`--logit_mean`**, **`--logit_std`**, **`--mode_scale`**: Timestep loss weighting options. For details, refer to the [`sd3_train_network.md` guide](sd3_train_network.md).
<details>
<summary>日本語</summary>
- **`--loss_type`**: 学習に用いる損失関数。デフォルト`l2`。`l1`, `l2`, `huber`, `smooth_l1`から選択。
- **`--huber_schedule`**, **`--huber_c`**, **`--huber_scale`**: Huber損失のパラメータ。
- **`--ip_noise_gamma`**: Input Perturbationイズガンマ値。
- **`--fused_backward_pass`**: バックワードパスとオプティマイザステップの融合。
- **`--weighting_scheme`** 等: タイムステップ損失の重み付け。詳細は[`sd3_train_network.md`](sd3_train_network.md)を参照。
</details>
## 9. Related Tools / 関連ツール
### `networks/anima_convert_lora_to_comfy.py`
A script to convert LoRA models to ComfyUI-compatible format. ComfyUI does not directly support sd-scripts format Qwen3 LoRA, so conversion is necessary (conversion may not be needed for DiT-only LoRA). You can convert from the sd-scripts format to ComfyUI format with:
```bash
python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
```
Using the `--reverse` option allows conversion in the opposite direction (ComfyUI format to sd-scripts format). However, reverse conversion is only possible for LoRAs converted by this script. LoRAs created with other training tools cannot be converted.
<details>
<summary>日本語</summary>
**`networks/convert_anima_lora_to_comfy.py`**
LoRAモデルをComfyUI互換形式に変換するスクリプト。ComfyUIがsd-scripts形式のQwen3 LoRAを直接サポートしていないため、変換が必要ですDiTのみのLoRAの場合は変換不要のようです。sd-scripts形式からComfyUI形式への変換は以下のコマンドで行います
```bash
python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
```
`--reverse`オプションを付けると、逆変換ComfyUI形式からsd-scripts形式も可能です。ただし、逆変換ができるのはこのスクリプトで変換したLoRAに限ります。他の学習ツールで作成したLoRAは変換できません。
</details>
## 10. Others / その他
### Metadata Saved in LoRA Models
The following metadata is saved in the LoRA model file:
* `ss_weighting_scheme`
* `ss_logit_mean`
* `ss_logit_std`
* `ss_mode_scale`
* `ss_timestep_sampling`
* `ss_sigmoid_scale`
* `ss_discrete_flow_shift`
<details>
<summary>日本語</summary>
`anima_train_network.py`には、サンプル画像の生成 (`--sample_prompts`など) や詳細なオプティマイザ設定など、`train_network.py`と共通の機能も多く存在します。これらについては、[`train_network.py`のガイド](train_network.md#5-other-features--その他の機能)やスクリプトのヘルプ (`python anima_train_network.py --help`) を参照してください。
### LoRAモデルに保存されるメタデータ
以下のメタデータがLoRAモデルファイルに保存されます
* `ss_weighting_scheme`
* `ss_logit_mean`
* `ss_logit_std`
* `ss_mode_scale`
* `ss_timestep_sampling`
* `ss_sigmoid_scale`
* `ss_discrete_flow_shift`
</details>

1654
library/anima_models.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,615 @@
# Anima Training Utilities
import argparse
import gc
import math
import os
import time
from typing import Optional
import numpy as np
import torch
from accelerate import Accelerator
from tqdm import tqdm
from PIL import Image
from library.device_utils import init_ipex, clean_memory_on_device, synchronize_device
from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl
init_ipex()
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# Anima-specific training arguments
def add_anima_training_arguments(parser: argparse.ArgumentParser):
"""Add Anima-specific training arguments to the parser."""
parser.add_argument(
"--qwen3",
type=str,
default=None,
help="Path to Qwen3-0.6B model (safetensors file or directory)",
)
parser.add_argument(
"--llm_adapter_path",
type=str,
default=None,
help="Path to separate LLM adapter weights. If None, adapter is loaded from DiT file if present",
)
parser.add_argument(
"--llm_adapter_lr",
type=float,
default=None,
help="Learning rate for LLM adapter. None=same as base LR, 0=freeze adapter",
)
parser.add_argument(
"--self_attn_lr",
type=float,
default=None,
help="Learning rate for self-attention layers. None=same as base LR, 0=freeze",
)
parser.add_argument(
"--cross_attn_lr",
type=float,
default=None,
help="Learning rate for cross-attention layers. None=same as base LR, 0=freeze",
)
parser.add_argument(
"--mlp_lr",
type=float,
default=None,
help="Learning rate for MLP layers. None=same as base LR, 0=freeze",
)
parser.add_argument(
"--mod_lr",
type=float,
default=None,
help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
)
parser.add_argument(
"--t5_tokenizer_path",
type=str,
default=None,
help="Path to T5 tokenizer directory. If None, uses default configs/t5_old/",
)
parser.add_argument(
"--qwen3_max_token_length",
type=int,
default=512,
help="Maximum token length for Qwen3 tokenizer (default: 512)",
)
parser.add_argument(
"--t5_max_token_length",
type=int,
default=512,
help="Maximum token length for T5 tokenizer (default: 512)",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=1.0,
help="Timestep distribution shift for rectified flow training (default: 1.0)",
)
parser.add_argument(
"--timestep_sampling",
type=str,
default="sigmoid",
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
help="Timestep sampling method (default: sigmoid (logit normal))",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
)
parser.add_argument(
"--attn_mode",
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
default=None,
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
" / 使用するAttentionの実装。デフォルトはNonetorchです。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません推論のみ。このオプションは--xformersまたは--sdpaを上書きします。",
)
parser.add_argument(
"--split_attn",
action="store_true",
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
)
parser.add_argument(
"--vae_chunk_size",
type=int,
default=None,
help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
+ " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
)
parser.add_argument(
"--vae_disable_cache",
action="store_true",
help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
+ " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
)
# Loss weighting
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Compute loss weighting for Anima training.
Same schemes as SD3 but can add Anima-specific ones if needed in future.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
elif weighting_scheme == "none" or weighting_scheme is None:
weighting = torch.ones_like(sigmas)
else:
weighting = torch.ones_like(sigmas)
return weighting
# Parameter groups (6 groups with separate LRs)
def get_anima_param_groups(
dit,
base_lr: float,
self_attn_lr: Optional[float] = None,
cross_attn_lr: Optional[float] = None,
mlp_lr: Optional[float] = None,
mod_lr: Optional[float] = None,
llm_adapter_lr: Optional[float] = None,
):
"""Create parameter groups for Anima training with separate learning rates.
Args:
dit: Anima model
base_lr: Base learning rate
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
cross_attn_lr: LR for cross-attention layers
mlp_lr: LR for MLP layers
mod_lr: LR for AdaLN modulation layers
llm_adapter_lr: LR for LLM adapter
Returns:
List of parameter group dicts for optimizer
"""
if self_attn_lr is None:
self_attn_lr = base_lr
if cross_attn_lr is None:
cross_attn_lr = base_lr
if mlp_lr is None:
mlp_lr = base_lr
if mod_lr is None:
mod_lr = base_lr
if llm_adapter_lr is None:
llm_adapter_lr = base_lr
base_params = []
self_attn_params = []
cross_attn_params = []
mlp_params = []
mod_params = []
llm_adapter_params = []
for name, p in dit.named_parameters():
# Store original name for debugging
p.original_name = name
if "llm_adapter" in name:
llm_adapter_params.append(p)
elif ".self_attn" in name:
self_attn_params.append(p)
elif ".cross_attn" in name:
cross_attn_params.append(p)
elif ".mlp" in name:
mlp_params.append(p)
elif ".adaln_modulation" in name:
mod_params.append(p)
else:
base_params.append(p)
logger.info(f"Parameter groups:")
logger.info(f" base_params: {len(base_params)} (lr={base_lr})")
logger.info(f" self_attn_params: {len(self_attn_params)} (lr={self_attn_lr})")
logger.info(f" cross_attn_params: {len(cross_attn_params)} (lr={cross_attn_lr})")
logger.info(f" mlp_params: {len(mlp_params)} (lr={mlp_lr})")
logger.info(f" mod_params: {len(mod_params)} (lr={mod_lr})")
logger.info(f" llm_adapter_params: {len(llm_adapter_params)} (lr={llm_adapter_lr})")
param_groups = []
for lr, params, name in [
(base_lr, base_params, "base"),
(self_attn_lr, self_attn_params, "self_attn"),
(cross_attn_lr, cross_attn_params, "cross_attn"),
(mlp_lr, mlp_params, "mlp"),
(mod_lr, mod_params, "mod"),
(llm_adapter_lr, llm_adapter_params, "llm_adapter"),
]:
if lr == 0:
for p in params:
p.requires_grad_(False)
logger.info(f" Frozen {name} params ({len(params)} parameters)")
elif len(params) > 0:
param_groups.append({"params": params, "lr": lr})
total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
logger.info(f"Total trainable parameters: {total_trainable:,}")
return param_groups
# Save functions
def save_anima_model_on_train_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
dit: anima_models.Anima,
):
"""Save Anima model at the end of training."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec_dataclass(
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
).to_metadata_dict()
dit_sd = dit.state_dict()
# Save with 'net.' prefix for ComfyUI compatibility
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
def save_anima_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator: Accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
dit: anima_models.Anima,
):
"""Save Anima model at epoch end or specific steps."""
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec_dataclass(
None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
).to_metadata_dict()
dit_sd = dit.state_dict()
anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
True,
True,
epoch,
num_train_epochs,
global_step,
sd_saver,
None,
)
# Sampling (Euler discrete for rectified flow)
def do_sample(
height: int,
width: int,
seed: Optional[int],
dit: anima_models.Anima,
crossattn_emb: torch.Tensor,
steps: int,
dtype: torch.dtype,
device: torch.device,
guidance_scale: float = 1.0,
flow_shift: float = 3.0,
neg_crossattn_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Generate a sample using Euler discrete sampling for rectified flow.
Args:
height, width: Output image dimensions
seed: Random seed (None for random)
dit: Anima model
crossattn_emb: Cross-attention embeddings (B, N, D)
steps: Number of sampling steps
dtype: Compute dtype
device: Compute device
guidance_scale: CFG scale (1.0 = no guidance)
flow_shift: Flow shift parameter for rectified flow
neg_crossattn_emb: Negative cross-attention embeddings for CFG
Returns:
Denoised latents
"""
# Latent shape: (1, 16, 1, H/8, W/8) for single image
latent_h = height // 8
latent_w = width // 8
latent = torch.zeros(1, 16, 1, latent_h, latent_w, device=device, dtype=dtype)
# Generate noise
if seed is not None:
generator = torch.manual_seed(seed)
else:
generator = None
noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
# Timestep schedule: linear from 1.0 to 0.0
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
flow_shift = float(flow_shift)
if flow_shift != 1.0:
sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
# Start from pure noise
x = noise.clone()
# Padding mask (zeros = no padding) — resized in prepare_embedded_sequence to match latent dims
padding_mask = torch.zeros(1, 1, latent_h, latent_w, dtype=dtype, device=device)
use_cfg = guidance_scale > 1.0 and neg_crossattn_emb is not None
for i in tqdm(range(steps), desc="Sampling"):
sigma = sigmas[i]
t = sigma.unsqueeze(0) # (1,)
if use_cfg:
# CFG: two separate passes to reduce memory usage
pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask)
pos_out = pos_out.float()
neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask)
neg_out = neg_out.float()
model_output = neg_out + guidance_scale * (pos_out - neg_out)
else:
model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
model_output = model_output.float()
# Euler step: x_{t-1} = x_t - (sigma_t - sigma_{t-1}) * model_output
dt = sigmas[i + 1] - sigma
x = x + model_output * dt
x = x.to(dtype)
return x
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
dit: anima_models.Anima,
vae,
text_encoder,
tokenize_strategy,
text_encoding_strategy,
sample_prompts_te_outputs=None,
prompt_replacement=None,
):
"""Generate sample images during training.
This is a simplified sampler for Anima - it generates images using the current model state.
"""
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None:
return
logger.info(f"Generating sample images at step {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file: {args.sample_prompts}")
return
# Unwrap models
dit = accelerator.unwrap_model(dit)
if text_encoder is not None:
text_encoder = accelerator.unwrap_model(text_encoder)
dit.switch_block_swap_for_inference()
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = os.path.join(args.output_dir, "sample")
os.makedirs(save_dir, exist_ok=True)
# Save RNG state
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
dit.prepare_block_swap_before_forward()
_sample_image_inference(
accelerator,
args,
dit,
text_encoder,
vae,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
# Restore RNG state
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
dit.switch_block_swap_for_training()
clean_memory_on_device(accelerator.device)
def _sample_image_inference(
accelerator,
args,
dit,
text_encoder,
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
tokenize_strategy,
text_encoding_strategy,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
"""Generate a single sample image."""
prompt = prompt_dict.get("prompt", "")
negative_prompt = prompt_dict.get("negative_prompt", "")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
flow_shift = prompt_dict.get("flow_shift", 3.0)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # seed all CUDA devices for multi-GPU
height = max(64, height - height % 16)
width = max(64, width - width % 16)
logger.info(
f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
)
# Encode prompt
def encode_prompt(prpt):
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
return sample_prompts_te_outputs[prpt]
if text_encoder is not None:
tokens = tokenize_strategy.tokenize(prpt)
encoded = text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
return encoded
return None
encoded = encode_prompt(prompt)
if encoded is None:
logger.warning("Cannot encode prompt, skipping sample")
return
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = encoded
# Convert to tensors if numpy
if isinstance(prompt_embeds, np.ndarray):
prompt_embeds = torch.from_numpy(prompt_embeds).unsqueeze(0)
attn_mask = torch.from_numpy(attn_mask).unsqueeze(0)
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Process through LLM adapter if available
if dit.use_llm_adapter:
crossattn_emb = dit.llm_adapter(
source_hidden_states=prompt_embeds,
target_input_ids=t5_input_ids,
target_attention_mask=t5_attn_mask,
source_attention_mask=attn_mask,
)
crossattn_emb[~t5_attn_mask.bool()] = 0
else:
crossattn_emb = prompt_embeds
# Encode negative prompt for CFG
neg_crossattn_emb = None
if scale > 1.0 and negative_prompt is not None:
neg_encoded = encode_prompt(negative_prompt)
if neg_encoded is not None:
neg_pe, neg_am, neg_t5_ids, neg_t5_am = neg_encoded
if isinstance(neg_pe, np.ndarray):
neg_pe = torch.from_numpy(neg_pe).unsqueeze(0)
neg_am = torch.from_numpy(neg_am).unsqueeze(0)
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
neg_am = neg_am.to(accelerator.device)
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
neg_t5_am = neg_t5_am.to(accelerator.device)
if dit.use_llm_adapter:
neg_crossattn_emb = dit.llm_adapter(
source_hidden_states=neg_pe,
target_input_ids=neg_t5_ids,
target_attention_mask=neg_t5_am,
source_attention_mask=neg_am,
)
neg_crossattn_emb[~neg_t5_am.bool()] = 0
else:
neg_crossattn_emb = neg_pe
# Generate sample
clean_memory_on_device(accelerator.device)
latents = do_sample(
height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
)
# Decode latents
gc.collect()
synchronize_device(accelerator.device)
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device
vae.to(accelerator.device)
decoded = vae.decode_to_pixels(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
# Convert to image
image = decoded.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
# Remove temporal dim if present
if image.ndim == 4:
image = image[:, 0, :, :]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
decoded_np = decoded_np.astype(np.uint8)
image = Image.fromarray(decoded_np)
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i = prompt_dict.get("enum", 0)
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# Log to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)

309
library/anima_utils.py Normal file
View File

@@ -0,0 +1,309 @@
# Anima model loading/saving utilities
import os
from typing import Dict, List, Optional, Union
import torch
from safetensors.torch import load_file, save_file
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
from accelerate import init_empty_weights
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library import anima_models
from library.safetensors_utils import WeightTransformHooks
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# Original Anima high-precision keys. Kept for reference, but not used currently.
# # Keys that should stay in high precision (float32/bfloat16, not quantized)
# KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
# ".embed." excludes Embedding in LLMAdapter
FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
def load_anima_model(
device: Union[str, torch.device],
dit_path: str,
attn_mode: str,
split_attn: bool,
loading_device: Union[str, torch.device],
dit_weight_dtype: Optional[torch.dtype],
fp8_scaled: bool = False,
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
lora_multipliers: Optional[list[float]] = None,
) -> anima_models.Anima:
"""
Load Anima model from the specified checkpoint.
Args:
device (Union[str, torch.device]): Device for optimization or merging
dit_path (str): Path to the DiT model checkpoint.
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
split_attn (bool): Whether to use split attention.
loading_device (Union[str, torch.device]): Device to load the model weights on.
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
"""
# dit_weight_dtype is None for fp8_scaled
assert (
not fp8_scaled and dit_weight_dtype is not None
) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
device = torch.device(device)
loading_device = torch.device(loading_device)
# We currently support fixed DiT config for Anima models
dit_config = {
"max_img_h": 512,
"max_img_w": 512,
"max_frames": 128,
"in_channels": 16,
"out_channels": 16,
"patch_spatial": 2,
"patch_temporal": 1,
"model_channels": 2048,
"concat_padding_mask": True,
"crossattn_emb_channels": 1024,
"pos_emb_cls": "rope3d",
"pos_emb_learnable": True,
"pos_emb_interpolation": "crop",
"min_fps": 1,
"max_fps": 30,
"use_adaln_lora": True,
"adaln_lora_dim": 256,
"num_blocks": 28,
"num_heads": 16,
"extra_per_block_abs_pos_emb": False,
"rope_h_extrapolation_ratio": 4.0,
"rope_w_extrapolation_ratio": 4.0,
"rope_t_extrapolation_ratio": 1.0,
"extra_h_extrapolation_ratio": 1.0,
"extra_w_extrapolation_ratio": 1.0,
"extra_t_extrapolation_ratio": 1.0,
"rope_enable_fps_modulation": False,
"use_llm_adapter": True,
"attn_mode": attn_mode,
"split_attn": split_attn,
}
with init_empty_weights():
model = anima_models.Anima(**dit_config)
if dit_weight_dtype is not None:
model.to(dit_weight_dtype)
# load model weights with dynamic fp8 optimization and LoRA merging if needed
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
sd = load_safetensors_with_lora_and_fp8(
model_files=dit_path,
lora_weights_list=lora_weights_list,
lora_multipliers=lora_multipliers,
fp8_optimization=fp8_scaled,
calc_device=device,
move_to_device=(loading_device == device),
dit_weight_dtype=dit_weight_dtype,
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
weight_transform_hooks=rename_hooks,
)
if fp8_scaled:
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
if loading_device.type != "cpu":
# make sure all the model weights are on the loading_device
logger.info(f"Moving weights to {loading_device}")
for key in sd.keys():
sd[key] = sd[key].to(loading_device)
missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
if missing:
# Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
unexpected_missing = [
k
for k in missing
if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
]
if unexpected_missing:
# Raise error to avoid silent failures
raise RuntimeError(
f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
)
missing = {} # all missing keys were expected
if unexpected:
# Raise error to avoid silent failures
raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
return model
def load_qwen3_tokenizer(qwen3_path: str):
"""Load Qwen3 tokenizer only (without the text encoder model).
Args:
qwen3_path: Path to either a directory with model files or a safetensors file.
If a directory, loads tokenizer from it directly.
If a file, uses configs/qwen3_06b/ for tokenizer config.
Returns:
tokenizer
"""
from transformers import AutoTokenizer
if os.path.isdir(qwen3_path):
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
else:
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
"Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
"You can download these from the Qwen3-0.6B HuggingFace repository."
)
tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def load_qwen3_text_encoder(
qwen3_path: str,
dtype: torch.dtype = torch.bfloat16,
device: str = "cpu",
lora_weights: Optional[List[Dict[str, torch.Tensor]]] = None,
lora_multipliers: Optional[List[float]] = None,
):
"""Load Qwen3-0.6B text encoder.
Args:
qwen3_path: Path to either a directory with model files or a safetensors file
dtype: Model dtype
device: Device to load to
Returns:
(text_encoder_model, tokenizer)
"""
import transformers
from transformers import AutoTokenizer
logger.info(f"Loading Qwen3 text encoder from {qwen3_path}")
if os.path.isdir(qwen3_path):
# Directory with full model
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
else:
# Single safetensors file - use configs/qwen3_06b/ for config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
"Expected configs/qwen3_06b/ with config.json, tokenizer.json, etc. "
"You can download these from the Qwen3-0.6B HuggingFace repository."
)
tokenizer = AutoTokenizer.from_pretrained(config_dir, local_files_only=True)
qwen3_config = transformers.Qwen3Config.from_pretrained(config_dir, local_files_only=True)
model = transformers.Qwen3ForCausalLM(qwen3_config).model
# Load weights
if qwen3_path.endswith(".safetensors"):
if lora_weights is None:
state_dict = load_file(qwen3_path, device="cpu")
else:
state_dict = load_safetensors_with_lora_and_fp8(
model_files=qwen3_path,
lora_weights_list=lora_weights,
lora_multipliers=lora_multipliers,
fp8_optimization=False,
calc_device=device,
move_to_device=True,
dit_weight_dtype=None,
)
else:
assert lora_weights is None, "LoRA weights merging is only supported for safetensors checkpoints"
state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
# Remove 'model.' prefix if present
new_sd = {}
for k, v in state_dict.items():
if k.startswith("model."):
new_sd[k[len("model.") :]] = v
else:
new_sd[k] = v
info = model.load_state_dict(new_sd, strict=False)
logger.info(f"Loaded Qwen3 state dict: {info}")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.use_cache = False
model = model.requires_grad_(False).to(device, dtype=dtype)
logger.info(f"Loaded Qwen3 text encoder. Parameters: {sum(p.numel() for p in model.parameters()):,}")
return model, tokenizer
def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
"""Load T5 tokenizer for LLM Adapter target tokens.
Args:
t5_tokenizer_path: Optional path to T5 tokenizer directory. If None, uses default configs.
"""
from transformers import T5TokenizerFast
if t5_tokenizer_path is not None:
return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
# Use bundled config
config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
if os.path.exists(config_dir):
return T5TokenizerFast(
vocab_file=os.path.join(config_dir, "spiece.model"),
tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
)
raise FileNotFoundError(
f"T5 tokenizer config directory not found at {config_dir}. "
"Expected configs/t5_old/ with spiece.model and tokenizer.json. "
"You can download these from the google/t5-v1_1-xxl HuggingFace repository."
)
def save_anima_model(
save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Dict[str, any], dtype: Optional[torch.dtype] = None
):
"""Save Anima DiT model with 'net.' prefix for ComfyUI compatibility.
Args:
save_path: Output path (.safetensors)
dit_state_dict: State dict from dit.state_dict()
metadata: Metadata dict to include in the safetensors file
dtype: Optional dtype to cast to before saving
"""
prefixed_sd = {}
for k, v in dit_state_dict.items():
if dtype is not None:
# v = v.to(dtype)
v = v.detach().clone().to("cpu").to(dtype) # Reduce GPU memory usage during save
prefixed_sd["net." + k] = v.contiguous()
if metadata is None:
metadata = {}
metadata["format"] = "pt" # For compatibility with the official .safetensors file
save_file(prefixed_sd, save_path, metadata=metadata) # safetensors.save_file cosumes a lot of memory, but Anima is small enough
logger.info(f"Saved Anima model to {save_path}")

View File

@@ -37,6 +37,14 @@ class AttentionParams:
cu_seqlens: Optional[torch.Tensor] = None
max_seqlen: Optional[int] = None
@property
def supports_fp32(self) -> bool:
return self.attn_mode not in ["flash"]
@property
def requires_same_dtype(self) -> bool:
return self.attn_mode in ["xformers"]
@staticmethod
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
return AttentionParams(attn_mode, split_attn)
@@ -95,7 +103,7 @@ def attention(
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
k: Key tensor [B, L, H, D].
v: Value tensor [B, L, H, D].
attn_param: Attention parameters including mask and sequence lengths.
attn_params: Attention parameters including mask and sequence lengths.
drop_rate: Attention dropout rate.
Returns:

View File

@@ -195,6 +195,9 @@ class ModelOffloader(Offloader):
self.remove_handles.append(handle)
def set_forward_only(self, forward_only: bool):
# switching must wait for all pending transfers
for block_idx in list(self.futures.keys()):
self._wait_blocks_move(block_idx)
self.forward_only = forward_only
def __del__(self):
@@ -237,6 +240,10 @@ class ModelOffloader(Offloader):
if self.debug:
print(f"Prepare block devices before forward")
# wait for all pending transfers
for block_idx in list(self.futures.keys()):
self._wait_blocks_move(block_idx)
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device

View File

@@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
@@ -512,7 +512,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
sigmas = sigmas.view(-1, 1, 1, 1) if latents.ndim == 4 else sigmas.view(-1, 1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)

View File

@@ -9,7 +9,7 @@ import logging
from tqdm import tqdm
from library.device_utils import clean_memory_on_device
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
from library.utils import setup_logging
setup_logging()
@@ -220,6 +220,8 @@ def quantize_weight(
tensor_max = torch.max(torch.abs(tensor).view(-1))
scale = tensor_max / max_value
# print(f"Optimizing {key} with scale: {scale}")
# numerical safety
scale = torch.clamp(scale, min=1e-8)
scale = scale.to(torch.float32) # ensure scale is in float32 for division
@@ -245,6 +247,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook=None,
quantization_mode: str = "block",
block_size: Optional[int] = 64,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict:
"""
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
@@ -260,6 +264,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading
Returns:
dict: FP8 optimized state dict
@@ -288,7 +294,9 @@ def load_safetensors_with_fp8_optimization(
# Process each file
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
keys = f.keys()
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
value = f.get_tensor(key)
@@ -311,6 +319,11 @@ def load_safetensors_with_fp8_optimization(
value = value.to(calc_device)
original_dtype = value.dtype
if original_dtype.itemsize == 1:
raise ValueError(
f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
+ f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
)
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)
@@ -387,7 +400,7 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
else:
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1)
o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
return o.to(input_dtype)
else:

View File

@@ -5,7 +5,7 @@ import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
from library.safetensors_utils import MemoryEfficientSafeOpen
from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
from library.utils import setup_logging
setup_logging()
@@ -44,7 +44,7 @@ def filter_lora_state_dict(
def load_safetensors_with_lora_and_fp8(
model_files: Union[str, List[str]],
lora_weights_list: Optional[Dict[str, torch.Tensor]],
lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
lora_multipliers: Optional[List[float]],
fp8_optimization: bool,
calc_device: torch.device,
@@ -52,19 +52,23 @@ def load_safetensors_with_lora_and_fp8(
dit_weight_dtype: Optional[torch.dtype] = None,
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
Args:
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
fp8_optimization (bool): Whether to apply FP8 optimization.
calc_device (torch.device): Device to calculate on.
move_to_device (bool): Whether to move tensors to the calculation device after loading.
target_keys (Optional[List[str]]): Keys to target for optimization.
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
"""
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
@@ -73,19 +77,9 @@ def load_safetensors_with_lora_and_fp8(
extended_model_files = []
for model_file in model_files:
basename = os.path.basename(model_file)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(model_file), filename)
if os.path.exists(filepath):
extended_model_files.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
split_filenames = get_split_weight_filenames(model_file)
if split_filenames is not None:
extended_model_files.extend(split_filenames)
else:
extended_model_files.append(model_file)
model_files = extended_model_files
@@ -114,7 +108,7 @@ def load_safetensors_with_lora_and_fp8(
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
# make hook for LoRA merging
def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
if not model_weight_key.endswith(".weight"):
@@ -126,13 +120,18 @@ def load_safetensors_with_lora_and_fp8(
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
# check if this weight has LoRA weights
lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
lora_name = "lora_unet_" + lora_name.replace(".", "_")
down_key = lora_name + ".lora_down.weight"
up_key = lora_name + ".lora_up.weight"
alpha_key = lora_name + ".alpha"
if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
continue
lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
found = False
for prefix in ["lora_unet_", ""]:
lora_name = prefix + lora_name_without_prefix.replace(".", "_")
down_key = lora_name + ".lora_down.weight"
up_key = lora_name + ".lora_up.weight"
alpha_key = lora_name + ".alpha"
if down_key in lora_weight_keys and up_key in lora_weight_keys:
found = True
break
if not found:
continue # no LoRA weights for this model weight
# get LoRA weights
down_weight = lora_sd[down_key]
@@ -145,6 +144,13 @@ def load_safetensors_with_lora_and_fp8(
down_weight = down_weight.to(calc_device)
up_weight = up_weight.to(calc_device)
original_dtype = model_weight.dtype
if original_dtype.itemsize == 1: # fp8
# temporarily convert to float16 for calculation
model_weight = model_weight.to(torch.float16)
down_weight = down_weight.to(torch.float16)
up_weight = up_weight.to(torch.float16)
# W <- W + U * D
if len(model_weight.size()) == 2:
# linear
@@ -166,6 +172,9 @@ def load_safetensors_with_lora_and_fp8(
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
model_weight = model_weight + multiplier * conved * scale
if original_dtype.itemsize == 1: # fp8
model_weight = model_weight.to(original_dtype) # convert back to original dtype
# remove LoRA keys from set
lora_weight_keys.remove(down_key)
lora_weight_keys.remove(up_key)
@@ -187,6 +196,8 @@ def load_safetensors_with_lora_and_fp8(
target_keys,
exclude_keys,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
)
for lora_weight_keys in list_of_lora_weight_keys:
@@ -208,6 +219,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
weight_hook: callable = None,
disable_numpy_memmap: bool = False,
weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
@@ -218,7 +231,14 @@ def load_safetensors_with_fp8_optimization_and_hook(
)
# dit_weight_dtype is not used because we use fp8 optimization
state_dict = load_safetensors_with_fp8_optimization(
model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
model_files,
calc_device,
target_keys,
exclude_keys,
move_to_device=move_to_device,
weight_hook=weight_hook,
disable_numpy_memmap=disable_numpy_memmap,
weight_transform_hooks=weight_transform_hooks,
)
else:
logger.info(
@@ -226,7 +246,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
)
state_dict = {}
for model_file in model_files:
with MemoryEfficientSafeOpen(model_file) as f:
with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
if weight_hook is None and move_to_device:
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
import os
import re
import numpy as np
@@ -44,6 +45,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
validated[key] = value
return validated
# print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
@@ -88,15 +90,17 @@ class MemoryEfficientSafeOpen:
by using memory mapping for large tensors and avoiding unnecessary copies.
"""
def __init__(self, filename):
def __init__(self, filename, disable_numpy_memmap=False):
"""Initialize the SafeTensor reader.
Args:
filename (str): Path to the safetensors file to read.
disable_numpy_memmap (bool): If True, disable numpy memory mapping for large tensors, using standard file read instead.
"""
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
self.disable_numpy_memmap = disable_numpy_memmap
def __enter__(self):
"""Enter context manager."""
@@ -178,7 +182,8 @@ class MemoryEfficientSafeOpen:
# Use memmap for large tensors to avoid intermediate copies.
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
# So we only use memmap if device is not cpu.
if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# If disable_numpy_memmap is True, skip numpy memory mapping to load with standard file read.
if not self.disable_numpy_memmap and num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# Create memory map for zero-copy reading
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
byte_tensor = torch.from_numpy(mm) # zero copy
@@ -285,7 +290,11 @@ class MemoryEfficientSafeOpen:
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
path: str,
device: Union[str, torch.device],
disable_mmap: bool = False,
dtype: Optional[torch.dtype] = None,
disable_numpy_memmap: bool = False,
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
@@ -293,7 +302,7 @@ def load_safetensors(
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
device = torch.device(device) if device is not None else None
with MemoryEfficientSafeOpen(path) as f:
with MemoryEfficientSafeOpen(path, disable_numpy_memmap=disable_numpy_memmap) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
synchronize_device(device)
@@ -309,6 +318,29 @@ def load_safetensors(
return state_dict
def get_split_weight_filenames(file_path: str) -> Optional[list[str]]:
"""
Get the list of split weight filenames (full paths) if the file name ends with 00001-of-00004 etc.
Returns None if the file is not split.
"""
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
filenames = []
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
filenames.append(filepath)
else:
raise FileNotFoundError(f"File {filepath} not found")
return filenames
else:
return None
def load_split_weights(
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]:
@@ -319,19 +351,11 @@ def load_split_weights(
device = torch.device(device)
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
basename = os.path.basename(file_path)
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
if match:
prefix = basename[: match.start(2)]
count = int(match.group(3))
split_filenames = get_split_weight_filenames(file_path)
if split_filenames is not None:
state_dict = {}
for i in range(count):
filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
filepath = os.path.join(os.path.dirname(file_path), filename)
if os.path.exists(filepath):
state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
raise FileNotFoundError(f"File {filepath} not found")
for filename in split_filenames:
state_dict.update(load_safetensors(filename, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
return state_dict
@@ -349,3 +373,106 @@ def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
return key
return None
@dataclass
class WeightTransformHooks:
split_hook: Optional[callable] = None
concat_hook: Optional[callable] = None
rename_hook: Optional[callable] = None
class TensorWeightAdapter:
"""
A wrapper for weight conversion hooks (split and concat) to be used with MemoryEfficientSafeOpen.
This wrapper adapts the original MemoryEfficientSafeOpen to apply the provided split and concat hooks
when loading tensors.
split_hook: A callable that takes (original_key: str, original_tensor: torch.Tensor) and returns (new_keys: list[str], new_tensors: list[torch.Tensor]).
concat_hook: A callable that takes (original_key: str, tensors: dict[str, torch.Tensor]) and returns (new_key: str, concatenated_tensor: torch.Tensor).
rename_hook: A callable that takes (original_key: str) and returns (new_key: str).
If tensors is None, the hook should return only the new keys (for split) or new key (for concat), without tensors.
No need to implement __enter__ and __exit__ methods, as they are handled by the original MemoryEfficientSafeOpen.
Do not use this wrapper as a context manager directly, like `with WeightConvertHookWrapper(...) as f:`.
**concat_hook is not tested yet.**
"""
def __init__(self, weight_convert_hook: WeightTransformHooks, original_f: MemoryEfficientSafeOpen):
self.original_f = original_f
self.new_key_to_original_key_map: dict[str, Union[str, list[str]]] = (
{}
) # for split: new_key -> original_key; for concat: new_key -> list of original_keys; for direct mapping: new_key -> original_key
self.concat_key_set = set() # set of concatenated keys
self.split_key_set = set() # set of split keys
self.new_keys = []
self.tensor_cache = {} # cache for split tensors
self.split_hook = weight_convert_hook.split_hook
self.concat_hook = weight_convert_hook.concat_hook
self.rename_hook = weight_convert_hook.rename_hook
for key in self.original_f.keys():
if self.split_hook is not None:
converted_keys, _ = self.split_hook(key, None) # get new keys only
if converted_keys is not None:
for converted_key in converted_keys:
self.new_key_to_original_key_map[converted_key] = key
self.split_key_set.add(converted_key)
self.new_keys.extend(converted_keys)
continue # skip concat_hook if split_hook is applied
if self.concat_hook is not None:
converted_key, _ = self.concat_hook(key, None) # get new key only
if converted_key is not None:
if converted_key not in self.concat_key_set: # first time seeing this concatenated key
self.concat_key_set.add(converted_key)
self.new_key_to_original_key_map[converted_key] = []
self.new_keys.append(converted_key)
# multiple original keys map to the same concatenated key
self.new_key_to_original_key_map[converted_key].append(key)
continue # skip to next key
# direct mapping
if self.rename_hook is not None:
new_key = self.rename_hook(key)
self.new_key_to_original_key_map[new_key] = key
else:
new_key = key
self.new_keys.append(new_key)
def keys(self) -> list[str]:
return self.new_keys
def get_tensor(self, new_key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# load tensor by new_key, applying split or concat hooks as needed
if new_key not in self.new_key_to_original_key_map:
# direct mapping
return self.original_f.get_tensor(new_key, device=device, dtype=dtype)
elif new_key in self.split_key_set:
# split hook: split key is requested multiple times, so we cache the result
original_key = self.new_key_to_original_key_map[new_key]
if original_key not in self.tensor_cache: # not yet split
original_tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
new_keys, new_tensors = self.split_hook(original_key, original_tensor) # apply split hook
for k, t in zip(new_keys, new_tensors):
self.tensor_cache[k] = t
return self.tensor_cache.pop(new_key) # return and remove from cache
elif new_key in self.concat_key_set:
# concat hook: concatenated key is requested only once, so we do not cache the result
tensors = {}
for original_key in self.new_key_to_original_key_map[new_key]:
tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
tensors[original_key] = tensor
_, concatenated_tensors = self.concat_hook(self.new_key_to_original_key_map[new_key][0], tensors) # apply concat hook
return concatenated_tensors
else:
# direct mapping
original_key = self.new_key_to_original_key_map[new_key]
return self.original_f.get_tensor(original_key, device=device, dtype=dtype)

View File

@@ -81,6 +81,8 @@ ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"
ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1"
ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image"
ARCH_ANIMA_PREVIEW = "anima-preview"
ARCH_ANIMA_UNKNOWN = "anima-unknown"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -92,6 +94,7 @@ IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1"
IMPL_ANIMA = "https://huggingface.co/circlestone-labs/Anima"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -220,6 +223,12 @@ def determine_architecture(
arch = ARCH_HUNYUAN_IMAGE_2_1
else:
arch = ARCH_HUNYUAN_IMAGE_UNKNOWN
elif "anima" in model_config:
anima_type = model_config["anima"]
if anima_type == "preview":
arch = ARCH_ANIMA_PREVIEW
else:
arch = ARCH_ANIMA_UNKNOWN
elif v2:
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
else:
@@ -252,6 +261,8 @@ def determine_implementation(
return IMPL_FLUX
elif "lumina" in model_config:
return IMPL_LUMINA
elif "anima" in model_config:
return IMPL_ANIMA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
return IMPL_STABILITY_AI
else:
@@ -325,7 +336,7 @@ def determine_resolution(
reso = (reso[0], reso[0])
else:
# Determine default resolution based on model type
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config:
if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config or "anima" in model_config:
reso = (1024, 1024)
elif v2 and v_parameterization:
reso = (768, 768)

302
library/strategy_anima.py Normal file
View File

@@ -0,0 +1,302 @@
# Anima Strategy Classes
import os
import random
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from library import anima_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library import qwen_image_autoencoder_kl
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class AnimaTokenizeStrategy(TokenizeStrategy):
"""Tokenize strategy for Anima: dual tokenization with Qwen3 + T5.
Qwen3 tokens are used for the text encoder.
T5 tokens are used as target input IDs for the LLM Adapter (NOT encoded by T5).
Can be initialized with either pre-loaded tokenizer objects or paths to load from.
"""
def __init__(
self,
qwen3_tokenizer=None,
t5_tokenizer=None,
qwen3_max_length: int = 512,
t5_max_length: int = 512,
qwen3_path: Optional[str] = None,
t5_tokenizer_path: Optional[str] = None,
) -> None:
# Load tokenizers from paths if not provided directly
if qwen3_tokenizer is None:
if qwen3_path is None:
raise ValueError("Either qwen3_tokenizer or qwen3_path must be provided")
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(qwen3_path)
if t5_tokenizer is None:
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
self.qwen3_tokenizer = qwen3_tokenizer
self.qwen3_max_length = qwen3_max_length
self.t5_tokenizer = t5_tokenizer
self.t5_max_length = t5_max_length
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
# Tokenize with Qwen3
qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
)
qwen3_input_ids = qwen3_encoding["input_ids"]
qwen3_attn_mask = qwen3_encoding["attention_mask"]
# Tokenize with T5 (for LLM Adapter target tokens)
t5_encoding = self.t5_tokenizer.batch_encode_plus(
text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
)
t5_input_ids = t5_encoding["input_ids"]
t5_attn_mask = t5_encoding["attention_mask"]
return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
class AnimaTextEncodingStrategy(TextEncodingStrategy):
"""Text encoding strategy for Anima.
Encodes Qwen3 tokens through the Qwen3 text encoder to get hidden states.
T5 tokens are passed through unchanged (only used by LLM Adapter).
"""
def __init__(self) -> None:
super().__init__()
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
Args:
models: [qwen3_text_encoder]
tokens: [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
Returns:
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
"""
# Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()
qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
encoder_device = qwen3_text_encoder.device
qwen3_input_ids = qwen3_input_ids.to(encoder_device)
qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
prompt_embeds = outputs.last_hidden_state
prompt_embeds[~qwen3_attn_mask.bool()] = 0
return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
prompt_embeds: torch.Tensor,
attn_mask: torch.Tensor,
t5_input_ids: torch.Tensor,
t5_attn_mask: torch.Tensor,
caption_dropout_rates: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""Apply dropout to cached text encoder outputs.
Called during training when using cached outputs.
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
to match diffusion-pipe-main behavior.
"""
if caption_dropout_rates is None or torch.all(caption_dropout_rates == 0.0).item():
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
# Clone to avoid in-place modification of cached tensors
prompt_embeds = prompt_embeds.clone()
if attn_mask is not None:
attn_mask = attn_mask.clone()
if t5_input_ids is not None:
t5_input_ids = t5_input_ids.clone()
if t5_attn_mask is not None:
t5_attn_mask = t5_attn_mask.clone()
for i in range(prompt_embeds.shape[0]):
if random.random() < caption_dropout_rates[i].item():
# Use pre-cached unconditional embeddings
prompt_embeds[i] = 0
if attn_mask is not None:
attn_mask[i] = 0
if t5_input_ids is not None:
t5_input_ids[i, 0] = 1 # Set to </s> token ID
t5_input_ids[i, 1:] = 0
if t5_attn_mask is not None:
t5_attn_mask[i, 0] = 1
t5_attn_mask[i, 1:] = 0
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
"""Caching strategy for Anima text encoder outputs.
Caches: prompt_embeds (float), attn_mask (int), t5_input_ids (int), t5_attn_mask (int)
"""
ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_anima_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + self.ANIMA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "prompt_embeds" not in npz:
return False
if "attn_mask" not in npz:
return False
if "t5_input_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "caption_dropout_rate" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
prompt_embeds = data["prompt_embeds"]
attn_mask = data["attn_mask"]
t5_input_ids = data["t5_input_ids"]
t5_attn_mask = data["t5_attn_mask"]
caption_dropout_rate = data["caption_dropout_rate"]
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
infos: List,
):
anima_text_encoding_strategy: AnimaTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = anima_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks
)
# Convert to numpy for caching
if prompt_embeds.dtype == torch.bfloat16:
prompt_embeds = prompt_embeds.float()
prompt_embeds = prompt_embeds.cpu().numpy()
attn_mask = attn_mask.cpu().numpy()
t5_input_ids = t5_input_ids.cpu().numpy().astype(np.int32)
t5_attn_mask = t5_attn_mask.cpu().numpy().astype(np.int32)
for i, info in enumerate(infos):
prompt_embeds_i = prompt_embeds[i]
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
prompt_embeds=prompt_embeds_i,
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
caption_dropout_rate=caption_dropout_rate,
)
else:
info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
"""Latent caching strategy for Anima using WanVAE.
WanVAE produces 16-channel latents with spatial downscale 8x.
Latent shape for images: (B, 16, 1, H/8, W/8)
"""
ANIMA_LATENTS_NPZ_SUFFIX = "_anima.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
"""Cache batch of latents using Qwen Image VAE.
vae is expected to be the Qwen Image VAE (AutoencoderKLQwenImage).
The encoding function handles the mean/std normalization.
"""
vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage = vae
vae_device = vae.device
vae_dtype = vae.dtype
def encode_by_vae(img_tensor):
"""Encode image tensor to latents.
img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS)
Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
Returns latents in (B, 16, 1, H/8, W/8) shape on CPU.
"""
latents = vae.encode_pixels_to_latents(img_tensor) # Keep 4D for input/output
return latents.to("cpu")
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae_device)

View File

@@ -524,7 +524,7 @@ class LatentsCachingStrategy:
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
latents_size = latents.shape[1:3] # H, W
latents_size = latents.shape[-2:] # H, W (supports both 4D and 5D latents)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
if self.cache_to_disk:

View File

@@ -179,12 +179,15 @@ def split_train_val(
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
def __init__(
self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.caption: str = caption
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.caption_dropout_rate: float = caption_dropout_rate
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
@@ -197,7 +200,7 @@ class ImageInfo:
)
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
self.text_encoder_outputs_npz: Optional[str] = None # filename. set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
@@ -1096,11 +1099,11 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def is_text_encoder_output_cacheable(self):
def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False):
return all(
[
not (
subset.caption_dropout_rate > 0
subset.caption_dropout_rate > 0 and not cache_supports_dropout
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
@@ -2137,7 +2140,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images += num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
@@ -2338,7 +2341,7 @@ class FineTuningDataset(BaseDataset):
if caption is None:
caption = ""
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
image_info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
@@ -2661,8 +2664,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
def is_text_encoder_output_cacheable(self) -> bool:
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False) -> bool:
return all([dataset.is_text_encoder_output_cacheable(cache_supports_dropout) for dataset in self.datasets])
def set_current_strategies(self):
for dataset in self.datasets:
@@ -3578,6 +3581,7 @@ def get_sai_model_spec_dataclass(
flux: str = None,
lumina: str = None,
hunyuan_image: str = None,
anima: str = None,
optional_metadata: dict[str, str] | None = None,
) -> sai_model_spec.ModelSpecMetadata:
"""
@@ -3609,7 +3613,8 @@ def get_sai_model_spec_dataclass(
model_config["lumina"] = lumina
if hunyuan_image is not None:
model_config["hunyuan_image"] = hunyuan_image
if anima is not None:
model_config["anima"] = anima
# Use the dataclass function directly
return sai_model_spec.build_metadata_dataclass(
state_dict,
@@ -6138,7 +6143,8 @@ def conditional_loss(
elif loss_type == "huber":
if huber_c is None:
raise NotImplementedError("huber_c not implemented correctly")
huber_c = huber_c.view(-1, 1, 1, 1)
# Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors)
huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
@@ -6147,7 +6153,8 @@ def conditional_loss(
elif loss_type == "smooth_l1":
if huber_c is None:
raise NotImplementedError("huber_c not implemented correctly")
huber_c = huber_c.view(-1, 1, 1, 1)
# Reshape huber_c to broadcast with model_pred (supports 4D and 5D tensors)
huber_c = huber_c.view(-1, *([1] * (model_pred.ndim - 1)))
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)

View File

@@ -0,0 +1,160 @@
import argparse
from safetensors.torch import save_file
from safetensors import safe_open
from library import train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
COMFYUI_DIT_PREFIX = "diffusion_model."
COMFYUI_QWEN3_PREFIX = "text_encoders.qwen3_06b.transformer.model."
def main(args):
# load source safetensors
logger.info(f"Loading source file {args.src_path}")
state_dict = {}
with safe_open(args.src_path, framework="pt") as f:
metadata = f.metadata()
for k in f.keys():
state_dict[k] = f.get_tensor(k)
logger.info(f"Converting...")
keys = list(state_dict.keys())
count = 0
for k in keys:
if not args.reverse:
is_dit_lora = k.startswith("lora_unet_")
module_and_weight_name = "_".join(k.split("_")[2:]) # Remove `lora_unet_`or `lora_te_` prefix
# Split at the first dot, e.g., "block1_linear.weight" -> "block1_linear", "weight"
module_name, weight_name = module_and_weight_name.split(".", 1)
# Weight name conversion: lora_up/lora_down to lora_A/lora_B
if weight_name.startswith("lora_up"):
weight_name = weight_name.replace("lora_up", "lora_B")
elif weight_name.startswith("lora_down"):
weight_name = weight_name.replace("lora_down", "lora_A")
else:
# Keep other weight names as-is: e.g. alpha
pass
# Module name conversion: convert dots to underscores
original_module_name = module_name.replace("_", ".") # Convert to dot notation
# Convert back illegal dots in module names
# DiT
original_module_name = original_module_name.replace("llm.adapter", "llm_adapter")
original_module_name = original_module_name.replace(".linear.", ".linear_")
original_module_name = original_module_name.replace("t.embedding.norm", "t_embedding_norm")
original_module_name = original_module_name.replace("x.embedder", "x_embedder")
original_module_name = original_module_name.replace("adaln.modulation.cross_attn", "adaln_modulation_cross_attn")
original_module_name = original_module_name.replace("adaln.modulation.mlp", "adaln_modulation_mlp")
original_module_name = original_module_name.replace("cross.attn", "cross_attn")
original_module_name = original_module_name.replace("k.proj", "k_proj")
original_module_name = original_module_name.replace("k.norm", "k_norm")
original_module_name = original_module_name.replace("q.proj", "q_proj")
original_module_name = original_module_name.replace("q.norm", "q_norm")
original_module_name = original_module_name.replace("v.proj", "v_proj")
original_module_name = original_module_name.replace("o.proj", "o_proj")
original_module_name = original_module_name.replace("output.proj", "output_proj")
original_module_name = original_module_name.replace("self.attn", "self_attn")
original_module_name = original_module_name.replace("final.layer", "final_layer")
original_module_name = original_module_name.replace("adaln.modulation", "adaln_modulation")
original_module_name = original_module_name.replace("norm.cross.attn", "norm_cross_attn")
original_module_name = original_module_name.replace("norm.mlp", "norm_mlp")
original_module_name = original_module_name.replace("norm.self.attn", "norm_self_attn")
original_module_name = original_module_name.replace("out.proj", "out_proj")
# Qwen3
original_module_name = original_module_name.replace("embed.tokens", "embed_tokens")
original_module_name = original_module_name.replace("input.layernorm", "input_layernorm")
original_module_name = original_module_name.replace("down.proj", "down_proj")
original_module_name = original_module_name.replace("gate.proj", "gate_proj")
original_module_name = original_module_name.replace("up.proj", "up_proj")
original_module_name = original_module_name.replace("post.attention.layernorm", "post_attention_layernorm")
# Prefix conversion
new_prefix = COMFYUI_DIT_PREFIX if is_dit_lora else COMFYUI_QWEN3_PREFIX
new_k = f"{new_prefix}{original_module_name}.{weight_name}"
else:
if k.startswith(COMFYUI_DIT_PREFIX):
is_dit_lora = True
module_and_weight_name = k[len(COMFYUI_DIT_PREFIX) :]
elif k.startswith(COMFYUI_QWEN3_PREFIX):
is_dit_lora = False
module_and_weight_name = k[len(COMFYUI_QWEN3_PREFIX) :]
else:
logger.warning(f"Skipping unrecognized key {k}")
continue
# Get weight name
if ".lora_" in module_and_weight_name:
module_name, weight_name = module_and_weight_name.rsplit(".lora_", 1)
weight_name = "lora_" + weight_name
else:
module_name, weight_name = module_and_weight_name.rsplit(".", 1) # Keep other weight names as-is: e.g. alpha
# Weight name conversion: lora_A/lora_B to lora_up/lora_down
# Note: we only convert lora_A and lora_B weights, other weights are kept as-is
if weight_name.startswith("lora_B"):
weight_name = weight_name.replace("lora_B", "lora_up")
elif weight_name.startswith("lora_A"):
weight_name = weight_name.replace("lora_A", "lora_down")
# Module name conversion: convert dots to underscores
module_name = module_name.replace(".", "_") # Convert to underscore notation
# Prefix conversion
prefix = "lora_unet_" if is_dit_lora else "lora_te_"
new_k = f"{prefix}{module_name}.{weight_name}"
state_dict[new_k] = state_dict.pop(k)
count += 1
logger.info(f"Converted {count} keys")
if count == 0:
logger.warning("No keys were converted. Please check if the source file is in the expected format.")
elif count > 0 and count < len(keys):
logger.warning(
f"Only {count} out of {len(keys)} keys were converted. Please check if there are unexpected keys in the source file."
)
# Calculate hash
if metadata is not None:
logger.info(f"Calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
# save destination safetensors
logger.info(f"Saving destination file {args.dst_path}")
save_file(state_dict, args.dst_path, metadata=metadata)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert LoRA format")
parser.add_argument(
"src_path",
type=str,
default=None,
help="source path, sd-scripts format (or ComfyUI compatible format if --reverse is set, only supported for LoRAs converted by this script)",
)
parser.add_argument(
"dst_path",
type=str,
default=None,
help="destination path, ComfyUI compatible format (or sd-scripts format if --reverse is set)",
)
parser.add_argument("--reverse", action="store_true", help="reverse conversion direction")
args = parser.parse_args()
main(args)

639
networks/lora_anima.py Normal file
View File

@@ -0,0 +1,639 @@
# LoRA network module for Anima
import ast
import os
import re
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from library.utils import setup_logging
from networks.lora_flux import LoRAModule, LoRAInfModule
import logging
setup_logging()
logger = logging.getLogger(__name__)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae,
text_encoders: list,
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4
if network_alpha is None:
network_alpha = 1.0
# train LLM adapter
train_llm_adapter = kwargs.get("train_llm_adapter", "false")
if train_llm_adapter is not None:
train_llm_adapter = True if train_llm_adapter.lower() == "true" else False
exclude_patterns = kwargs.get("exclude_patterns", None)
if exclude_patterns is None:
exclude_patterns = []
else:
exclude_patterns = ast.literal_eval(exclude_patterns)
if not isinstance(exclude_patterns, list):
exclude_patterns = [exclude_patterns]
# add default exclude patterns
exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")
# regular expression for module selection: exclude and include
include_patterns = kwargs.get("include_patterns", None)
if include_patterns is not None:
include_patterns = ast.literal_eval(include_patterns)
if not isinstance(include_patterns, list):
include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# verbose
verbose = kwargs.get("verbose", "false")
if verbose is not None:
verbose = True if verbose.lower() == "true" else False
# regex-specific learning rates / dimensions
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
"""
Parse a string of key-value pairs separated by commas.
"""
pairs = {}
for pair in kv_pair_str.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
continue
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip()
try:
pairs[key] = int(value) if is_int else float(value)
except ValueError:
logger.warning(f"Invalid value for {key}: {value}")
return pairs
network_reg_lrs = kwargs.get("network_reg_lrs", None)
if network_reg_lrs is not None:
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
else:
reg_lrs = None
network_reg_dims = kwargs.get("network_reg_dims", None)
if network_reg_dims is not None:
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
else:
reg_dims = None
network = LoRANetwork(
text_encoders,
unet,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
train_llm_adapter=train_llm_adapter,
exclude_patterns=exclude_patterns,
include_patterns=include_patterns,
reg_dims=reg_dims,
reg_lrs=reg_lrs,
verbose=verbose,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
modules_dim = {}
modules_alpha = {}
train_llm_adapter = False
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
if "llm_adapter" in lora_name:
train_llm_adapter = True
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoders,
unet,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
train_llm_adapter=train_llm_adapter,
)
return network, weights_sd
class LoRANetwork(torch.nn.Module):
# Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
# Target modules: LLM Adapter blocks
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
# Target modules for text encoder (Qwen3)
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"]
LORA_PREFIX_ANIMA = "lora_unet" # ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER = "lora_te" # Qwen3
def __init__(
self,
text_encoders: list,
unet,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
module_class: Type[object] = LoRAModule,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
train_llm_adapter: bool = False,
exclude_patterns: Optional[List[str]] = None,
include_patterns: Optional[List[str]] = None,
reg_dims: Optional[Dict[str, int]] = None,
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.train_llm_adapter = train_llm_adapter
self.reg_dims = reg_dims
self.reg_lrs = reg_lrs
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info("create LoRA network from weights")
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# compile regular expression if specified
def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
re_patterns = []
if patterns is not None:
for pattern in patterns:
try:
re_pattern = re.compile(pattern)
except re.error as e:
logger.error(f"Invalid pattern '{pattern}': {e}")
continue
re_patterns.append(re_pattern)
return re_patterns
exclude_re_patterns = str_to_re_patterns(exclude_patterns)
include_re_patterns = str_to_re_patterns(include_patterns)
# create module instances
def create_modules(
is_unet: bool,
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
default_dim: Optional[int] = None,
) -> Tuple[List[LoRAModule], List[str]]:
prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER
loras = []
skipped = []
for name, module in root_module.named_modules():
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None:
module = root_module
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
original_name = (name + "." if name else "") + child_name
lora_name = f"{prefix}.{original_name}".replace(".", "_")
# exclude/include filter (fullmatch: pattern must match the entire original_name)
excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
if excluded and not included:
if verbose:
logger.info(f"exclude: {original_name}")
continue
dim = None
alpha_val = None
if modules_dim is not None:
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha_val = modules_alpha[lora_name]
else:
if self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.fullmatch(reg, original_name):
dim = d
alpha_val = self.alpha
logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
break
# fallback to default dim if not matched by reg_dims or reg_dims is not specified
if dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha_val = self.alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1:
skipped.append(lora_name)
continue
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha_val,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
lora.original_name = original_name
loras.append(lora)
if target_replace_modules is None:
break
return loras, skipped
# Create LoRA for text encoders (Qwen3 - typically not trained for Anima)
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
skipped_te = []
if text_encoders is not None:
for i, text_encoder in enumerate(text_encoders):
if text_encoder is None:
continue
logger.info(f"create LoRA for Text Encoder {i+1}:")
te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
self.text_encoder_loras.extend(te_loras)
skipped_te += te_skipped
# Create LoRA for DiT blocks
target_modules = list(LoRANetwork.ANIMA_TARGET_REPLACE_MODULE)
if train_llm_adapter:
target_modules.extend(LoRANetwork.ANIMA_ADAPTER_TARGET_REPLACE_MODULE)
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}")
skipped = skipped_te + skipped_un
if verbose and len(skipped) > 0:
logger.warning(f"dim (rank) is 0, {len(skipped)} LoRA modules are skipped:")
for name in skipped:
logger.info(f"\t{name}")
# assertion: no duplicate names
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info(f"enable LoRA for DiT: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
def is_mergeable(self):
return True
def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None):
apply_text_encoder = apply_unet = False
for key in weights_sd.keys():
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_ANIMA):
apply_unet = True
if apply_text_encoder:
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable LoRA for DiT")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info("weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
text_encoder_lr = [default_lr]
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
text_encoder_lr = [float(text_encoder_lr)]
elif len(text_encoder_lr) == 1:
pass # already a list with one element
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
reg_groups = {}
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
for lora in loras:
matched_reg_lr = None
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
if re.fullmatch(regex_str, lora.original_name):
matched_reg_lr = (i, reg_lr)
logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
break
for name, param in lora.named_parameters():
if matched_reg_lr is not None:
reg_idx, reg_lr = matched_reg_lr
group_key = f"reg_lr_{reg_idx}"
if group_key not in reg_groups:
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
if loraplus_ratio is not None and "lora_up" in name:
reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
else:
reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
continue
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for group_key, group in reg_groups.items():
reg_lr = group["lr"]
for key in ("lora", "plus"):
param_data = {"params": group[key].values()}
if len(param_data["params"]) == 0:
continue
if key == "plus":
param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
else:
param_data["lr"] = reg_lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
desc = f"reg_lr_{group_key.split('_')[-1]}"
descriptions.append(desc + (" plus" if key == "plus" else ""))
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * loraplus_ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
if self.text_encoder_loras:
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 1" + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params, lr_descriptions
def enable_gradient_checkpointing(self):
pass # not supported
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
downkeys = []
upkeys = []
alphakeys = []
norms = []
keys_scaled = 0
state_dict = self.state_dict()
for key in state_dict.keys():
if "lora_down" in key and "weight" in key:
downkeys.append(key)
upkeys.append(key.replace("lora_down", "lora_up"))
alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)):
down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
dim = down.shape[0]
scale = alpha / dim
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
updown *= scale
norm = updown.norm().clamp(min=max_norm_value / 2)
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired.cpu() / norm.cpu()
sqrt_ratio = ratio**0.5
if ratio != 1:
keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)

View File

@@ -15,6 +15,12 @@ import random
import re
import diffusers
# Compatible import for diffusers old/new UNet path
try:
from diffusers.models.unet_2d_condition import UNet2DConditionModel
except ImportError:
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
import numpy as np
import torch
@@ -80,7 +86,7 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
"""
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
def replace_unet_modules(unet: UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
logger.info("Enable memory efficient attention for U-Net")

View File

@@ -0,0 +1,607 @@
"""
Diagnostic script to test Anima latent & text encoder caching independently.
Usage:
python manual_test_anima_cache.py \
--image_dir /path/to/images \
--qwen3_path /path/to/qwen3 \
--vae_path /path/to/vae.safetensors \
[--t5_tokenizer_path /path/to/t5] \
[--cache_to_disk]
The image_dir should contain pairs of:
image1.png + image1.txt
image2.jpg + image2.txt
...
"""
import argparse
import glob
import os
import sys
import traceback
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
# Helpers
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"}
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(), # [0,1]
transforms.Normalize([0.5], [0.5]), # [-1,1]
]
)
def find_image_caption_pairs(image_dir: str):
"""Find (image_path, caption_text) pairs from a directory."""
pairs = []
for f in sorted(os.listdir(image_dir)):
ext = os.path.splitext(f)[1].lower()
if ext not in IMAGE_EXTENSIONS:
continue
img_path = os.path.join(image_dir, f)
txt_path = os.path.splitext(img_path)[0] + ".txt"
if os.path.exists(txt_path):
with open(txt_path, "r", encoding="utf-8") as fh:
caption = fh.read().strip()
else:
caption = ""
pairs.append((img_path, caption))
return pairs
def print_tensor_info(name: str, t, indent=2):
prefix = " " * indent
if t is None:
print(f"{prefix}{name}: None")
return
if isinstance(t, np.ndarray):
print(f"{prefix}{name}: numpy {t.dtype} shape={t.shape} " f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
elif isinstance(t, torch.Tensor):
print(
f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}"
)
else:
print(f"{prefix}{name}: type={type(t)} value={t}")
# Test 1: Latent Cache
def test_latent_cache(args, pairs):
print("\n" + "=" * 70)
print("TEST 1: LATENT CACHING (VAE encode -> cache -> reload)")
print("=" * 70)
from library import qwen_image_autoencoder_kl
# Load VAE
print("\n[1.1] Loading VAE...")
device = "cuda" if torch.cuda.is_available() else "cpu"
vae_dtype = torch.float32
vae = qwen_image_autoencoder_kl.load_vae(args.vae_path, dtype=vae_dtype, device=device)
print(f" VAE loaded on {device}, dtype={vae_dtype}")
for img_path, caption in pairs:
print(f"\n[1.2] Processing: {os.path.basename(img_path)}")
# Load image
img = Image.open(img_path).convert("RGB")
img_np = np.array(img)
print(f" Raw image: {img_np.shape} dtype={img_np.dtype} " f"min={img_np.min()} max={img_np.max()}")
# Apply IMAGE_TRANSFORMS (same as sd-scripts training)
img_tensor = IMAGE_TRANSFORMS(img_np)
print(
f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} " f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}"
)
# Check range is [-1, 1]
if img_tensor.min() < -1.01 or img_tensor.max() > 1.01:
print(" ** WARNING: tensor out of [-1, 1] range!")
else:
print(" OK: tensor in [-1, 1] range")
# Encode with VAE
img_batch = img_tensor.unsqueeze(0).to(device, dtype=vae_dtype) # (1, C, H, W)
img_5d = img_batch.unsqueeze(2) # (1, C, 1, H, W) - add temporal dim
print(f" VAE input: shape={tuple(img_5d.shape)} dtype={img_5d.dtype}")
with torch.no_grad():
latents = vae.encode_pixels_to_latents(img_5d)
latents_cpu = latents.cpu()
print_tensor_info("Encoded latents", latents_cpu)
# Check for NaN/Inf
if torch.any(torch.isnan(latents_cpu)):
print(" ** ERROR: NaN in latents!")
elif torch.any(torch.isinf(latents_cpu)):
print(" ** ERROR: Inf in latents!")
else:
print(" OK: no NaN/Inf")
# Test disk cache round-trip
if args.cache_to_disk:
npz_path = os.path.splitext(img_path)[0] + "_test_latent.npz"
latents_np = latents_cpu.float().numpy()
h, w = img_np.shape[:2]
np.savez(
npz_path,
latents=latents_np,
original_size=np.array([w, h]),
crop_ltrb=np.array([0, 0, 0, 0]),
)
print(f" Saved to: {npz_path}")
# Reload
loaded = np.load(npz_path)
loaded_latents = loaded["latents"]
print_tensor_info("Reloaded latents", loaded_latents)
# Compare
diff = np.abs(latents_np - loaded_latents).max()
print(f" Max diff (save vs load): {diff:.2e}")
if diff > 1e-5:
print(" ** WARNING: latent cache round-trip has significant diff!")
else:
print(" OK: round-trip matches")
os.remove(npz_path)
print(f" Cleaned up {npz_path}")
vae.to("cpu")
del vae
torch.cuda.empty_cache() if torch.cuda.is_available() else None
print("\n[1.3] Latent cache test DONE.")
# Test 2: Text Encoder Output Cache
def test_text_encoder_cache(args, pairs):
# TODO Rewrite this
print("\n" + "=" * 70)
print("TEST 2: TEXT ENCODER OUTPUT CACHING")
print("=" * 70)
from library import anima_utils
# Load tokenizers
print("\n[2.1] Loading tokenizers...")
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
print(f" Qwen3 tokenizer vocab: {qwen3_tokenizer.vocab_size}")
print(f" T5 tokenizer vocab: {t5_tokenizer.vocab_size}")
# Load text encoder
print("\n[2.2] Loading Qwen3 text encoder...")
device = "cuda" if torch.cuda.is_available() else "cpu"
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
qwen3_model.eval()
# Create strategy objects
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
tokenize_strategy = AnimaTokenizeStrategy(
qwen3_tokenizer=qwen3_tokenizer,
t5_tokenizer=t5_tokenizer,
qwen3_max_length=args.qwen3_max_length,
t5_max_length=args.t5_max_length,
)
text_encoding_strategy = AnimaTextEncodingStrategy()
captions = [cap for _, cap in pairs]
print(f"\n[2.3] Tokenizing {len(captions)} captions...")
for i, cap in enumerate(captions):
print(f" [{i}] \"{cap[:80]}{'...' if len(cap) > 80 else ''}\"")
tokens_and_masks = tokenize_strategy.tokenize(captions)
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens_and_masks
print(f"\n Tokenization results:")
print_tensor_info("qwen3_input_ids", qwen3_input_ids)
print_tensor_info("qwen3_attn_mask", qwen3_attn_mask)
print_tensor_info("t5_input_ids", t5_input_ids)
print_tensor_info("t5_attn_mask", t5_attn_mask)
# Encode
print(f"\n[2.4] Encoding with Qwen3 text encoder...")
with torch.no_grad():
prompt_embeds, attn_mask, t5_ids_out, t5_mask_out = text_encoding_strategy.encode_tokens(
tokenize_strategy, [qwen3_model], tokens_and_masks
)
print(f" Encoding results:")
print_tensor_info("prompt_embeds", prompt_embeds)
print_tensor_info("attn_mask", attn_mask)
print_tensor_info("t5_input_ids", t5_ids_out)
print_tensor_info("t5_attn_mask", t5_mask_out)
# Check for NaN/Inf
if torch.any(torch.isnan(prompt_embeds)):
print(" ** ERROR: NaN in prompt_embeds!")
elif torch.any(torch.isinf(prompt_embeds)):
print(" ** ERROR: Inf in prompt_embeds!")
else:
print(" OK: no NaN/Inf in prompt_embeds")
# Test cache round-trip (simulate what AnimaTextEncoderOutputsCachingStrategy does)
print(f"\n[2.5] Testing cache round-trip (encode -> numpy -> npz -> reload -> tensor)...")
# Convert to numpy (same as cache_batch_outputs in strategy_anima.py)
pe_cpu = prompt_embeds.cpu()
if pe_cpu.dtype == torch.bfloat16:
pe_cpu = pe_cpu.float()
pe_np = pe_cpu.numpy()
am_np = attn_mask.cpu().numpy()
t5_ids_np = t5_ids_out.cpu().numpy().astype(np.int32)
t5_mask_np = t5_mask_out.cpu().numpy().astype(np.int32)
print(f" Numpy conversions:")
print_tensor_info("prompt_embeds_np", pe_np)
print_tensor_info("attn_mask_np", am_np)
print_tensor_info("t5_input_ids_np", t5_ids_np)
print_tensor_info("t5_attn_mask_np", t5_mask_np)
if args.cache_to_disk:
npz_path = os.path.join(args.image_dir, "_test_te_cache.npz")
# Save per-sample (simulating cache_batch_outputs)
for i in range(len(captions)):
sample_npz = os.path.splitext(pairs[i][0])[0] + "_test_te.npz"
np.savez(
sample_npz,
prompt_embeds=pe_np[i],
attn_mask=am_np[i],
t5_input_ids=t5_ids_np[i],
t5_attn_mask=t5_mask_np[i],
)
print(f" Saved: {sample_npz}")
# Reload (simulating load_outputs_npz)
data = np.load(sample_npz)
print(f" Reloaded keys: {list(data.keys())}")
print_tensor_info(" loaded prompt_embeds", data["prompt_embeds"], indent=4)
print_tensor_info(" loaded attn_mask", data["attn_mask"], indent=4)
print_tensor_info(" loaded t5_input_ids", data["t5_input_ids"], indent=4)
print_tensor_info(" loaded t5_attn_mask", data["t5_attn_mask"], indent=4)
# Check diff
diff_pe = np.abs(pe_np[i] - data["prompt_embeds"]).max()
diff_t5 = np.abs(t5_ids_np[i] - data["t5_input_ids"]).max()
print(f" Max diff prompt_embeds: {diff_pe:.2e}")
print(f" Max diff t5_input_ids: {diff_t5:.2e}")
if diff_pe > 1e-5 or diff_t5 > 0:
print(" ** WARNING: cache round-trip mismatch!")
else:
print(" OK: round-trip matches")
os.remove(sample_npz)
print(f" Cleaned up {sample_npz}")
# Test in-memory cache round-trip (simulating what __getitem__ does)
print(f"\n[2.6] Testing in-memory cache simulation (tuple -> none_or_stack_elements -> batch)...")
# Simulate per-sample storage (like info.text_encoder_outputs = tuple)
per_sample_cached = []
for i in range(len(captions)):
per_sample_cached.append((pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]))
# Simulate none_or_stack_elements with torch.FloatTensor converter
# This is what train_util.py __getitem__ does at line 1784
stacked = []
for elem_idx in range(4):
arrays = [sample[elem_idx] for sample in per_sample_cached]
stacked.append(torch.stack([torch.FloatTensor(a) for a in arrays]))
print(f" Stacked batch (like batch['text_encoder_outputs_list']):")
names = ["prompt_embeds", "attn_mask", "t5_input_ids", "t5_attn_mask"]
for name, tensor in zip(names, stacked):
print_tensor_info(name, tensor)
# Check condition: len(text_encoder_conds) == 0 or text_encoder_conds[0] is None
text_encoder_conds = stacked
cond_check_1 = len(text_encoder_conds) == 0
cond_check_2 = text_encoder_conds[0] is None
print(f"\n Condition check (should both be False when caching works):")
print(f" len(text_encoder_conds) == 0 : {cond_check_1}")
print(f" text_encoder_conds[0] is None: {cond_check_2}")
if not cond_check_1 and not cond_check_2:
print(" OK: cached text encoder outputs would be used")
else:
print(" ** BUG: code would try to re-encode (and crash on None input_ids_list)!")
# Test unpack for get_noise_pred_and_target (line 311)
print(f"\n[2.7] Testing unpack: prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds")
try:
pe_batch, am_batch, t5_ids_batch, t5_mask_batch = text_encoder_conds
print(f" Unpack OK")
print_tensor_info("prompt_embeds", pe_batch)
print_tensor_info("attn_mask", am_batch)
print_tensor_info("t5_input_ids", t5_ids_batch)
print_tensor_info("t5_attn_mask", t5_mask_batch)
# Check t5_input_ids are integers (they were converted to FloatTensor!)
if t5_ids_batch.dtype != torch.long and t5_ids_batch.dtype != torch.int32:
print(f"\n ** NOTE: t5_input_ids dtype is {t5_ids_batch.dtype}, will be cast to long at line 316")
t5_ids_long = t5_ids_batch.to(dtype=torch.long)
# Check if any precision was lost
diff = (t5_ids_batch - t5_ids_long.float()).abs().max()
print(f" Float->Long precision loss: {diff:.2e}")
if diff > 0.5:
print(" ** ERROR: token IDs corrupted by float conversion!")
else:
print(" OK: float->long conversion is lossless for these IDs")
except Exception as e:
print(f" ** ERROR unpacking: {e}")
traceback.print_exc()
# Test drop_cached_text_encoder_outputs
print(f"\n[2.8] Testing drop_cached_text_encoder_outputs (caption dropout)...")
dropout_strategy = AnimaTextEncodingStrategy(
dropout_rate=0.5, # high rate to ensure some drops
)
dropped = dropout_strategy.drop_cached_text_encoder_outputs(*stacked)
print(f" Returned {len(dropped)} tensors")
for name, tensor in zip(names, dropped):
print_tensor_info(f"dropped_{name}", tensor)
# Check which items were dropped
for i in range(len(captions)):
is_zero = (dropped[0][i].abs().sum() == 0).item()
print(f" Sample {i}: {'DROPPED' if is_zero else 'KEPT'}")
qwen3_model.to("cpu")
del qwen3_model
torch.cuda.empty_cache() if torch.cuda.is_available() else None
print("\n[2.8] Text encoder cache test DONE.")
# Test 3: Full batch simulation
def test_full_batch_simulation(args, pairs):
print("\n" + "=" * 70)
print("TEST 3: FULL BATCH SIMULATION (mimics process_batch flow)")
print("=" * 70)
from library import anima_utils
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
device = "cuda" if torch.cuda.is_available() else "cpu"
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
vae_dtype = torch.float32
# Load all models
print("\n[3.1] Loading models...")
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
qwen3_model.eval()
vae, _, _, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=vae_dtype, device=device)
tokenize_strategy = AnimaTokenizeStrategy(
qwen3_tokenizer=qwen3_tokenizer,
t5_tokenizer=t5_tokenizer,
qwen3_max_length=args.qwen3_max_length,
t5_max_length=args.t5_max_length,
)
text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
captions = [cap for _, cap in pairs]
# --- Simulate caching phase ---
print("\n[3.2] Simulating text encoder caching phase...")
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
te_outputs = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[qwen3_model],
tokens_and_masks,
enable_dropout=False,
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = te_outputs
# Convert to numpy (same as cache_batch_outputs)
pe_np = prompt_embeds.cpu().float().numpy()
am_np = attn_mask.cpu().numpy()
t5_ids_np = t5_input_ids.cpu().numpy().astype(np.int32)
t5_mask_np = t5_attn_mask.cpu().numpy().astype(np.int32)
# Per-sample storage (like info.text_encoder_outputs)
per_sample_te = [(pe_np[i], am_np[i], t5_ids_np[i], t5_mask_np[i]) for i in range(len(captions))]
print(f"\n[3.3] Simulating latent caching phase...")
per_sample_latents = []
for img_path, _ in pairs:
img = Image.open(img_path).convert("RGB")
img_np = np.array(img)
img_tensor = IMAGE_TRANSFORMS(img_np).unsqueeze(0).unsqueeze(2) # (1,C,1,H,W)
img_tensor = img_tensor.to(device, dtype=vae_dtype)
with torch.no_grad():
lat = vae.encode(img_tensor, vae_scale).cpu()
per_sample_latents.append(lat.squeeze(0)) # (C,1,H,W)
print(f" {os.path.basename(img_path)}: latent shape={tuple(lat.shape)}")
# --- Simulate batch construction (__getitem__) ---
print(f"\n[3.4] Simulating batch construction...")
# Use first image's latents only (images may have different resolutions)
latents_batch = per_sample_latents[0].unsqueeze(0) # (1,C,1,H,W)
print(f" Using first image latent for simulation: shape={tuple(latents_batch.shape)}")
# Stack text encoder outputs (none_or_stack_elements)
text_encoder_outputs_list = []
for elem_idx in range(4):
arrays = [s[elem_idx] for s in per_sample_te]
text_encoder_outputs_list.append(torch.stack([torch.FloatTensor(a) for a in arrays]))
# input_ids_list is None when caching
input_ids_list = None
batch = {
"latents": latents_batch,
"text_encoder_outputs_list": text_encoder_outputs_list,
"input_ids_list": input_ids_list,
"loss_weights": torch.ones(len(captions)),
}
print(f" batch keys: {list(batch.keys())}")
print(f" batch['latents']: shape={tuple(batch['latents'].shape)}")
print(f" batch['text_encoder_outputs_list']: {len(batch['text_encoder_outputs_list'])} tensors")
print(f" batch['input_ids_list']: {batch['input_ids_list']}")
# --- Simulate process_batch logic ---
print(f"\n[3.5] Simulating process_batch logic...")
text_encoder_conds = []
te_out = batch.get("text_encoder_outputs_list", None)
if te_out is not None:
text_encoder_conds = te_out
print(f" text_encoder_conds loaded from cache: {len(text_encoder_conds)} tensors")
else:
print(f" text_encoder_conds: empty (no cache)")
# The critical condition
train_text_encoder_TRUE = True # OLD behavior (base class default, no override)
train_text_encoder_FALSE = False # NEW behavior (with is_train_text_encoder override)
cond_old = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_TRUE
cond_new = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_FALSE
print(f"\n === CRITICAL CONDITION CHECK ===")
print(f" len(text_encoder_conds) == 0 : {len(text_encoder_conds) == 0}")
print(f" text_encoder_conds[0] is None: {text_encoder_conds[0] is None}")
print(f" train_text_encoder (OLD=True) : {train_text_encoder_TRUE}")
print(f" train_text_encoder (NEW=False): {train_text_encoder_FALSE}")
print(f"")
print(f" Condition with OLD behavior (no override): {cond_old}")
msg = (
"ENTERS re-encode block -> accesses batch['input_ids_list'] -> CRASH!"
if cond_old
else "SKIPS re-encode block -> uses cache -> OK"
)
print(f" -> {msg}")
print(f" Condition with NEW behavior (override): {cond_new}")
print(f" -> {'ENTERS re-encode block' if cond_new else 'SKIPS re-encode block -> uses cache -> OK'}")
if cond_old and not cond_new:
print(f"\n ** CONFIRMED: the is_train_text_encoder override fixes the crash **")
# Simulate the rest of process_batch
print(f"\n[3.6] Simulating get_noise_pred_and_target unpack...")
try:
pe, am, t5_ids, t5_mask = text_encoder_conds
pe = pe.to(device, dtype=te_dtype)
am = am.to(device)
t5_ids = t5_ids.to(device, dtype=torch.long)
t5_mask = t5_mask.to(device)
print(f" Unpack + device transfer OK:")
print_tensor_info("prompt_embeds", pe)
print_tensor_info("attn_mask", am)
print_tensor_info("t5_input_ids", t5_ids)
print_tensor_info("t5_attn_mask", t5_mask)
# Verify t5_input_ids didn't get corrupted by float conversion
t5_ids_orig = torch.tensor(t5_ids_np, dtype=torch.long, device=device)
id_match = torch.all(t5_ids == t5_ids_orig).item()
print(f"\n t5_input_ids integrity (float->long roundtrip): {'OK' if id_match else '** MISMATCH **'}")
if not id_match:
diff_count = (t5_ids != t5_ids_orig).sum().item()
print(f" {diff_count} token IDs differ!")
# Show example
idx = torch.where(t5_ids != t5_ids_orig)
if len(idx[0]) > 0:
i, j = idx[0][0].item(), idx[1][0].item()
print(f" Example: position [{i},{j}] original={t5_ids_orig[i,j].item()} loaded={t5_ids[i,j].item()}")
except Exception as e:
print(f" ** ERROR: {e}")
traceback.print_exc()
# Cleanup
vae.to("cpu")
qwen3_model.to("cpu")
del vae, qwen3_model
torch.cuda.empty_cache() if torch.cuda.is_available() else None
print("\n[3.7] Full batch simulation DONE.")
# Main
def main():
parser = argparse.ArgumentParser(description="Test Anima caching mechanisms")
parser.add_argument("--image_dir", type=str, required=True, help="Directory with image+txt pairs")
parser.add_argument("--qwen3_path", type=str, required=True, help="Path to Qwen3 model (directory or safetensors)")
parser.add_argument("--vae_path", type=str, required=True, help="Path to WanVAE safetensors")
parser.add_argument("--t5_tokenizer_path", type=str, default=None, help="Path to T5 tokenizer (optional, uses bundled config)")
parser.add_argument("--qwen3_max_length", type=int, default=512)
parser.add_argument("--t5_max_length", type=int, default=512)
parser.add_argument("--cache_to_disk", action="store_true", help="Also test disk cache round-trip")
parser.add_argument("--skip_latent", action="store_true", help="Skip latent cache test")
parser.add_argument("--skip_text", action="store_true", help="Skip text encoder cache test")
parser.add_argument("--skip_full", action="store_true", help="Skip full batch simulation")
args = parser.parse_args()
# Find pairs
pairs = find_image_caption_pairs(args.image_dir)
if len(pairs) == 0:
print(f"ERROR: No image+txt pairs found in {args.image_dir}")
print("Expected: image.png + image.txt, image.jpg + image.txt, etc.")
sys.exit(1)
print(f"Found {len(pairs)} image-caption pairs:")
for img_path, cap in pairs:
print(f" {os.path.basename(img_path)}: \"{cap[:60]}{'...' if len(cap) > 60 else ''}\"")
results = {}
if not args.skip_latent:
try:
test_latent_cache(args, pairs)
results["latent_cache"] = "PASS"
except Exception as e:
print(f"\n** LATENT CACHE TEST FAILED: {e}")
traceback.print_exc()
results["latent_cache"] = f"FAIL: {e}"
if not args.skip_text:
try:
test_text_encoder_cache(args, pairs)
results["text_encoder_cache"] = "PASS"
except Exception as e:
print(f"\n** TEXT ENCODER CACHE TEST FAILED: {e}")
traceback.print_exc()
results["text_encoder_cache"] = f"FAIL: {e}"
if not args.skip_full:
try:
test_full_batch_simulation(args, pairs)
results["full_batch_sim"] = "PASS"
except Exception as e:
print(f"\n** FULL BATCH SIMULATION FAILED: {e}")
traceback.print_exc()
results["full_batch_sim"] = f"FAIL: {e}"
# Summary
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
for test, result in results.items():
status = "OK" if result == "PASS" else "FAIL"
print(f" [{status}] {test}: {result}")
print()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,242 @@
"""
Test script that actually runs anima_train.py and anima_train_network.py
for a few steps to verify --cache_text_encoder_outputs works.
Usage:
python test_anima_real_training.py \
--image_dir /path/to/images_with_txt \
--dit_path /path/to/dit.safetensors \
--qwen3_path /path/to/qwen3 \
--vae_path /path/to/vae.safetensors \
[--t5_tokenizer_path /path/to/t5] \
[--resolution 512]
This will run 4 tests:
1. anima_train.py (full finetune, no cache)
2. anima_train.py (full finetune, --cache_text_encoder_outputs)
3. anima_train_network.py (LoRA, no cache)
4. anima_train_network.py (LoRA, --cache_text_encoder_outputs)
Each test runs only 2 training steps then stops.
"""
import argparse
import os
import subprocess
import sys
import tempfile
import shutil
def create_dataset_toml(image_dir: str, resolution: int, toml_path: str):
"""Create a minimal dataset toml config."""
content = f"""[general]
resolution = {resolution}
enable_bucket = true
bucket_reso_steps = 8
min_bucket_reso = 256
max_bucket_reso = 1024
[[datasets]]
batch_size = 1
[[datasets.subsets]]
image_dir = "{image_dir}"
num_repeats = 1
caption_extension = ".txt"
"""
with open(toml_path, "w", encoding="utf-8") as f:
f.write(content)
return toml_path
def run_test(test_name: str, cmd: list, timeout: int = 300) -> dict:
"""Run a training command and capture result."""
print(f"\n{'=' * 70}")
print(f"TEST: {test_name}")
print(f"{'=' * 70}")
print(f"Command: {' '.join(cmd)}\n")
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=timeout,
cwd=os.path.dirname(os.path.abspath(__file__)),
)
stdout = result.stdout
stderr = result.stderr
returncode = result.returncode
# Print last N lines of output
all_output = stdout + "\n" + stderr
lines = all_output.strip().split("\n")
print(f"--- Last 30 lines of output ---")
for line in lines[-30:]:
print(f" {line}")
print(f"--- End output ---\n")
if returncode == 0:
print(f"RESULT: PASS (exit code 0)")
return {"status": "PASS", "detail": "completed successfully"}
else:
# Check if it's a known error
if "TypeError: 'NoneType' object is not iterable" in all_output:
print(f"RESULT: FAIL - input_ids_list is None (the cache_text_encoder_outputs bug)")
return {"status": "FAIL", "detail": "input_ids_list is None - cache TE outputs bug"}
elif "steps: 0%" in all_output and "Error" in all_output:
# Find the actual error
error_lines = [l for l in lines if "Error" in l or "Traceback" in l or "raise" in l.lower()]
detail = error_lines[-1] if error_lines else f"exit code {returncode}"
print(f"RESULT: FAIL - {detail}")
return {"status": "FAIL", "detail": detail}
else:
print(f"RESULT: FAIL (exit code {returncode})")
return {"status": "FAIL", "detail": f"exit code {returncode}"}
except subprocess.TimeoutExpired:
print(f"RESULT: TIMEOUT (>{timeout}s)")
return {"status": "TIMEOUT", "detail": f"exceeded {timeout}s"}
except Exception as e:
print(f"RESULT: ERROR - {e}")
return {"status": "ERROR", "detail": str(e)}
def main():
parser = argparse.ArgumentParser(description="Test Anima real training with cache flags")
parser.add_argument("--image_dir", type=str, required=True,
help="Directory with image+txt pairs")
parser.add_argument("--dit_path", type=str, required=True,
help="Path to Anima DiT safetensors")
parser.add_argument("--qwen3_path", type=str, required=True,
help="Path to Qwen3 model")
parser.add_argument("--vae_path", type=str, required=True,
help="Path to WanVAE safetensors")
parser.add_argument("--t5_tokenizer_path", type=str, default=None)
parser.add_argument("--resolution", type=int, default=512)
parser.add_argument("--timeout", type=int, default=300,
help="Timeout per test in seconds (default: 300)")
parser.add_argument("--only", type=str, default=None,
choices=["finetune", "lora"],
help="Only run finetune or lora tests")
args = parser.parse_args()
# Validate paths
for name, path in [("image_dir", args.image_dir), ("dit_path", args.dit_path),
("qwen3_path", args.qwen3_path), ("vae_path", args.vae_path)]:
if not os.path.exists(path):
print(f"ERROR: {name} does not exist: {path}")
sys.exit(1)
# Create temp dir for outputs
tmp_dir = tempfile.mkdtemp(prefix="anima_test_")
print(f"Temp directory: {tmp_dir}")
# Create dataset toml
toml_path = os.path.join(tmp_dir, "dataset.toml")
create_dataset_toml(args.image_dir, args.resolution, toml_path)
print(f"Dataset config: {toml_path}")
output_dir = os.path.join(tmp_dir, "output")
os.makedirs(output_dir, exist_ok=True)
python = sys.executable
# Common args for both scripts
common_anima_args = [
"--dit_path", args.dit_path,
"--qwen3_path", args.qwen3_path,
"--vae_path", args.vae_path,
"--pretrained_model_name_or_path", args.dit_path, # required by base parser
"--output_dir", output_dir,
"--output_name", "test",
"--dataset_config", toml_path,
"--max_train_steps", "2",
"--learning_rate", "1e-5",
"--mixed_precision", "bf16",
"--save_every_n_steps", "999", # don't save
"--max_data_loader_n_workers", "0", # single process for clarity
"--logging_dir", os.path.join(tmp_dir, "logs"),
"--cache_latents",
]
if args.t5_tokenizer_path:
common_anima_args += ["--t5_tokenizer_path", args.t5_tokenizer_path]
results = {}
# TEST 1: anima_train.py - NO cache_text_encoder_outputs
if args.only is None or args.only == "finetune":
cmd = [python, "anima_train.py"] + common_anima_args + [
"--optimizer_type", "AdamW8bit",
]
results["finetune_no_cache"] = run_test(
"anima_train.py (full finetune, NO text encoder cache)",
cmd, args.timeout,
)
# TEST 2: anima_train.py - WITH cache_text_encoder_outputs
cmd = [python, "anima_train.py"] + common_anima_args + [
"--optimizer_type", "AdamW8bit",
"--cache_text_encoder_outputs",
]
results["finetune_with_cache"] = run_test(
"anima_train.py (full finetune, WITH --cache_text_encoder_outputs)",
cmd, args.timeout,
)
# TEST 3: anima_train_network.py - NO cache_text_encoder_outputs
if args.only is None or args.only == "lora":
lora_args = common_anima_args + [
"--optimizer_type", "AdamW8bit",
"--network_module", "networks.lora_anima",
"--network_dim", "4",
"--network_alpha", "1",
]
cmd = [python, "anima_train_network.py"] + lora_args
results["lora_no_cache"] = run_test(
"anima_train_network.py (LoRA, NO text encoder cache)",
cmd, args.timeout,
)
# TEST 4: anima_train_network.py - WITH cache_text_encoder_outputs
cmd = [python, "anima_train_network.py"] + lora_args + [
"--cache_text_encoder_outputs",
]
results["lora_with_cache"] = run_test(
"anima_train_network.py (LoRA, WITH --cache_text_encoder_outputs)",
cmd, args.timeout,
)
# SUMMARY
print(f"\n{'=' * 70}")
print("SUMMARY")
print(f"{'=' * 70}")
all_pass = True
for test_name, result in results.items():
status = result["status"]
icon = "OK" if status == "PASS" else "FAIL"
if status != "PASS":
all_pass = False
print(f" [{icon:4s}] {test_name}: {result['detail']}")
print(f"\nTemp directory (can delete): {tmp_dir}")
# Cleanup
try:
shutil.rmtree(tmp_dir)
print("Temp directory cleaned up.")
except Exception:
print(f"Note: could not clean up {tmp_dir}")
if all_pass:
print("\nAll tests PASSED!")
else:
print("\nSome tests FAILED!")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -470,7 +470,7 @@ class NetworkTrainer:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss = loss.mean(dim=list(range(1, loss.ndim))) # mean over all dims except batch
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights