Kohya-ss-sd-scripts/anima_train_network.py
Kohya S. 34e7138b6a
Add/modify some implementations for anima (#2261)
* fix: update extend-exclude list in _typos.toml to include configs

* fix: exclude anima tests from pytest

* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE

* fix: update default value for --discrete_flow_shift in anima training guide

* feat: add Qwen-Image VAE

* feat: simplify encode_tokens

* feat: use unified attention module, add wrapper for state dict compatibility

* feat: loading with dynamic fp8 optimization and LoRA support

* feat: add anima minimal inference script (WIP)

* format: format

* feat: simplify target module selection by regular expression patterns

* feat: keep caption dropout rate in cache and handle it in training script

* feat: update train_llm_adapter and verbose default values to string type

* fix: use strategy instead of using tokenizers directly

* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

* feat: support 5d tensor in get_noisy_model_input_and_timesteps

* feat: update loss calculation to support 5d tensor

* fix: update argument names in anima_train_utils to align with other architectures

* feat: simplify Anima training script and update empty caption handling

* feat: support LoRA format without `net.` prefix

* fix: make fp8_scaled option work

* feat: add regex-based learning rates and dimensions handling in create_network

* fix: improve regex matching for module selection and learning rates in LoRANetwork

* fix: update logging message for regex match in LoRANetwork

* fix: keep latents 4D except DiT call

* feat: enhance block swap functionality for inference and training in Anima model

* feat: refactor Anima training script

* feat: optimize VAE processing by adjusting tensor dimensions and data types

* fix: wait for all block transfers before switching offloader mode

* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!

* feat: support LoRA for Qwen3

* feat: update Anima SAI model spec metadata handling

* fix: remove unused code

* feat: split CFG processing in do_sample function to reduce memory usage

* feat: add VAE chunking and caching options to reduce memory usage

* feat: optimize RMSNorm forward method and remove unused torch_attention_op

* Update library/strategy_anima.py

Use torch.all instead of all.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/safetensors_utils.py

Fix duplicated new_key for concat_hook.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_minimal_inference.py

Remove unused code.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_train.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/anima_train_utils.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: address Copilot review comments

* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)

* feat: add process_escape function to handle escape sequences in prompts

* feat: enhance LoRA weight handling in model loading and add text encoder loading function

* feat: improve ComfyUI conversion script with prefix constants and module name adjustments

* feat: update caption dropout documentation to clarify cache regeneration requirement

* feat: add clarification on learning rate adjustments

* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-13 08:15:06 +09:00


# Anima LoRA training script
import argparse
from typing import Any, Optional, Union

import torch
import torch.nn as nn
from accelerate import Accelerator

from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from library import (
    anima_models,
    anima_train_utils,
    anima_utils,
    flux_train_utils,
    qwen_image_autoencoder_kl,
    sd3_train_utils,
    strategy_anima,
    strategy_base,
    train_util,
)
import train_network
from library.utils import setup_logging

setup_logging()

import logging

logger = logging.getLogger(__name__)
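

# AnimaNetworkTrainer wires the Anima architecture into the generic NetworkTrainer loop from
# train_network.py: a Qwen3 text encoder (with a separate T5 tokenizer providing target token
# ids), the Qwen-Image VAE, and a lazily loaded DiT trained with a rectified flow objective.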
class AnimaNetworkTrainer(train_network.NetworkTrainer):
    def __init__(self):
        super().__init__()
        self.sample_prompts_te_outputs = None

    def assert_extra_args(
        self,
        args,
        train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
        val_dataset_group: Optional[train_util.DatasetGroup],
    ):
        if args.fp8_base or args.fp8_base_unet:
            logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。")
            args.fp8_base = False
            args.fp8_base_unet = False
            args.fp8_scaled = False  # Anima DiT does not support fp8_scaled

        if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
            logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
            args.cache_text_encoder_outputs = True

        if args.cache_text_encoder_outputs:
            assert train_dataset_group.is_text_encoder_output_cacheable(
                cache_supports_dropout=True
            ), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"

        assert (
            args.network_train_unet_only or not args.cache_text_encoder_outputs
        ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"

        assert (
            args.blocks_to_swap is None or args.blocks_to_swap == 0
        ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"

        if args.unsloth_offload_checkpointing:
            if not args.gradient_checkpointing:
                logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
                args.gradient_checkpointing = True
            assert (
                not args.cpu_offload_checkpointing
            ), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
            assert (
                args.blocks_to_swap is None or args.blocks_to_swap == 0
            ), "blocks_to_swap is not supported with unsloth_offload_checkpointing"

        train_dataset_group.verify_bucket_reso_steps(16)  # WanVAE spatial downscale = 8 and patch size = 2
        if val_dataset_group is not None:
            val_dataset_group.verify_bucket_reso_steps(16)
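
    # Model loading is split in two stages: the text encoder and VAE are loaded eagerly here,
    # while the DiT is deferred to load_unet_lazily (presumably so text encoder outputs can be
    # cached, and the encoder moved off the GPU, before the large DiT claims memory).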
    def load_target_model(self, args, weight_dtype, accelerator):
        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0

        # Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
        logger.info("Loading Qwen3 text encoder...")
        qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
        qwen3_text_encoder.eval()

        # Load VAE
        logger.info("Loading Anima VAE...")
        vae = qwen_image_autoencoder_kl.load_vae(
            args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
        )
        vae.to(weight_dtype)
        vae.eval()

        # Return format: (model_type, text_encoders, vae, unet)
        return "anima", [qwen3_text_encoder], vae, None  # unet loaded lazily

    def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
        loading_dtype = None if args.fp8_scaled else weight_dtype
        loading_device = "cpu" if self.is_swapping_blocks else accelerator.device

        attn_mode = "torch"
        if args.xformers:
            attn_mode = "xformers"
        if args.attn_mode is not None:
            attn_mode = args.attn_mode

        # Load DiT
        logger.info(f"Loading Anima DiT model with attn_mode={attn_mode}, split_attn: {args.split_attn}...")
        model = anima_utils.load_anima_model(
            accelerator.device,
            args.pretrained_model_name_or_path,
            attn_mode,
            args.split_attn,
            loading_device,
            loading_dtype,
            args.fp8_scaled,
        )

        # Store unsloth preference so that when the base NetworkTrainer calls
        # dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
        # The base trainer only passes cpu_offload, so we store the flag on the trainer and
        # re-apply it in prepare_unet_with_accelerator.
        self._use_unsloth_offload_checkpointing = args.unsloth_offload_checkpointing

        # Block swap
        self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
        if self.is_swapping_blocks:
            logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
            model.enable_block_swap(args.blocks_to_swap, accelerator.device)

        return model, text_encoders

    def get_tokenize_strategy(self, args):
        # Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
        tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
            qwen3_path=args.qwen3,
            t5_tokenizer_path=args.t5_tokenizer_path,
            qwen3_max_length=args.qwen3_max_token_length,
            t5_max_length=args.t5_max_token_length,
        )
        return tokenize_strategy
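
    # NOTE: only the Qwen3 tokenizer is exposed to the base trainer; the T5 tokenizer stays
    # inside AnimaTokenizeStrategy and surfaces only as t5_input_ids / t5_attn_mask in the
    # text encoder conditions consumed by get_noise_pred_and_target.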
    def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
        return [tokenize_strategy.qwen3_tokenizer]

    def get_latents_caching_strategy(self, args):
        return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)

    def get_text_encoding_strategy(self, args):
        return strategy_anima.AnimaTextEncodingStrategy()

    def post_process_network(self, args, accelerator, network, text_encoders, unet):
        pass

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        if args.cache_text_encoder_outputs:
            return None  # no text encoders needed for encoding
        return text_encoders

    def get_text_encoder_outputs_caching_strategy(self, args):
        if args.cache_text_encoder_outputs:
            return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
                args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
            )
        return None
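
    # Device choreography for one-shot caching: the DiT may be pinned by block swap, so only the
    # VAE is moved to CPU to make room while the Qwen3 encoder visits the GPU; both are restored
    # once dataset captions and sample prompts have been encoded.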
    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # We cannot move DiT to CPU because of block swap, so only move VAE
                logger.info("move vae to cpu to save memory")
                org_vae_device = vae.device
                vae.to("cpu")
                clean_memory_on_device(accelerator.device)

            logger.info("move text encoder to gpu")
            text_encoders[0].to(accelerator.device)

            with accelerator.autocast():
                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

            # cache sample prompts
            if args.sample_prompts is not None:
                logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")

                tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
                text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

                prompts = train_util.load_prompts(args.sample_prompts)
                sample_prompts_te_outputs = {}
                with accelerator.autocast(), torch.no_grad():
                    for prompt_dict in prompts:
                        for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                            if p not in sample_prompts_te_outputs:
                                logger.info(f" cache TE outputs for: {p}")
                                tokens_and_masks = tokenize_strategy.tokenize(p)
                                sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                    tokenize_strategy, text_encoders, tokens_and_masks
                                )
                self.sample_prompts_te_outputs = sample_prompts_te_outputs

            accelerator.wait_for_everyone()

            # move text encoder back to cpu
            logger.info("move text encoder back to cpu")
            text_encoders[0].to("cpu")

            if not args.lowram:
                logger.info("move vae back to original device")
                vae.to(org_vae_device)

            clean_memory_on_device(accelerator.device)
        else:
            # move text encoder to device for encoding during training/validation
            text_encoders[0].to(accelerator.device)

    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
        text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]  # compatibility
        te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
        qwen3_te = te[0] if te is not None else None

        text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
        tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()

        anima_train_utils.sample_images(
            accelerator,
            args,
            epoch,
            global_step,
            unet,
            vae,
            qwen3_te,
            tokenize_strategy,
            text_encoding_strategy,
            self.sample_prompts_te_outputs,
        )

    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
        return noise_scheduler

    def encode_images_to_latents(self, args, vae, images):
        vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
        return vae.encode_pixels_to_latents(images)  # Keep 4D for input/output

    def shift_scale_latents(self, args, latents):
        # Latents are already normalized by vae.encode with scale
        return latents
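
    # Rectified flow training step: sample noise, interpolate toward it at a random timestep, and
    # have the DiT predict the velocity (noise - latents). Latents stay 4D throughout and are
    # expanded to 5D only around the DiT call, which takes input with a singleton third dimension.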
    def get_noise_pred_and_target(
        self,
        args,
        accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds,
        unet,
        network,
        weight_dtype,
        train_unet,
        is_train=True,
    ):
        anima: anima_models.Anima = unet

        # Sample noise
        if latents.ndim == 5:  # Fallback for 5D latents (old cache)
            latents = latents.squeeze(2)  # [B, C, 1, H, W] -> [B, C, H, W]
        noise = torch.randn_like(latents)

        # Get noisy model input and timesteps
        noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
        )
        timesteps = timesteps / 1000.0  # scale to [0, 1] range. timesteps is float32

        # Gradient checkpointing support
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)

        # Unpack text encoder conditions
        prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds

        # Move to device
        prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
        attn_mask = attn_mask.to(accelerator.device)
        t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
        t5_attn_mask = t5_attn_mask.to(accelerator.device)

        # Create padding mask
        bs = latents.shape[0]
        h_latent = latents.shape[-2]
        w_latent = latents.shape[-1]
        padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)

        # Call model
        noisy_model_input = noisy_model_input.unsqueeze(2)  # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W]
        with torch.set_grad_enabled(is_train), accelerator.autocast():
            model_pred = anima(
                noisy_model_input,
                timesteps,
                prompt_embeds,
                padding_mask=padding_mask,
                target_input_ids=t5_input_ids,
                target_attention_mask=t5_attn_mask,
                source_attention_mask=attn_mask,
            )
        model_pred = model_pred.squeeze(2)  # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W]

        # Rectified flow target: noise - latents
        target = noise - latents

        # Loss weighting
        weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

        return model_pred, target, timesteps, weighting

    def process_batch(
        self,
        batch,
        text_encoders,
        unet,
        network,
        vae,
        noise_scheduler,
        vae_dtype,
        weight_dtype,
        accelerator,
        args,
        text_encoding_strategy,
        tokenize_strategy,
        is_train=True,
        train_text_encoder=True,
        train_unet=True,
    ) -> torch.Tensor:
        """Override base process_batch to apply caption dropout to cached text encoder outputs."""
        # Text encoder conditions
        text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
        anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
        if text_encoder_outputs_list is not None:
            # The caching strategy stores the per-sample caption dropout rate as the last element
            caption_dropout_rates = text_encoder_outputs_list[-1]
            text_encoder_outputs_list = text_encoder_outputs_list[:-1]

            # Apply caption dropout to cached outputs
            text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
                *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
            )
            batch["text_encoder_outputs_list"] = text_encoder_outputs_list

        return super().process_batch(
            batch,
            text_encoders,
            unet,
            network,
            vae,
            noise_scheduler,
            vae_dtype,
            weight_dtype,
            accelerator,
            args,
            text_encoding_strategy,
            tokenize_strategy,
            is_train,
            train_text_encoder,
            train_unet,
        )

    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        return loss

    def get_sai_model_spec(self, args):
        return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, anima="preview").to_metadata_dict()

    def update_metadata(self, metadata, args):
        metadata["ss_weighting_scheme"] = args.weighting_scheme
        metadata["ss_logit_mean"] = args.logit_mean
        metadata["ss_logit_std"] = args.logit_std
        metadata["ss_mode_scale"] = args.mode_scale
        metadata["ss_timestep_sampling"] = args.timestep_sampling
        metadata["ss_sigmoid_scale"] = args.sigmoid_scale
        metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift

    def is_text_encoder_not_needed_for_training(self, args):
        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
        # Set first parameter's requires_grad to True to work around an Accelerate gradient checkpointing bug
        first_param = next(text_encoder.parameters())
        first_param.requires_grad_(True)

    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
        # The base NetworkTrainer only calls enable_gradient_checkpointing(cpu_offload=True/False),
        # so we re-apply with unsloth_offload if needed (after the base has already enabled it).
        if self._use_unsloth_offload_checkpointing and args.gradient_checkpointing:
            unet.enable_gradient_checkpointing(unsloth_offload=True)

        if not self.is_swapping_blocks:
            return super().prepare_unet_with_accelerator(args, accelerator, unet)

        # With block swap, let Accelerate prepare the model without device placement,
        # then move everything except the swapped blocks ourselves
        model = unet
        model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
        accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
        accelerator.unwrap_model(model).prepare_block_swap_before_forward()
        return model

    def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
        if self.is_swapping_blocks:
            # prepare for the next forward pass: the backward pass is not called during validation,
            # so block swap state must be reset here
            accelerator.unwrap_model(unet).prepare_block_swap_before_forward()


def setup_parser() -> argparse.ArgumentParser:
    parser = train_network.setup_parser()
    train_util.add_dit_training_arguments(parser)
    anima_train_utils.add_anima_training_arguments(parser)
    # parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
    parser.add_argument(
        "--unsloth_offload_checkpointing",
        action="store_true",
        help="offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
        "Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
    )
    return parser
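
# A minimal invocation sketch (paths are placeholders; see the Anima training guide for the full
# argument reference). All standard train_network.py options apply in addition to the
# Anima-specific arguments, e.g.:
#
#   accelerate launch anima_train_network.py \
#     --pretrained_model_name_or_path anima_dit.safetensors \
#     --vae qwen_image_vae.safetensors \
#     --qwen3 qwen3.safetensors --t5_tokenizer_path t5_tokenizer \
#     --cache_text_encoder_outputs --blocks_to_swap 16 ...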

if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    if args.attn_mode == "sdpa":
        args.attn_mode = "torch"  # backward compatibility

    trainer = AnimaNetworkTrainer()
    trainer.train(args)