# Anima full finetune training script
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import gc
import math
import os
from multiprocessing import Value
from typing import List

import toml
from tqdm import tqdm

import torch
from library import flux_train_utils, qwen_image_autoencoder_kl
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler

init_ipex()

from accelerate.utils import set_seed
from library import deepspeed_utils, anima_models, anima_train_utils, anima_utils, strategy_base, strategy_anima, sai_model_spec
import library.train_util as train_util
from library.utils import setup_logging, add_logging_arguments

setup_logging()
import logging

logger = logging.getLogger(__name__)

import library.config_util as config_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments


def train(args):
    """Run a full finetune of the Anima DiT model.

    Pipeline: validate/normalize args -> build dataset (config file, DreamBooth
    dirs, or captioned metadata) -> optionally cache text-encoder outputs and
    VAE latents -> load DiT -> build per-component optimizer param groups ->
    rectified-flow training loop with optional fused backward pass, block swap,
    and DeepSpeed -> save final model.

    Args:
        args: parsed argparse namespace from ``setup_parser()`` (after
            ``read_config_from_file``).
    """
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)
    deepspeed_utils.prepare_deepspeed_args(args)
    setup_logging(args, reset=True)

    # backward compatibility: old flag name is honored when the new one is unset
    if not args.skip_cache_check:
        args.skip_cache_check = args.skip_latents_validity_check

    if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
        logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
        args.cache_text_encoder_outputs = True

    if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
        logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
        args.gradient_checkpointing = True

    if args.unsloth_offload_checkpointing:
        if not args.gradient_checkpointing:
            logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
            args.gradient_checkpointing = True
        assert (
            not args.cpu_offload_checkpointing
        ), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"

    # block swap is incompatible with either checkpoint-offload mode
    assert (
        args.blocks_to_swap is None or args.blocks_to_swap == 0
    ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
    assert (
        args.blocks_to_swap is None or args.blocks_to_swap == 0
    ) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"

    cache_latents = args.cache_latents
    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)

    # prepare caching strategy: must be set before preparing dataset
    if args.cache_latents:
        latents_caching_strategy = strategy_anima.AnimaLatentsCachingStrategy(
            args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
        )
        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # prepare dataset: config file takes precedence, else DreamBooth dirs or captioned metadata
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args)
        train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args)
        val_dataset_group = None

    # shared counters so dataloader workers can see the current epoch/step
    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    train_dataset_group.verify_bucket_reso_steps(16)  # Qwen-Image VAE spatial downscale = 8 * patch size = 2

    if args.debug_dataset:
        if args.cache_text_encoder_outputs:
            strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
                strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
                    args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
                )
            )
        train_dataset_group.set_current_strategies()
        train_util.debug_dataset(train_dataset_group, True)
        return

    if len(train_dataset_group) == 0:
        logger.error("No data found. Please verify the metadata file and train_data_dir option.")
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used"

    if args.cache_text_encoder_outputs:
        assert train_dataset_group.is_text_encoder_output_cacheable(
            cache_supports_dropout=True
        ), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"

    # prepare accelerator
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # mixed precision dtype
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # Load tokenizers and set strategies
    logger.info("Loading tokenizers...")
    qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
    t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)

    # Set tokenize strategy
    tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
        qwen3_tokenizer=qwen3_tokenizer,
        t5_tokenizer=t5_tokenizer,
        qwen3_max_length=args.qwen3_max_token_length,
        t5_max_length=args.t5_max_token_length,
    )
    strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)

    text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

    # Prepare text encoder (always frozen for Anima)
    qwen3_text_encoder.to(weight_dtype)
    qwen3_text_encoder.requires_grad_(False)

    # Cache text encoder outputs; the encoder is released afterwards to free memory
    sample_prompts_te_outputs = None
    if args.cache_text_encoder_outputs:
        qwen3_text_encoder.to(accelerator.device)
        qwen3_text_encoder.eval()
        text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
            args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
        )
        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

        with accelerator.autocast():
            train_dataset_group.new_cache_text_encoder_outputs([qwen3_text_encoder], accelerator)

        # cache sample prompt embeddings
        if args.sample_prompts is not None:
            logger.info(f"Cache Text Encoder outputs for sample prompts: {args.sample_prompts}")
            prompts = train_util.load_prompts(args.sample_prompts)
            sample_prompts_te_outputs = {}
            with accelerator.autocast(), torch.no_grad():
                for prompt_dict in prompts:
                    for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
                        if p not in sample_prompts_te_outputs:
                            logger.info(f" cache TE outputs for: {p}")
                            tokens_and_masks = tokenize_strategy.tokenize(p)
                            sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
                                tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
                            )
        accelerator.wait_for_everyone()

        # free text encoder memory
        qwen3_text_encoder = None
        gc.collect()  # Force garbage collection to free memory
        clean_memory_on_device(accelerator.device)

    # Load VAE and cache latents
    logger.info("Loading Anima VAE...")
    vae = qwen_image_autoencoder_kl.load_vae(
        args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
    )
    if cache_latents:
        vae.to(accelerator.device, dtype=weight_dtype)
        vae.requires_grad_(False)
        vae.eval()

        train_dataset_group.new_cache_latents(vae, accelerator)

        vae.to("cpu")
        clean_memory_on_device(accelerator.device)

        accelerator.wait_for_everyone()

    # Load DiT (MiniTrainDIT + optional LLM Adapter)
    logger.info("Loading Anima DiT...")
    dit = anima_utils.load_anima_model(
        "cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
    )

    if args.gradient_checkpointing:
        dit.enable_gradient_checkpointing(
            cpu_offload=args.cpu_offload_checkpointing,
            unsloth_offload=args.unsloth_offload_checkpointing,
        )

    # learning_rate == 0 means "evaluate only": DiT stays frozen on-device
    train_dit = args.learning_rate != 0
    dit.requires_grad_(train_dit)
    if not train_dit:
        dit.to(accelerator.device, dtype=weight_dtype)

    # Block swap
    is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
    if is_swapping_blocks:
        logger.info(f"Enable block swap: blocks_to_swap={args.blocks_to_swap}")
        dit.enable_block_swap(args.blocks_to_swap, accelerator.device)

    # VAE is needed on-device during training only when latents are not cached
    if not cache_latents:
        vae.requires_grad_(False)
        vae.eval()
        vae.to(accelerator.device, dtype=weight_dtype)

    # Setup optimizer with per-component parameter groups (each with its own LR)
    if train_dit:
        param_groups = anima_train_utils.get_anima_param_groups(
            dit,
            base_lr=args.learning_rate,
            self_attn_lr=args.self_attn_lr,
            cross_attn_lr=args.cross_attn_lr,
            mlp_lr=args.mlp_lr,
            mod_lr=args.mod_lr,
            llm_adapter_lr=args.llm_adapter_lr,
        )
    else:
        param_groups = []

    training_models = []
    if train_dit:
        training_models.append(dit)

    # calculate trainable parameters
    n_params = sum(p.numel() for group in param_groups for p in group["params"])

    accelerator.print(f"train dit: {train_dit}")
    accelerator.print(f"number of training models: {len(training_models)}")
    accelerator.print(f"number of trainable parameters: {n_params:,}")

    # prepare optimizer
    accelerator.print("prepare optimizer, data loader etc.")
    # Pass per-component param_groups directly to preserve per-component LRs
    _, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
    optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)

    # prepare dataloader
    train_dataset_group.set_current_strategies()
    # NOTE(review): os.cpu_count() may return None on exotic platforms; assumed non-None here
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # calculate training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
        )
        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs: {args.max_train_steps}")

    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # lr scheduler
    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)

    # full fp16/bf16 training
    dit_weight_dtype = weight_dtype
    if args.full_fp16:
        assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
        accelerator.print("enable full fp16 training.")
    elif args.full_bf16:
        assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
        accelerator.print("enable full bf16 training.")
    else:
        dit_weight_dtype = torch.float32  # If neither full_fp16 nor full_bf16, the model weights should be in float32

    dit.to(dit_weight_dtype)  # convert dit to target weight dtype

    # move text encoder to GPU if not cached
    if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
        qwen3_text_encoder.to(accelerator.device)

    clean_memory_on_device(accelerator.device)

    # Prepare with accelerator
    # Temporarily move non-training models off GPU to reduce memory during DDP init
    # if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
    #     qwen3_text_encoder.to("cpu")
    # if not cache_latents and vae is not None:
    #     vae.to("cpu")
    # clean_memory_on_device(accelerator.device)

    if args.deepspeed:
        ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=dit)
        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            ds_model, optimizer, train_dataloader, lr_scheduler
        )
        training_models = [ds_model]
    else:
        if train_dit:
            # when swapping blocks, device placement is done manually below
            dit = accelerator.prepare(dit, device_placement=[not is_swapping_blocks])
            if is_swapping_blocks:
                accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

    # Move non-training models back to GPU
    if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
        qwen3_text_encoder.to(accelerator.device)
    if not cache_latents and vae is not None:
        vae.to(accelerator.device, dtype=weight_dtype)

    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    if args.fused_backward_pass:
        # use fused optimizer for backward pass: other optimizers will be supported in the future
        import library.adafactor_fused

        library.adafactor_fused.patch_adafactor_fused(optimizer)

        for param_group in optimizer.param_groups:
            for parameter in param_group["params"]:
                if parameter.requires_grad:

                    # factory binds the param group per hook, avoiding the late-binding closure pitfall
                    def create_grad_hook(p_group):
                        def grad_hook(tensor: torch.Tensor):
                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                            optimizer.step_param(tensor, p_group)
                            tensor.grad = None  # free the gradient immediately after the fused step

                        return grad_hook

                    parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))

    # Training loop
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    accelerator.print("running training / 学習開始")
    accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
    accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
    accelerator.print(
        f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
    )
    accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
    # Copy for noise and timestep generation, because noise_scheduler may be changed during training in future
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    if accelerator.is_main_process:
        init_kwargs = {}
        if args.wandb_run_name:
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers(
            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
            config=train_util.get_sanitized_config_or_none(args),
            init_kwargs=init_kwargs,
        )
        if "wandb" in [tracker.name for tracker in accelerator.trackers]:
            import wandb

            wandb.define_metric("epoch")
            wandb.define_metric("loss/epoch", step_metric="epoch")

    if is_swapping_blocks:
        accelerator.unwrap_model(dit).prepare_block_swap_before_forward()

    # For --sample_at_first
    optimizer_eval_fn()
    anima_train_utils.sample_images(
        accelerator,
        args,
        0,
        global_step,
        dit,
        vae,
        qwen3_text_encoder,
        tokenize_strategy,
        text_encoding_strategy,
        sample_prompts_te_outputs,
    )
    optimizer_train_fn()
    if len(accelerator.trackers) > 0:
        accelerator.log({}, step=0)

    # Show model info
    unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
    if unwrapped_dit is not None:
        logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
    if qwen3_text_encoder is not None:
        logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
    if vae is not None:
        logger.info(f"vae device: {vae.device}")

    loss_recorder = train_util.LossRecorder()
    epoch = 0
    for epoch in range(num_train_epochs):
        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        for m in training_models:
            m.train()

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step
            with accelerator.accumulate(*training_models):
                # Get latents: prefer cached, else encode images through the VAE
                if "latents" in batch and batch["latents"] is not None:
                    latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
                    if latents.ndim == 5:  # Fallback for 5D latents (old cache)
                        latents = latents.squeeze(2)  # (B, C, 1, H, W) -> (B, C, H, W)
                else:
                    with torch.no_grad():
                        # images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
                        images = batch["images"].to(accelerator.device, dtype=weight_dtype)
                        latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)
                    if torch.any(torch.isnan(latents)):
                        accelerator.print("NaN found in latents, replacing with zeros")
                        latents = torch.nan_to_num(latents, 0, out=latents)

                # Get text encoder outputs
                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                if text_encoder_outputs_list is not None:
                    # Cached outputs: last element carries per-sample caption dropout rates
                    caption_dropout_rates = text_encoder_outputs_list[-1]
                    text_encoder_outputs_list = text_encoder_outputs_list[:-1]
                    # Apply caption dropout to cached outputs
                    text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
                        *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
                    )
                    prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
                else:
                    # Encode on-the-fly
                    input_ids_list = batch["input_ids_list"]
                    with torch.no_grad():
                        prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
                            tokenize_strategy, [qwen3_text_encoder], input_ids_list
                        )

                # Move to device
                prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
                attn_mask = attn_mask.to(accelerator.device)
                t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
                t5_attn_mask = t5_attn_mask.to(accelerator.device)

                # Noise and timesteps
                noise = torch.randn_like(latents)

                # Get noisy model input and timesteps
                noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
                    args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
                )
                timesteps = timesteps / 1000.0  # scale to [0, 1] range. timesteps is float32

                # NaN checks
                if torch.any(torch.isnan(noisy_model_input)):
                    accelerator.print("NaN found in noisy_model_input, replacing with zeros")
                    noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)

                # Create padding mask
                # padding_mask: (B, 1, H_latent, W_latent)
                bs = latents.shape[0]
                h_latent = latents.shape[-2]
                w_latent = latents.shape[-1]
                padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)

                # DiT forward (LLM adapter runs inside forward for DDP gradient sync)
                noisy_model_input = noisy_model_input.unsqueeze(2)  # 4D to 5D, (B, C, 1, H, W)
                with accelerator.autocast():
                    model_pred = dit(
                        noisy_model_input,
                        timesteps,
                        prompt_embeds,
                        padding_mask=padding_mask,
                        source_attention_mask=attn_mask,
                        t5_input_ids=t5_input_ids,
                        t5_attn_mask=t5_attn_mask,
                    )
                model_pred = model_pred.squeeze(2)  # 5D to 4D, (B, C, H, W)

                # Compute loss (rectified flow: target = noise - latents)
                target = noise - latents

                # Weighting
                weighting = anima_train_utils.compute_loss_weighting_for_anima(
                    weighting_scheme=args.weighting_scheme, sigmas=sigmas
                )

                # Loss
                huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
                loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
                if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                    loss = apply_masked_loss(loss, batch)
                loss = loss.mean([1, 2, 3])  # (B, C, H, W) -> (B,)

                if weighting is not None:
                    loss = loss * weighting
                loss_weights = batch["loss_weights"]
                loss = loss * loss_weights

                loss = loss.mean()

                accelerator.backward(loss)

                if not args.fused_backward_pass:
                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                        params_to_clip = []
                        for m in training_models:
                            params_to_clip.extend(m.parameters())
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                else:
                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
                    lr_scheduler.step()

            # Checks if the accelerator has performed an optimization step
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                optimizer_eval_fn()
                anima_train_utils.sample_images(
                    accelerator,
                    args,
                    None,
                    global_step,
                    dit,
                    vae,
                    qwen3_text_encoder,
                    tokenize_strategy,
                    text_encoding_strategy,
                    sample_prompts_te_outputs,
                )

                # Save at specific steps
                if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
                            args,
                            False,
                            accelerator,
                            save_dtype,
                            epoch,
                            num_train_epochs,
                            global_step,
                            accelerator.unwrap_model(dit) if train_dit else None,
                        )
                optimizer_train_fn()

            current_loss = loss.detach().item()
            if len(accelerator.trackers) > 0:
                logs = {"loss": current_loss}
                train_util.append_lr_to_logs_with_names(
                    logs,
                    lr_scheduler,
                    args.optimizer_type,
                    ["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
                )
                accelerator.log(logs, step=global_step)

            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
            avr_loss: float = loss_recorder.moving_average
            logs = {"avr_loss": avr_loss}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if len(accelerator.trackers) > 0:
            logs = {"loss/epoch": loss_recorder.moving_average, "epoch": epoch + 1}
            accelerator.log(logs, step=global_step)

        accelerator.wait_for_everyone()

        optimizer_eval_fn()
        if args.save_every_n_epochs is not None:
            if accelerator.is_main_process:
                anima_train_utils.save_anima_model_on_epoch_end_or_stepwise(
                    args,
                    True,
                    accelerator,
                    save_dtype,
                    epoch,
                    num_train_epochs,
                    global_step,
                    accelerator.unwrap_model(dit) if train_dit else None,
                )

        anima_train_utils.sample_images(
            accelerator,
            args,
            epoch + 1,
            global_step,
            dit,
            vae,
            qwen3_text_encoder,
            tokenize_strategy,
            text_encoding_strategy,
            sample_prompts_te_outputs,
        )
        # BUGFIX: restore optimizer train mode before the next epoch; it was left
        # in eval mode after the epoch-end save/sample (matters for schedule-free
        # optimizers whose train/eval modes change parameter handling)
        optimizer_train_fn()

    # End training
    is_main_process = accelerator.is_main_process
    dit = accelerator.unwrap_model(dit)

    accelerator.end_training()
    optimizer_eval_fn()

    if args.save_state or args.save_state_on_train_end:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator

    if is_main_process and train_dit:
        anima_train_utils.save_anima_model_on_train_end(
            args,
            save_dtype,
            epoch,
            global_step,
            dit,
        )
        logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    """Build the argument parser: shared sd-scripts argument groups plus
    Anima-specific checkpoint-offload and deprecated-flag options."""
    parser = argparse.ArgumentParser()

    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_masked_loss_arguments(parser)
    deepspeed_utils.add_deepspeed_arguments(parser)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    add_custom_train_arguments(parser)
    train_util.add_dit_training_arguments(parser)
    anima_train_utils.add_anima_training_arguments(parser)
    sai_model_spec.add_model_spec_arguments(parser)

    parser.add_argument(
        "--cpu_offload_checkpointing",
        action="store_true",
        help="offload gradient checkpointing to CPU (reduces VRAM at cost of speed)",
    )
    parser.add_argument(
        "--unsloth_offload_checkpointing",
        action="store_true",
        help="offload activations to CPU RAM using async non-blocking transfers (faster than --cpu_offload_checkpointing). "
        "Cannot be used with --cpu_offload_checkpointing or --blocks_to_swap.",
    )
    parser.add_argument(
        "--skip_latents_validity_check",
        action="store_true",
        help="[Deprecated] use 'skip_cache_check' instead",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    if args.attn_mode == "sdpa":
        args.attn_mode = "torch"  # backward compatibility

    train(args)