Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-06 21:52:27 +00:00
* fix: update extend-exclude list in _typos.toml to include configs
* fix: exclude anima tests from pytest
* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE
* fix: update default value for --discrete_flow_shift in anima training guide
* feat: add Qwen-Image VAE
* feat: simplify encode_tokens
* feat: use unified attention module, add wrapper for state dict compatibility
* feat: loading with dynamic fp8 optimization and LoRA support
* feat: add anima minimal inference script (WIP)
* format: format
* feat: simplify target module selection by regular expression patterns
* feat: keep caption dropout rate in cache and handle it in the training script
* feat: update train_llm_adapter and verbose default values to string type
* fix: use strategy instead of using tokenizers directly
* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock
* feat: support 5d tensor in get_noisy_model_input_and_timesteps
* feat: update loss calculation to support 5d tensor
* fix: update argument names in anima_train_utils to align with other architectures
* feat: simplify Anima training script and update empty caption handling
* feat: support LoRA format without `net.` prefix
* fix: make the fp8_scaled option work
* feat: add regex-based learning rates and dimensions handling in create_network
* fix: improve regex matching for module selection and learning rates in LoRANetwork
* fix: update logging message for regex match in LoRANetwork
* fix: keep latents 4D except for the DiT call
* feat: enhance block swap functionality for inference and training in Anima model
* feat: refactor Anima training script
* feat: optimize VAE processing by adjusting tensor dimensions and data types
* fix: wait for all block transfers before switching offloader mode
* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!
* feat: support LoRA for Qwen3
* feat: update Anima SAI model spec metadata handling
* fix: remove unused code
* feat: split CFG processing in do_sample function to reduce memory usage
* feat: add VAE chunking and caching options to reduce memory usage
* feat: optimize RMSNorm forward method and remove unused torch_attention_op
* Update library/strategy_anima.py: use torch.all instead of all. (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update library/safetensors_utils.py: fix duplicated new_key for concat_hook. (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update anima_minimal_inference.py: remove unused code. (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update anima_train.py: remove unused import. (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Update library/anima_train_utils.py: remove unused import. (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* fix: review with Copilot
* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)
* feat: add process_escape function to handle escape sequences in prompts
* feat: enhance LoRA weight handling in model loading and add text encoder loading function
* feat: improve ComfyUI conversion script with prefix constants and module name adjustments
* feat: update caption dropout documentation to clarify cache regeneration requirement
* feat: add clarification on learning rate adjustments
* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
760 lines · 31 KiB · Python
# Anima full finetune training script

import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import gc
import math
import os
from multiprocessing import Value
from typing import List
import toml

from tqdm import tqdm

import torch
from library import flux_train_utils, qwen_image_autoencoder_kl
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler

init_ipex()

from accelerate.utils import set_seed
from library import deepspeed_utils, anima_models, anima_train_utils, anima_utils, strategy_base, strategy_anima, sai_model_spec

import library.train_util as train_util

from library.utils import setup_logging, add_logging_arguments

setup_logging()
import logging

logger = logging.getLogger(__name__)

import library.config_util as config_util

from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments


def train(args):
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)
    deepspeed_utils.prepare_deepspeed_args(args)
    setup_logging(args, reset=True)

    # backward compatibility
    if not args.skip_cache_check:
        args.skip_cache_check = args.skip_latents_validity_check

    if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
        logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
        args.cache_text_encoder_outputs = True

    if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
        logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
        args.gradient_checkpointing = True

    if args.unsloth_offload_checkpointing:
        if not args.gradient_checkpointing:
            logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
            args.gradient_checkpointing = True
        assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"

    assert (
        args.blocks_to_swap is None or args.blocks_to_swap == 0
    ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"

    assert (
        args.blocks_to_swap is None or args.blocks_to_swap == 0
    ) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"

    cache_latents = args.cache_latents
    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)

    # prepare caching strategy: must be set before preparing dataset
    if args.cache_latents:
        latents_caching_strategy = strategy_anima.AnimaLatentsCachingStrategy(
            args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
        )
        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # prepare dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args)
        train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args)
        val_dataset_group = None

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    train_dataset_group.verify_bucket_reso_steps(16)  # Qwen-Image VAE spatial downscale = 8 * patch size = 2

    if args.debug_dataset:
        if args.cache_text_encoder_outputs:
            strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
                strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
                    args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
                )
            )
        train_dataset_group.set_current_strategies()
        train_util.debug_dataset(train_dataset_group, True)
        return
    if len(train_dataset_group) == 0:
        logger.error("No data found. Please verify the metadata file and train_data_dir option.")
        return

    if cache_latents:
        assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"

    if args.cache_text_encoder_outputs:
        assert train_dataset_group.is_text_encoder_output_cacheable(
            cache_supports_dropout=True
        ), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"

    # prepare accelerator
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # mixed precision dtype
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # Load tokenizers and set strategies
    logger.info("Loading tokenizers...")
    qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
    t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)

    # Set tokenize strategy
    tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
        qwen3_tokenizer=qwen3_tokenizer,
        t5_tokenizer=t5_tokenizer,
        qwen3_max_length=args.qwen3_max_token_length,
        t5_max_length=args.t5_max_token_length,
    )
    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

    text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

    # Prepare text encoder (always frozen for Anima)
    qwen3_text_encoder.to(weight_dtype)
    qwen3_text_encoder.requires_grad_(False)

    # Cache text encoder outputs
    sample_prompts_te_outputs = None
    if args.cache_text_encoder_outputs:
        qwen3_text_encoder.to(accelerator.device)
        qwen3_text_encoder.eval()

        text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
            args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
        )
        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

        with accelerator.autocast():
            train_dataset_group.new_cache_text_encoder_outputs([qwen3_text_encoder], accelerator)

        # cache sample prompt embeddings
        if args.sample_prompts is not None:
            logger.info(f"Cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
            prompts = train_util.load_prompts(args.sample_prompts)
            sample_prompts_te_outputs = {}
            with accelerator.autocast(), torch.no_grad():
                for prompt_dict in prompts:
                    for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                        if p not in sample_prompts_te_outputs:
                            logger.info(f" cache TE outputs for: {p}")
                            tokens_and_masks = tokenize_strategy.tokenize(p)
                            sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
                            )

        accelerator.wait_for_everyone()

        # free text encoder memory
        qwen3_text_encoder = None
        gc.collect()  # Force garbage collection to free memory
        clean_memory_on_device(accelerator.device)

    # Load VAE and cache latents
    logger.info("Loading Anima VAE...")
    vae = qwen_image_autoencoder_kl.load_vae(
        args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
    )

    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()

        train_dataset_group.new_cache_latents(vae, accelerator)

        vae.to("cpu")
        clean_memory_on_device(accelerator.device)
        accelerator.wait_for_everyone()

    # Load DiT (MiniTrainDIT + optional LLM Adapter)
    logger.info("Loading Anima DiT...")
    dit = anima_utils.load_anima_model(
        "cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
    )

    if args.gradient_checkpointing:
        dit.enable_gradient_checkpointing(
            cpu_offload=args.cpu_offload_checkpointing,
            unsloth_offload=args.unsloth_offload_checkpointing,
        )

    train_dit = args.learning_rate != 0
    dit.requires_grad_(train_dit)
    if not train_dit:
        dit.to(accelerator.device, dtype=weight_dtype)

    # Block swap
    is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
    if is_swapping_blocks:
        logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
        dit.enable_block_swap(args.blocks_to_swap, accelerator.device)
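
    # Block swap (a sketch of the behavior, which is implemented in the model code, not here): with
    # --blocks_to_swap N, N transformer blocks live in CPU RAM and are transferred to the GPU on demand
    # as the forward and backward passes reach them, trading step time for a substantial VRAM reduction.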

    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # Setup optimizer with parameter groups
    if train_dit:
        param_groups = anima_train_utils.get_anima_param_groups(
            dit,
            base_lr=args.learning_rate,
            self_attn_lr=args.self_attn_lr,
            cross_attn_lr=args.cross_attn_lr,
            mlp_lr=args.mlp_lr,
            mod_lr=args.mod_lr,
            llm_adapter_lr=args.llm_adapter_lr,
        )
    else:
        param_groups = []

    training_models = []
    if train_dit:
        training_models.append(dit)

    # calculate trainable parameters
    n_params = 0
    for group in param_groups:
        for p in group["params"]:
            n_params += p.numel()

    accelerator.print(f"train dit: {train_dit}")
    accelerator.print(f"number of training models: {len(training_models)}")
    accelerator.print(f"number of trainable parameters: {n_params:,}")

    # prepare optimizer
    accelerator.print("prepare optimizer, data loader etc.")

    if args.fused_backward_pass:
        # Pass per-component param_groups directly to preserve per-component LRs
        _, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
    else:
        _, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
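
    # get_optimizer receives the per-component param_groups unchanged, so each group (base, self_attn,
    # cross_attn, mlp, mod, llm_adapter) keeps its own learning rate inside the optimizer; the named
    # per-group LR logging later in the loop appears to assume this same grouping and order.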

    # prepare dataloader
    train_dataset_group.set_current_strategies()

    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )
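
    # Note: batch_size=1 is deliberate. In these training scripts the dataset group returns an already
    # bucketed, pre-collated batch per item, so the effective batch size comes from the dataset
    # configuration and the collator only unwraps that single prepared batch.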

    # calculate training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs: {args.max_train_steps}")

    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # lr scheduler
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # full fp16/bf16 training
    dit_weight_dtype = weight_dtype
    if args.full_fp16:
        assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
        accelerator.print("enable full fp16 training.")
    elif args.full_bf16:
        assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
        accelerator.print("enable full bf16 training.")
    else:
        dit_weight_dtype = torch.float32  # If neither full_fp16 nor full_bf16, the model weights should be in float32
    dit.to(dit_weight_dtype)  # convert dit to target weight dtype

    # move text encoder to GPU if not cached
    if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
        qwen3_text_encoder.to(accelerator.device)

    clean_memory_on_device(accelerator.device)

    # Prepare with accelerator
    # Temporarily move non-training models off GPU to reduce memory during DDP init
    # if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
    #     qwen3_text_encoder.to("cpu")
    # if not cache_latents and vae is not None:
    #     vae.to("cpu")
    # clean_memory_on_device(accelerator.device)

    if args.deepspeed:
        ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            ds_model, optimizer, train_dataloader, lr_scheduler
        )
        training_models = [ds_model]
    else:
        if train_dit:
            dit = accelerator.prepare(dit, device_placement=[not is_swapping_blocks])
            if is_swapping_blocks:
                accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

    # Move non-training models back to GPU
    if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
        qwen3_text_encoder.to(accelerator.device)
    if not cache_latents and vae is not None:
        vae.to(accelerator.device, dtype=weight_dtype)

    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    if args.fused_backward_pass:
        # use fused optimizer for backward pass: other optimizers will be supported in the future
        import library.adafactor_fused

        library.adafactor_fused.patch_adafactor_fused(optimizer)

        for param_group in optimizer.param_groups:
            for parameter in param_group["params"]:
                if parameter.requires_grad:

                    def create_grad_hook(p_group):
                        def grad_hook(tensor: torch.Tensor):
                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                            optimizer.step_param(tensor, p_group)
                            tensor.grad = None

                        return grad_hook

                    parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
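
    # How the fused backward pass works, based on the hooks above: each trainable parameter registers a
    # post-accumulate-grad hook that clips its own gradient, applies the optimizer update for just that
    # parameter via step_param, and then frees the gradient. Updates thus happen during backward, so
    # full-model gradients are never held all at once; this path currently relies on the patched
    # Adafactor from library.adafactor_fused.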

    # Training loop
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
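    # Example of the step arithmetic: 1,000 dataloader batches with gradient_accumulation_steps=4 give
    # ceil(1000 / 4) = 250 update steps per epoch; with max_train_steps=2,000 that is ceil(2000 / 250) = 8 epochs.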
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    accelerator.print("running training / 学習開始")
    accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
    accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
    accelerator.print(
        f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
    )
    accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
    # Copy for noise and timestep generation, because noise_scheduler may be changed during training in the future
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    if accelerator.is_main_process:
        init_kwargs = {}
        if args.wandb_run_name:
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers(
            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
            config=train_util.get_sanitized_config_or_none(args),
            init_kwargs=init_kwargs,
        )

    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
        import wandb

        wandb.define_metric("epoch")
        wandb.define_metric("loss/epoch", step_metric="epoch")

    if is_swapping_blocks:
        accelerator.unwrap_model(dit).prepare_block_swap_before_forward()

    # For --sample_at_first
    optimizer_eval_fn()
    anima_train_utils.sample_images(
        accelerator,
        args,
        0,
        global_step,
        dit,
        vae,
        qwen3_text_encoder,
        tokenize_strategy,
        text_encoding_strategy,
        sample_prompts_te_outputs,
    )
    optimizer_train_fn()
    if len(accelerator.trackers) > 0:
        accelerator.log({}, step=0)

    # Show model info
    unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
    if unwrapped_dit is not None:
        logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
    if qwen3_text_encoder is not None:
        logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
    if vae is not None:
        logger.info(f"vae device: {vae.device}")

    loss_recorder = train_util.LossRecorder()
    epoch = 0
    for epoch in range(num_train_epochs):
        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        for m in training_models:
            m.train()

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step

            with accelerator.accumulate(*training_models):
                # Get latents
                if "latents" in batch and batch["latents"] is not None:
                    latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
                    if latents.ndim == 5:  # Fallback for 5D latents (old cache)
                        latents = latents.squeeze(2)  # (B, C, 1, H, W) -> (B, C, H, W)
                else:
                    with torch.no_grad():
                        # images are already [-1, 1] from IMAGE_TRANSFORMS
                        images = batch["images"].to(accelerator.device, dtype=weight_dtype)
                        latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)

                if torch.any(torch.isnan(latents)):
                    accelerator.print("NaN found in latents, replacing with zeros")
                    latents = torch.nan_to_num(latents, 0, out=latents)

                # Get text encoder outputs
                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                if text_encoder_outputs_list is not None:
                    # Cached outputs
                    caption_dropout_rates = text_encoder_outputs_list[-1]
                    text_encoder_outputs_list = text_encoder_outputs_list[:-1]

                    # Apply caption dropout to cached outputs
                    text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
                        *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
                    )
                    prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
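                    # Caption dropout for cached outputs is applied here at training time: the last element
                    # of the cached list is each sample's caption_dropout_rate, and
                    # drop_cached_text_encoder_outputs drops the conditioning at that rate (presumably by
                    # substituting the cached empty-caption outputs; see the strategy implementation).
                    # Changing the dropout rate therefore requires regenerating the cache.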
                else:
                    # Encode on-the-fly
                    input_ids_list = batch["input_ids_list"]
                    with torch.no_grad():
                        prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
                            tokenize_strategy, [qwen3_text_encoder], input_ids_list
                        )

                # Move to device
                prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
                attn_mask = attn_mask.to(accelerator.device)
                t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
                t5_attn_mask = t5_attn_mask.to(accelerator.device)

                # Noise and timesteps
                noise = torch.randn_like(latents)

                # Get noisy model input and timesteps
                noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
                    args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
                )
                timesteps = timesteps / 1000.0  # scale to [0, 1] range. timesteps is float32

                # NaN checks
                if torch.any(torch.isnan(noisy_model_input)):
                    accelerator.print("NaN found in noisy_model_input, replacing with zeros")
                    noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)

                # Create padding mask
                # padding_mask: (B, 1, H_latent, W_latent)
                bs = latents.shape[0]
                h_latent = latents.shape[-2]
                w_latent = latents.shape[-1]
                padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)
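                # All-zero padding mask: no latent positions are flagged as padding, since every sample in
                # a bucketed batch shares the same latent resolution (this assumes 0 means "valid" in the
                # model's padding-mask convention).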

                # DiT forward (LLM adapter runs inside forward for DDP gradient sync)
                noisy_model_input = noisy_model_input.unsqueeze(2)  # 4D to 5D, (B, C, 1, H, W)
                with accelerator.autocast():
                    model_pred = dit(
                        noisy_model_input,
                        timesteps,
                        prompt_embeds,
                        padding_mask=padding_mask,
                        source_attention_mask=attn_mask,
                        t5_input_ids=t5_input_ids,
                        t5_attn_mask=t5_attn_mask,
                    )
                model_pred = model_pred.squeeze(2)  # 5D to 4D, (B, C, H, W)

                # Compute loss (rectified flow: target = noise - latents)
                target = noise - latents
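
                # Why noise - latents: with the flow-matching convention used by
                # get_noisy_model_input_and_timesteps (noisy input = (1 - sigma) * latents + sigma * noise,
                # sigma in [0, 1] shaped by --discrete_flow_shift), the velocity d(x_t)/d(sigma) equals
                # noise - latents, which is what the DiT is trained to predict. This is a reading of the
                # helper's convention; the exact formula lives in flux_train_utils, not in this file.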

                # Weighting
                weighting = anima_train_utils.compute_loss_weighting_for_anima(
                    weighting_scheme=args.weighting_scheme, sigmas=sigmas
                )

                # Loss
                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
                loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
                if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                    loss = apply_masked_loss(loss, batch)
                loss = loss.mean([1, 2, 3])  # (B, C, H, W) -> (B,)

                if weighting is not None:
                    loss = loss * weighting

                loss_weights = batch["loss_weights"]
                loss = loss * loss_weights
                loss = loss.mean()

                accelerator.backward(loss)

                if not args.fused_backward_pass:
                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                        params_to_clip = []
                        for m in training_models:
                            params_to_clip.extend(m.parameters())
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                else:
                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
                    lr_scheduler.step()

            # Checks if the accelerator has performed an optimization step
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                optimizer_eval_fn()
                anima_train_utils.sample_images(
                    accelerator,
                    args,
                    None,
                    global_step,
                    dit,
                    vae,
                    qwen3_text_encoder,
                    tokenize_strategy,
                    text_encoding_strategy,
                    sample_prompts_te_outputs,
                )

                # Save at specific steps
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
                            args,
                            False,
                            accelerator,
                            save_dtype,
                            epoch,
                            num_train_epochs,
                            global_step,
                            accelerator.unwrap_model(dit) if train_dit else None,
                        )
                optimizer_train_fn()

            current_loss = loss.detach().item()
            if len(accelerator.trackers) > 0:
                logs = {"loss": current_loss}
                train_util.append_lr_to_logs_with_names(
                    logs,
                    lr_scheduler,
                    args.optimizer_type,
                    ["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
                )
                accelerator.log(logs, step=global_step)

            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
            avr_loss: float = loss_recorder.moving_average
            logs = {"avr_loss": avr_loss}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if len(accelerator.trackers) > 0:
            logs = {"loss/epoch": loss_recorder.moving_average, "epoch": epoch + 1}
            accelerator.log(logs, step=global_step)

        accelerator.wait_for_everyone()

        optimizer_eval_fn()
        if args.save_every_n_epochs is not None:
            if accelerator.is_main_process:
                anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
                    args,
                    True,
                    accelerator,
                    save_dtype,
                    epoch,
                    num_train_epochs,
                    global_step,
                    accelerator.unwrap_model(dit) if train_dit else None,
                )

        anima_train_utils.sample_images(
            accelerator,
            args,
            epoch + 1,
            global_step,
            dit,
            vae,
            qwen3_text_encoder,
            tokenize_strategy,
            text_encoding_strategy,
            sample_prompts_te_outputs,
        )

    # End training
    is_main_process = accelerator.is_main_process
    dit = accelerator.unwrap_model(dit)

    accelerator.end_training()
    optimizer_eval_fn()

    if args.save_state or args.save_state_on_train_end:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator

    if is_main_process and train_dit:
        anima_train_utils.save_anima_model_on_train_end(
            args,
            save_dtype,
            epoch,
            global_step,
            dit,
        )
        logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_masked_loss_arguments(parser)
    deepspeed_utils.add_deepspeed_arguments(parser)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    add_custom_train_arguments(parser)
    train_util.add_dit_training_arguments(parser)
    anima_train_utils.add_anima_training_arguments(parser)
    sai_model_spec.add_model_spec_arguments(parser)

    parser.add_argument(
        "--cpu_offload_checkpointing",
        action="store_true",
        help="offload gradient checkpointing to CPU (reduces VRAM at the cost of speed)",
    )
    parser.add_argument(
        "--unsloth_offload_checkpointing",
        action="store_true",
        help="offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
        "Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
    )
    parser.add_argument(
        "--skip_latents_validity_check",
        action="store_true",
        help="[Deprecated] use 'skip_cache_check' instead",
    )

    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    if args.attn_mode == "sdpa":
        args.attn_mode = "torch"  # backward compatibility

    train(args)
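

# Illustrative invocation (a sketch, not a verified command line): the flags below mirror the args.*
# attributes referenced in train() above; paths and values are placeholders, and other required options
# (output/save settings, optimizer choice, etc.) come from the shared add_*_arguments helpers and are
# omitted here.
#
#   accelerate launch anima_train.py \
#     --pretrained_model_name_or_path /path/to/anima_dit.safetensors \
#     --vae /path/to/qwen_image_vae.safetensors \
#     --qwen3 /path/to/qwen3_text_encoder \
#     --t5_tokenizer_path /path/to/t5_tokenizer \
#     --dataset_config dataset.toml \
#     --mixed_precision bf16 --full_bf16 \
#     --gradient_checkpointing --blocks_to_swap 16 \
#     --cache_latents --cache_text_encoder_outputs \
#     --learning_rate 1e-5 --max_train_epochs 4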