diff --git a/_typos.toml b/_typos.toml
index 686da4af..fc33b6b3 100644
--- a/_typos.toml
+++ b/_typos.toml
@@ -32,6 +32,7 @@ hime="hime"
OT="OT"
byt="byt"
tak="tak"
+temperal="temperal"
[files]
-extend-exclude = ["_typos.toml", "venv"]
+extend-exclude = ["_typos.toml", "venv", "configs"]
diff --git a/anima_minimal_inference.py b/anima_minimal_inference.py
new file mode 100644
index 00000000..2a6d4ba4
--- /dev/null
+++ b/anima_minimal_inference.py
@@ -0,0 +1,1082 @@
+import argparse
+import datetime
+import gc
+from importlib.util import find_spec
+import random
+import os
+import time
+import copy
+from types import SimpleNamespace
+from typing import Tuple, Optional, List, Any, Dict, Union
+
+import torch
+from safetensors.torch import load_file, save_file
+from safetensors import safe_open
+from tqdm import tqdm
+from diffusers.utils.torch_utils import randn_tensor
+from PIL import Image
+
+from library import anima_models, anima_utils, hunyuan_image_utils, qwen_image_autoencoder_kl, strategy_anima, strategy_base
+from library.device_utils import clean_memory_on_device, synchronize_device
+
+lycoris_available = find_spec("lycoris") is not None
+if lycoris_available:
+ from lycoris.kohya import create_network_from_weights
+
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class GenerationSettings:
+ def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None):
+ self.device = device
+ self.dit_weight_dtype = dit_weight_dtype # not used currently because model may be optimized
+
+
+def parse_args() -> argparse.Namespace:
+ """parse command line arguments"""
+ parser = argparse.ArgumentParser(description="HunyuanImage inference script")
+
+ parser.add_argument("--dit", type=str, default=None, help="DiT directory or path")
+ parser.add_argument("--vae", type=str, default=None, help="VAE directory or path")
+ parser.add_argument(
+ "--vae_chunk_size",
+ type=int,
+ default=None,
+ help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
+ + " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
+ )
+ parser.add_argument(
+ "--vae_disable_cache",
+ action="store_true",
+ help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
+ + " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
+ )
+ parser.add_argument("--text_encoder", type=str, required=True, help="Text Encoder 1 (Qwen2.5-VL) directory or path")
+
+ # LoRA
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
+ parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
+ parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
+
+ # inference
+ parser.add_argument(
+ "--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier free guidance. Default is 3.5."
+ )
+ parser.add_argument("--prompt", type=str, default=None, help="prompt for generation")
+ parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt for generation, default is empty string")
+ parser.add_argument("--image_size", type=int, nargs=2, default=[1024, 1024], help="image size, height and width")
+ parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps, default is 50")
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
+
+ # Flow Matching
+ parser.add_argument(
+ "--flow_shift",
+ type=float,
+ default=5.0,
+ help="Shift factor for flow matching schedulers. Default is 5.0.",
+ )
+
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
+
+ parser.add_argument("--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders")
+ parser.add_argument(
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
+ )
+ parser.add_argument(
+ "--attn_mode",
+ type=str,
+ default="torch",
+ choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "sdpa" for backward compatibility
+ help="attention mode",
+ )
+ parser.add_argument(
+ "--output_type",
+ type=str,
+ default="images",
+ choices=["images", "latent", "latent_images"],
+ help="output type",
+ )
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
+ parser.add_argument(
+ "--lycoris", action="store_true", help=f"use lycoris for inference{'' if lycoris_available else ' (not available)'}"
+ )
+
+ # arguments for batch and interactive modes
+ parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
+ parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
+
+ args = parser.parse_args()
+
+ # Validate arguments
+ if args.from_file and args.interactive:
+ raise ValueError("Cannot use both --from_file and --interactive at the same time")
+
+ if args.latent_path is None or len(args.latent_path) == 0:
+ if args.prompt is None and not args.from_file and not args.interactive:
+ raise ValueError("Either --prompt, --from_file or --interactive must be specified")
+
+ if args.lycoris and not lycoris_available:
+ raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS")
+
+ if args.attn_mode == "sdpa":
+ args.attn_mode = "torch" # backward compatibility
+
+ return args
+
+
+def parse_prompt_line(line: str) -> Dict[str, Any]:
+ """Parse a prompt line into a dictionary of argument overrides
+
+ Args:
+ line: Prompt line with options
+
+ Returns:
+ Dict[str, Any]: Dictionary of argument overrides
+ """
+ parts = line.split(" --")
+ prompt = parts[0].strip()
+
+ # Create dictionary of overrides
+ overrides = {"prompt": prompt}
+
+ for part in parts[1:]:
+ if not part.strip():
+ continue
+ option_parts = part.split(" ", 1)
+ option = option_parts[0].strip()
+ value = option_parts[1].strip() if len(option_parts) > 1 else ""
+
+ # Map options to argument names
+ if option == "w":
+ overrides["image_size_width"] = int(value)
+ elif option == "h":
+ overrides["image_size_height"] = int(value)
+ elif option == "d":
+ overrides["seed"] = int(value)
+ elif option == "s":
+ overrides["infer_steps"] = int(value)
+ elif option == "g" or option == "l":
+ overrides["guidance_scale"] = float(value)
+ elif option == "fs":
+ overrides["flow_shift"] = float(value)
+ elif option == "n":
+ overrides["negative_prompt"] = value
+
+ return overrides
+
+
+def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
+ """Apply overrides to args
+
+ Args:
+ args: Original arguments
+ overrides: Dictionary of overrides
+
+ Returns:
+ argparse.Namespace: New arguments with overrides applied
+ """
+ args_copy = copy.deepcopy(args)
+
+ for key, value in overrides.items():
+ if key == "image_size_width":
+ args_copy.image_size[1] = value
+ elif key == "image_size_height":
+ args_copy.image_size[0] = value
+ else:
+ setattr(args_copy, key, value)
+
+ return args_copy
+
+
+def check_inputs(args: argparse.Namespace) -> Tuple[int, int]:
+ """Validate video size and length
+
+ Args:
+ args: command line arguments
+
+ Returns:
+ Tuple[int, int]: (height, width)
+ """
+ height = args.image_size[0]
+ width = args.image_size[1]
+
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ return height, width
+
+
+# region Model
+
+
+def load_dit_model(
+ args: argparse.Namespace, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None
+) -> anima_models.Anima:
+ """load DiT model
+
+ Args:
+ args: command line arguments
+ device: device to use
+ dit_weight_dtype: data type for the model weights. None for as-is
+
+ Returns:
+ anima_models.Anima: DiT model instance
+ """
+ # If LyCORIS is enabled, we will load the model to CPU and then merge LoRA weights (static method)
+
+ loading_device = "cpu"
+ if not args.lycoris:
+ loading_device = device
+
+ # load LoRA weights
+ if not args.lycoris and args.lora_weight is not None and len(args.lora_weight) > 0:
+ lora_weights_list = []
+ for lora_weight in args.lora_weight:
+ logger.info(f"Loading LoRA weight from: {lora_weight}")
+ lora_sd = load_file(lora_weight) # load on CPU, dtype is as is
+ # lora_sd = filter_lora_state_dict(lora_sd, args.include_patterns, args.exclude_patterns)
+ lora_sd = {k: v for k, v in lora_sd.items() if k.startswith("lora_unet_")} # only keep unet lora weights
+ lora_weights_list.append(lora_sd)
+ else:
+ lora_weights_list = None
+
+ loading_weight_dtype = dit_weight_dtype
+ if args.fp8_scaled and not args.lycoris:
+ loading_weight_dtype = None # we will load weights as-is and then optimize to fp8
+
+ model = anima_utils.load_anima_model(
+ device,
+ args.dit,
+ args.attn_mode,
+ True, # enable split_attn to trim masked tokens
+ loading_device,
+ loading_weight_dtype,
+ args.fp8_scaled and not args.lycoris,
+ lora_weights_list=lora_weights_list,
+ lora_multipliers=args.lora_multiplier,
+ )
+ if not args.fp8_scaled:
+ # simple cast to dit_weight_dtype
+ target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict)
+ if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled
+ logger.info(f"Convert model to {dit_weight_dtype}")
+ target_dtype = dit_weight_dtype
+
+ logger.info(f"Move model to device: {device}")
+ target_device = device
+
+ model.to(target_device, target_dtype) # move and cast at the same time. this reduces redundant copy operations
+
+ # model.to(device)
+ model.to(device, dtype=torch.bfloat16) # ensure model is in bfloat16 for inference
+
+ model.eval().requires_grad_(False)
+ clean_memory_on_device(device)
+
+ return model
+
+
+def load_text_encoder(
+ args: argparse.Namespace, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
+) -> torch.nn.Module:
+ lora_weights_list = None
+ if args.lora_weight is not None and len(args.lora_weight) > 0:
+ lora_weights_list = []
+ for lora_weight in args.lora_weight:
+ logger.info(f"Loading LoRA weight from: {lora_weight}")
+ lora_sd = load_file(lora_weight) # load on CPU, dtype is as is
+ # lora_sd = filter_lora_state_dict(lora_sd, args.include_patterns, args.exclude_patterns)
+ lora_sd = {
+ "model_" + k[len("lora_te_") :]: v for k, v in lora_sd.items() if k.startswith("lora_te_")
+ } # only keep Text Encoder lora weights, remove prefix "lora_te_" and add "model_" prefix
+ lora_weights_list.append(lora_sd)
+
+ text_encoder, _ = anima_utils.load_qwen3_text_encoder(
+ args.text_encoder, dtype=dtype, device=device, lora_weights=lora_weights_list, lora_multipliers=args.lora_multiplier
+ )
+ text_encoder.eval()
+ return text_encoder
+
+
+# endregion
+
+
+def decode_latent(
+ vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage, latent: torch.Tensor, device: torch.device
+) -> torch.Tensor:
+ logger.info(f"Decoding image. Latent shape {latent.shape}, device {device}")
+
+ vae.to(device)
+ with torch.no_grad():
+ pixels = vae.decode_to_pixels(latent.to(device, dtype=vae.dtype))
+ # pixels = vae.decode(latent.to(device, dtype=torch.bfloat16), scale=vae_scale)
+ if pixels.ndim == 5: # remove frame dimension if exists, [B, C, F, H, W] -> [B, C, H, W]
+ pixels = pixels.squeeze(2)
+
+ pixels = pixels.to("cpu", dtype=torch.float32) # move to CPU and convert to float32 (bfloat16 is not supported by numpy)
+ vae.to("cpu")
+
+ logger.info(f"Decoded. Pixel shape {pixels.shape}")
+ return pixels[0] # remove batch dimension
+
+
+def process_escape(text: str) -> str:
+ """Process escape sequences in text
+
+ Args:
+ text: Input text with escape sequences
+
+ Returns:
+ str: Processed text
+ """
+ return text.encode("utf-8").decode("unicode_escape")
+
+
+def prepare_text_inputs(
+ args: argparse.Namespace, device: torch.device, anima: anima_models.Anima, shared_models: Optional[Dict] = None
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """Prepare text-related inputs for T2I: LLM encoding. Anima model is also needed for preprocessing"""
+
+ # load text encoder: conds_cache holds cached encodings for prompts without padding
+ conds_cache = {}
+ text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
+ if shared_models is not None:
+ text_encoder = shared_models.get("text_encoder")
+
+ if "conds_cache" in shared_models: # Use shared cache if available
+ conds_cache = shared_models["conds_cache"]
+
+ # text_encoder is on device (batched inference) or CPU (interactive inference)
+ else: # Load if not in shared_models
+ text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
+ text_encoder = load_text_encoder(args, dtype=text_encoder_dtype, device=text_encoder_device)
+ text_encoder.eval()
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
+ # Store references so load_target_model can reuse them
+
+ # Store original devices to move back later if they were shared. This does nothing if shared_models is None
+ text_encoder_original_device = text_encoder.device if text_encoder else None
+
+ # Ensure text_encoder is not None before proceeding
+ if not text_encoder:
+ raise ValueError("Text encoder is not loaded properly.")
+
+ # Define a function to move models to device if needed
+ # This is to avoid moving models if not needed, especially in interactive mode
+ model_is_moved = False
+
+ def move_models_to_device_if_needed():
+ nonlocal model_is_moved
+ nonlocal shared_models
+
+ if model_is_moved:
+ return
+ model_is_moved = True
+
+ logger.info(f"Moving Text Encoder to appropriate device: {text_encoder_device}")
+ text_encoder.to(text_encoder_device) # If text_encoder_cpu is True, this will be CPU
+
+ logger.info("Encoding prompt with Text Encoder")
+
+ prompt = process_escape(args.prompt)
+ cache_key = prompt
+ if cache_key in conds_cache:
+ embed = conds_cache[cache_key]
+ else:
+ move_models_to_device_if_needed()
+
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
+ encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
+
+ with torch.no_grad():
+ # embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, prompt)
+ tokens = tokenize_strategy.tokenize(prompt)
+ embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
+ crossattn_emb = anima._preprocess_text_embeds(
+ source_hidden_states=embed[0].to(anima.device),
+ target_input_ids=embed[2].to(anima.device),
+ target_attention_mask=embed[3].to(anima.device),
+ source_attention_mask=embed[1].to(anima.device),
+ )
+ crossattn_emb[~embed[3].bool()] = 0
+ embed[0] = crossattn_emb
+ embed[0] = embed[0].cpu()
+
+ conds_cache[cache_key] = embed
+
+ negative_prompt = process_escape(args.negative_prompt)
+ cache_key = negative_prompt
+ if cache_key in conds_cache:
+ negative_embed = conds_cache[cache_key]
+ else:
+ move_models_to_device_if_needed()
+
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
+ encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
+
+ with torch.no_grad():
+ # negative_embed = anima_text_encoder.get_text_embeds(anima, tokenizer, text_encoder, t5xxl_tokenizer, negative_prompt)
+ tokens = tokenize_strategy.tokenize(negative_prompt)
+ negative_embed = encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)
+ crossattn_emb = anima._preprocess_text_embeds(
+ source_hidden_states=negative_embed[0].to(anima.device),
+ target_input_ids=negative_embed[2].to(anima.device),
+ target_attention_mask=negative_embed[3].to(anima.device),
+ source_attention_mask=negative_embed[1].to(anima.device),
+ )
+ crossattn_emb[~negative_embed[3].bool()] = 0
+ negative_embed[0] = crossattn_emb
+ negative_embed[0] = negative_embed[0].cpu()
+
+ conds_cache[cache_key] = negative_embed
+
+ if not (shared_models and "text_encoder" in shared_models): # if loaded locally
+ # There is a bug text_encoder is not freed from GPU memory when text encoder is fp8
+ del text_encoder
+ gc.collect() # This may force Text Encoder to be freed from GPU memory
+ else: # if shared, move back to original device (likely CPU)
+ if text_encoder:
+ text_encoder.to(text_encoder_original_device)
+
+ clean_memory_on_device(device)
+
+ arg_c = {"embed": embed, "prompt": prompt}
+ arg_null = {"embed": negative_embed, "prompt": negative_prompt}
+
+ return arg_c, arg_null
+
+
+def generate(
+ args: argparse.Namespace,
+ gen_settings: GenerationSettings,
+ shared_models: Optional[Dict] = None,
+ precomputed_text_data: Optional[Dict] = None,
+) -> torch.Tensor:
+ """main function for generation
+
+ Args:
+ args: command line arguments
+ shared_models: dictionary containing pre-loaded models (mainly for DiT)
+ precomputed_image_data: Optional dictionary with precomputed image data
+ precomputed_text_data: Optional dictionary with precomputed text data
+
+ Returns:
+ tuple: (HunyuanVAE2D model (vae) or None, torch.Tensor generated latent)
+ """
+ device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype)
+
+ # prepare seed
+ seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
+ args.seed = seed # set seed to args for saving
+
+ if shared_models is None or "model" not in shared_models:
+ # load DiT model
+ anima = load_dit_model(args, device, dit_weight_dtype)
+
+ if shared_models is not None:
+ shared_models["model"] = anima
+ else:
+ # use shared model
+ logger.info("Using shared DiT model.")
+ anima: anima_models.Anima = shared_models["model"]
+
+ if precomputed_text_data is not None:
+ logger.info("Using precomputed text data.")
+ context = precomputed_text_data["context"]
+ context_null = precomputed_text_data["context_null"]
+
+ else:
+ logger.info("No precomputed data. Preparing image and text inputs.")
+ context, context_null = prepare_text_inputs(args, device, anima, shared_models)
+
+ return generate_body(args, anima, context, context_null, device, seed)
+
+
+def generate_body(
+ args: Union[argparse.Namespace, SimpleNamespace],
+ anima: anima_models.Anima,
+ context: Dict[str, Any],
+ context_null: Optional[Dict[str, Any]],
+ device: torch.device,
+ seed: int,
+) -> torch.Tensor:
+
+ # set random generator
+ seed_g = torch.Generator(device="cpu")
+ seed_g.manual_seed(seed)
+
+ height, width = check_inputs(args)
+ logger.info(f"Image size: {height}x{width} (HxW), infer_steps: {args.infer_steps}")
+
+ # image generation ######
+
+ logger.info(f"Prompt: {context['prompt']}")
+
+ embed = context["embed"][0].to(device, dtype=torch.bfloat16)
+ if context_null is None:
+ context_null = context # dummy for unconditional
+ negative_embed = context_null["embed"][0].to(device, dtype=torch.bfloat16)
+
+ # Prepare latent variables
+ num_channels_latents = anima_models.Anima.LATENT_CHANNELS
+ shape = (
+ 1,
+ num_channels_latents,
+ 1, # Frame dimension
+ height // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR,
+ width // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR,
+ )
+ latents = randn_tensor(shape, generator=seed_g, device=device, dtype=torch.bfloat16)
+
+ # Create padding mask
+ bs = latents.shape[0]
+ h_latent = latents.shape[-2]
+ w_latent = latents.shape[-1]
+ padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=torch.bfloat16, device=device)
+
+ logger.info(f"Embed: {embed.shape}, negative_embed: {negative_embed.shape}, latents: {latents.shape}")
+ embed = embed.to(torch.bfloat16)
+ negative_embed = negative_embed.to(torch.bfloat16)
+
+ # Prepare timesteps
+ timesteps, sigmas = hunyuan_image_utils.get_timesteps_sigmas(args.infer_steps, args.flow_shift, device)
+ timesteps /= 1000 # scale to [0,1] range
+ timesteps = timesteps.to(device, dtype=torch.bfloat16)
+
+ # Denoising loop
+ do_cfg = args.guidance_scale != 1.0
+ autocast_enabled = args.fp8
+
+ with tqdm(total=len(timesteps), desc="Denoising steps") as pbar:
+ for i, t in enumerate(timesteps):
+ t_expand = t.expand(latents.shape[0])
+
+ with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
+ noise_pred = anima(latents, t_expand, embed, padding_mask=padding_mask)
+
+ if do_cfg:
+ with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=autocast_enabled):
+ uncond_noise_pred = anima(latents, t_expand, negative_embed, padding_mask=padding_mask)
+ noise_pred = uncond_noise_pred + args.guidance_scale * (noise_pred - uncond_noise_pred)
+
+ # ensure latents dtype is consistent
+ latents = hunyuan_image_utils.step(latents, noise_pred, sigmas, i).to(latents.dtype)
+
+ pbar.update()
+
+ return latents
+
+
+def get_time_flag():
+ return datetime.datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S-%f")[:-3]
+
+
+def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
+ """Save latent to file
+
+ Args:
+ latent: Latent tensor
+ args: command line arguments
+ height: height of frame
+ width: width of frame
+
+ Returns:
+ str: Path to saved latent file
+ """
+ save_path = args.save_path
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = get_time_flag()
+
+ seed = args.seed
+
+ latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"
+
+ if args.no_metadata:
+ metadata = None
+ else:
+ metadata = {
+ "seeds": f"{seed}",
+ "prompt": f"{args.prompt}",
+ "height": f"{height}",
+ "width": f"{width}",
+ "infer_steps": f"{args.infer_steps}",
+ # "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
+ "guidance_scale": f"{args.guidance_scale}",
+ }
+ if args.negative_prompt is not None:
+ metadata["negative_prompt"] = f"{args.negative_prompt}"
+
+ sd = {"latent": latent.contiguous()}
+ save_file(sd, latent_path, metadata=metadata)
+ logger.info(f"Latent saved to: {latent_path}")
+
+ return latent_path
+
+
+def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
+ """Save images to directory
+
+ Args:
+ sample: Video tensor
+ args: command line arguments
+ original_base_name: Original base name (if latents are loaded from files)
+
+ Returns:
+ str: Path to saved images directory
+ """
+ save_path = args.save_path
+ os.makedirs(save_path, exist_ok=True)
+ time_flag = get_time_flag()
+
+ seed = args.seed
+ original_name = "" if original_base_name is None else f"_{original_base_name}"
+ image_name = f"{time_flag}_{seed}{original_name}"
+
+ x = torch.clamp(sample, -1.0, 1.0)
+ x = ((x + 1.0) * 127.5).to(torch.uint8).cpu().numpy()
+ x = x.transpose(1, 2, 0) # C, H, W -> H, W, C
+
+ image = Image.fromarray(x)
+ image.save(os.path.join(save_path, f"{image_name}.png"))
+
+ logger.info(f"Sample images saved to: {save_path}/{image_name}")
+
+ return f"{save_path}/{image_name}"
+
+
+def save_output(
+ args: argparse.Namespace,
+ vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
+ latent: torch.Tensor,
+ device: torch.device,
+ original_base_name: Optional[str] = None,
+) -> None:
+ """save output
+
+ Args:
+ args: command line arguments
+ vae: VAE model
+ latent: latent tensor
+ device: device to use
+ original_base_name: original base name (if latents are loaded from files)
+ """
+ height, width = latent.shape[-2], latent.shape[-1] # BCTHW
+ height *= 8 # qwen_image_autoencoder_kl.SCALE_FACTOR
+ width *= 8 # qwen_image_autoencoder_kl.SCALE_FACTOR
+ # print(f"Saving output. Latent shape {latent.shape}; pixel shape {height}x{width}")
+ if args.output_type == "latent" or args.output_type == "latent_images":
+ # save latent
+ save_latent(latent, args, height, width)
+ if args.output_type == "latent":
+ return
+
+ if vae is None:
+ logger.error("VAE is None, cannot decode latents for saving video/images.")
+ return
+
+ if latent.ndim == 2: # S,C. For packed latents from other inference scripts
+ latent = latent.unsqueeze(0)
+ height, width = check_inputs(args) # Get height/width from args
+ latent = latent.view(
+ 1,
+ vae.latent_channels,
+ 1, # Frame dimension
+ height // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR,
+ width // 8, # qwen_image_autoencoder_kl.SCALE_FACTOR,
+ )
+
+ image = decode_latent(vae, latent, device)
+
+ if args.output_type == "images" or args.output_type == "latent_images":
+ # save images
+ if original_base_name is None:
+ original_name = ""
+ else:
+ original_name = f"_{original_base_name}"
+ save_images(image, args, original_name)
+
+
+def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
+ """Process multiple prompts for batch mode
+
+ Args:
+ prompt_lines: List of prompt lines
+ base_args: Base command line arguments
+
+ Returns:
+ List[Dict]: List of prompt data dictionaries
+ """
+ prompts_data = []
+
+ for line in prompt_lines:
+ line = line.strip()
+ if not line or line.startswith("#"): # Skip empty lines and comments
+ continue
+
+ # Parse prompt line and create override dictionary
+ prompt_data = parse_prompt_line(line)
+ logger.info(f"Parsed prompt data: {prompt_data}")
+ prompts_data.append(prompt_data)
+
+ return prompts_data
+
+
+def load_shared_models(args: argparse.Namespace) -> Dict:
+ """Load shared models for batch processing or interactive mode.
+ Models are loaded to CPU to save memory. VAE is NOT loaded here.
+ DiT model is also NOT loaded here, handled by process_batch_prompts or generate.
+
+ Args:
+ args: Base command line arguments
+
+ Returns:
+ Dict: Dictionary of shared models (text/image encoders)
+ """
+ shared_models = {}
+ # Load text encoders to CPU
+ text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
+ text_encoder = load_text_encoder(args, dtype=text_encoder_dtype, device=torch.device("cpu"))
+ shared_models["text_encoder"] = text_encoder
+ return shared_models
+
+
+def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
+ """Process multiple prompts with model reuse and batched precomputation
+
+ Args:
+ prompts_data: List of prompt data dictionaries
+ args: Base command line arguments
+ """
+ if not prompts_data:
+ logger.warning("No valid prompts found")
+ return
+
+ gen_settings = get_generation_settings(args)
+ dit_weight_dtype = gen_settings.dit_weight_dtype
+ device = gen_settings.device
+
+ # 1. Prepare VAE
+ logger.info("Loading VAE for batch generation...")
+ vae_for_batch = qwen_image_autoencoder_kl.load_vae(
+ args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
+ )
+ vae_for_batch.to(torch.bfloat16)
+ vae_for_batch.eval()
+
+ all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first
+ for prompt_args in all_prompt_args_list:
+ check_inputs(prompt_args) # Validate each prompt's height/width
+
+ # 2. Load DiT Model once
+ logger.info("Loading DiT model for batch generation...")
+ # Use args from the first prompt for DiT loading (LoRA etc. should be consistent for a batch)
+ first_prompt_args = all_prompt_args_list[0]
+ anima = load_dit_model(first_prompt_args, device, dit_weight_dtype) # Load directly to target device if possible
+
+ shared_models_for_generate = {"model": anima} # Pass DiT via shared_models
+
+ # 3. Precompute Text Data (Text Encoder)
+ logger.info("Loading Text Encoder for batch text preprocessing...")
+
+ # Text Encoder loaded to CPU by load_text_encoder
+ text_encoder_dtype = torch.bfloat16 # Default dtype for Text Encoder
+ text_encoder_batch = load_text_encoder(args, dtype=text_encoder_dtype, device=torch.device("cpu"))
+
+ # Text Encoder to device for this phase
+ text_encoder_device = torch.device("cpu") if args.text_encoder_cpu else device
+ text_encoder_batch.to(text_encoder_device) # Moved into prepare_text_inputs logic
+
+ all_precomputed_text_data = []
+ conds_cache_batch = {}
+
+ logger.info("Preprocessing text and LLM/TextEncoder encoding for all prompts...")
+ temp_shared_models_txt = {
+ "text_encoder": text_encoder_batch, # on GPU if not text_encoder_cpu
+ "conds_cache": conds_cache_batch,
+ }
+
+ for i, prompt_args_item in enumerate(all_prompt_args_list):
+ logger.info(f"Text preprocessing for prompt {i+1}/{len(all_prompt_args_list)}: {prompt_args_item.prompt}")
+
+ # prepare_text_inputs will move text_encoders to device temporarily
+ context, context_null = prepare_text_inputs(prompt_args_item, device, anima, temp_shared_models_txt)
+ text_data = {"context": context, "context_null": context_null}
+ all_precomputed_text_data.append(text_data)
+
+ # Models should be removed from device after prepare_text_inputs
+ del text_encoder_batch, temp_shared_models_txt, conds_cache_batch
+ gc.collect() # Force cleanup of Text Encoder from GPU memory
+ clean_memory_on_device(device)
+
+ all_latents = []
+
+ logger.info("Generating latents for all prompts...")
+ with torch.no_grad():
+ for i, prompt_args_item in enumerate(all_prompt_args_list):
+ current_text_data = all_precomputed_text_data[i]
+ height, width = check_inputs(prompt_args_item) # Get height/width for each prompt
+
+ logger.info(f"Generating latent for prompt {i+1}/{len(all_prompt_args_list)}: {prompt_args_item.prompt}")
+ try:
+ # generate is called with precomputed data, so it won't load Text Encoders.
+ # It will use the DiT model from shared_models_for_generate.
+ latent = generate(prompt_args_item, gen_settings, shared_models_for_generate, current_text_data)
+
+ if latent is None: # and prompt_args_item.save_merged_model: # Should be caught earlier
+ continue
+
+ # Save latent if needed (using data from precomputed_image_data for H/W)
+ if prompt_args_item.output_type in ["latent", "latent_images"]:
+ save_latent(latent, prompt_args_item, height, width)
+
+ all_latents.append(latent)
+ except Exception as e:
+ logger.error(f"Error generating latent for prompt: {prompt_args_item.prompt}. Error: {e}", exc_info=True)
+ all_latents.append(None) # Add placeholder for failed generations
+ continue
+
+ # Free DiT model
+ logger.info("Releasing DiT model from memory...")
+
+ del shared_models_for_generate["model"]
+ del anima
+ clean_memory_on_device(device)
+ synchronize_device(device) # Ensure memory is freed before loading VAE for decoding
+
+ # 4. Decode latents and save outputs (using vae_for_batch)
+ if args.output_type != "latent":
+ logger.info("Decoding latents to videos/images using batched VAE...")
+ vae_for_batch.to(device) # Move VAE to device for decoding
+
+ for i, latent in enumerate(all_latents):
+ if latent is None: # Skip failed generations
+ logger.warning(f"Skipping decoding for prompt {i+1} due to previous error.")
+ continue
+
+ current_args = all_prompt_args_list[i]
+ logger.info(f"Decoding output {i+1}/{len(all_latents)} for prompt: {current_args.prompt}")
+
+ # if args.output_type is "latent_images", we already saved latent above.
+ # so we skip saving latent here.
+ if current_args.output_type == "latent_images":
+ current_args.output_type = "images"
+
+ # save_output expects latent to be [BCTHW] or [CTHW]. generate returns [BCTHW] (batch size 1).
+ save_output(current_args, vae_for_batch, latent, device) # Pass vae_for_batch
+
+ vae_for_batch.to("cpu") # Move VAE back to CPU
+
+ del vae_for_batch
+ clean_memory_on_device(device)
+
+
+def process_interactive(args: argparse.Namespace) -> None:
+ """Process prompts in interactive mode
+
+ Args:
+ args: Base command line arguments
+ """
+ gen_settings = get_generation_settings(args)
+ device = gen_settings.device
+ shared_models = load_shared_models(args)
+ shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode
+
+ vae = qwen_image_autoencoder_kl.load_vae(
+ args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
+ )
+ vae.to(torch.bfloat16)
+ vae.eval()
+
+ print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")
+
+ try:
+ import prompt_toolkit
+ except ImportError:
+ logger.warning("prompt_toolkit not found. Using basic input instead.")
+ prompt_toolkit = None
+
+ if prompt_toolkit:
+ session = prompt_toolkit.PromptSession()
+
+ def input_line(prompt: str) -> str:
+ return session.prompt(prompt)
+
+ else:
+
+ def input_line(prompt: str) -> str:
+ return input(prompt)
+
+ try:
+ while True:
+ try:
+ line = input_line("> ")
+ if not line.strip():
+ continue
+ if len(line.strip()) == 1 and line.strip() in ["\x04", "\x1a"]: # Ctrl+D or Ctrl+Z with prompt_toolkit
+ raise EOFError # Exit on Ctrl+D or Ctrl+Z
+
+ # Parse prompt
+ prompt_data = parse_prompt_line(line)
+ prompt_args = apply_overrides(args, prompt_data)
+
+ # Generate latent
+ # For interactive, precomputed data is None. shared_models contains text encoders.
+ latent = generate(prompt_args, gen_settings, shared_models)
+
+ # # If not one_frame_inference, move DiT model to CPU after generation
+ # model = shared_models.get("model")
+ # model.to("cpu") # Move DiT model to CPU after generation
+
+ # Save latent and video
+ # returned_vae from generate will be used for decoding here.
+ save_output(prompt_args, vae, latent, device)
+
+ except KeyboardInterrupt:
+ print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")
+ continue
+
+ except EOFError:
+ print("\nExiting interactive mode")
+
+
+def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
+ device = torch.device(args.device)
+
+ dit_weight_dtype = torch.bfloat16 # default
+ if args.fp8_scaled:
+ dit_weight_dtype = None # various precision weights, so don't cast to specific dtype
+ elif args.fp8:
+ dit_weight_dtype = torch.float8_e4m3fn
+
+ logger.info(f"Using device: {device}, DiT weight weight precision: {dit_weight_dtype}")
+
+ gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype)
+ return gen_settings
+
+
+def main():
+ # Parse arguments
+ args = parse_args()
+
+ # Check if latents are provided
+ latents_mode = args.latent_path is not None and len(args.latent_path) > 0
+
+ # Set device
+ device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+ logger.info(f"Using device: {device}")
+ args.device = device
+
+ if latents_mode:
+ # Original latent decode mode
+ original_base_names = []
+ latents_list = []
+ seeds = []
+
+ # assert len(args.latent_path) == 1, "Only one latent path is supported for now"
+
+ for latent_path in args.latent_path:
+ original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
+ seed = 0
+
+ if os.path.splitext(latent_path)[1] != ".safetensors":
+ latents = torch.load(latent_path, map_location="cpu")
+ else:
+ latents = load_file(latent_path)["latent"]
+ with safe_open(latent_path, framework="pt") as f:
+ metadata = f.metadata()
+ if metadata is None:
+ metadata = {}
+ logger.info(f"Loaded metadata: {metadata}")
+
+ if "seeds" in metadata:
+ seed = int(metadata["seeds"])
+ if "height" in metadata and "width" in metadata:
+ height = int(metadata["height"])
+ width = int(metadata["width"])
+ args.image_size = [height, width]
+
+ seeds.append(seed)
+ logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
+
+ if latents.ndim == 5: # [BCTHW]
+ latents = latents.squeeze(0) # [CTHW]
+
+ latents_list.append(latents)
+
+ vae = qwen_image_autoencoder_kl.load_vae(
+ args.vae,
+ device=device,
+ disable_mmap=True,
+ spatial_chunk_size=args.vae_chunk_size,
+ disable_cache=args.vae_disable_cache,
+ )
+ vae.to(torch.bfloat16)
+ vae.eval()
+
+ for i, latent in enumerate(latents_list):
+ args.seed = seeds[i]
+ save_output(args, vae, latent, device, original_base_names[i])
+
+ else:
+ tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
+ qwen3_path=args.text_encoder, t5_tokenizer_path=None, qwen3_max_length=512, t5_max_length=512
+ )
+ strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
+
+ encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
+ strategy_base.TextEncodingStrategy.set_strategy(encoding_strategy)
+
+ if args.from_file:
+ # Batch mode from file
+
+ # Read prompts from file
+ with open(args.from_file, "r", encoding="utf-8") as f:
+ prompt_lines = f.readlines()
+
+ # Process prompts
+ prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
+ process_batch_prompts(prompts_data, args)
+
+ elif args.interactive:
+ # Interactive mode
+ process_interactive(args)
+
+ else:
+ # Single prompt mode (original behavior)
+
+ # Generate latent
+ gen_settings = get_generation_settings(args)
+
+ # For single mode, precomputed data is None, shared_models is None.
+ # generate will load all necessary models (Text Encoders, DiT).
+ latent = generate(args, gen_settings)
+
+ clean_memory_on_device(device)
+
+ # Save latent and video
+ vae = qwen_image_autoencoder_kl.load_vae(
+ args.vae,
+ device="cpu",
+ disable_mmap=True,
+ spatial_chunk_size=args.vae_chunk_size,
+ disable_cache=args.vae_disable_cache,
+ )
+ vae.to(torch.bfloat16)
+ vae.eval()
+ save_output(args, vae, latent, device)
+
+ logger.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/anima_train.py b/anima_train.py
index a86c30c3..4d1eb10f 100644
--- a/anima_train.py
+++ b/anima_train.py
@@ -3,6 +3,7 @@
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
+import gc
import math
import os
from multiprocessing import Value
@@ -12,8 +13,9 @@ import toml
from tqdm import tqdm
import torch
-from library import utils
+from library import flux_train_utils, qwen_image_autoencoder_kl
from library.device_utils import init_ipex, clean_memory_on_device
+from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
init_ipex()
@@ -49,21 +51,18 @@ def train(args):
args.skip_cache_check = args.skip_latents_validity_check
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
- logger.warning(
- "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
- )
+ logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning("cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
- if getattr(args, 'unsloth_offload_checkpointing', False):
+ if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
- assert not args.cpu_offload_checkpointing, \
- "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
+ assert not args.cpu_offload_checkpointing, "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
@@ -71,17 +70,7 @@ def train(args):
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
- ) or not getattr(args, 'unsloth_offload_checkpointing', False), \
- "blocks_to_swap is not supported with unsloth_offload_checkpointing"
-
- # Flash attention: validate availability
- if getattr(args, 'flash_attn', False):
- try:
- import flash_attn # noqa: F401
- logger.info("Flash Attention enabled for DiT blocks")
- except ImportError:
- logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
- args.flash_attn = False
+ ) or not args.unsloth_offload_checkpointing, "blocks_to_swap is not supported with unsloth_offload_checkpointing"
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
@@ -104,9 +93,7 @@ def train(args):
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
- logger.warning(
- "ignore following options because config file is found: {0}".format(", ".join(ignored))
- )
+ logger.warning("ignore following options because config file is found: {0}".format(", ".join(ignored)))
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
@@ -145,26 +132,13 @@ def train(args):
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
- train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
-
- # Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
- # dataset-level caption dropout, so we save the rate and zero out subset-level
- # caption_dropout_rate to allow text encoder output caching.
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- if caption_dropout_rate > 0:
- logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
- for dataset in train_dataset_group.datasets:
- for subset in dataset.subsets:
- subset.caption_dropout_rate = 0.0
+ train_dataset_group.verify_bucket_reso_steps(16) # Qwen-Image VAE spatial downscale = 8 * patch size = 2
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
- args.cache_text_encoder_outputs_to_disk,
- args.text_encoder_batch_size,
- False,
- False,
+ args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
)
)
train_dataset_group.set_current_strategies()
@@ -175,13 +149,11 @@ def train(args):
return
if cache_latents:
- assert (
- train_dataset_group.is_latent_cacheable()
- ), "when caching latents, either color_aug or random_crop cannot be used"
+ assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used"
if args.cache_text_encoder_outputs:
- assert (
- train_dataset_group.is_text_encoder_output_cacheable()
+ assert train_dataset_group.is_text_encoder_output_cacheable(
+ cache_supports_dropout=True
), "when caching text encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
# prepare accelerator
@@ -191,24 +163,10 @@ def train(args):
# mixed precision dtype
weight_dtype, save_dtype = train_util.prepare_dtype(args)
- # parse transformer_dtype
- transformer_dtype = None
- if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
- transformer_dtype_map = {
- "float16": torch.float16,
- "bfloat16": torch.bfloat16,
- "float32": torch.float32,
- }
- transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
-
# Load tokenizers and set strategies
logger.info("Loading tokenizers...")
- qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(
- args.qwen3_path, dtype=weight_dtype, device="cpu"
- )
- t5_tokenizer = anima_utils.load_t5_tokenizer(
- getattr(args, 't5_tokenizer_path', None)
- )
+ qwen3_text_encoder, qwen3_tokenizer = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
+ t5_tokenizer = anima_utils.load_t5_tokenizer(args.t5_tokenizer_path)
# Set tokenize strategy
tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
@@ -219,11 +177,7 @@ def train(args):
)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
- # Set text encoding strategy
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
- dropout_rate=caption_dropout_rate,
- )
+ text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# Prepare text encoder (always frozen for Anima)
@@ -237,10 +191,7 @@ def train(args):
qwen3_text_encoder.eval()
text_encoder_caching_strategy = strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
- args.cache_text_encoder_outputs_to_disk,
- args.text_encoder_batch_size,
- args.skip_cache_check,
- is_partial=False,
+ args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, is_partial=False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
@@ -259,27 +210,21 @@ def train(args):
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
- tokenize_strategy,
- [qwen3_text_encoder],
- tokens_and_masks,
- enable_dropout=False,
+ tokenize_strategy, [qwen3_text_encoder], tokens_and_masks
)
- # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- if caption_dropout_rate > 0.0:
- with accelerator.autocast():
- text_encoding_strategy.cache_uncond_embeddings(tokenize_strategy, [qwen3_text_encoder])
-
accelerator.wait_for_everyone()
# free text encoder memory
qwen3_text_encoder = None
+ gc.collect() # Force garbage collection to free memory
clean_memory_on_device(accelerator.device)
# Load VAE and cache latents
logger.info("Loading Anima VAE...")
- vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=weight_dtype, device="cpu")
+ vae = qwen_image_autoencoder_kl.load_vae(
+ args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
+ )
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
@@ -294,24 +239,16 @@ def train(args):
# Load DiT (MiniTrainDIT + optional LLM Adapter)
logger.info("Loading Anima DiT...")
- dit = anima_utils.load_anima_dit(
- args.dit_path,
- dtype=weight_dtype,
- device="cpu",
- transformer_dtype=transformer_dtype,
- llm_adapter_path=getattr(args, 'llm_adapter_path', None),
- disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
+ dit = anima_utils.load_anima_model(
+ "cpu", args.pretrained_model_name_or_path, args.attn_mode, args.split_attn, "cpu", dit_weight_dtype=None
)
if args.gradient_checkpointing:
dit.enable_gradient_checkpointing(
cpu_offload=args.cpu_offload_checkpointing,
- unsloth_offload=getattr(args, 'unsloth_offload_checkpointing', False),
+ unsloth_offload=args.unsloth_offload_checkpointing,
)
- if getattr(args, 'flash_attn', False):
- dit.set_flash_attn(True)
-
train_dit = args.learning_rate != 0
dit.requires_grad_(train_dit)
if not train_dit:
@@ -327,19 +264,17 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
- # Move scale tensors to same device as VAE for on-the-fly encoding
- vae_scale = [s.to(accelerator.device) if isinstance(s, torch.Tensor) else s for s in vae_scale]
# Setup optimizer with parameter groups
if train_dit:
param_groups = anima_train_utils.get_anima_param_groups(
dit,
base_lr=args.learning_rate,
- self_attn_lr=getattr(args, 'self_attn_lr', None),
- cross_attn_lr=getattr(args, 'cross_attn_lr', None),
- mlp_lr=getattr(args, 'mlp_lr', None),
- mod_lr=getattr(args, 'mod_lr', None),
- llm_adapter_lr=getattr(args, 'llm_adapter_lr', None),
+ self_attn_lr=args.self_attn_lr,
+ cross_attn_lr=args.cross_attn_lr,
+ mlp_lr=args.mlp_lr,
+ mod_lr=args.mod_lr,
+ llm_adapter_lr=args.llm_adapter_lr,
)
else:
param_groups = []
@@ -361,57 +296,7 @@ def train(args):
# prepare optimizer
accelerator.print("prepare optimizer, data loader etc.")
- if args.blockwise_fused_optimizers:
- # Split params into per-block groups for blockwise fused optimizer
- # Build param_id → lr mapping from param_groups to propagate per-component LRs
- param_lr_map = {}
- for group in param_groups:
- for p in group['params']:
- param_lr_map[id(p)] = group['lr']
-
- grouped_params = []
- param_group = {}
- named_parameters = list(dit.named_parameters())
- for name, p in named_parameters:
- if not p.requires_grad:
- continue
- # Determine block type and index
- if name.startswith("blocks."):
- block_index = int(name.split(".")[1])
- block_type = "blocks"
- elif name.startswith("llm_adapter.blocks."):
- block_index = int(name.split(".")[2])
- block_type = "llm_adapter"
- else:
- block_index = -1
- block_type = "other"
-
- param_group_key = (block_type, block_index)
- if param_group_key not in param_group:
- param_group[param_group_key] = []
- param_group[param_group_key].append(p)
-
- for param_group_key, params in param_group.items():
- # Use per-component LR from param_groups if available
- lr = param_lr_map.get(id(params[0]), args.learning_rate)
- grouped_params.append({"params": params, "lr": lr})
- num_params = sum(p.numel() for p in params)
- accelerator.print(f"block {param_group_key}: {num_params} parameters, lr={lr}")
-
- # Create per-group optimizers
- optimizers = []
- for group in grouped_params:
- _, _, opt = train_util.get_optimizer(args, trainable_params=[group])
- optimizers.append(opt)
- optimizer = optimizers[0] # avoid error in following code
-
- logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
-
- if train_util.is_schedulefree_optimizer(optimizers[0], args):
- raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
- optimizer_train_fn = lambda: None
- optimizer_eval_fn = lambda: None
- elif args.fused_backward_pass:
+ if args.fused_backward_pass:
# Pass per-component param_groups directly to preserve per-component LRs
_, _, optimizer = train_util.get_optimizer(args, trainable_params=param_groups)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
@@ -442,21 +327,19 @@ def train(args):
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr scheduler
- if args.blockwise_fused_optimizers:
- lr_schedulers = [train_util.get_scheduler_fix(args, opt, accelerator.num_processes) for opt in optimizers]
- lr_scheduler = lr_schedulers[0] # avoid error in following code
- else:
- lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+ lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# full fp16/bf16 training
+ dit_weight_dtype = weight_dtype
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed_precision='fp16'"
accelerator.print("enable full fp16 training.")
- dit.to(weight_dtype)
elif args.full_bf16:
assert args.mixed_precision == "bf16", "full_bf16 requires mixed_precision='bf16'"
accelerator.print("enable full bf16 training.")
- dit.to(weight_dtype)
+ else:
+ dit_weight_dtype = torch.float32 # If neither full_fp16 nor full_bf16, the model weights should be in float32
+ dit.to(dit_weight_dtype) # convert dit to target weight dtype
# move text encoder to GPU if not cached
if not args.cache_text_encoder_outputs and qwen3_text_encoder is not None:
@@ -498,6 +381,7 @@ def train(args):
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
+ # use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
@@ -517,55 +401,29 @@ def train(args):
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_group))
- elif args.blockwise_fused_optimizers:
- # Prepare additional optimizers and lr schedulers
- for i in range(1, len(optimizers)):
- optimizers[i] = accelerator.prepare(optimizers[i])
- lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
-
- # Counters for blockwise gradient hook
- optimizer_hooked_count = {}
- num_parameters_per_group = [0] * len(optimizers)
- parameter_optimizer_map = {}
-
- for opt_idx, opt in enumerate(optimizers):
- for param_group in opt.param_groups:
- for parameter in param_group["params"]:
- if parameter.requires_grad:
-
- def grad_hook(parameter: torch.Tensor):
- if accelerator.sync_gradients and args.max_grad_norm != 0.0:
- accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
-
- i = parameter_optimizer_map[parameter]
- optimizer_hooked_count[i] += 1
- if optimizer_hooked_count[i] == num_parameters_per_group[i]:
- optimizers[i].step()
- optimizers[i].zero_grad(set_to_none=True)
-
- parameter.register_post_accumulate_grad_hook(grad_hook)
- parameter_optimizer_map[parameter] = opt_idx
- num_parameters_per_group[opt_idx] += 1
-
# Training loop
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
- accelerator.print("running training")
- accelerator.print(f" num examples: {train_dataset_group.num_train_images}")
- accelerator.print(f" num batches per epoch: {len(train_dataloader)}")
- accelerator.print(f" num epochs: {num_train_epochs}")
+ accelerator.print("running training / 学習開始")
+ accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
+ accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+ accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
- f" batch size per device: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
+ f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
- accelerator.print(f" gradient accumulation steps = {args.gradient_accumulation_steps}")
- accelerator.print(f" total optimization steps: {args.max_train_steps}")
+ accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+ accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
+ noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
+ # Copy for noise and timestep generation, because noise_scheduler may be changed during training in future
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
@@ -580,6 +438,7 @@ def train(args):
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb
+
wandb.define_metric("epoch")
wandb.define_metric("loss/epoch", step_metric="epoch")
@@ -589,8 +448,15 @@ def train(args):
# For --sample_at_first
optimizer_eval_fn()
anima_train_utils.sample_images(
- accelerator, args, 0, global_step, dit, vae, vae_scale,
- qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
+ accelerator,
+ args,
+ 0,
+ global_step,
+ dit,
+ vae,
+ qwen3_text_encoder,
+ tokenize_strategy,
+ text_encoding_strategy,
sample_prompts_te_outputs,
)
optimizer_train_fn()
@@ -600,11 +466,11 @@ def train(args):
# Show model info
unwrapped_dit = accelerator.unwrap_model(dit) if dit is not None else None
if unwrapped_dit is not None:
- logger.info(f"dit device: {unwrapped_dit.t_embedding_norm.weight.device}, dtype: {unwrapped_dit.t_embedding_norm.weight.dtype}")
+ logger.info(f"dit device: {unwrapped_dit.device}, dtype: {unwrapped_dit.dtype}")
if qwen3_text_encoder is not None:
- logger.info(f"qwen3 device: {next(qwen3_text_encoder.parameters()).device}")
+ logger.info(f"qwen3 device: {qwen3_text_encoder.device}")
if vae is not None:
- logger.info(f"vae device: {next(vae.parameters()).device}")
+ logger.info(f"vae device: {vae.device}")
loss_recorder = train_util.LossRecorder()
epoch = 0
@@ -618,19 +484,17 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
- if args.blockwise_fused_optimizers:
- optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
-
with accelerator.accumulate(*training_models):
# Get latents
if "latents" in batch and batch["latents"] is not None:
- latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
+ latents = batch["latents"].to(accelerator.device, dtype=dit_weight_dtype)
+ if latents.ndim == 5: # Fallback for 5D latents (old cache)
+ latents = latents.squeeze(2) # (B, C, 1, H, W) -> (B, C, H, W)
else:
with torch.no_grad():
# images are already [-1, 1] from IMAGE_TRANSFORMS, add temporal dim
images = batch["images"].to(accelerator.device, dtype=weight_dtype)
- images = images.unsqueeze(2) # (B, C, 1, H, W)
- latents = vae.encode(images, vae_scale).to(accelerator.device, dtype=weight_dtype)
+ latents = vae.encode_pixels_to_latents(images).to(accelerator.device, dtype=dit_weight_dtype)
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
@@ -640,23 +504,24 @@ def train(args):
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Cached outputs
+ caption_dropout_rates = text_encoder_outputs_list[-1]
+ text_encoder_outputs_list = text_encoder_outputs_list[:-1]
+
+ # Apply caption dropout to cached outputs
text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(
- *text_encoder_outputs_list
+ *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_outputs_list
else:
# Encode on-the-fly
input_ids_list = batch["input_ids_list"]
- qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = input_ids_list
with torch.no_grad():
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoding_strategy.encode_tokens(
- tokenize_strategy,
- [qwen3_text_encoder],
- [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask],
+ tokenize_strategy, [qwen3_text_encoder], input_ids_list
)
# Move to device
- prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype)
+ prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit_weight_dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
@@ -664,9 +529,11 @@ def train(args):
# Noise and timesteps
noise = torch.randn_like(latents)
- noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
- args, latents, noise, accelerator.device, weight_dtype
+ # Get noisy model input and timesteps
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
+ args, noise_scheduler_copy, latents, noise, accelerator.device, dit_weight_dtype
)
+ timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
# NaN checks
if torch.any(torch.isnan(noisy_model_input)):
@@ -678,15 +545,10 @@ def train(args):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
- padding_mask = torch.zeros(
- bs, 1, h_latent, w_latent,
- dtype=weight_dtype, device=accelerator.device
- )
+ padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=dit_weight_dtype, device=accelerator.device)
# DiT forward (LLM adapter runs inside forward for DDP gradient sync)
- if is_swapping_blocks:
- accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
-
+ noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, (B, C, 1, H, W)
with accelerator.autocast():
model_pred = dit(
noisy_model_input,
@@ -697,6 +559,7 @@ def train(args):
t5_input_ids=t5_input_ids,
t5_attn_mask=t5_attn_mask,
)
+ model_pred = model_pred.squeeze(2) # 5D to 4D, (B, C, H, W)
# Compute loss (rectified flow: target = noise - latents)
target = noise - latents
@@ -708,12 +571,10 @@ def train(args):
# Loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, None)
- loss = train_util.conditional_loss(
- model_pred.float(), target.float(), args.loss_type, "none", huber_c
- )
+ loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
- loss = loss.mean([1, 2, 3, 4]) # (B, C, T, H, W) -> (B,)
+ loss = loss.mean([1, 2, 3]) # (B, C, H, W) -> (B,)
if weighting is not None:
loss = loss * weighting
@@ -724,7 +585,7 @@ def train(args):
accelerator.backward(loss)
- if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
+ if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
@@ -737,9 +598,6 @@ def train(args):
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
- if args.blockwise_fused_optimizers:
- for i in range(1, len(optimizers)):
- lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step
if accelerator.sync_gradients:
@@ -748,8 +606,15 @@ def train(args):
optimizer_eval_fn()
anima_train_utils.sample_images(
- accelerator, args, None, global_step, dit, vae, vae_scale,
- qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
+ accelerator,
+ args,
+ None,
+ global_step,
+ dit,
+ vae,
+ qwen3_text_encoder,
+ tokenize_strategy,
+ text_encoding_strategy,
sample_prompts_te_outputs,
)
@@ -773,8 +638,10 @@ def train(args):
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs_with_names(
- logs, lr_scheduler, args.optimizer_type,
- ["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else []
+ logs,
+ lr_scheduler,
+ args.optimizer_type,
+ ["base", "self_attn", "cross_attn", "mlp", "mod", "llm_adapter"] if train_dit else [],
)
accelerator.log(logs, step=global_step)
@@ -807,8 +674,15 @@ def train(args):
)
anima_train_utils.sample_images(
- accelerator, args, epoch + 1, global_step, dit, vae, vae_scale,
- qwen3_text_encoder, tokenize_strategy, text_encoding_strategy,
+ accelerator,
+ args,
+ epoch + 1,
+ global_step,
+ dit,
+ vae,
+ qwen3_text_encoder,
+ tokenize_strategy,
+ text_encoding_strategy,
sample_prompts_te_outputs,
)
@@ -852,11 +726,6 @@ def setup_parser() -> argparse.ArgumentParser:
anima_train_utils.add_anima_training_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
- parser.add_argument(
- "--blockwise_fused_optimizers",
- action="store_true",
- help="enable blockwise optimizers for fused backward pass and optimizer step",
- )
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
@@ -884,4 +753,7 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
+ if args.attn_mode == "sdpa":
+ args.attn_mode = "torch" # backward compatibility
+
train(args)
diff --git a/anima_train_network.py b/anima_train_network.py
index 57ad1681..eaad7197 100644
--- a/anima_train_network.py
+++ b/anima_train_network.py
@@ -1,16 +1,26 @@
# Anima LoRA training script
import argparse
-import math
from typing import Any, Optional, Union
import torch
+import torch.nn as nn
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
-from library import anima_models, anima_train_utils, anima_utils, strategy_anima, strategy_base, train_util
+from library import (
+ anima_models,
+ anima_train_utils,
+ anima_utils,
+ flux_train_utils,
+ qwen_image_autoencoder_kl,
+ sd3_train_utils,
+ strategy_anima,
+ strategy_base,
+ train_util,
+)
import train_network
from library.utils import setup_logging
@@ -24,13 +34,6 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
- self.vae = None
- self.vae_scale = None
- self.qwen3_text_encoder = None
- self.qwen3_tokenizer = None
- self.t5_tokenizer = None
- self.tokenize_strategy = None
- self.text_encoding_strategy = None
def assert_extra_args(
self,
@@ -38,140 +41,118 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
+ if args.fp8_base or args.fp8_base_unet:
+ logger.warning("fp8_base and fp8_base_unet are not supported. / fp8_baseとfp8_base_unetはサポートされていません。")
+ args.fp8_base = False
+ args.fp8_base_unet = False
+ args.fp8_scaled = False # Anima DiT does not support fp8_scaled
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
- logger.warning(
- "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled"
- )
+ logger.warning("cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled")
args.cache_text_encoder_outputs = True
- # Anima uses embedding-level dropout (in AnimaTextEncodingStrategy) instead of
- # dataset-level caption dropout, so zero out subset-level rates to allow caching.
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- if caption_dropout_rate > 0:
- logger.info(f"Using embedding-level caption dropout rate: {caption_dropout_rate}")
- if hasattr(train_dataset_group, 'datasets'):
- for dataset in train_dataset_group.datasets:
- for subset in dataset.subsets:
- subset.caption_dropout_rate = 0.0
-
if args.cache_text_encoder_outputs:
- assert (
- train_dataset_group.is_text_encoder_output_cacheable()
+ assert train_dataset_group.is_text_encoder_output_cacheable(
+ cache_supports_dropout=True
), "when caching Text Encoder output, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used"
+ assert (
+ args.network_train_unet_only or not args.cache_text_encoder_outputs
+ ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
+
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing"
- if getattr(args, 'unsloth_offload_checkpointing', False):
+ if args.unsloth_offload_checkpointing:
if not args.gradient_checkpointing:
logger.warning("unsloth_offload_checkpointing is enabled, so gradient_checkpointing is also enabled")
args.gradient_checkpointing = True
- assert not args.cpu_offload_checkpointing, \
- "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
+ assert (
+ not args.cpu_offload_checkpointing
+ ), "Cannot use both --unsloth_offload_checkpointing and --cpu_offload_checkpointing"
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
), "blocks_to_swap is not supported with unsloth_offload_checkpointing"
- # Flash attention: validate availability
- if getattr(args, 'flash_attn', False):
- try:
- import flash_attn # noqa: F401
- logger.info("Flash Attention enabled for DiT blocks")
- except ImportError:
- logger.warning("flash_attn package not installed, falling back to PyTorch SDPA")
- args.flash_attn = False
-
- if getattr(args, 'blockwise_fused_optimizers', False):
- raise ValueError("blockwise_fused_optimizers is not supported with LoRA/NetworkTrainer")
-
- train_dataset_group.verify_bucket_reso_steps(8) # WanVAE spatial downscale = 8
+ train_dataset_group.verify_bucket_reso_steps(16) # WanVAE spatial downscale = 8 and patch size = 2
if val_dataset_group is not None:
- val_dataset_group.verify_bucket_reso_steps(8)
+ val_dataset_group.verify_bucket_reso_steps(16)
def load_target_model(self, args, weight_dtype, accelerator):
+ self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
+
# Load Qwen3 text encoder (tokenizers already loaded in get_tokenize_strategy)
logger.info("Loading Qwen3 text encoder...")
- self.qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(
- args.qwen3_path, dtype=weight_dtype, device="cpu"
- )
- self.qwen3_text_encoder.eval()
+ qwen3_text_encoder, _ = anima_utils.load_qwen3_text_encoder(args.qwen3, dtype=weight_dtype, device="cpu")
+ qwen3_text_encoder.eval()
- # Parse transformer_dtype
- transformer_dtype = None
- if hasattr(args, 'transformer_dtype') and args.transformer_dtype is not None:
- transformer_dtype_map = {
- "float16": torch.float16,
- "bfloat16": torch.bfloat16,
- "float32": torch.float32,
- }
- transformer_dtype = transformer_dtype_map.get(args.transformer_dtype, None)
+ # Load VAE
+ logger.info("Loading Anima VAE...")
+ vae = qwen_image_autoencoder_kl.load_vae(
+ args.vae, device="cpu", disable_mmap=True, spatial_chunk_size=args.vae_chunk_size, disable_cache=args.vae_disable_cache
+ )
+ vae.to(weight_dtype)
+ vae.eval()
+
+ # Return format: (model_type, text_encoders, vae, unet)
+ return "anima", [qwen3_text_encoder], vae, None # unet loaded lazily
+
+ def load_unet_lazily(self, args, weight_dtype, accelerator, text_encoders) -> tuple[nn.Module, list[nn.Module]]:
+ loading_dtype = None if args.fp8_scaled else weight_dtype
+ loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
+
+ attn_mode = "torch"
+ if args.xformers:
+ attn_mode = "xformers"
+ if args.attn_mode is not None:
+ attn_mode = args.attn_mode
# Load DiT
- logger.info("Loading Anima DiT...")
- dit = anima_utils.load_anima_dit(
- args.dit_path,
- dtype=weight_dtype,
- device="cpu",
- transformer_dtype=transformer_dtype,
- llm_adapter_path=getattr(args, 'llm_adapter_path', None),
- disable_mmap=getattr(args, 'disable_mmap_load_safetensors', False),
+ logger.info(f"Loading Anima DiT model with attn_mode={attn_mode}, split_attn: {args.split_attn}...")
+ model = anima_utils.load_anima_model(
+ accelerator.device,
+ args.pretrained_model_name_or_path,
+ attn_mode,
+ args.split_attn,
+ loading_device,
+ loading_dtype,
+ args.fp8_scaled,
)
- # Flash attention
- if getattr(args, 'flash_attn', False):
- dit.set_flash_attn(True)
-
# Store unsloth preference so that when the base NetworkTrainer calls
# dit.enable_gradient_checkpointing(cpu_offload=...), we can override to use unsloth.
# The base trainer only passes cpu_offload, so we store the flag on the model.
- self._use_unsloth_offload_checkpointing = getattr(args, 'unsloth_offload_checkpointing', False)
+ self._use_unsloth_offload_checkpointing = args.unsloth_offload_checkpointing
# Block swap
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
- dit.enable_block_swap(args.blocks_to_swap, accelerator.device)
+ model.enable_block_swap(args.blocks_to_swap, accelerator.device)
- # Load VAE
- logger.info("Loading Anima VAE...")
- self.vae, vae_mean, vae_std, self.vae_scale = anima_utils.load_anima_vae(
- args.vae_path, dtype=weight_dtype, device="cpu"
- )
-
- # Return format: (model_type, text_encoders, vae, unet)
- return "anima", [self.qwen3_text_encoder], self.vae, dit
+ return model, text_encoders
def get_tokenize_strategy(self, args):
# Load tokenizers from paths (called before load_target_model, so self.qwen3_tokenizer isn't set yet)
- self.tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
- qwen3_path=args.qwen3_path,
- t5_tokenizer_path=getattr(args, 't5_tokenizer_path', None),
+ tokenize_strategy = strategy_anima.AnimaTokenizeStrategy(
+ qwen3_path=args.qwen3,
+ t5_tokenizer_path=args.t5_tokenizer_path,
qwen3_max_length=args.qwen3_max_token_length,
t5_max_length=args.t5_max_token_length,
)
- # Store references so load_target_model can reuse them
- self.qwen3_tokenizer = self.tokenize_strategy.qwen3_tokenizer
- self.t5_tokenizer = self.tokenize_strategy.t5_tokenizer
- return self.tokenize_strategy
+ return tokenize_strategy
def get_tokenizers(self, tokenize_strategy: strategy_anima.AnimaTokenizeStrategy):
return [tokenize_strategy.qwen3_tokenizer]
def get_latents_caching_strategy(self, args):
- return strategy_anima.AnimaLatentsCachingStrategy(
- args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
- )
+ return strategy_anima.AnimaLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check)
def get_text_encoding_strategy(self, args):
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- self.text_encoding_strategy = strategy_anima.AnimaTextEncodingStrategy(
- dropout_rate=caption_dropout_rate,
- )
- return self.text_encoding_strategy
+ return strategy_anima.AnimaTextEncodingStrategy()
def post_process_network(self, args, accelerator, network, text_encoders, unet):
- # Qwen3 text encoder is always frozen for Anima
pass
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
@@ -179,19 +160,10 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
return None # no text encoders needed for encoding
return text_encoders
- def get_text_encoders_train_flags(self, args, text_encoders):
- return [False] # Qwen3 always frozen
-
- def is_train_text_encoder(self, args):
- return False # Qwen3 text encoder is always frozen for Anima
-
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_anima.AnimaTextEncoderOutputsCachingStrategy(
- args.cache_text_encoder_outputs_to_disk,
- args.text_encoder_batch_size,
- args.skip_cache_check,
- is_partial=False,
+ args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
return None
@@ -200,15 +172,14 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
):
if args.cache_text_encoder_outputs:
if not args.lowram:
- logger.info("move vae and unet to cpu to save memory")
- org_vae_device = next(vae.parameters()).device
- org_unet_device = unet.device
+ # We cannot move DiT to CPU because of block swap, so only move VAE
+ logger.info("move vae to cpu to save memory")
+ org_vae_device = vae.device
vae.to("cpu")
- unet.to("cpu")
clean_memory_on_device(accelerator.device)
logger.info("move text encoder to gpu")
- text_encoders[0].to(accelerator.device, dtype=weight_dtype)
+ text_encoders[0].to(accelerator.device)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
@@ -229,59 +200,52 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
logger.info(f" cache TE outputs for: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
- tokenize_strategy,
- text_encoders,
- tokens_and_masks,
- enable_dropout=False,
+ tokenize_strategy, text_encoders, tokens_and_masks
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
- # Pre-cache unconditional embeddings for caption dropout before text encoder is deleted
- caption_dropout_rate = getattr(args, 'caption_dropout_rate', 0.0)
- text_encoding_strategy_for_uncond = strategy_base.TextEncodingStrategy.get_strategy()
- if caption_dropout_rate > 0.0:
- tokenize_strategy_for_uncond = strategy_base.TokenizeStrategy.get_strategy()
- with accelerator.autocast():
- text_encoding_strategy_for_uncond.cache_uncond_embeddings(tokenize_strategy_for_uncond, text_encoders)
-
accelerator.wait_for_everyone()
# move text encoder back to cpu
logger.info("move text encoder back to cpu")
text_encoders[0].to("cpu")
- clean_memory_on_device(accelerator.device)
if not args.lowram:
- logger.info("move vae and unet back to original device")
+ logger.info("move vae back to original device")
vae.to(org_vae_device)
- unet.to(org_unet_device)
+
+ clean_memory_on_device(accelerator.device)
else:
- text_encoders[0].to(accelerator.device, dtype=weight_dtype)
+ # move text encoder to device for encoding during training/validation
+ text_encoders[0].to(accelerator.device)
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] # compatibility
te = self.get_models_for_text_encoding(args, accelerator, text_encoders)
qwen3_te = te[0] if te is not None else None
+ text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
+ tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
anima_train_utils.sample_images(
- accelerator, args, epoch, global_step, unet, vae, self.vae_scale,
- qwen3_te, self.tokenize_strategy, self.text_encoding_strategy,
+ accelerator,
+ args,
+ epoch,
+ global_step,
+ unet,
+ vae,
+ qwen3_te,
+ tokenize_strategy,
+ text_encoding_strategy,
self.sample_prompts_te_outputs,
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
- noise_scheduler = anima_train_utils.FlowMatchEulerDiscreteScheduler(
- num_train_timesteps=1000, shift=args.discrete_flow_shift
- )
+ noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
return noise_scheduler
def encode_images_to_latents(self, args, vae, images):
- # images are already [-1,1] from IMAGE_TRANSFORMS, add temporal dim
- images = images.unsqueeze(2) # (B, C, 1, H, W)
- # Ensure scale tensors are on the same device as images
- vae_device = images.device
- scale = [s.to(vae_device) if isinstance(s, torch.Tensor) else s for s in self.vae_scale]
- return vae.encode(images, scale)
+ vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage
+ return vae.encode_pixels_to_latents(images) # Keep 4D for input/output
def shift_scale_latents(self, args, latents):
# Latents already normalized by vae.encode with scale
@@ -301,13 +265,18 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
train_unet,
is_train=True,
):
+ anima: anima_models.Anima = unet
+
# Sample noise
+ if latents.ndim == 5: # Fallback for 5D latents (old cache)
+ latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
noise = torch.randn_like(latents)
# Get noisy model input and timesteps
- noisy_model_input, timesteps, sigmas = anima_train_utils.get_noisy_model_input_and_timesteps(
- args, latents, noise, accelerator.device, weight_dtype
+ noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
+ args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
+ timesteps = timesteps / 1000.0 # scale to [0, 1] range. timesteps is float32
# Gradient checkpointing support
if args.gradient_checkpointing:
@@ -329,161 +298,103 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
bs = latents.shape[0]
h_latent = latents.shape[-2]
w_latent = latents.shape[-1]
- padding_mask = torch.zeros(
- bs, 1, h_latent, w_latent,
- dtype=weight_dtype, device=accelerator.device
- )
+ padding_mask = torch.zeros(bs, 1, h_latent, w_latent, dtype=weight_dtype, device=accelerator.device)
- # Prepare block swap
- if self.is_swapping_blocks:
- accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
-
- # Call model (LLM adapter runs inside forward for DDP gradient sync)
+ # Call model
+ noisy_model_input = noisy_model_input.unsqueeze(2) # 4D to 5D, [B, C, H, W] -> [B, C, 1, H, W]
with torch.set_grad_enabled(is_train), accelerator.autocast():
- model_pred = unet(
+ model_pred = anima(
noisy_model_input,
timesteps,
prompt_embeds,
padding_mask=padding_mask,
+ target_input_ids=t5_input_ids,
+ target_attention_mask=t5_attn_mask,
source_attention_mask=attn_mask,
- t5_input_ids=t5_input_ids,
- t5_attn_mask=t5_attn_mask,
)
+ model_pred = model_pred.squeeze(2) # 5D to 4D, [B, C, 1, H, W] -> [B, C, H, W]
# Rectified flow target: noise - latents
target = noise - latents
# Loss weighting
- weighting = anima_train_utils.compute_loss_weighting_for_anima(
- weighting_scheme=args.weighting_scheme, sigmas=sigmas
- )
-
- # Differential output preservation
- if "custom_attributes" in batch:
- diff_output_pr_indices = []
- for i, custom_attributes in enumerate(batch["custom_attributes"]):
- if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
- diff_output_pr_indices.append(i)
-
- if len(diff_output_pr_indices) > 0:
- network.set_multiplier(0.0)
- with torch.no_grad(), accelerator.autocast():
- if self.is_swapping_blocks:
- accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
- model_pred_prior = unet(
- noisy_model_input[diff_output_pr_indices],
- timesteps[diff_output_pr_indices],
- prompt_embeds[diff_output_pr_indices],
- padding_mask=padding_mask[diff_output_pr_indices],
- source_attention_mask=attn_mask[diff_output_pr_indices],
- t5_input_ids=t5_input_ids[diff_output_pr_indices],
- t5_attn_mask=t5_attn_mask[diff_output_pr_indices],
- )
- network.set_multiplier(1.0)
-
- target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
+ weighting = anima_train_utils.compute_loss_weighting_for_anima(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
return model_pred, target, timesteps, weighting
def process_batch(
- self, batch, text_encoders, unet, network, vae, noise_scheduler,
- vae_dtype, weight_dtype, accelerator, args,
- text_encoding_strategy, tokenize_strategy,
- is_train=True, train_text_encoder=True, train_unet=True,
+ self,
+ batch,
+ text_encoders,
+ unet,
+ network,
+ vae,
+ noise_scheduler,
+ vae_dtype,
+ weight_dtype,
+ accelerator,
+ args,
+ text_encoding_strategy,
+ tokenize_strategy,
+ is_train=True,
+ train_text_encoder=True,
+ train_unet=True,
) -> torch.Tensor:
- """Override base process_batch for 5D video latents (B, C, T, H, W).
-
- Base class assumes 4D (B, C, H, W) for loss.mean([1,2,3]) and weighting broadcast.
- """
- import typing
- from library.custom_train_functions import apply_masked_loss
-
- with torch.no_grad():
- if "latents" in batch and batch["latents"] is not None:
- latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
- else:
- if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
- latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
- else:
- chunks = [
- batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
- ]
- list_latents = []
- for chunk in chunks:
- with torch.no_grad():
- chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
- list_latents.append(chunk)
- latents = torch.cat(list_latents, dim=0)
-
- if torch.any(torch.isnan(latents)):
- accelerator.print("NaN found in latents, replacing with zeros")
- latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
-
- latents = self.shift_scale_latents(args, latents)
+ """Override base process_batch for caption dropout with cached text encoder outputs."""
# Text encoder conditions
- text_encoder_conds = []
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
+ anima_text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = text_encoding_strategy
if text_encoder_outputs_list is not None:
- text_encoder_conds = text_encoder_outputs_list
+ caption_dropout_rates = text_encoder_outputs_list[-1]
+ text_encoder_outputs_list = text_encoder_outputs_list[:-1]
- if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
- with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
- input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
- encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
- tokenize_strategy,
- self.get_models_for_text_encoding(args, accelerator, text_encoders),
- input_ids,
- )
- if args.full_fp16:
- encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
+ # Apply caption dropout to cached outputs
+ text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs(
+ *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates
+ )
+ batch["text_encoder_outputs_list"] = text_encoder_outputs_list
- if len(text_encoder_conds) == 0:
- text_encoder_conds = encoded_text_encoder_conds
- else:
- for i in range(len(encoded_text_encoder_conds)):
- if encoded_text_encoder_conds[i] is not None:
- text_encoder_conds[i] = encoded_text_encoder_conds[i]
-
- noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
- args, accelerator, noise_scheduler, latents, batch,
- text_encoder_conds, unet, network, weight_dtype, train_unet, is_train=is_train,
+ return super().process_batch(
+ batch,
+ text_encoders,
+ unet,
+ network,
+ vae,
+ noise_scheduler,
+ vae_dtype,
+ weight_dtype,
+ accelerator,
+ args,
+ text_encoding_strategy,
+ tokenize_strategy,
+ is_train,
+ train_text_encoder,
+ train_unet,
)
- huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
- loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
-
- if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
- loss = apply_masked_loss(loss, batch)
-
- # Reduce all non-batch dims: (B, C, T, H, W) -> (B,) for 5D, (B, C, H, W) -> (B,) for 4D
- reduce_dims = list(range(1, loss.ndim))
- loss = loss.mean(reduce_dims)
-
- # Apply weighting after reducing to (B,)
- if weighting is not None:
- loss = loss * weighting
-
- loss_weights = batch["loss_weights"]
- loss = loss * loss_weights
-
- loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
- return loss.mean()
-
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
- return train_util.get_sai_model_spec(None, args, False, True, False, is_stable_diffusion_ckpt=True)
+ return train_util.get_sai_model_spec_dataclass(None, args, False, True, False, anima="preview").to_metadata_dict()
def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme
+ metadata["ss_logit_mean"] = args.logit_mean
+ metadata["ss_logit_std"] = args.logit_std
+ metadata["ss_mode_scale"] = args.mode_scale
+ metadata["ss_timestep_sampling"] = args.timestep_sampling
+ metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
- metadata["ss_timestep_sample_method"] = getattr(args, 'timestep_sample_method', 'logit_normal')
- metadata["ss_sigmoid_scale"] = getattr(args, 'sigmoid_scale', 1.0)
def is_text_encoder_not_needed_for_training(self, args):
- return args.cache_text_encoder_outputs
+ return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
+
+ def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
+ # Set first parameter's requires_grad to True to workaround Accelerate gradient checkpointing bug
+ first_param = next(text_encoder.parameters())
+ first_param.requires_grad_(True)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
@@ -496,23 +407,16 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer):
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
- dit = unet
- dit = accelerator.prepare(dit, device_placement=[not self.is_swapping_blocks])
- accelerator.unwrap_model(dit).move_to_device_except_swap_blocks(accelerator.device)
- accelerator.unwrap_model(dit).prepare_block_swap_before_forward()
+ model = unet
+ model = accelerator.prepare(model, device_placement=[not self.is_swapping_blocks])
+ accelerator.unwrap_model(model).move_to_device_except_swap_blocks(accelerator.device)
+ accelerator.unwrap_model(model).prepare_block_swap_before_forward()
- return dit
-
- def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
- # Drop cached text encoder outputs for caption dropout
- text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
- if text_encoder_outputs_list is not None:
- text_encoding_strategy: strategy_anima.AnimaTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
- text_encoder_outputs_list = text_encoding_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
- batch["text_encoder_outputs_list"] = text_encoder_outputs_list
+ return model
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
+ # prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
@@ -520,6 +424,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
anima_train_utils.add_anima_training_arguments(parser)
+ # parser.add_argument("--fp8_scaled", action="store_true", help="Use scaled fp8 for DiT / DiTにスケーリングされたfp8を使う")
parser.add_argument(
"--unsloth_offload_checkpointing",
action="store_true",
@@ -536,5 +441,8 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
+ if args.attn_mode == "sdpa":
+ args.attn_mode = "torch" # backward compatibility
+
trainer = AnimaNetworkTrainer()
trainer.train(args)
diff --git a/docs/anima_train_network.md b/docs/anima_train_network.md
index fe6b2354..f97aa975 100644
--- a/docs/anima_train_network.md
+++ b/docs/anima_train_network.md
@@ -11,7 +11,9 @@ This document explains how to train LoRA (Low-Rank Adaptation) models for Anima
## 1. Introduction / はじめに
-`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a WanVAE (16-channel, 8x spatial downscale).
+`anima_train_network.py` trains additional networks such as LoRA for Anima models. Anima adopts a DiT (Diffusion Transformer) architecture based on the MiniTrainDIT design with Rectified Flow training. It uses a Qwen3-0.6B text encoder, an LLM Adapter (6-layer transformer bridge from Qwen3 to T5-compatible space), and a Qwen-Image VAE (16-channel, 8x spatial downscale).
+
+Qwen-Image VAE and Qwen-Image VAE have same architecture, but [official Anima weight is named for Qwen-Image VAE](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae).
This guide assumes you already understand the basics of LoRA training. For common usage and options, see the [train_network.py guide](train_network.md). Some parameters are similar to those in [`sd3_train_network.py`](sd3_train_network.md) and [`flux_train_network.py`](flux_train_network.md).
@@ -24,7 +26,9 @@ This guide assumes you already understand the basics of LoRA training. For commo
日本語
-`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびWanVAE (16チャンネル、8倍空間ダウンスケール) を使用します。
+`anima_train_network.py`は、Anima モデルに対してLoRAなどの追加ネットワークを学習させるためのスクリプトです。AnimaはMiniTrainDIT設計に基づくDiT (Diffusion Transformer) アーキテクチャを採用しており、Rectified Flow学習を使用します。テキストエンコーダーとしてQwen3-0.6B、LLM Adapter (Qwen3からT5互換空間への6層Transformerブリッジ)、およびQwen-Image VAE (16チャンネル、8倍空間ダウンスケール) を使用します。
+
+Qwen-Image VAEとQwen-Image VAEは同じアーキテクチャですが、[Anima公式の重みはQwen-Image VAE用](https://huggingface.co/circlestone-labs/Anima/tree/main/split_files/vae)のようです。
このガイドは、基本的なLoRA学習の手順を理解しているユーザーを対象としています。基本的な使い方や共通のオプションについては、[`train_network.py`のガイド](train_network.md)を参照してください。また一部のパラメータは [`sd3_train_network.py`](sd3_train_network.md) や [`flux_train_network.py`](flux_train_network.md) と同様のものがあるため、そちらも参考にしてください。
@@ -37,14 +41,14 @@ This guide assumes you already understand the basics of LoRA training. For commo
## 2. Differences from `train_network.py` / `train_network.py` との違い
-`anima_train_network.py` is based on `train_network.py` but modified for Anima . Main differences are:
+`anima_train_network.py` is based on `train_network.py` but modified for Anima. Main differences are:
* **Target models:** Anima DiT models.
-* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a WanVAE (16-channel latent space with 8x spatial downscale).
-* **Arguments:** Options exist to specify the Anima DiT model, Qwen3 text encoder, WanVAE, LLM adapter, and T5 tokenizer separately.
-* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used.
-* **Anima specific options:** Additional parameters for component-wise learning rates (self_attn, cross_attn, mlp, mod, llm_adapter), timestep sampling, discrete flow shift, and flash attention.
-* **6 Parameter Groups:** Independent learning rates for `base`, `self_attn`, `cross_attn`, `mlp`, `adaln_modulation`, and `llm_adapter` components.
+* **Model structure:** Uses a MiniTrainDIT (Transformer based) instead of U-Net. Employs a single text encoder (Qwen3-0.6B), an LLM Adapter that bridges Qwen3 embeddings to T5-compatible cross-attention space, and a Qwen-Image VAE (16-channel latent space with 8x spatial downscale).
+* **Arguments:** Uses the common `--pretrained_model_name_or_path` for the DiT model path, `--qwen3` for the Qwen3 text encoder, and `--vae` for the Qwen-Image VAE. The LLM adapter and T5 tokenizer can be specified separately with `--llm_adapter_path` and `--t5_tokenizer_path`.
+* **Incompatible arguments:** Stable Diffusion v1/v2 options such as `--v2`, `--v_parameterization` and `--clip_skip` are not used. `--fp8_base` is not supported.
+* **Timestep sampling:** Uses the same `--timestep_sampling` options as FLUX training (`sigma`, `uniform`, `sigmoid`, `shift`, `flux_shift`).
+* **LoRA:** Uses regex-based module selection and per-module rank/learning rate control (`network_reg_dims`, `network_reg_lrs`) instead of per-component arguments. Module exclusion/inclusion is controlled by `exclude_patterns` and `include_patterns`.
日本語
@@ -52,11 +56,11 @@ This guide assumes you already understand the basics of LoRA training. For commo
`anima_train_network.py`は`train_network.py`をベースに、Anima モデルに対応するための変更が加えられています。主な違いは以下の通りです。
* **対象モデル:** Anima DiTモデルを対象とします。
-* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびWanVAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
-* **引数:** Anima DiTモデル、Qwen3テキストエンコーダー、WanVAE、LLM Adapter、T5トークナイザーを個別に指定する引数があります。
-* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)はAnimaの学習では使用されません。
-* **Anima特有の引数:** コンポーネント別学習率(self_attn, cross_attn, mlp, mod, llm_adapter)、タイムステップサンプリング、離散フローシフト、Flash Attentionに関する引数が追加されています。
-* **6パラメータグループ:** `base`、`self_attn`、`cross_attn`、`mlp`、`adaln_modulation`、`llm_adapter`の各コンポーネントに対して独立した学習率を設定できます。
+* **モデル構造:** U-Netの代わりにMiniTrainDIT (Transformerベース) を使用します。テキストエンコーダーとしてQwen3-0.6B、Qwen3埋め込みをT5互換のクロスアテンション空間に変換するLLM Adapter、およびQwen-Image VAE (16チャンネル潜在空間、8倍空間ダウンスケール) を使用します。
+* **引数:** DiTモデルのパスには共通引数`--pretrained_model_name_or_path`を、Qwen3テキストエンコーダーには`--qwen3`を、Qwen-Image VAEには`--vae`を使用します。LLM AdapterとT5トークナイザーはそれぞれ`--llm_adapter_path`、`--t5_tokenizer_path`で個別に指定できます。
+* **一部引数の非互換性:** Stable Diffusion v1/v2向けの引数(例: `--v2`, `--v_parameterization`, `--clip_skip`)は使用されません。`--fp8_base`はサポートされていません。
+* **タイムステップサンプリング:** FLUX学習と同じ`--timestep_sampling`オプション(`sigma`、`uniform`、`sigmoid`、`shift`、`flux_shift`)を使用します。
+* **LoRA:** コンポーネント別の引数の代わりに、正規表現ベースのモジュール選択とモジュール単位のランク/学習率制御(`network_reg_dims`、`network_reg_lrs`)を使用します。モジュールの除外/包含は`exclude_patterns`と`include_patterns`で制御します。
## 3. Preparation / 準備
@@ -65,16 +69,16 @@ The following files are required before starting training:
1. **Training script:** `anima_train_network.py`
2. **Anima DiT model file:** `.safetensors` file for the base DiT model.
-3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory or a single `.safetensors` file (requires `configs/qwen3_06b/` config files).
-4. **WanVAE model file:** `.safetensors` or `.pth` file for the VAE.
+3. **Qwen3-0.6B text encoder:** Either a HuggingFace model directory, or a single `.safetensors` file (uses the bundled config files in `configs/qwen3_06b/`).
+4. **Qwen-Image VAE model file:** `.safetensors` or `.pth` file for the VAE.
5. **LLM Adapter model file (optional):** `.safetensors` file. If not provided separately, the adapter is loaded from the DiT file if the key `llm_adapter.out_proj.weight` exists.
6. **T5 Tokenizer (optional):** If not specified, uses the bundled tokenizer at `configs/t5_old/`.
7. **Dataset definition file (.toml):** Dataset settings in TOML format. (See the [Dataset Configuration Guide](./config_README-en.md).) In this document we use `my_anima_dataset_config.toml` as an example.
+Model files can be obtained from the [Anima HuggingFace repository](https://huggingface.co/circlestone-labs/Anima).
+
**Notes:**
-* When using a single `.safetensors` file for Qwen3, download the `config.json`, `tokenizer.json`, `tokenizer_config.json`, and `vocab.json` from the [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFace repository into the `configs/qwen3_06b/` directory.
* The T5 tokenizer only needs the tokenizer files (not the T5 model weights). It uses the vocabulary from `google/t5-v1_1-xxl`.
-* Models are saved with a `net.` prefix on all keys for ComfyUI compatibility.
日本語
@@ -83,16 +87,16 @@ The following files are required before starting training:
1. **学習スクリプト:** `anima_train_network.py`
2. **Anima DiTモデルファイル:** ベースとなるDiTモデルの`.safetensors`ファイル。
-3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(`configs/qwen3_06b/`の設定ファイルが必要)。
-4. **WanVAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
+3. **Qwen3-0.6Bテキストエンコーダー:** HuggingFaceモデルディレクトリまたは単体の`.safetensors`ファイル(バンドル版の`configs/qwen3_06b/`の設定ファイルが使用されます)。
+4. **Qwen-Image VAEモデルファイル:** VAEの`.safetensors`または`.pth`ファイル。
5. **LLM Adapterモデルファイル(オプション):** `.safetensors`ファイル。個別に指定しない場合、DiTファイル内に`llm_adapter.out_proj.weight`キーが存在すればそこから読み込まれます。
6. **T5トークナイザー(オプション):** 指定しない場合、`configs/t5_old/`のバンドル版トークナイザーを使用します。
7. **データセット定義ファイル (.toml):** 学習データセットの設定を記述したTOML形式のファイル。(詳細は[データセット設定ガイド](./config_README-en.md)を参照してください)。例として`my_anima_dataset_config.toml`を使用します。
+モデルファイルは[HuggingFaceのAnimaリポジトリ](https://huggingface.co/circlestone-labs/Anima)から入手できます。
+
**注意:**
-* Qwen3の単体`.safetensors`ファイルを使用する場合、[Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) HuggingFaceリポジトリから`config.json`、`tokenizer.json`、`tokenizer_config.json`、`vocab.json`をダウンロードし、`configs/qwen3_06b/`ディレクトリに配置してください。
-* T5トークナイザーはトークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
-* モデルはComfyUI互換のため、すべてのキーに`net.`プレフィックスを付けて保存されます。
+* T5トークナイザーを別途指定する場合、トークナイザーファイルのみ必要です(T5モデルの重みは不要)。`google/t5-v1_1-xxl`の語彙を使用します。
## 4. Running the Training / 学習の実行
@@ -103,33 +107,38 @@ Example command:
```bash
accelerate launch --num_cpu_threads_per_process 1 anima_train_network.py \
- --dit_path="" \
- --qwen3_path="" \
- --vae_path="" \
- --llm_adapter_path="" \
+ --pretrained_model_name_or_path="" \
+ --qwen3="" \
+ --vae="" \
--dataset_config="my_anima_dataset_config.toml" \
--output_dir="
## 6. Using the Trained Model / 学習済みモデルの利用
-When training finishes, a LoRA model file (e.g. `my_anima_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Anima , such as ComfyUI with appropriate nodes.
+When training finishes, a LoRA model file (e.g. `my_anima_lora.safetensors`) is saved in the directory specified by `output_dir`. Use this file with inference environments that support Anima, such as ComfyUI with appropriate nodes.
日本語
@@ -394,8 +438,6 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
#### Key VRAM Reduction Options
-- **`--fp8_base`**: Enables training in FP8 format for the DiT model.
-
- **`--blocks_to_swap `**: Swaps blocks between CPU and GPU to reduce VRAM usage. Higher numbers save more VRAM but reduce training speed. See model-specific max values in section 4.1.
- **`--unsloth_offload_checkpointing`**: Offloads gradient checkpoints to CPU using async non-blocking transfers. Faster than `--cpu_offload_checkpointing`. Cannot be combined with `--blocks_to_swap`.
@@ -404,7 +446,7 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
- **`--cache_text_encoder_outputs`**: Caches Qwen3 outputs so the text encoder can be freed from VRAM during training.
-- **`--cache_latents`**: Caches WanVAE outputs so the VAE can be freed from VRAM during training.
+- **`--cache_latents`**: Caches Qwen-Image VAE outputs so the VAE can be freed from VRAM during training.
- **Using Adafactor optimizer**: Can reduce VRAM usage:
```
@@ -417,12 +459,11 @@ Anima models can be large, so GPUs with limited VRAM may require optimization:
Animaモデルは大きい場合があるため、VRAMが限られたGPUでは最適化が必要です。
主要なVRAM削減オプション:
-- `--fp8_base`: FP8形式での学習を有効化
- `--blocks_to_swap`: CPUとGPU間でブロックをスワップ
- `--unsloth_offload_checkpointing`: 非同期転送でアクティベーションをCPUにオフロード
- `--gradient_checkpointing`: 標準的な勾配チェックポイント
- `--cache_text_encoder_outputs`: Qwen3の出力をキャッシュ
-- `--cache_latents`: WanVAEの出力をキャッシュ
+- `--cache_latents`: Qwen-Image VAEの出力をキャッシュ
- Adafactorオプティマイザの使用
@@ -431,21 +472,24 @@ Animaモデルは大きい場合があるため、VRAMが限られたGPUでは
#### Timestep Sampling
-The `--timestep_sample_method` option specifies how timesteps (0-1) are sampled:
+The `--timestep_sampling` option specifies how timesteps are sampled. The available methods are the same as FLUX training:
-- `logit_normal` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
+- `sigma`: Sigma-based sampling like SD3.
- `uniform`: Uniform random sampling from [0, 1].
+- `sigmoid` (default): Sample from Normal(0,1), multiply by `sigmoid_scale`, apply sigmoid. Good general-purpose option.
+- `shift`: Like `sigmoid`, but applies the discrete flow shift formula: `t_shifted = (t * shift) / (1 + (shift - 1) * t)`.
+- `flux_shift`: Resolution-dependent shift used in FLUX training.
+
+See the [flux_train_network.py guide](flux_train_network.md) for detailed descriptions.
#### Discrete Flow Shift
-The `--discrete_flow_shift` option (default `3.0`) shifts the timestep distribution toward higher noise levels. The formula is:
+The `--discrete_flow_shift` option (default `1.0`) only applies when `--timestep_sampling` is set to `shift`. The formula is:
```
t_shifted = (t * shift) / (1 + (shift - 1) * t)
```
-Timesteps are clamped to `[1e-5, 1-1e-5]` after shifting.
-
#### Loss Weighting
The `--weighting_scheme` option specifies loss weighting by timestep:
@@ -454,23 +498,34 @@ The `--weighting_scheme` option specifies loss weighting by timestep:
- `sigma_sqrt`: Weight by `sigma^(-2)`.
- `cosmap`: Weight by `2 / (pi * (1 - 2*sigma + 2*sigma^2))`.
- `none`: Same as uniform.
+- `logit_normal`, `mode`: Additional schemes from SD3 training. See the [`sd3_train_network.md` guide](sd3_train_network.md) for details.
#### Caption Dropout
-Use `--caption_dropout_rate` for embedding-level caption dropout. This is handled by `AnimaTextEncodingStrategy` and is compatible with text encoder output caching. The subset-level `caption_dropout_rate` is automatically zeroed when this is set.
+Caption dropout uses the `caption_dropout_rate` setting from the dataset configuration (per-subset in TOML). When using `--cache_text_encoder_outputs`, the dropout rate is stored with each cached entry and applied during training, so caption dropout is compatible with text encoder output caching.
+
+**If you change the `caption_dropout_rate` setting, you must delete and regenerate the cache.**
+
+Note: Currently, only Anima supports combining `caption_dropout_rate` with text encoder output caching.
日本語
#### タイムステップサンプリング
-`--timestep_sample_method`でタイムステップのサンプリング方法を指定します:
-- `logit_normal`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。
+`--timestep_sampling`でタイムステップのサンプリング方法を指定します。FLUX学習と同じ方法が利用できます:
+
+- `sigma`: SD3と同様のシグマベースサンプリング。
- `uniform`: [0, 1]の一様分布からサンプリング。
+- `sigmoid`(デフォルト): 正規分布からサンプリングし、sigmoidを適用。汎用的なオプション。
+- `shift`: `sigmoid`と同様だが、離散フローシフトの式を適用。
+- `flux_shift`: FLUX学習で使用される解像度依存のシフト。
+
+詳細は[flux_train_network.pyのガイド](flux_train_network.md)を参照してください。
#### 離散フローシフト
-`--discrete_flow_shift`(デフォルト`3.0`)はタイムステップ分布を高ノイズ側にシフトします。
+`--discrete_flow_shift`(デフォルト`1.0`)は`--timestep_sampling`が`shift`の場合のみ適用されます。
#### 損失の重み付け
@@ -478,7 +533,11 @@ Use `--caption_dropout_rate` for embedding-level caption dropout. This is handle
#### キャプションドロップアウト
-`--caption_dropout_rate`で埋め込みレベルのキャプションドロップアウトを使用します。テキストエンコーダー出力のキャッシュと互換性があります。
+キャプションドロップアウトにはデータセット設定(TOMLでのサブセット単位)の`caption_dropout_rate`を使用します。`--cache_text_encoder_outputs`使用時は、ドロップアウト率が各キャッシュエントリとともに保存され、学習中に適用されるため、テキストエンコーダー出力キャッシュと同時に使用できます。
+
+**`caption_dropout_rate`の設定を変えた場合、キャッシュを削除し、再生成する必要があります。**
+
+※`caption_dropout_rate`をテキストエンコーダー出力キャッシュと組み合わせられるのは、今のところAnimaのみです。
@@ -487,17 +546,23 @@ Use `--caption_dropout_rate` for embedding-level caption dropout. This is handle
Anima LoRA training supports training Qwen3 text encoder LoRA:
- To train only DiT: specify `--network_train_unet_only`
-- To train DiT and Qwen3: omit `--network_train_unet_only`
+- To train DiT and Qwen3: omit `--network_train_unet_only` and do NOT use `--cache_text_encoder_outputs`
You can specify a separate learning rate for Qwen3 with `--text_encoder_lr`. If not specified, the default `--learning_rate` is used.
+Note: When `--cache_text_encoder_outputs` is used, text encoder outputs are pre-computed and the text encoder is removed from GPU, so text encoder LoRA cannot be trained.
+
日本語
Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレーニングできます。
- DiTのみ学習: `--network_train_unet_only`を指定
-- DiTとQwen3を学習: `--network_train_unet_only`を省略
+- DiTとQwen3を学習: `--network_train_unet_only`を省略し、`--cache_text_encoder_outputs`を使用しない
+
+Qwen3に個別の学習率を指定するには`--text_encoder_lr`を使用します。未指定の場合は`--learning_rate`が使われます。
+
+注意: `--cache_text_encoder_outputs`を使用する場合、テキストエンコーダーの出力が事前に計算されGPUから解放されるため、テキストエンコーダーLoRAは学習できません。
@@ -528,16 +593,47 @@ Anima LoRA学習では、Qwen3テキストエンコーダーのLoRAもトレー
-## 9. Others / その他
+## 9. Related Tools / 関連ツール
+
+### `networks/anima_convert_lora_to_comfy.py`
+
+A script to convert LoRA models to ComfyUI-compatible format. ComfyUI does not directly support sd-scripts format Qwen3 LoRA, so conversion is necessary (conversion may not be needed for DiT-only LoRA). You can convert from the sd-scripts format to ComfyUI format with:
+
+```bash
+python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
+```
+
+Using the `--reverse` option allows conversion in the opposite direction (ComfyUI format to sd-scripts format). However, reverse conversion is only possible for LoRAs converted by this script. LoRAs created with other training tools cannot be converted.
+
+
+日本語
+
+**`networks/convert_anima_lora_to_comfy.py`**
+
+LoRAモデルをComfyUI互換形式に変換するスクリプト。ComfyUIがsd-scripts形式のQwen3 LoRAを直接サポートしていないため、変換が必要です(DiTのみのLoRAの場合は変換不要のようです)。sd-scripts形式からComfyUI形式への変換は以下のコマンドで行います:
+
+```bash
+python networks/convert_anima_lora_to_comfy.py path/to/source.safetensors path/to/destination.safetensors
+```
+
+`--reverse`オプションを付けると、逆変換(ComfyUI形式からsd-scripts形式)も可能です。ただし、逆変換ができるのはこのスクリプトで変換したLoRAに限ります。他の学習ツールで作成したLoRAは変換できません。
+
+
+
+
+## 10. Others / その他
### Metadata Saved in LoRA Models
-The following Anima-specific metadata is saved in the LoRA model file:
+The following metadata is saved in the LoRA model file:
* `ss_weighting_scheme`
-* `ss_discrete_flow_shift`
-* `ss_timestep_sample_method`
+* `ss_logit_mean`
+* `ss_logit_std`
+* `ss_mode_scale`
+* `ss_timestep_sampling`
* `ss_sigmoid_scale`
+* `ss_discrete_flow_shift`
日本語
@@ -546,11 +642,14 @@ The following Anima-specific metadata is saved in the LoRA model file:
### LoRAモデルに保存されるメタデータ
-以下のAnima固有のメタデータがLoRAモデルファイルに保存されます:
+以下のメタデータがLoRAモデルファイルに保存されます:
* `ss_weighting_scheme`
-* `ss_discrete_flow_shift`
-* `ss_timestep_sample_method`
+* `ss_logit_mean`
+* `ss_logit_std`
+* `ss_mode_scale`
+* `ss_timestep_sampling`
* `ss_sigmoid_scale`
+* `ss_discrete_flow_shift`
diff --git a/library/anima_models.py b/library/anima_models.py
index 6aad9d8c..6828e598 100644
--- a/library/anima_models.py
+++ b/library/anima_models.py
@@ -2,7 +2,7 @@
# Original code: NVIDIA CORPORATION & AFFILIATES, licensed under Apache-2.0
import math
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
@@ -13,9 +13,7 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
-from library import custom_offloading_utils
-from library.device_utils import clean_memory_on_device
-
+from library import custom_offloading_utils, attention
def to_device(x, device):
@@ -39,11 +37,13 @@ def to_cpu(x):
else:
return x
+
# Unsloth Offloaded Gradient Checkpointing
# Based on Unsloth Zoo by Daniel Han-Chen & the Unsloth team
try:
from deepspeed.runtime.activation_checkpointing.checkpointing import detach_variable
except ImportError:
+
def detach_variable(inputs, device=None):
"""Detach tensors from computation graph, optionally moving to a device.
@@ -80,11 +80,11 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
"""
@staticmethod
- @torch.amp.custom_fwd(device_type='cuda')
+ @torch.amp.custom_fwd(device_type="cuda")
def forward(ctx, forward_function, hidden_states, *args):
# Remember the original device for backward pass (multi-GPU support)
ctx.input_device = hidden_states.device
- saved_hidden_states = hidden_states.to('cpu', non_blocking=True)
+ saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
output = forward_function(hidden_states, *args)
ctx.save_for_backward(saved_hidden_states)
@@ -96,7 +96,7 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
return output
@staticmethod
- @torch.amp.custom_bwd(device_type='cuda')
+ @torch.amp.custom_bwd(device_type="cuda")
def backward(ctx, *grads):
(hidden_states,) = ctx.saved_tensors
hidden_states = hidden_states.to(ctx.input_device, non_blocking=True).detach()
@@ -108,8 +108,9 @@ class UnslothOffloadedGradientCheckpointer(torch.autograd.Function):
output_tensors = []
grad_tensors = []
- for out, grad in zip(outputs if isinstance(outputs, tuple) else (outputs,),
- grads if isinstance(grads, tuple) else (grads,)):
+ for out, grad in zip(
+ outputs if isinstance(outputs, tuple) else (outputs,), grads if isinstance(grads, tuple) else (grads,)
+ ):
if isinstance(out, torch.Tensor) and out.requires_grad:
output_tensors.append(out)
grad_tensors.append(grad)
@@ -123,26 +124,6 @@ def unsloth_checkpoint(function, *args):
return UnslothOffloadedGradientCheckpointer.apply(function, *args)
-# Flash Attention support
-try:
- from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
- FLASH_ATTN_AVAILABLE = True
-except ImportError:
- _flash_attn_func = None
- FLASH_ATTN_AVAILABLE = False
-
-
-def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
- """Computes multi-head attention using Flash Attention.
-
- Input format: (batch, seq_len, n_heads, head_dim)
- Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
- """
- # flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
- out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
- return rearrange(out, "b s h d -> b s (h d)")
-
-
from .utils import setup_logging
setup_logging()
@@ -174,14 +155,10 @@ def _apply_rotary_pos_emb_base(
if start_positions is not None:
max_offset = torch.max(start_positions)
- assert (
- max_offset + cur_seq_len <= max_seq_len
- ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
+ assert max_offset + cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)
- assert (
- cur_seq_len <= max_seq_len
- ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
+ assert cur_seq_len <= max_seq_len, f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
@@ -205,13 +182,9 @@ def apply_rotary_pos_emb(
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
) -> torch.Tensor:
- assert not (
- cp_size > 1 and start_positions is not None
- ), "start_positions != None with CP SIZE > 1 is not supported!"
+ assert not (cp_size > 1 and start_positions is not None), "start_positions != None with CP SIZE > 1 is not supported!"
- assert (
- tensor_format != "thd" or cu_seqlens is not None
- ), "cu_seqlens must not be None when tensor_format is 'thd'."
+ assert tensor_format != "thd" or cu_seqlens is not None, "cu_seqlens must not be None when tensor_format is 'thd'."
assert fused == False
@@ -223,9 +196,7 @@ def apply_rotary_pos_emb(
_apply_rotary_pos_emb_base(
x.unsqueeze(1),
freqs,
- start_positions=(
- start_positions[idx : idx + 1] if start_positions is not None else None
- ),
+ start_positions=(start_positions[idx : idx + 1] if start_positions is not None else None),
interleaved=interleaved,
)
for idx, x in enumerate(torch.split(t, seqlens))
@@ -262,10 +233,10 @@ class RMSNorm(torch.nn.Module):
def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
- @torch.amp.autocast(device_type='cuda', dtype=torch.float32)
def forward(self, x: torch.Tensor) -> torch.Tensor:
- output = self._norm(x.float()).type_as(x)
- return output * self.weight
+ with torch.autocast(device_type=x.device.type, dtype=torch.float32):
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
class GPT2FeedForward(nn.Module):
@@ -298,22 +269,6 @@ class GPT2FeedForward(nn.Module):
return x
-def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
- """Computes multi-head attention using PyTorch's native scaled_dot_product_attention.
-
- Input/output format: (batch, seq_len, n_heads, head_dim)
- """
- in_q_shape = q_B_S_H_D.shape
- in_k_shape = k_B_S_H_D.shape
- q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
- k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
- v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
- result_B_S_HD = rearrange(
- F.scaled_dot_product_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D), "b h ... l -> b ... (h l)"
- )
- return result_B_S_HD
-
-
# Attention module for DiT
class Attention(nn.Module):
"""Multi-head attention supporting both self-attention and cross-attention.
@@ -354,8 +309,6 @@ class Attention(nn.Module):
self.output_proj = nn.Linear(inner_dim, query_dim, bias=False)
self.output_dropout = nn.Dropout(dropout) if dropout > 1e-4 else nn.Identity()
- self.attn_op = torch_attention_op
-
self._query_dim = query_dim
self._context_dim = context_dim
self._inner_dim = inner_dim
@@ -399,18 +352,25 @@ class Attention(nn.Module):
return q, k, v
- def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
- result = self.attn_op(q, k, v) # [B, S, H, D]
- return self.output_dropout(self.output_proj(result))
-
def forward(
self,
x: torch.Tensor,
+ attn_params: attention.AttentionParams,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
- return self.compute_attention(q, k, v)
+ if q.dtype != v.dtype:
+ if (not attn_params.supports_fp32 or attn_params.requires_same_dtype) and torch.is_autocast_enabled():
+ # FlashAttention requires fp16/bf16, xformers require same dtype; only cast when autocast is active.
+ target_dtype = v.dtype # v has fp16/bf16 dtype
+ q = q.to(target_dtype)
+ k = k.to(target_dtype)
+ # return self.compute_attention(q, k, v)
+ qkv = [q, k, v]
+ del q, k, v
+ result = attention.attention(qkv, attn_params=attn_params)
+ return self.output_dropout(self.output_proj(result))
# Positional Embeddings
@@ -484,12 +444,8 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
dim_t = self._dim_t
self.seq = torch.arange(max(self.max_h, self.max_w, self.max_t)).float().to(self.dim_spatial_range.device)
- self.dim_spatial_range = (
- torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
- )
- self.dim_temporal_range = (
- torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
- )
+ self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(self.dim_spatial_range.device) / dim_h
+ self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(self.dim_spatial_range.device) / dim_t
def generate_embeddings(
self,
@@ -664,31 +620,30 @@ class TimestepEmbedding(nn.Module):
return emb_B_T_D, adaln_lora_B_T_3D
-class FourierFeatures(nn.Module):
- """Fourier feature transform: [B] -> [B, D]."""
+# Commented out Fourier Features (not used in Anima). Kept for reference.
+# class FourierFeatures(nn.Module):
+# """Fourier feature transform: [B] -> [B, D]."""
- def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
- super().__init__()
- self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
- self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
- self.gain = np.sqrt(2) if normalize else 1
- self.bandwidth = bandwidth
- self.num_channels = num_channels
- self.reset_parameters()
+# def __init__(self, num_channels: int, bandwidth: int = 1, normalize: bool = False):
+# super().__init__()
+# self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True)
+# self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True)
+# self.gain = np.sqrt(2) if normalize else 1
+# self.bandwidth = bandwidth
+# self.num_channels = num_channels
+# self.reset_parameters()
- def reset_parameters(self) -> None:
- generator = torch.Generator()
- generator.manual_seed(0)
- self.freqs = (
- 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
- )
- self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
+# def reset_parameters(self) -> None:
+# generator = torch.Generator()
+# generator.manual_seed(0)
+# self.freqs = 2 * np.pi * self.bandwidth * torch.randn(self.num_channels, generator=generator).to(self.freqs.device)
+# self.phases = 2 * np.pi * torch.rand(self.num_channels, generator=generator).to(self.freqs.device)
- def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
- in_dtype = x.dtype
- x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
- x = x.cos().mul(self.gain * gain).to(in_dtype)
- return x
+# def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
+# in_dtype = x.dtype
+# x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32))
+# x = x.cos().mul(self.gain * gain).to(in_dtype)
+# return x
# Patch Embedding
@@ -713,9 +668,7 @@ class PatchEmbed(nn.Module):
m=spatial_patch_size,
n=spatial_patch_size,
),
- nn.Linear(
- in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False
- ),
+ nn.Linear(in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=False),
)
self.dim = in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size
@@ -765,9 +718,7 @@ class FinalLayer(nn.Module):
nn.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False),
)
else:
- self.adaln_modulation = nn.Sequential(
- nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False)
- )
+ self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False))
self.init_weights()
@@ -790,9 +741,9 @@ class FinalLayer(nn.Module):
):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
- shift_B_T_D, scale_B_T_D = (
- self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]
- ).chunk(2, dim=-1)
+ shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk(
+ 2, dim=-1
+ )
else:
shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1)
@@ -833,7 +784,11 @@ class Block(nn.Module):
self.layer_norm_cross_attn = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
self.cross_attn = Attention(
- x_dim, context_dim, num_heads, x_dim // num_heads, qkv_format="bshd",
+ x_dim,
+ context_dim,
+ num_heads,
+ x_dim // num_heads,
+ qkv_format="bshd",
)
self.layer_norm_mlp = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6)
@@ -904,6 +859,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
+ attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -919,13 +875,13 @@ class Block(nn.Module):
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = (
self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D
).chunk(3, dim=-1)
- shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (
- self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D
- ).chunk(3, dim=-1)
+ shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk(
+ 3, dim=-1
+ )
else:
- shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(
- emb_B_T_D
- ).chunk(3, dim=-1)
+ shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk(
+ 3, dim=-1
+ )
shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn(
emb_B_T_D
).chunk(3, dim=-1)
@@ -954,11 +910,14 @@ class Block(nn.Module):
result = rearrange(
self.self_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
+ attn_params,
None,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
- t=T, h=H, w=W,
+ t=T,
+ h=H,
+ w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result
@@ -967,11 +926,14 @@ class Block(nn.Module):
result = rearrange(
self.cross_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
+ attn_params,
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
),
"b (t h w) d -> b t h w d",
- t=T, h=H, w=W,
+ t=T,
+ h=H,
+ w=W,
)
x_B_T_H_W_D = result * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@@ -987,6 +949,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
+ attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -996,8 +959,13 @@ class Block(nn.Module):
# Unsloth: async non-blocking CPU RAM offload (fastest offload method)
return unsloth_checkpoint(
self._forward,
- x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
- rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
+ x_B_T_H_W_D,
+ emb_B_T_D,
+ crossattn_emb,
+ attn_params,
+ rope_emb_L_1_1_D,
+ adaln_lora_B_T_3D,
+ extra_per_block_pos_emb,
)
elif self.cpu_offload_checkpointing:
# Standard cpu offload: blocking transfers
@@ -1008,36 +976,54 @@ class Block(nn.Module):
device_inputs = to_device(inputs, device)
outputs = func(*device_inputs)
return to_cpu(outputs)
+
return custom_forward
return torch_checkpoint(
create_custom_forward(self._forward),
- x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
- rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
+ x_B_T_H_W_D,
+ emb_B_T_D,
+ crossattn_emb,
+ attn_params,
+ rope_emb_L_1_1_D,
+ adaln_lora_B_T_3D,
+ extra_per_block_pos_emb,
use_reentrant=False,
)
else:
# Standard gradient checkpointing (no offload)
return torch_checkpoint(
self._forward,
- x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
- rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
+ x_B_T_H_W_D,
+ emb_B_T_D,
+ crossattn_emb,
+ attn_params,
+ rope_emb_L_1_1_D,
+ adaln_lora_B_T_3D,
+ extra_per_block_pos_emb,
use_reentrant=False,
)
else:
return self._forward(
- x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
- rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
+ x_B_T_H_W_D,
+ emb_B_T_D,
+ crossattn_emb,
+ attn_params,
+ rope_emb_L_1_1_D,
+ adaln_lora_B_T_3D,
+ extra_per_block_pos_emb,
)
-# Main DiT Model: MiniTrainDIT
-class MiniTrainDIT(nn.Module):
+# Main DiT Model: MiniTrainDIT (renamed to Anima)
+class Anima(nn.Module):
"""Cosmos-Predict2 DiT model for image/video generation.
28 transformer blocks with AdaLN-LoRA modulation, 3D RoPE, and optional LLM Adapter.
"""
+ LATENT_CHANNELS = 16
+
def __init__(
self,
max_img_h: int,
@@ -1069,6 +1055,8 @@ class MiniTrainDIT(nn.Module):
extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True,
use_llm_adapter: bool = False,
+ attn_mode: str = "torch",
+ split_attn: bool = False,
) -> None:
super().__init__()
self.max_img_h = max_img_h
@@ -1097,6 +1085,9 @@ class MiniTrainDIT(nn.Module):
self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.use_llm_adapter = use_llm_adapter
+ self.attn_mode = attn_mode
+ self.split_attn = split_attn
+
# Block swap support
self.blocks_to_swap = None
self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None
@@ -1156,7 +1147,6 @@ class MiniTrainDIT(nn.Module):
self.final_layer.init_weights()
self.t_embedding_norm.reset_parameters()
-
def enable_gradient_checkpointing(self, cpu_offload: bool = False, unsloth_offload: bool = False):
for block in self.blocks:
block.enable_gradient_checkpointing(cpu_offload=cpu_offload, unsloth_offload=unsloth_offload)
@@ -1169,18 +1159,9 @@ class MiniTrainDIT(nn.Module):
def device(self):
return next(self.parameters()).device
-
- def set_flash_attn(self, use_flash_attn: bool):
- """Toggle flash attention for all DiT blocks (self-attn + cross-attn).
-
- LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
- """
- if use_flash_attn and not FLASH_ATTN_AVAILABLE:
- raise ImportError("flash_attn package is required for --flash_attn but is not installed")
- attn_op = flash_attention_op if use_flash_attn else torch_attention_op
- for block in self.blocks:
- block.self_attn.attn_op = attn_op
- block.cross_attn.attn_op = attn_op
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
def build_patch_embed(self) -> None:
in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels
@@ -1232,9 +1213,7 @@ class MiniTrainDIT(nn.Module):
padding_mask = transforms.functional.resize(
padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
- x_B_C_T_H_W = torch.cat(
- [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
- )
+ x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1)
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
@@ -1258,7 +1237,6 @@ class MiniTrainDIT(nn.Module):
)
return x_B_C_Tt_Hp_Wp
-
def enable_block_swap(self, num_blocks: int, device: torch.device):
self.blocks_to_swap = num_blocks
@@ -1266,9 +1244,7 @@ class MiniTrainDIT(nn.Module):
self.blocks_to_swap <= self.num_blocks - 2
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."
- self.offloader = custom_offloading_utils.ModelOffloader(
- self.blocks, self.blocks_to_swap, device
- )
+ self.offloader = custom_offloading_utils.ModelOffloader(self.blocks, self.blocks_to_swap, device)
logger.info(f"Anima: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")
def move_to_device_except_swap_blocks(self, device: torch.device):
@@ -1282,12 +1258,26 @@ class MiniTrainDIT(nn.Module):
if self.blocks_to_swap:
self.blocks = save_blocks
+ def switch_block_swap_for_inference(self):
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+ self.offloader.set_forward_only(True)
+ self.prepare_block_swap_before_forward()
+ print(f"Anima: Block swap set to forward only.")
+
+ def switch_block_swap_for_training(self):
+ if self.blocks_to_swap is None or self.blocks_to_swap == 0:
+ return
+ self.offloader.set_forward_only(False)
+ self.prepare_block_swap_before_forward()
+ print(f"Anima: Block swap set to forward and backward.")
+
def prepare_block_swap_before_forward(self):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self.offloader.prepare_block_devices_before_forward(self.blocks)
- def forward(
+ def forward_mini_train_dit(
self,
x_B_C_T_H_W: torch.Tensor,
timesteps_B_T: torch.Tensor,
@@ -1310,7 +1300,7 @@ class MiniTrainDIT(nn.Module):
t5_attn_mask: Optional T5 attention mask
"""
# Run LLM adapter inside forward for correct DDP gradient synchronization
- if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, 'llm_adapter'):
+ if t5_input_ids is not None and self.use_llm_adapter and hasattr(self, "llm_adapter"):
crossattn_emb = self.llm_adapter(
source_hidden_states=crossattn_emb,
target_input_ids=t5_input_ids,
@@ -1337,16 +1327,13 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb,
}
+ attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
+
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
- x_B_T_H_W_D = block(
- x_B_T_H_W_D,
- t_embedding_B_T_D,
- crossattn_emb,
- **block_kwargs,
- )
+ x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)
@@ -1355,6 +1342,36 @@ class MiniTrainDIT(nn.Module):
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp
+ def forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ context: Optional[torch.Tensor] = None,
+ fps: Optional[torch.Tensor] = None,
+ padding_mask: Optional[torch.Tensor] = None,
+ target_input_ids: Optional[torch.Tensor] = None,
+ target_attention_mask: Optional[torch.Tensor] = None,
+ source_attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ context = self._preprocess_text_embeds(context, target_input_ids, target_attention_mask, source_attention_mask)
+ return self.forward_mini_train_dit(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
+
+ def _preprocess_text_embeds(
+ self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None
+ ):
+ if target_input_ids is not None:
+ context = self.llm_adapter(
+ source_hidden_states,
+ target_input_ids,
+ target_attention_mask=target_attention_mask,
+ source_attention_mask=source_attention_mask,
+ )
+ context[~target_attention_mask.bool()] = 0 # zero out padding tokens
+ return context
+ else:
+ return source_hidden_states
+
# LLM Adapter: Bridges Qwen3 embeddings to T5-compatible space
class LLMAdapterRMSNorm(nn.Module):
@@ -1485,24 +1502,37 @@ class LLMAdapterTransformerBlock(nn.Module):
self.norm_mlp = nn.LayerNorm(model_dim) if layer_norm else LLMAdapterRMSNorm(model_dim)
self.mlp = nn.Sequential(
- nn.Linear(model_dim, int(model_dim * mlp_ratio)),
- nn.GELU(),
- nn.Linear(int(model_dim * mlp_ratio), model_dim)
+ nn.Linear(model_dim, int(model_dim * mlp_ratio)), nn.GELU(), nn.Linear(int(model_dim * mlp_ratio), model_dim)
)
- def forward(self, x, context, target_attention_mask=None, source_attention_mask=None,
- position_embeddings=None, position_embeddings_context=None):
+ def forward(
+ self,
+ x,
+ context,
+ target_attention_mask=None,
+ source_attention_mask=None,
+ position_embeddings=None,
+ position_embeddings_context=None,
+ ):
if self.has_self_attn:
+ # Self-attention: target_attention_mask is not expected to be all zeros
normed = self.norm_self_attn(x)
- attn_out = self.self_attn(normed, mask=target_attention_mask,
- position_embeddings=position_embeddings,
- position_embeddings_context=position_embeddings)
+ attn_out = self.self_attn(
+ normed,
+ mask=target_attention_mask,
+ position_embeddings=position_embeddings,
+ position_embeddings_context=position_embeddings,
+ )
x = x + attn_out
normed = self.norm_cross_attn(x)
- attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context,
- position_embeddings=position_embeddings,
- position_embeddings_context=position_embeddings_context)
+ attn_out = self.cross_attn(
+ normed,
+ mask=source_attention_mask,
+ context=context,
+ position_embeddings=position_embeddings,
+ position_embeddings_context=position_embeddings_context,
+ )
x = x + attn_out
x = x + self.mlp(self.norm_mlp(x))
@@ -1518,8 +1548,9 @@ class LLMAdapter(nn.Module):
Uses T5 token IDs as target input, embeds them, and cross-attends to Qwen3 hidden states.
"""
- def __init__(self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16,
- embed=None, self_attn=False, layer_norm=False):
+ def __init__(
+ self, source_dim, target_dim, model_dim, num_layers=6, num_heads=16, embed=None, self_attn=False, layer_norm=False
+ ):
super().__init__()
if embed is not None:
self.embed = nn.Embedding.from_pretrained(embed.weight)
@@ -1530,11 +1561,12 @@ class LLMAdapter(nn.Module):
else:
self.in_proj = nn.Identity()
self.rotary_emb = AdapterRotaryEmbedding(model_dim // num_heads)
- self.blocks = nn.ModuleList([
- LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads,
- self_attn=self_attn, layer_norm=layer_norm)
- for _ in range(num_layers)
- ])
+ self.blocks = nn.ModuleList(
+ [
+ LLMAdapterTransformerBlock(source_dim, model_dim, num_heads=num_heads, self_attn=self_attn, layer_norm=layer_norm)
+ for _ in range(num_layers)
+ ]
+ )
self.out_proj = nn.Linear(model_dim, target_dim)
self.norm = LLMAdapterRMSNorm(target_dim)
@@ -1556,75 +1588,67 @@ class LLMAdapter(nn.Module):
position_embeddings = self.rotary_emb(x, position_ids)
position_embeddings_context = self.rotary_emb(x, position_ids_context)
for block in self.blocks:
- x = block(x, context, target_attention_mask=target_attention_mask,
- source_attention_mask=source_attention_mask,
- position_embeddings=position_embeddings,
- position_embeddings_context=position_embeddings_context)
+ x = block(
+ x,
+ context,
+ target_attention_mask=target_attention_mask,
+ source_attention_mask=source_attention_mask,
+ position_embeddings=position_embeddings,
+ position_embeddings_context=position_embeddings_context,
+ )
return self.norm(self.out_proj(x))
-# VAE Wrapper
+# Not used currently, but kept for reference
-# VAE normalization constants
-ANIMA_VAE_MEAN = [
- -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
- 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
-]
-ANIMA_VAE_STD = [
- 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
- 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
-]
+# def get_dit_config(state_dict, key_prefix=""):
+# """Derive DiT configuration from state_dict weight shapes."""
+# dit_config = {}
+# dit_config["max_img_h"] = 512
+# dit_config["max_img_w"] = 512
+# dit_config["max_frames"] = 128
+# concat_padding_mask = True
+# dit_config["in_channels"] = (state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[1] // 4) - int(
+# concat_padding_mask
+# )
+# dit_config["out_channels"] = 16
+# dit_config["patch_spatial"] = 2
+# dit_config["patch_temporal"] = 1
+# dit_config["model_channels"] = state_dict["{}x_embedder.proj.1.weight".format(key_prefix)].shape[0]
+# dit_config["concat_padding_mask"] = concat_padding_mask
+# dit_config["crossattn_emb_channels"] = 1024
+# dit_config["pos_emb_cls"] = "rope3d"
+# dit_config["pos_emb_learnable"] = True
+# dit_config["pos_emb_interpolation"] = "crop"
+# dit_config["min_fps"] = 1
+# dit_config["max_fps"] = 30
-# DiT config detection from state_dict
-KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
+# dit_config["use_adaln_lora"] = True
+# dit_config["adaln_lora_dim"] = 256
+# if dit_config["model_channels"] == 2048:
+# dit_config["num_blocks"] = 28
+# dit_config["num_heads"] = 16
+# elif dit_config["model_channels"] == 5120:
+# dit_config["num_blocks"] = 36
+# dit_config["num_heads"] = 40
+# elif dit_config["model_channels"] == 1280:
+# dit_config["num_blocks"] = 20
+# dit_config["num_heads"] = 20
+# if dit_config["in_channels"] == 16:
+# dit_config["extra_per_block_abs_pos_emb"] = False
+# dit_config["rope_h_extrapolation_ratio"] = 4.0
+# dit_config["rope_w_extrapolation_ratio"] = 4.0
+# dit_config["rope_t_extrapolation_ratio"] = 1.0
+# elif dit_config["in_channels"] == 17:
+# dit_config["extra_per_block_abs_pos_emb"] = False
+# dit_config["rope_h_extrapolation_ratio"] = 3.0
+# dit_config["rope_w_extrapolation_ratio"] = 3.0
+# dit_config["rope_t_extrapolation_ratio"] = 1.0
-def get_dit_config(state_dict, key_prefix=''):
- """Derive DiT configuration from state_dict weight shapes."""
- dit_config = {}
- dit_config["max_img_h"] = 512
- dit_config["max_img_w"] = 512
- dit_config["max_frames"] = 128
- concat_padding_mask = True
- dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
- dit_config["out_channels"] = 16
- dit_config["patch_spatial"] = 2
- dit_config["patch_temporal"] = 1
- dit_config["model_channels"] = state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[0]
- dit_config["concat_padding_mask"] = concat_padding_mask
- dit_config["crossattn_emb_channels"] = 1024
- dit_config["pos_emb_cls"] = "rope3d"
- dit_config["pos_emb_learnable"] = True
- dit_config["pos_emb_interpolation"] = "crop"
- dit_config["min_fps"] = 1
- dit_config["max_fps"] = 30
+# dit_config["extra_h_extrapolation_ratio"] = 1.0
+# dit_config["extra_w_extrapolation_ratio"] = 1.0
+# dit_config["extra_t_extrapolation_ratio"] = 1.0
+# dit_config["rope_enable_fps_modulation"] = False
- dit_config["use_adaln_lora"] = True
- dit_config["adaln_lora_dim"] = 256
- if dit_config["model_channels"] == 2048:
- dit_config["num_blocks"] = 28
- dit_config["num_heads"] = 16
- elif dit_config["model_channels"] == 5120:
- dit_config["num_blocks"] = 36
- dit_config["num_heads"] = 40
- elif dit_config["model_channels"] == 1280:
- dit_config["num_blocks"] = 20
- dit_config["num_heads"] = 20
-
- if dit_config["in_channels"] == 16:
- dit_config["extra_per_block_abs_pos_emb"] = False
- dit_config["rope_h_extrapolation_ratio"] = 4.0
- dit_config["rope_w_extrapolation_ratio"] = 4.0
- dit_config["rope_t_extrapolation_ratio"] = 1.0
- elif dit_config["in_channels"] == 17:
- dit_config["extra_per_block_abs_pos_emb"] = False
- dit_config["rope_h_extrapolation_ratio"] = 3.0
- dit_config["rope_w_extrapolation_ratio"] = 3.0
- dit_config["rope_t_extrapolation_ratio"] = 1.0
-
- dit_config["extra_h_extrapolation_ratio"] = 1.0
- dit_config["extra_w_extrapolation_ratio"] = 1.0
- dit_config["extra_t_extrapolation_ratio"] = 1.0
- dit_config["rope_enable_fps_modulation"] = False
-
- return dit_config
+# return dit_config
diff --git a/library/anima_train_utils.py b/library/anima_train_utils.py
index ef0016b5..3161e79e 100644
--- a/library/anima_train_utils.py
+++ b/library/anima_train_utils.py
@@ -1,20 +1,20 @@
# Anima Training Utilities
import argparse
+import gc
import math
import os
import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional
import numpy as np
import torch
-import torch.nn.functional as F
-from safetensors.torch import save_file
-from accelerate import Accelerator, PartialState
+from accelerate import Accelerator
from tqdm import tqdm
from PIL import Image
-from library.device_utils import init_ipex, clean_memory_on_device
+from library.device_utils import init_ipex, clean_memory_on_device, synchronize_device
+from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl
init_ipex()
@@ -25,29 +25,14 @@ import logging
logger = logging.getLogger(__name__)
-from library import anima_models, anima_utils, strategy_base, train_util
-
-from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
-
# Anima-specific training arguments
+
def add_anima_training_arguments(parser: argparse.ArgumentParser):
"""Add Anima-specific training arguments to the parser."""
parser.add_argument(
- "--dit_path",
- type=str,
- default=None,
- help="Path to Anima DiT model safetensors file",
- )
- parser.add_argument(
- "--vae_path",
- type=str,
- default=None,
- help="Path to WanVAE safetensors/pth file",
- )
- parser.add_argument(
- "--qwen3_path",
+ "--qwen3",
type=str,
default=None,
help="Path to Qwen3-0.6B model (safetensors file or directory)",
@@ -86,7 +71,7 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
"--mod_lr",
type=float,
default=None,
- help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze",
+ help="Learning rate for AdaLN modulation layers. None=same as base LR, 0=freeze. Note: mod layers are not included in LoRA by default.",
)
parser.add_argument(
"--t5_tokenizer_path",
@@ -113,110 +98,52 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
help="Timestep distribution shift for rectified flow training (default: 1.0)",
)
parser.add_argument(
- "--timestep_sample_method",
+ "--timestep_sampling",
type=str,
- default="logit_normal",
- choices=["logit_normal", "uniform"],
- help="Timestep sampling method (default: logit_normal)",
+ default="sigmoid",
+ choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
+ help="Timestep sampling method (default: sigmoid (logit normal))",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
- help="Scale factor for logit_normal timestep sampling (default: 1.0)",
+ help="Scale factor for sigmoid (logit_normal) timestep sampling (default: 1.0)",
)
- # Note: --caption_dropout_rate is defined by base add_dataset_arguments().
- # Anima uses embedding-level dropout (via AnimaTextEncodingStrategy.dropout_rate)
- # instead of dataset-level caption dropout, so the subset caption_dropout_rate
- # is zeroed out in the training scripts to allow caching.
parser.add_argument(
- "--transformer_dtype",
- type=str,
+ "--attn_mode",
+ choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
default=None,
- choices=["float16", "bfloat16", "float32", None],
- help="Separate dtype for transformer blocks. If None, uses same as mixed_precision",
+ help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
+ " / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。",
)
parser.add_argument(
- "--flash_attn",
+ "--split_attn",
action="store_true",
- help="Use Flash Attention for DiT self/cross-attention (requires flash-attn package). "
- "Falls back to PyTorch SDPA if flash-attn is not installed.",
+ help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
+ )
+ parser.add_argument(
+ "--vae_chunk_size",
+ type=int,
+ default=None,
+ help="Spatial chunk size for VAE encoding/decoding to reduce memory usage. Must be even number. If not specified, chunking is disabled (official behavior)."
+ + " / メモリ使用量を減らすためのVAEエンコード/デコードの空間チャンクサイズ。偶数である必要があります。未指定の場合、チャンク処理は無効になります(公式の動作)。",
+ )
+ parser.add_argument(
+ "--vae_disable_cache",
+ action="store_true",
+ help="Disable internal VAE caching mechanism to reduce memory usage. Encoding / decoding will also be faster, but this differs from official behavior."
+ + " / VAEのメモリ使用量を減らすために内部のキャッシュ機構を無効にします。エンコード/デコードも速くなりますが、公式の動作とは異なります。",
)
-
-
-# Noise & Timestep sampling (Rectified Flow)
-def get_noisy_model_input_and_timesteps(
- args,
- latents: torch.Tensor,
- noise: torch.Tensor,
- device: torch.device,
- dtype: torch.dtype,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Generate noisy model input and timesteps for rectified flow training.
-
- Rectified flow: noisy_input = (1 - t) * latents + t * noise
- Target: noise - latents
-
- Args:
- args: Training arguments with timestep_sample_method, sigmoid_scale, discrete_flow_shift
- latents: Clean latent tensors
- noise: Random noise tensors
- device: Target device
- dtype: Target dtype
-
- Returns:
- (noisy_model_input, timesteps, sigmas)
- """
- bs = latents.shape[0]
-
- timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal')
- sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0)
- shift = getattr(args, 'discrete_flow_shift', 1.0)
-
- if timestep_sample_method == 'logit_normal':
- dist = torch.distributions.normal.Normal(0, 1)
- elif timestep_sample_method == 'uniform':
- dist = torch.distributions.uniform.Uniform(0, 1)
- else:
- raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
-
- t = dist.sample((bs,)).to(device)
-
- if timestep_sample_method == 'logit_normal':
- t = t * sigmoid_scale
- t = torch.sigmoid(t)
-
- # Apply shift
- if shift is not None and shift != 1.0:
- t = (t * shift) / (1 + (shift - 1) * t)
-
- # Clamp to avoid exact 0 or 1
- t = t.clamp(1e-5, 1.0 - 1e-5)
-
- # Create noisy input: (1 - t) * latents + t * noise
- t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
-
- ip_noise_gamma = getattr(args, 'ip_noise_gamma', None)
- if ip_noise_gamma:
- xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
- if getattr(args, 'ip_noise_gamma_random_strength', False):
- ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
- noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
- else:
- noisy_model_input = (1 - t_expanded) * latents + t_expanded * noise
-
- # Sigmas for potential loss weighting
- sigmas = t.view(-1, 1)
-
- return noisy_model_input.to(dtype), t.to(dtype), sigmas.to(dtype)
# Loss weighting
+
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Compute loss weighting for Anima training.
- Same schemes as SD3 but can add Anima-specific ones.
+ Same schemes as SD3 but can add Anima-specific ones if needed in future.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
@@ -243,7 +170,7 @@ def get_anima_param_groups(
"""Create parameter groups for Anima training with separate learning rates.
Args:
- dit: MiniTrainDIT model
+ dit: Anima model
base_lr: Base learning rate
self_attn_lr: LR for self-attention layers (None = base_lr, 0 = freeze)
cross_attn_lr: LR for cross-attention layers
@@ -276,15 +203,15 @@ def get_anima_param_groups(
# Store original name for debugging
p.original_name = name
- if 'llm_adapter' in name:
+ if "llm_adapter" in name:
llm_adapter_params.append(p)
- elif '.self_attn' in name:
+ elif ".self_attn" in name:
self_attn_params.append(p)
- elif '.cross_attn' in name:
+ elif ".cross_attn" in name:
cross_attn_params.append(p)
- elif '.mlp' in name:
+ elif ".mlp" in name:
mlp_params.append(p)
- elif '.adaln_modulation' in name:
+ elif ".adaln_modulation" in name:
mod_params.append(p)
else:
base_params.append(p)
@@ -311,9 +238,9 @@ def get_anima_param_groups(
p.requires_grad_(False)
logger.info(f" Frozen {name} params ({len(params)} parameters)")
elif len(params) > 0:
- param_groups.append({'params': params, 'lr': lr})
+ param_groups.append({"params": params, "lr": lr})
- total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad)
+ total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
logger.info(f"Total trainable parameters: {total_trainable:,}")
return param_groups
@@ -325,16 +252,17 @@ def save_anima_model_on_train_end(
save_dtype: torch.dtype,
epoch: int,
global_step: int,
- dit: anima_models.MiniTrainDIT,
+ dit: anima_models.Anima,
):
"""Save Anima model at the end of training."""
+
def sd_saver(ckpt_file, epoch_no, global_step):
- sai_metadata = train_util.get_sai_model_spec(
- None, args, False, False, False, is_stable_diffusion_ckpt=True
- )
+ sai_metadata = train_util.get_sai_model_spec_dataclass(
+ None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
+ ).to_metadata_dict()
dit_sd = dit.state_dict()
# Save with 'net.' prefix for ComfyUI compatibility
- anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
+ anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
@@ -347,15 +275,16 @@ def save_anima_model_on_epoch_end_or_stepwise(
epoch: int,
num_train_epochs: int,
global_step: int,
- dit: anima_models.MiniTrainDIT,
+ dit: anima_models.Anima,
):
"""Save Anima model at epoch end or specific steps."""
+
def sd_saver(ckpt_file, epoch_no, global_step):
- sai_metadata = train_util.get_sai_model_spec(
- None, args, False, False, False, is_stable_diffusion_ckpt=True
- )
+ sai_metadata = train_util.get_sai_model_spec_dataclass(
+ None, args, False, False, False, is_stable_diffusion_ckpt=True, anima="preview"
+ ).to_metadata_dict()
dit_sd = dit.state_dict()
- anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
+ anima_utils.save_anima_model(ckpt_file, dit_sd, sai_metadata, save_dtype)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
@@ -376,12 +305,13 @@ def do_sample(
height: int,
width: int,
seed: Optional[int],
- dit: anima_models.MiniTrainDIT,
+ dit: anima_models.Anima,
crossattn_emb: torch.Tensor,
steps: int,
dtype: torch.dtype,
device: torch.device,
guidance_scale: float = 1.0,
+ flow_shift: float = 3.0,
neg_crossattn_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Generate a sample using Euler discrete sampling for rectified flow.
@@ -389,12 +319,13 @@ def do_sample(
Args:
height, width: Output image dimensions
seed: Random seed (None for random)
- dit: MiniTrainDIT model
+ dit: Anima model
crossattn_emb: Cross-attention embeddings (B, N, D)
steps: Number of sampling steps
dtype: Compute dtype
device: Compute device
guidance_scale: CFG scale (1.0 = no guidance)
+ flow_shift: Flow shift parameter for rectified flow
neg_crossattn_emb: Negative cross-attention embeddings for CFG
Returns:
@@ -410,12 +341,13 @@ def do_sample(
generator = torch.manual_seed(seed)
else:
generator = None
- noise = torch.randn(
- latent.size(), dtype=torch.float32, generator=generator, device="cpu"
- ).to(dtype).to(device)
+ noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
# Timestep schedule: linear from 1.0 to 0.0
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
+ flow_shift = float(flow_shift)
+ if flow_shift != 1.0:
+ sigmas = (sigmas * flow_shift) / (1 + (flow_shift - 1) * sigmas)
# Start from pure noise
x = noise.clone()
@@ -429,19 +361,13 @@ def do_sample(
sigma = sigmas[i]
t = sigma.unsqueeze(0) # (1,)
- dit.prepare_block_swap_before_forward()
-
if use_cfg:
- # CFG: concat positive and negative
- x_input = torch.cat([x, x], dim=0)
- t_input = torch.cat([t, t], dim=0)
- crossattn_input = torch.cat([crossattn_emb, neg_crossattn_emb], dim=0)
- padding_input = torch.cat([padding_mask, padding_mask], dim=0)
+ # CFG: two separate passes to reduce memory usage
+ pos_out = dit(x, t, crossattn_emb, padding_mask=padding_mask)
+ pos_out = pos_out.float()
+ neg_out = dit(x, t, neg_crossattn_emb, padding_mask=padding_mask)
+ neg_out = neg_out.float()
- model_output = dit(x_input, t_input, crossattn_input, padding_mask=padding_input)
- model_output = model_output.float()
-
- pos_out, neg_out = model_output.chunk(2)
model_output = neg_out + guidance_scale * (pos_out - neg_out)
else:
model_output = dit(x, t, crossattn_emb, padding_mask=padding_mask)
@@ -452,7 +378,6 @@ def do_sample(
x = x + model_output * dt
x = x.to(dtype)
- dit.prepare_block_swap_before_forward()
return x
@@ -461,9 +386,8 @@ def sample_images(
args: argparse.Namespace,
epoch,
steps,
- dit,
+ dit: anima_models.Anima,
vae,
- vae_scale,
text_encoder,
tokenize_strategy,
text_encoding_strategy,
@@ -497,6 +421,8 @@ def sample_images(
if text_encoder is not None:
text_encoder = accelerator.unwrap_model(text_encoder)
+ dit.switch_block_swap_for_inference()
+
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = os.path.join(args.output_dir, "sample")
os.makedirs(save_dir, exist_ok=True)
@@ -511,11 +437,21 @@ def sample_images(
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
+ dit.prepare_block_swap_before_forward()
_sample_image_inference(
- accelerator, args, dit, text_encoder, vae, vae_scale,
- tokenize_strategy, text_encoding_strategy,
- save_dir, prompt_dict, epoch, steps,
- sample_prompts_te_outputs, prompt_replacement,
+ accelerator,
+ args,
+ dit,
+ text_encoder,
+ vae,
+ tokenize_strategy,
+ text_encoding_strategy,
+ save_dir,
+ prompt_dict,
+ epoch,
+ steps,
+ sample_prompts_te_outputs,
+ prompt_replacement,
)
# Restore RNG state
@@ -523,14 +459,24 @@ def sample_images(
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
+ dit.switch_block_swap_for_training()
clean_memory_on_device(accelerator.device)
def _sample_image_inference(
- accelerator, args, dit, text_encoder, vae, vae_scale,
- tokenize_strategy, text_encoding_strategy,
- save_dir, prompt_dict, epoch, steps,
- sample_prompts_te_outputs, prompt_replacement,
+ accelerator,
+ args,
+ dit,
+ text_encoder,
+ vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage,
+ tokenize_strategy,
+ text_encoding_strategy,
+ save_dir,
+ prompt_dict,
+ epoch,
+ steps,
+ sample_prompts_te_outputs,
+ prompt_replacement,
):
"""Generate a single sample image."""
prompt = prompt_dict.get("prompt", "")
@@ -540,6 +486,7 @@ def _sample_image_inference(
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
+ flow_shift = prompt_dict.get("flow_shift", 3.0)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
@@ -553,7 +500,9 @@ def _sample_image_inference(
height = max(64, height - height % 16)
width = max(64, width - width % 16)
- logger.info(f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}")
+ logger.info(
+ f" prompt: {prompt}, size: {width}x{height}, steps: {sample_steps}, scale: {scale}, flow_shift: {flow_shift}, seed: {seed}"
+ )
# Encode prompt
def encode_prompt(prpt):
@@ -579,13 +528,13 @@ def _sample_image_inference(
t5_input_ids = torch.from_numpy(t5_input_ids).unsqueeze(0)
t5_attn_mask = torch.from_numpy(t5_attn_mask).unsqueeze(0)
- prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
+ prompt_embeds = prompt_embeds.to(accelerator.device, dtype=dit.dtype)
attn_mask = attn_mask.to(accelerator.device)
t5_input_ids = t5_input_ids.to(accelerator.device, dtype=torch.long)
t5_attn_mask = t5_attn_mask.to(accelerator.device)
# Process through LLM adapter if available
- if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
+ if dit.use_llm_adapter:
crossattn_emb = dit.llm_adapter(
source_hidden_states=prompt_embeds,
target_input_ids=t5_input_ids,
@@ -608,12 +557,12 @@ def _sample_image_inference(
neg_t5_ids = torch.from_numpy(neg_t5_ids).unsqueeze(0)
neg_t5_am = torch.from_numpy(neg_t5_am).unsqueeze(0)
- neg_pe = neg_pe.to(accelerator.device, dtype=dit.t_embedding_norm.weight.dtype)
+ neg_pe = neg_pe.to(accelerator.device, dtype=dit.dtype)
neg_am = neg_am.to(accelerator.device)
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
neg_t5_am = neg_t5_am.to(accelerator.device)
- if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
+ if dit.use_llm_adapter:
neg_crossattn_emb = dit.llm_adapter(
source_hidden_states=neg_pe,
target_input_ids=neg_t5_ids,
@@ -627,16 +576,16 @@ def _sample_image_inference(
# Generate sample
clean_memory_on_device(accelerator.device)
latents = do_sample(
- height, width, seed, dit, crossattn_emb,
- sample_steps, dit.t_embedding_norm.weight.dtype,
- accelerator.device, scale, neg_crossattn_emb,
+ height, width, seed, dit, crossattn_emb, sample_steps, dit.dtype, accelerator.device, scale, flow_shift, neg_crossattn_emb
)
# Decode latents
+ gc.collect()
+ synchronize_device(accelerator.device)
clean_memory_on_device(accelerator.device)
- org_vae_device = next(vae.parameters()).device
+ org_vae_device = vae.device
vae.to(accelerator.device)
- decoded = vae.decode(latents.to(next(vae.parameters()).device, dtype=next(vae.parameters()).dtype), vae_scale)
+ decoded = vae.decode_to_pixels(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
@@ -662,4 +611,5 @@ def _sample_image_inference(
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
+
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)
diff --git a/library/anima_utils.py b/library/anima_utils.py
index 8c171e0e..6a4422a8 100644
--- a/library/anima_utils.py
+++ b/library/anima_utils.py
@@ -3,10 +3,14 @@
import os
from typing import Dict, List, Optional, Union
import torch
-import torch.nn as nn
from safetensors.torch import load_file, save_file
from accelerate.utils import set_module_tensor_to_device # kept for potential future use
+from accelerate import init_empty_weights
+from library.fp8_optimization_utils import apply_fp8_monkey_patch
+from library.lora_utils import load_safetensors_with_lora_and_fp8
+from library import anima_models
+from library.safetensors_utils import WeightTransformHooks
from .utils import setup_logging
setup_logging()
@@ -14,150 +18,134 @@ import logging
logger = logging.getLogger(__name__)
-from library import anima_models
+
+# Original Anima high-precision keys. Kept for reference, but not used currently.
+# # Keys that should stay in high precision (float32/bfloat16, not quantized)
+# KEEP_IN_HIGH_PRECISION = ["x_embedder", "t_embedder", "t_embedding_norm", "final_layer"]
-# Keys that should stay in high precision (float32/bfloat16, not quantized)
-KEEP_IN_HIGH_PRECISION = ['x_embedder', 't_embedder', 't_embedding_norm', 'final_layer']
+FP8_OPTIMIZATION_TARGET_KEYS = ["blocks", ""]
+# ".embed." excludes Embedding in LLMAdapter
+FP8_OPTIMIZATION_EXCLUDE_KEYS = ["_embedder", "norm", "adaln", "final_layer", ".embed."]
-def load_safetensors(path: str, device: str = "cpu", dtype: Optional[torch.dtype] = None) -> Dict[str, torch.Tensor]:
- """Load a safetensors file and optionally cast to dtype."""
- sd = load_file(path, device=device)
- if dtype is not None:
- sd = {k: v.to(dtype) for k, v in sd.items()}
- return sd
-
-
-def load_anima_dit(
+def load_anima_model(
+ device: Union[str, torch.device],
dit_path: str,
- dtype: torch.dtype,
- device: Union[str, torch.device] = "cpu",
- transformer_dtype: Optional[torch.dtype] = None,
- llm_adapter_path: Optional[str] = None,
- disable_mmap: bool = False,
-) -> anima_models.MiniTrainDIT:
- """Load the MiniTrainDIT model from safetensors.
+ attn_mode: str,
+ split_attn: bool,
+ loading_device: Union[str, torch.device],
+ dit_weight_dtype: Optional[torch.dtype],
+ fp8_scaled: bool = False,
+ lora_weights_list: Optional[List[Dict[str, torch.Tensor]]] = None,
+ lora_multipliers: Optional[list[float]] = None,
+) -> anima_models.Anima:
+ """
+ Load Anima model from the specified checkpoint.
Args:
- dit_path: Path to DiT safetensors file
- dtype: Base dtype for model parameters
- device: Device to load to
- transformer_dtype: Optional separate dtype for transformer blocks (lower precision)
- llm_adapter_path: Optional separate path for LLM adapter weights
- disable_mmap: If True, disable memory-mapped loading (reduces peak memory)
+ device (Union[str, torch.device]): Device for optimization or merging
+ dit_path (str): Path to the DiT model checkpoint.
+ attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
+ split_attn (bool): Whether to use split attention.
+ loading_device (Union[str, torch.device]): Device to load the model weights on.
+ dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
+ If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
+ fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
+ lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): LoRA weights to apply, if any.
+ lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
"""
- if transformer_dtype is None:
- transformer_dtype = dtype
+ # dit_weight_dtype is None for fp8_scaled
+ assert (
+ not fp8_scaled and dit_weight_dtype is not None
+ ) or dit_weight_dtype is None, "dit_weight_dtype should be None when fp8_scaled is True"
- logger.info(f"Loading Anima DiT from {dit_path}")
- if disable_mmap:
- from library.safetensors_utils import load_safetensors as load_safetensors_no_mmap
- state_dict = load_safetensors_no_mmap(dit_path, device="cpu", disable_mmap=True)
- else:
- state_dict = load_file(dit_path, device="cpu")
+ device = torch.device(device)
+ loading_device = torch.device(loading_device)
- # Remove 'net.' prefix if present
- new_state_dict = {}
- for k, v in state_dict.items():
- if k.startswith('net.'):
- k = k[len('net.'):]
- new_state_dict[k] = v
- state_dict = new_state_dict
+ # We currently support fixed DiT config for Anima models
+ dit_config = {
+ "max_img_h": 512,
+ "max_img_w": 512,
+ "max_frames": 128,
+ "in_channels": 16,
+ "out_channels": 16,
+ "patch_spatial": 2,
+ "patch_temporal": 1,
+ "model_channels": 2048,
+ "concat_padding_mask": True,
+ "crossattn_emb_channels": 1024,
+ "pos_emb_cls": "rope3d",
+ "pos_emb_learnable": True,
+ "pos_emb_interpolation": "crop",
+ "min_fps": 1,
+ "max_fps": 30,
+ "use_adaln_lora": True,
+ "adaln_lora_dim": 256,
+ "num_blocks": 28,
+ "num_heads": 16,
+ "extra_per_block_abs_pos_emb": False,
+ "rope_h_extrapolation_ratio": 4.0,
+ "rope_w_extrapolation_ratio": 4.0,
+ "rope_t_extrapolation_ratio": 1.0,
+ "extra_h_extrapolation_ratio": 1.0,
+ "extra_w_extrapolation_ratio": 1.0,
+ "extra_t_extrapolation_ratio": 1.0,
+ "rope_enable_fps_modulation": False,
+ "use_llm_adapter": True,
+ "attn_mode": attn_mode,
+ "split_attn": split_attn,
+ }
+ with init_empty_weights():
+ model = anima_models.Anima(**dit_config)
+ if dit_weight_dtype is not None:
+ model.to(dit_weight_dtype)
- # Derive config from state_dict
- dit_config = anima_models.get_dit_config(state_dict)
-
- # Detect LLM adapter
- if llm_adapter_path is not None:
- use_llm_adapter = True
- dit_config['use_llm_adapter'] = True
- llm_adapter_state_dict = load_safetensors(llm_adapter_path, device="cpu")
- elif 'llm_adapter.out_proj.weight' in state_dict:
- use_llm_adapter = True
- dit_config['use_llm_adapter'] = True
- llm_adapter_state_dict = None # Loaded as part of DiT
- else:
- use_llm_adapter = False
- llm_adapter_state_dict = None
-
- logger.info(f"DiT config: model_channels={dit_config['model_channels']}, num_blocks={dit_config['num_blocks']}, "
- f"num_heads={dit_config['num_heads']}, use_llm_adapter={use_llm_adapter}")
-
- # Build model normally on CPU — buffers get proper values from __init__
- dit = anima_models.MiniTrainDIT(**dit_config)
-
- # Merge LLM adapter weights into state_dict if loaded separately
- if use_llm_adapter and llm_adapter_state_dict is not None:
- for k, v in llm_adapter_state_dict.items():
- state_dict[f"llm_adapter.{k}"] = v
-
- # Load checkpoint: strict=False keeps buffers not in checkpoint (e.g. pos_embedder.seq)
- missing, unexpected = dit.load_state_dict(state_dict, strict=False)
- if missing:
- # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
- unexpected_missing = [k for k in missing if not any(
- buf_name in k for buf_name in ('seq', 'dim_spatial_range', 'dim_temporal_range', 'inv_freq')
- )]
- if unexpected_missing:
- logger.warning(f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}")
- if unexpected:
- logger.info(f"Unexpected keys in checkpoint (ignored): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
-
- # Apply per-parameter dtype (high precision for 1D/critical, transformer_dtype for rest)
- for name, p in dit.named_parameters():
- dtype_to_use = dtype if (
- any(keyword in name for keyword in KEEP_IN_HIGH_PRECISION) or p.ndim == 1
- ) else transformer_dtype
- p.data = p.data.to(dtype=dtype_to_use)
-
- dit.to(device)
- logger.info(f"Loaded Anima DiT successfully. Parameters: {sum(p.numel() for p in dit.parameters()):,}")
- return dit
-
-
-def load_anima_vae(vae_path: str, dtype: torch.dtype = torch.float32, device: str = "cpu"):
- """Load WanVAE from a safetensors/pth file.
-
- Returns (vae_model, mean_tensor, std_tensor, scale).
- """
- from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
-
- logger.info(f"Loading Anima VAE from {vae_path}")
-
- # VAE config (fixed for WanVAE)
- vae_config = dict(
- dim=96,
- z_dim=16,
- dim_mult=[1, 2, 4, 4],
- num_res_blocks=2,
- attn_scales=[],
- temperal_downsample=[False, True, True],
- dropout=0.0,
+ # load model weights with dynamic fp8 optimization and LoRA merging if needed
+ logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
+ rename_hooks = WeightTransformHooks(rename_hook=lambda k: k[len("net.") :] if k.startswith("net.") else k)
+ sd = load_safetensors_with_lora_and_fp8(
+ model_files=dit_path,
+ lora_weights_list=lora_weights_list,
+ lora_multipliers=lora_multipliers,
+ fp8_optimization=fp8_scaled,
+ calc_device=device,
+ move_to_device=(loading_device == device),
+ dit_weight_dtype=dit_weight_dtype,
+ target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
+ exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
+ weight_transform_hooks=rename_hooks,
)
- from library.anima_vae import WanVAE_
+ if fp8_scaled:
+ apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
- # Build model
- with torch.device('meta'):
- vae = WanVAE_(**vae_config)
+ if loading_device.type != "cpu":
+ # make sure all the model weights are on the loading_device
+ logger.info(f"Moving weights to {loading_device}")
+ for key in sd.keys():
+ sd[key] = sd[key].to(loading_device)
- # Load state dict
- if vae_path.endswith('.safetensors'):
- vae_sd = load_file(vae_path, device='cpu')
- else:
- vae_sd = torch.load(vae_path, map_location='cpu', weights_only=True)
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
+ if missing:
+ # Filter out expected missing buffers (initialized in __init__, not saved in checkpoint)
+ unexpected_missing = [
+ k
+ for k in missing
+ if not any(buf_name in k for buf_name in ("seq", "dim_spatial_range", "dim_temporal_range", "inv_freq"))
+ ]
+ if unexpected_missing:
+ # Raise error to avoid silent failures
+ raise RuntimeError(
+ f"Missing keys in checkpoint: {unexpected_missing[:10]}{'...' if len(unexpected_missing) > 10 else ''}"
+ )
+ missing = {} # all missing keys were expected
+ if unexpected:
+ # Raise error to avoid silent failures
+ raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
+ logger.info(f"Loaded DiT model from {dit_path}, unexpected missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
- vae.load_state_dict(vae_sd, assign=True)
- vae = vae.eval().requires_grad_(False).to(device, dtype=dtype)
-
- # Create normalization tensors
- mean = torch.tensor(ANIMA_VAE_MEAN, dtype=dtype, device=device)
- std = torch.tensor(ANIMA_VAE_STD, dtype=dtype, device=device)
- scale = [mean, 1.0 / std]
-
- logger.info(f"Loaded Anima VAE successfully.")
- return vae, mean, std, scale
+ return model
def load_qwen3_tokenizer(qwen3_path: str):
@@ -175,7 +163,7 @@ def load_qwen3_tokenizer(qwen3_path: str):
if os.path.isdir(qwen3_path):
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
else:
- config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
+ config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
@@ -190,7 +178,13 @@ def load_qwen3_tokenizer(qwen3_path: str):
return tokenizer
-def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16, device: str = "cpu"):
+def load_qwen3_text_encoder(
+ qwen3_path: str,
+ dtype: torch.dtype = torch.bfloat16,
+ device: str = "cpu",
+ lora_weights: Optional[List[Dict[str, torch.Tensor]]] = None,
+ lora_multipliers: Optional[List[float]] = None,
+):
"""Load Qwen3-0.6B text encoder.
Args:
@@ -209,12 +203,10 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
if os.path.isdir(qwen3_path):
# Directory with full model
tokenizer = AutoTokenizer.from_pretrained(qwen3_path, local_files_only=True)
- model = transformers.AutoModelForCausalLM.from_pretrained(
- qwen3_path, torch_dtype=dtype, local_files_only=True
- ).model
+ model = transformers.AutoModelForCausalLM.from_pretrained(qwen3_path, torch_dtype=dtype, local_files_only=True).model
else:
# Single safetensors file - use configs/qwen3_06b/ for config
- config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 'qwen3_06b')
+ config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "qwen3_06b")
if not os.path.exists(config_dir):
raise FileNotFoundError(
f"Qwen3 config directory not found at {config_dir}. "
@@ -227,16 +219,28 @@ def load_qwen3_text_encoder(qwen3_path: str, dtype: torch.dtype = torch.bfloat16
model = transformers.Qwen3ForCausalLM(qwen3_config).model
# Load weights
- if qwen3_path.endswith('.safetensors'):
- state_dict = load_file(qwen3_path, device='cpu')
+ if qwen3_path.endswith(".safetensors"):
+ if lora_weights is None:
+ state_dict = load_file(qwen3_path, device="cpu")
+ else:
+ state_dict = load_safetensors_with_lora_and_fp8(
+ model_files=qwen3_path,
+ lora_weights_list=lora_weights,
+ lora_multipliers=lora_multipliers,
+ fp8_optimization=False,
+ calc_device=device,
+ move_to_device=True,
+ dit_weight_dtype=None,
+ )
else:
- state_dict = torch.load(qwen3_path, map_location='cpu', weights_only=True)
+ assert lora_weights is None, "LoRA weights merging is only supported for safetensors checkpoints"
+ state_dict = torch.load(qwen3_path, map_location="cpu", weights_only=True)
# Remove 'model.' prefix if present
new_sd = {}
for k, v in state_dict.items():
- if k.startswith('model.'):
- new_sd[k[len('model.'):]] = v
+ if k.startswith("model."):
+ new_sd[k[len("model.") :]] = v
else:
new_sd[k] = v
@@ -265,11 +269,11 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
return T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
# Use bundled config
- config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'configs', 't5_old')
+ config_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "configs", "t5_old")
if os.path.exists(config_dir):
return T5TokenizerFast(
- vocab_file=os.path.join(config_dir, 'spiece.model'),
- tokenizer_file=os.path.join(config_dir, 'tokenizer.json'),
+ vocab_file=os.path.join(config_dir, "spiece.model"),
+ tokenizer_file=os.path.join(config_dir, "tokenizer.json"),
)
raise FileNotFoundError(
@@ -279,47 +283,27 @@ def load_t5_tokenizer(t5_tokenizer_path: Optional[str] = None):
)
-def save_anima_model(save_path: str, dit_state_dict: Dict[str, torch.Tensor], dtype: Optional[torch.dtype] = None):
+def save_anima_model(
+ save_path: str, dit_state_dict: Dict[str, torch.Tensor], metadata: Dict[str, any], dtype: Optional[torch.dtype] = None
+):
"""Save Anima DiT model with 'net.' prefix for ComfyUI compatibility.
Args:
save_path: Output path (.safetensors)
dit_state_dict: State dict from dit.state_dict()
+ metadata: Metadata dict to include in the safetensors file
dtype: Optional dtype to cast to before saving
"""
prefixed_sd = {}
for k, v in dit_state_dict.items():
if dtype is not None:
- v = v.to(dtype)
- prefixed_sd['net.' + k] = v.contiguous()
+ # v = v.to(dtype)
+ v = v.detach().clone().to("cpu").to(dtype) # Reduce GPU memory usage during save
+ prefixed_sd["net." + k] = v.contiguous()
- save_file(prefixed_sd, save_path, metadata={'format': 'pt'})
+ if metadata is None:
+ metadata = {}
+ metadata["format"] = "pt" # For compatibility with the official .safetensors file
+
+ save_file(prefixed_sd, save_path, metadata=metadata) # safetensors.save_file cosumes a lot of memory, but Anima is small enough
logger.info(f"Saved Anima model to {save_path}")
-
-
-def vae_encode(tensor: torch.Tensor, vae, scale):
- """Encode tensor through WanVAE with normalization.
-
- Args:
- tensor: Input tensor (B, C, T, H, W) in [-1, 1] range
- vae: WanVAE_ model
- scale: [mean, 1/std] list
-
- Returns:
- Normalized latents
- """
- return vae.encode(tensor, scale)
-
-
-def vae_decode(latents: torch.Tensor, vae, scale):
- """Decode latents through WanVAE with denormalization.
-
- Args:
- latents: Normalized latents
- vae: WanVAE_ model
- scale: [mean, 1/std] list
-
- Returns:
- Decoded tensor in [-1, 1] range
- """
- return vae.decode(latents, scale)
diff --git a/library/anima_vae.py b/library/anima_vae.py
deleted file mode 100644
index 872bdfa2..00000000
--- a/library/anima_vae.py
+++ /dev/null
@@ -1,577 +0,0 @@
-import logging
-
-import torch
-import torch.cuda.amp as amp
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import rearrange
-
-CACHE_T = 2
-
-
-class CausalConv3d(nn.Conv3d):
- """
- Causal 3d convolusion.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._padding = (self.padding[2], self.padding[2], self.padding[1],
- self.padding[1], 2 * self.padding[0], 0)
- self.padding = (0, 0, 0)
-
- def forward(self, x, cache_x=None):
- padding = list(self._padding)
- if cache_x is not None and self._padding[4] > 0:
- cache_x = cache_x.to(x.device)
- x = torch.cat([cache_x, x], dim=2)
- padding[4] -= cache_x.shape[2]
- x = F.pad(x, padding)
-
- return super().forward(x)
-
-
-class RMS_norm(nn.Module):
-
- def __init__(self, dim, channel_first=True, images=True, bias=False):
- super().__init__()
- broadcastable_dims = (1, 1, 1) if not images else (1, 1)
- shape = (dim, *broadcastable_dims) if channel_first else (dim,)
-
- self.channel_first = channel_first
- self.scale = dim**0.5
- self.gamma = nn.Parameter(torch.ones(shape))
- self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
-
- def forward(self, x):
- return F.normalize(
- x, dim=(1 if self.channel_first else
- -1)) * self.scale * self.gamma + self.bias
-
-
-class Upsample(nn.Upsample):
-
- def forward(self, x):
- """
- Fix bfloat16 support for nearest neighbor interpolation.
- """
- return super().forward(x.float()).type_as(x)
-
-
-class Resample(nn.Module):
-
- def __init__(self, dim, mode):
- assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
- 'downsample3d')
- super().__init__()
- self.dim = dim
- self.mode = mode
-
- # layers
- if mode == 'upsample2d':
- self.resample = nn.Sequential(
- Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
- nn.Conv2d(dim, dim // 2, 3, padding=1))
- elif mode == 'upsample3d':
- self.resample = nn.Sequential(
- Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
- nn.Conv2d(dim, dim // 2, 3, padding=1))
- self.time_conv = CausalConv3d(
- dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
-
- elif mode == 'downsample2d':
- self.resample = nn.Sequential(
- nn.ZeroPad2d((0, 1, 0, 1)),
- nn.Conv2d(dim, dim, 3, stride=(2, 2)))
- elif mode == 'downsample3d':
- self.resample = nn.Sequential(
- nn.ZeroPad2d((0, 1, 0, 1)),
- nn.Conv2d(dim, dim, 3, stride=(2, 2)))
- self.time_conv = CausalConv3d(
- dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
-
- else:
- self.resample = nn.Identity()
-
- def forward(self, x, feat_cache=None, feat_idx=[0]):
- b, c, t, h, w = x.size()
- if self.mode == 'upsample3d':
- if feat_cache is not None:
- idx = feat_idx[0]
- if feat_cache[idx] is None:
- feat_cache[idx] = 'Rep'
- feat_idx[0] += 1
- else:
-
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[
- idx] is not None and feat_cache[idx] != 'Rep':
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
- if cache_x.shape[2] < 2 and feat_cache[
- idx] is not None and feat_cache[idx] == 'Rep':
- cache_x = torch.cat([
- torch.zeros_like(cache_x).to(cache_x.device),
- cache_x
- ],
- dim=2)
- if feat_cache[idx] == 'Rep':
- x = self.time_conv(x)
- else:
- x = self.time_conv(x, feat_cache[idx])
- feat_cache[idx] = cache_x
- feat_idx[0] += 1
-
- x = x.reshape(b, 2, c, t, h, w)
- x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
- 3)
- x = x.reshape(b, c, t * 2, h, w)
- t = x.shape[2]
- x = rearrange(x, 'b c t h w -> (b t) c h w')
- x = self.resample(x)
- x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
-
- if self.mode == 'downsample3d':
- if feat_cache is not None:
- idx = feat_idx[0]
- if feat_cache[idx] is None:
- feat_cache[idx] = x.clone()
- feat_idx[0] += 1
- else:
-
- cache_x = x[:, :, -1:, :, :].clone()
- x = self.time_conv(
- torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
- feat_cache[idx] = cache_x
- feat_idx[0] += 1
- return x
-
- def init_weight(self, conv):
- conv_weight = conv.weight
- nn.init.zeros_(conv_weight)
- c1, c2, t, h, w = conv_weight.size()
- one_matrix = torch.eye(c1, c2)
- init_matrix = one_matrix
- nn.init.zeros_(conv_weight)
- conv_weight.data[:, :, 1, 0, 0] = init_matrix
- conv.weight.data.copy_(conv_weight)
- nn.init.zeros_(conv.bias.data)
-
- def init_weight2(self, conv):
- conv_weight = conv.weight.data
- nn.init.zeros_(conv_weight)
- c1, c2, t, h, w = conv_weight.size()
- init_matrix = torch.eye(c1 // 2, c2)
- conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
- conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
- conv.weight.data.copy_(conv_weight)
- nn.init.zeros_(conv.bias.data)
-
-
-class ResidualBlock(nn.Module):
-
- def __init__(self, in_dim, out_dim, dropout=0.0):
- super().__init__()
- self.in_dim = in_dim
- self.out_dim = out_dim
-
- # layers
- self.residual = nn.Sequential(
- RMS_norm(in_dim, images=False), nn.SiLU(),
- CausalConv3d(in_dim, out_dim, 3, padding=1),
- RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
- CausalConv3d(out_dim, out_dim, 3, padding=1))
- self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
- if in_dim != out_dim else nn.Identity()
-
- def forward(self, x, feat_cache=None, feat_idx=[0]):
- h = self.shortcut(x)
- for layer in self.residual:
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
- idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
- x = layer(x, feat_cache[idx])
- feat_cache[idx] = cache_x
- feat_idx[0] += 1
- else:
- x = layer(x)
- return x + h
-
-
-class AttentionBlock(nn.Module):
- """
- Causal self-attention with a single head.
- """
-
- def __init__(self, dim):
- super().__init__()
- self.dim = dim
-
- # layers
- self.norm = RMS_norm(dim)
- self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
- self.proj = nn.Conv2d(dim, dim, 1)
-
- # zero out the last layer params
- nn.init.zeros_(self.proj.weight)
-
- def forward(self, x):
- identity = x
- b, c, t, h, w = x.size()
- x = rearrange(x, 'b c t h w -> (b t) c h w')
- x = self.norm(x)
- # compute query, key, value
- q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
- -1).permute(0, 1, 3,
- 2).contiguous().chunk(
- 3, dim=-1)
-
- # apply attention
- x = F.scaled_dot_product_attention(
- q,
- k,
- v,
- )
- x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
-
- # output
- x = self.proj(x)
- x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
- return x + identity
-
-
-class Encoder3d(nn.Module):
-
- def __init__(self,
- dim=128,
- z_dim=4,
- dim_mult=[1, 2, 4, 4],
- num_res_blocks=2,
- attn_scales=[],
- temperal_downsample=[True, True, False],
- dropout=0.0):
- super().__init__()
- self.dim = dim
- self.z_dim = z_dim
- self.dim_mult = dim_mult
- self.num_res_blocks = num_res_blocks
- self.attn_scales = attn_scales
- self.temperal_downsample = temperal_downsample
-
- # dimensions
- dims = [dim * u for u in [1] + dim_mult]
- scale = 1.0
-
- # init block
- self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
-
- # downsample blocks
- downsamples = []
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
- # residual (+attention) blocks
- for _ in range(num_res_blocks):
- downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
- if scale in attn_scales:
- downsamples.append(AttentionBlock(out_dim))
- in_dim = out_dim
-
- # downsample block
- if i != len(dim_mult) - 1:
- mode = 'downsample3d' if temperal_downsample[
- i] else 'downsample2d'
- downsamples.append(Resample(out_dim, mode=mode))
- scale /= 2.0
- self.downsamples = nn.Sequential(*downsamples)
-
- # middle blocks
- self.middle = nn.Sequential(
- ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
- ResidualBlock(out_dim, out_dim, dropout))
-
- # output blocks
- self.head = nn.Sequential(
- RMS_norm(out_dim, images=False), nn.SiLU(),
- CausalConv3d(out_dim, z_dim, 3, padding=1))
-
- def forward(self, x, feat_cache=None, feat_idx=[0]):
- if feat_cache is not None:
- idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
- x = self.conv1(x, feat_cache[idx])
- feat_cache[idx] = cache_x
- feat_idx[0] += 1
- else:
- x = self.conv1(x)
-
- ## downsamples
- for layer in self.downsamples:
- if feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
- else:
- x = layer(x)
-
- ## middle
- for layer in self.middle:
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
- else:
- x = layer(x)
-
- ## head
- for layer in self.head:
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
- idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
- x = layer(x, feat_cache[idx])
- feat_cache[idx] = cache_x
- feat_idx[0] += 1
- else:
- x = layer(x)
- return x
-
-
-class Decoder3d(nn.Module):
-
- def __init__(self,
- dim=128,
- z_dim=4,
- dim_mult=[1, 2, 4, 4],
- num_res_blocks=2,
- attn_scales=[],
- temperal_upsample=[False, True, True],
- dropout=0.0):
- super().__init__()
- self.dim = dim
- self.z_dim = z_dim
- self.dim_mult = dim_mult
- self.num_res_blocks = num_res_blocks
- self.attn_scales = attn_scales
- self.temperal_upsample = temperal_upsample
-
- # dimensions
- dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
- scale = 1.0 / 2**(len(dim_mult) - 2)
-
- # init block
- self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
-
- # middle blocks
- self.middle = nn.Sequential(
- ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
- ResidualBlock(dims[0], dims[0], dropout))
-
- # upsample blocks
- upsamples = []
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
- # residual (+attention) blocks
- if i == 1 or i == 2 or i == 3:
- in_dim = in_dim // 2
- for _ in range(num_res_blocks + 1):
- upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
- if scale in attn_scales:
- upsamples.append(AttentionBlock(out_dim))
- in_dim = out_dim
-
- # upsample block
- if i != len(dim_mult) - 1:
- mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
- upsamples.append(Resample(out_dim, mode=mode))
- scale *= 2.0
- self.upsamples = nn.Sequential(*upsamples)
-
- # output blocks
- self.head = nn.Sequential(
- RMS_norm(out_dim, images=False), nn.SiLU(),
- CausalConv3d(out_dim, 3, 3, padding=1))
-
- def forward(self, x, feat_cache=None, feat_idx=[0]):
- ## conv1
- if feat_cache is not None:
- idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
- x = self.conv1(x, feat_cache[idx])
- feat_cache[idx] = cache_x
- feat_idx[0] += 1
- else:
- x = self.conv1(x)
-
- ## middle
- for layer in self.middle:
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
- else:
- x = layer(x)
-
- ## upsamples
- for layer in self.upsamples:
- if feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
- else:
- x = layer(x)
-
- ## head
- for layer in self.head:
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
- idx = feat_idx[0]
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
- # cache last frame of last two chunk
- cache_x = torch.cat([
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
- cache_x.device), cache_x
- ],
- dim=2)
- x = layer(x, feat_cache[idx])
- feat_cache[idx] = cache_x
- feat_idx[0] += 1
- else:
- x = layer(x)
- return x
-
-
-def count_conv3d(model):
- count = 0
- for m in model.modules():
- if isinstance(m, CausalConv3d):
- count += 1
- return count
-
-
-class WanVAE_(nn.Module):
-
- def __init__(self,
- dim=128,
- z_dim=4,
- dim_mult=[1, 2, 4, 4],
- num_res_blocks=2,
- attn_scales=[],
- temperal_downsample=[True, True, False],
- dropout=0.0):
- super().__init__()
- self.dim = dim
- self.z_dim = z_dim
- self.dim_mult = dim_mult
- self.num_res_blocks = num_res_blocks
- self.attn_scales = attn_scales
- self.temperal_downsample = temperal_downsample
- self.temperal_upsample = temperal_downsample[::-1]
-
- # modules
- self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
- attn_scales, self.temperal_downsample, dropout)
- self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
- self.conv2 = CausalConv3d(z_dim, z_dim, 1)
- self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
- attn_scales, self.temperal_upsample, dropout)
-
- def forward(self, x):
- mu, log_var = self.encode(x)
- z = self.reparameterize(mu, log_var)
- x_recon = self.decode(z)
- return x_recon, mu, log_var
-
- def encode(self, x, scale):
- self.clear_cache()
- ## cache
- t = x.shape[2]
- iter_ = 1 + (t - 1) // 4
- for i in range(iter_):
- self._enc_conv_idx = [0]
- if i == 0:
- out = self.encoder(
- x[:, :, :1, :, :],
- feat_cache=self._enc_feat_map,
- feat_idx=self._enc_conv_idx)
- else:
- out_ = self.encoder(
- x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
- feat_cache=self._enc_feat_map,
- feat_idx=self._enc_conv_idx)
- out = torch.cat([out, out_], 2)
- mu, log_var = self.conv1(out).chunk(2, dim=1)
- if isinstance(scale[0], torch.Tensor):
- mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
- 1, self.z_dim, 1, 1, 1)
- else:
- mu = (mu - scale[0]) * scale[1]
- self.clear_cache()
- return mu
-
- def decode(self, z, scale):
- self.clear_cache()
- # z: [b,c,t,h,w]
- if isinstance(scale[0], torch.Tensor):
- z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
- 1, self.z_dim, 1, 1, 1)
- else:
- z = z / scale[1] + scale[0]
- iter_ = z.shape[2]
- x = self.conv2(z)
- for i in range(iter_):
- self._conv_idx = [0]
- if i == 0:
- out = self.decoder(
- x[:, :, i:i + 1, :, :],
- feat_cache=self._feat_map,
- feat_idx=self._conv_idx)
- else:
- out_ = self.decoder(
- x[:, :, i:i + 1, :, :],
- feat_cache=self._feat_map,
- feat_idx=self._conv_idx)
- out = torch.cat([out, out_], 2)
- self.clear_cache()
- return out
-
- def reparameterize(self, mu, log_var):
- std = torch.exp(0.5 * log_var)
- eps = torch.randn_like(std)
- return eps * std + mu
-
- def sample(self, imgs, deterministic=False):
- mu, log_var = self.encode(imgs)
- if deterministic:
- return mu
- std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
- return mu + std * torch.randn_like(std)
-
- def clear_cache(self):
- self._conv_num = count_conv3d(self.decoder)
- self._conv_idx = [0]
- self._feat_map = [None] * self._conv_num
- #cache encode
- self._enc_conv_num = count_conv3d(self.encoder)
- self._enc_conv_idx = [0]
- self._enc_feat_map = [None] * self._enc_conv_num
diff --git a/library/attention.py b/library/attention.py
index d3b8441e..4f6a5422 100644
--- a/library/attention.py
+++ b/library/attention.py
@@ -37,6 +37,14 @@ class AttentionParams:
cu_seqlens: Optional[torch.Tensor] = None
max_seqlen: Optional[int] = None
+ @property
+ def supports_fp32(self) -> bool:
+ return self.attn_mode not in ["flash"]
+
+ @property
+ def requires_same_dtype(self) -> bool:
+ return self.attn_mode in ["xformers"]
+
@staticmethod
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
return AttentionParams(attn_mode, split_attn)
@@ -95,7 +103,7 @@ def attention(
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
k: Key tensor [B, L, H, D].
v: Value tensor [B, L, H, D].
- attn_param: Attention parameters including mask and sequence lengths.
+ attn_params: Attention parameters including mask and sequence lengths.
drop_rate: Attention dropout rate.
Returns:
diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py
index 0681dcdc..883379ce 100644
--- a/library/custom_offloading_utils.py
+++ b/library/custom_offloading_utils.py
@@ -195,6 +195,9 @@ class ModelOffloader(Offloader):
self.remove_handles.append(handle)
def set_forward_only(self, forward_only: bool):
+ # switching must wait for all pending transfers
+ for block_idx in list(self.futures.keys()):
+ self._wait_blocks_move(block_idx)
self.forward_only = forward_only
def __del__(self):
@@ -237,6 +240,10 @@ class ModelOffloader(Offloader):
if self.debug:
print(f"Prepare block devices before forward")
+ # wait for all pending transfers
+ for block_idx in list(self.futures.keys()):
+ self._wait_blocks_move(block_idx)
+
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device
diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py
index 06fe0b95..c96e4bb6 100644
--- a/library/flux_train_utils.py
+++ b/library/flux_train_utils.py
@@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- bsz, _, h, w = latents.shape
+ bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
@@ -512,7 +512,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Broadcast sigmas to latent shape
- sigmas = sigmas.view(-1, 1, 1, 1)
+ sigmas = sigmas.view(-1, 1, 1, 1) if latents.ndim == 4 else sigmas.view(-1, 1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
diff --git a/library/fp8_optimization_utils.py b/library/fp8_optimization_utils.py
index 02f99ab6..af35fd3e 100644
--- a/library/fp8_optimization_utils.py
+++ b/library/fp8_optimization_utils.py
@@ -9,7 +9,7 @@ import logging
from tqdm import tqdm
from library.device_utils import clean_memory_on_device
-from library.safetensors_utils import MemoryEfficientSafeOpen
+from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks
from library.utils import setup_logging
setup_logging()
@@ -220,6 +220,8 @@ def quantize_weight(
tensor_max = torch.max(torch.abs(tensor).view(-1))
scale = tensor_max / max_value
+ # print(f"Optimizing {key} with scale: {scale}")
+
# numerical safety
scale = torch.clamp(scale, min=1e-8)
scale = scale.to(torch.float32) # ensure scale is in float32 for division
@@ -245,6 +247,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook=None,
quantization_mode: str = "block",
block_size: Optional[int] = 64,
+ disable_numpy_memmap: bool = False,
+ weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict:
"""
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
@@ -260,6 +264,8 @@ def load_safetensors_with_fp8_optimization(
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
quantization_mode (str): Quantization mode, "tensor", "channel", or "block"
block_size (int, optional): Block size for block-wise quantization (used if quantization_mode is "block")
+ disable_numpy_memmap (bool): Disable numpy memmap when loading safetensors
+ weight_transform_hooks (WeightTransformHooks, optional): Hooks for weight transformation during loading
Returns:
dict: FP8 optimized state dict
@@ -288,7 +294,9 @@ def load_safetensors_with_fp8_optimization(
# Process each file
state_dict = {}
for model_file in model_files:
- with MemoryEfficientSafeOpen(model_file) as f:
+ with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
+ f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
+
keys = f.keys()
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
value = f.get_tensor(key)
@@ -311,6 +319,11 @@ def load_safetensors_with_fp8_optimization(
value = value.to(calc_device)
original_dtype = value.dtype
+ if original_dtype.itemsize == 1:
+ raise ValueError(
+ f"Layer {key} is already in {original_dtype} format. `--fp8_scaled` optimization should not be applied. Please use fp16/bf16/float32 model weights."
+ + f" / レイヤー {key} は既に{original_dtype}形式です。`--fp8_scaled` 最適化は適用できません。FP16/BF16/Float32のモデル重みを使用してください。"
+ )
quantized_weight, scale_tensor = quantize_weight(
key, value, fp8_dtype, max_value, min_value, quantization_mode, block_size
)
@@ -387,7 +400,7 @@ def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=
else:
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
- o = o.reshape(original_shape[0], original_shape[1], -1) if x.ndim == 3 else o.reshape(original_shape[0], -1)
+ o = o.reshape(original_shape[0], original_shape[1], -1) if len(original_shape) == 3 else o.reshape(original_shape[0], -1)
return o.to(input_dtype)
else:
diff --git a/library/lora_utils.py b/library/lora_utils.py
index 6f0fc228..90e3c389 100644
--- a/library/lora_utils.py
+++ b/library/lora_utils.py
@@ -5,7 +5,7 @@ import torch
from tqdm import tqdm
from library.device_utils import synchronize_device
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
-from library.safetensors_utils import MemoryEfficientSafeOpen
+from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames
from library.utils import setup_logging
setup_logging()
@@ -44,7 +44,7 @@ def filter_lora_state_dict(
def load_safetensors_with_lora_and_fp8(
model_files: Union[str, List[str]],
- lora_weights_list: Optional[Dict[str, torch.Tensor]],
+ lora_weights_list: Optional[List[Dict[str, torch.Tensor]]],
lora_multipliers: Optional[List[float]],
fp8_optimization: bool,
calc_device: torch.device,
@@ -52,19 +52,23 @@ def load_safetensors_with_lora_and_fp8(
dit_weight_dtype: Optional[torch.dtype] = None,
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
+ disable_numpy_memmap: bool = False,
+ weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
Args:
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
- lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
+ lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load.
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
fp8_optimization (bool): Whether to apply FP8 optimization.
calc_device (torch.device): Device to calculate on.
move_to_device (bool): Whether to move tensors to the calculation device after loading.
target_keys (Optional[List[str]]): Keys to target for optimization.
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
+ disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors.
+ weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading.
"""
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
@@ -73,19 +77,9 @@ def load_safetensors_with_lora_and_fp8(
extended_model_files = []
for model_file in model_files:
- basename = os.path.basename(model_file)
- match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
- if match:
- prefix = basename[: match.start(2)]
- count = int(match.group(3))
- state_dict = {}
- for i in range(count):
- filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
- filepath = os.path.join(os.path.dirname(model_file), filename)
- if os.path.exists(filepath):
- extended_model_files.append(filepath)
- else:
- raise FileNotFoundError(f"File {filepath} not found")
+ split_filenames = get_split_weight_filenames(model_file)
+ if split_filenames is not None:
+ extended_model_files.extend(split_filenames)
else:
extended_model_files.append(model_file)
model_files = extended_model_files
@@ -114,7 +108,7 @@ def load_safetensors_with_lora_and_fp8(
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
# make hook for LoRA merging
- def weight_hook_func(model_weight_key, model_weight, keep_on_calc_device=False):
+ def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False):
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
if not model_weight_key.endswith(".weight"):
@@ -126,13 +120,18 @@ def load_safetensors_with_lora_and_fp8(
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
# check if this weight has LoRA weights
- lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
- lora_name = "lora_unet_" + lora_name.replace(".", "_")
- down_key = lora_name + ".lora_down.weight"
- up_key = lora_name + ".lora_up.weight"
- alpha_key = lora_name + ".alpha"
- if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
- continue
+ lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
+ found = False
+ for prefix in ["lora_unet_", ""]:
+ lora_name = prefix + lora_name_without_prefix.replace(".", "_")
+ down_key = lora_name + ".lora_down.weight"
+ up_key = lora_name + ".lora_up.weight"
+ alpha_key = lora_name + ".alpha"
+ if down_key in lora_weight_keys and up_key in lora_weight_keys:
+ found = True
+ break
+ if not found:
+ continue # no LoRA weights for this model weight
# get LoRA weights
down_weight = lora_sd[down_key]
@@ -145,6 +144,13 @@ def load_safetensors_with_lora_and_fp8(
down_weight = down_weight.to(calc_device)
up_weight = up_weight.to(calc_device)
+ original_dtype = model_weight.dtype
+ if original_dtype.itemsize == 1: # fp8
+ # temporarily convert to float16 for calculation
+ model_weight = model_weight.to(torch.float16)
+ down_weight = down_weight.to(torch.float16)
+ up_weight = up_weight.to(torch.float16)
+
# W <- W + U * D
if len(model_weight.size()) == 2:
# linear
@@ -166,6 +172,9 @@ def load_safetensors_with_lora_and_fp8(
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
model_weight = model_weight + multiplier * conved * scale
+ if original_dtype.itemsize == 1: # fp8
+ model_weight = model_weight.to(original_dtype) # convert back to original dtype
+
# remove LoRA keys from set
lora_weight_keys.remove(down_key)
lora_weight_keys.remove(up_key)
@@ -187,6 +196,8 @@ def load_safetensors_with_lora_and_fp8(
target_keys,
exclude_keys,
weight_hook=weight_hook,
+ disable_numpy_memmap=disable_numpy_memmap,
+ weight_transform_hooks=weight_transform_hooks,
)
for lora_weight_keys in list_of_lora_weight_keys:
@@ -208,6 +219,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
target_keys: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
weight_hook: callable = None,
+ disable_numpy_memmap: bool = False,
+ weight_transform_hooks: Optional[WeightTransformHooks] = None,
) -> dict[str, torch.Tensor]:
"""
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
@@ -218,7 +231,14 @@ def load_safetensors_with_fp8_optimization_and_hook(
)
# dit_weight_dtype is not used because we use fp8 optimization
state_dict = load_safetensors_with_fp8_optimization(
- model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
+ model_files,
+ calc_device,
+ target_keys,
+ exclude_keys,
+ move_to_device=move_to_device,
+ weight_hook=weight_hook,
+ disable_numpy_memmap=disable_numpy_memmap,
+ weight_transform_hooks=weight_transform_hooks,
)
else:
logger.info(
@@ -226,7 +246,8 @@ def load_safetensors_with_fp8_optimization_and_hook(
)
state_dict = {}
for model_file in model_files:
- with MemoryEfficientSafeOpen(model_file) as f:
+ with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f:
+ f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
if weight_hook is None and move_to_device:
value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype)
diff --git a/library/qwen_image_autoencoder_kl.py b/library/qwen_image_autoencoder_kl.py
new file mode 100644
index 00000000..ab65e3b9
--- /dev/null
+++ b/library/qwen_image_autoencoder_kl.py
@@ -0,0 +1,1735 @@
+# Copied and modified from Diffusers (via Musubi-Tuner). Original copyright notice follows.
+
+# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# We gratefully acknowledge the Wan Team for their outstanding contributions.
+# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
+# For more information about the Wan VAE, please refer to:
+# - GitHub: https://github.com/Wan-Video/Wan2.1
+# - arXiv: https://arxiv.org/abs/2503.20314
+
+import json
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from library.safetensors_utils import load_safetensors
+
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+CACHE_T = 2
+
+SCALE_FACTOR = 8 # VAE downsampling factor
+
+
+# region diffusers-vae
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+ # make sure sample is on the same device as the parameters and has same dtype
+ if generator is not None and generator.device.type != self.parameters.device.type:
+ rand_device = generator.device
+ else:
+ rand_device = self.parameters.device
+ sample = torch.randn(self.mean.shape, generator=generator, device=rand_device, dtype=self.parameters.dtype).to(
+ self.parameters.device
+ )
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self) -> torch.Tensor:
+ return self.mean
+
+
+# endregion diffusers-vae
+
+
+class ChunkedConv2d(nn.Conv2d):
+ """
+ Convolutional layer that processes input in chunks to reduce memory usage.
+
+ Parameters
+ ----------
+ spatial_chunk_size : int, optional
+ Size of chunks to process at a time. Default is None, which means no chunking.
+
+ TODO: Commonize with similar implementation in hunyuan_image_vae.py
+ """
+
+ def __init__(self, *args, **kwargs):
+ if "spatial_chunk_size" in kwargs:
+ self.spatial_chunk_size = kwargs.pop("spatial_chunk_size", None)
+ else:
+ self.spatial_chunk_size = None
+ super().__init__(*args, **kwargs)
+ assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported."
+ assert self.dilation == (1, 1), "Only dilation=1 is supported."
+ assert self.groups == 1, "Only groups=1 is supported."
+ assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported."
+ assert self.stride[0] == self.stride[1], "Only equal strides are supported."
+ self.original_padding = self.padding
+ self.padding = (0, 0) # We handle padding manually in forward
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # If chunking is not needed, process normally. We chunk only along height dimension.
+ if (
+ self.spatial_chunk_size is None
+ or x.shape[2] <= self.spatial_chunk_size + self.kernel_size[0] + self.spatial_chunk_size // 4
+ ):
+ self.padding = self.original_padding
+ x = super().forward(x)
+ self.padding = (0, 0)
+ return x
+
+ # Process input in chunks to reduce memory usage
+ org_shape = x.shape
+
+ # If kernel size is not 1, we need to use overlapping chunks
+ overlap = self.kernel_size[0] // 2 # 1 for kernel size 3
+ if self.original_padding[0] == 0:
+ overlap = 0
+
+ # If stride > 1, QwenImageVAE pads manually with zeros before convolution, so we do not need to consider it here
+ y_height = org_shape[2] // self.stride[0]
+ y_width = org_shape[3] // self.stride[1]
+ y = torch.zeros((org_shape[0], self.out_channels, y_height, y_width), dtype=x.dtype, device=x.device)
+ yi = 0
+ i = 0
+ while i < org_shape[2]:
+ si = i if i == 0 else i - overlap
+ ei = i + self.spatial_chunk_size + overlap + self.stride[0] - 1
+
+ # Check last chunk. If remaining part is small, include it in last chunk
+ if ei > org_shape[2] or ei + self.spatial_chunk_size // 4 > org_shape[2]:
+ ei = org_shape[2]
+
+ chunk = x[:, :, si:ei, :]
+
+ # Pad chunk if needed: This is as the original Conv2d with padding
+ if i == 0 and overlap > 0: # First chunk
+ # Pad except bottom
+ chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, 0), mode="constant", value=0)
+ elif ei == org_shape[2] and overlap > 0: # Last chunk
+ # Pad except top
+ chunk = torch.nn.functional.pad(chunk, (overlap, overlap, 0, overlap), mode="constant", value=0)
+ elif overlap > 0: # Middle chunks
+ # Pad left and right only
+ chunk = torch.nn.functional.pad(chunk, (overlap, overlap), mode="constant", value=0)
+
+ # print(f"Processing chunk: org_shape={org_shape}, si={si}, ei={ei}, chunk.shape={chunk.shape}, overlap={overlap}")
+ chunk = super().forward(chunk)
+ # print(f" -> chunk after conv shape: {chunk.shape}")
+ y[:, :, yi : yi + chunk.shape[2], :] = chunk
+ yi += chunk.shape[2]
+ del chunk
+
+ if ei == org_shape[2]:
+ break
+ i += self.spatial_chunk_size
+
+ assert yi == y_height, f"yi={yi}, y_height={y_height}"
+
+ return y
+
+
+class QwenImageCausalConv3d(nn.Conv3d):
+ r"""
+ A custom 3D causal convolution layer with feature caching support.
+
+ This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
+ caching for efficient inference.
+
+ Args:
+ in_channels (int): Number of channels in the input image
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]],
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ padding: Union[int, Tuple[int, int, int]] = 0,
+ spatial_chunk_size: Optional[int] = None,
+ ) -> None:
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ )
+
+ # Set up causal padding
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
+ self.spatial_chunk_size = spatial_chunk_size
+ self._supports_spatial_chunking = (
+ self.groups == 1 and self.dilation[1] == 1 and self.dilation[2] == 1 and self.stride[1] == 1 and self.stride[2] == 1
+ )
+
+ def _forward_chunked_height(self, x: torch.Tensor) -> torch.Tensor:
+ chunk_size = self.spatial_chunk_size
+ if chunk_size is None or chunk_size <= 0:
+ return super().forward(x)
+ if not self._supports_spatial_chunking:
+ return super().forward(x)
+
+ kernel_h = self.kernel_size[1]
+ if kernel_h <= 1 or x.shape[3] <= chunk_size:
+ return super().forward(x)
+
+ receptive_h = kernel_h
+ out_h = x.shape[3] - receptive_h + 1
+ if out_h <= 0:
+ return super().forward(x)
+
+ y0 = 0
+ out = None
+ while y0 < out_h:
+ y1 = min(y0 + chunk_size, out_h)
+ in0 = y0
+ in1 = y1 + receptive_h - 1
+ out_chunk = super().forward(x[:, :, :, in0:in1, :])
+ if out is None:
+ out_shape = list(out_chunk.shape)
+ out_shape[3] = out_h
+ out = out_chunk.new_empty(out_shape)
+ out[:, :, :, y0:y1, :] = out_chunk
+ y0 = y1
+
+ return out
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+ return self._forward_chunked_height(x)
+
+
+class QwenImageRMS_norm(nn.Module):
+ r"""
+ A custom RMS normalization layer.
+
+ Args:
+ dim (int): The number of dimensions to normalize over.
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
+ Default is True.
+ images (bool, optional): Whether the input represents image data. Default is True.
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
+ """
+
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+ def forward(self, x):
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
+class QwenImageUpsample(nn.Upsample):
+ r"""
+ Perform upsampling while ensuring the output tensor has the same data type as the input.
+
+ Args:
+ x (torch.Tensor): Input tensor to be upsampled.
+
+ Returns:
+ torch.Tensor: Upsampled tensor with the same data type as the input.
+ """
+
+ def forward(self, x):
+ return super().forward(x.float()).type_as(x)
+
+
+class QwenImageResample(nn.Module):
+ r"""
+ A custom resampling module for 2D and 3D data.
+
+ Args:
+ dim (int): The number of input/output channels.
+ mode (str): The resampling mode. Must be one of:
+ - 'none': No resampling (identity operation).
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
+ - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
+ - 'downsample2d': 2D downsampling with zero-padding and convolution.
+ - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
+ """
+
+ def __init__(self, dim: int, mode: str) -> None:
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == "upsample2d":
+ self.resample = nn.Sequential(
+ QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ ChunkedConv2d(dim, dim // 2, 3, padding=1),
+ )
+ elif mode == "upsample3d":
+ self.resample = nn.Sequential(
+ QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ ChunkedConv2d(dim, dim // 2, 3, padding=1),
+ )
+ self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+ elif mode == "downsample2d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), ChunkedConv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == "downsample3d":
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), ChunkedConv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == "upsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = "Rep"
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+ if feat_cache[idx] == "Rep":
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ x = self.resample(x)
+ x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+ if self.mode == "downsample3d":
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+ cache_x = x[:, :, -1:, :, :].clone()
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+
+class QwenImageResidualBlock(nn.Module):
+ r"""
+ A custom residual block module.
+
+ Args:
+ in_dim (int): Number of input channels.
+ out_dim (int): Number of output channels.
+ dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ dropout: float = 0.0,
+ non_linearity: str = "silu",
+ ) -> None:
+ assert non_linearity in ["silu"], "Only 'silu' non-linearity is supported currently."
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.nonlinearity = nn.SiLU() # get_activation(non_linearity)
+
+ # layers
+ self.norm1 = QwenImageRMS_norm(in_dim, images=False)
+ self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
+ self.norm2 = QwenImageRMS_norm(out_dim, images=False)
+ self.dropout = nn.Dropout(dropout)
+ self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
+ self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ # Apply shortcut connection
+ h = self.conv_shortcut(x)
+
+ # First normalization and activation
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ # Second normalization and activation
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+
+ # Dropout
+ x = self.dropout(x)
+
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.conv2(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv2(x)
+
+ # Add residual connection
+ return x + h
+
+
+class QwenImageAttentionBlock(nn.Module):
+ r"""
+ Causal self-attention with a single head.
+
+ Args:
+ dim (int): The number of channels in the input tensor.
+ """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ # layers
+ self.norm = QwenImageRMS_norm(dim)
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+ self.proj = nn.Conv2d(dim, dim, 1)
+
+ def forward(self, x):
+ identity = x
+ batch_size, channels, time, height, width = x.size()
+
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
+ x = self.norm(x)
+
+ # compute query, key, value
+ qkv = self.to_qkv(x)
+ qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+ qkv = qkv.permute(0, 1, 3, 2).contiguous()
+ q, k, v = qkv.chunk(3, dim=-1)
+
+ # apply attention
+ x = F.scaled_dot_product_attention(q, k, v)
+
+ x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
+
+ # output projection
+ x = self.proj(x)
+
+ # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
+ x = x.view(batch_size, time, channels, height, width)
+ x = x.permute(0, 2, 1, 3, 4)
+
+ return x + identity
+
+
+class QwenImageMidBlock(nn.Module):
+ """
+ Middle block for QwenImageVAE encoder and decoder.
+
+ Args:
+ dim (int): Number of input/output channels.
+ dropout (float): Dropout rate.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
+ super().__init__()
+ self.dim = dim
+
+ # Create the components
+ resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
+ attentions = []
+ for _ in range(num_layers):
+ attentions.append(QwenImageAttentionBlock(dim))
+ resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ # First residual block
+ x = self.resnets[0](x, feat_cache, feat_idx)
+
+ # Process through attention and residual blocks
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ x = attn(x)
+
+ x = resnet(x, feat_cache, feat_idx)
+
+ return x
+
+
+class QwenImageEncoder3d(nn.Module):
+ r"""
+ A 3D encoder module.
+
+ Args:
+ dim (int): The base number of channels in the first layer.
+ z_dim (int): The dimensionality of the latent space.
+ dim_mult (list of int): Multipliers for the number of channels in each block.
+ num_res_blocks (int): Number of residual blocks in each block.
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
+ temperal_downsample (list of bool): Whether to downsample temporally in each block.
+ dropout (float): Dropout rate for the dropout layers.
+ input_channels (int): Number of input channels.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0,
+ input_channels: int = 3,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ assert non_linearity in ["silu"], "Only 'silu' non-linearity is supported currently."
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.nonlinearity = nn.SiLU() # get_activation(non_linearity)
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1)
+
+ # downsample blocks
+ self.down_blocks = nn.ModuleList([])
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ self.down_blocks.append(QwenImageAttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+ self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
+ scale /= 2.0
+
+ # middle blocks
+ self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
+
+ # output blocks
+ self.norm_out = QwenImageRMS_norm(out_dim, images=False)
+ self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_in(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_in(x)
+
+ ## downsamples
+ for layer in self.down_blocks:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ x = self.mid_block(x, feat_cache, feat_idx)
+
+ ## head
+ x = self.norm_out(x)
+ x = self.nonlinearity(x)
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_out(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_out(x)
+ return x
+
+
+class QwenImageUpBlock(nn.Module):
+ """
+ A block that handles upsampling for the QwenImageVAE decoder.
+
+ Args:
+ in_dim (int): Input dimension
+ out_dim (int): Output dimension
+ num_res_blocks (int): Number of residual blocks
+ dropout (float): Dropout rate
+ upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
+ non_linearity (str): Type of non-linearity to use
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ num_res_blocks: int,
+ dropout: float = 0.0,
+ upsample_mode: Optional[str] = None,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # Create layers list
+ resnets = []
+ # Add residual blocks and attention if needed
+ current_dim = in_dim
+ for _ in range(num_res_blocks + 1):
+ resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
+ current_dim = out_dim
+
+ self.resnets = nn.ModuleList(resnets)
+
+ # Add upsampling layer if needed
+ self.upsamplers = None
+ if upsample_mode is not None:
+ self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ """
+ Forward pass through the upsampling block.
+
+ Args:
+ x (torch.Tensor): Input tensor
+ feat_cache (list, optional): Feature cache for causal convolutions
+ feat_idx (list, optional): Feature index for cache management
+
+ Returns:
+ torch.Tensor: Output tensor
+ """
+ for resnet in self.resnets:
+ if feat_cache is not None:
+ x = resnet(x, feat_cache, feat_idx)
+ else:
+ x = resnet(x)
+
+ if self.upsamplers is not None:
+ if feat_cache is not None:
+ x = self.upsamplers[0](x, feat_cache, feat_idx)
+ else:
+ x = self.upsamplers[0](x)
+ return x
+
+
+class QwenImageDecoder3d(nn.Module):
+ r"""
+ A 3D decoder module.
+
+ Args:
+ dim (int): The base number of channels in the first layer.
+ z_dim (int): The dimensionality of the latent space.
+ dim_mult (list of int): Multipliers for the number of channels in each block.
+ num_res_blocks (int): Number of residual blocks in each block.
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
+ temperal_upsample (list of bool): Whether to upsample temporally in each block.
+ dropout (float): Dropout rate for the dropout layers.
+ output_channels (int): Number of output channels.
+ non_linearity (str): Type of non-linearity to use.
+ """
+
+ def __init__(
+ self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0,
+ output_channels: int = 3,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ assert non_linearity in ["silu"], "Only 'silu' non-linearity is supported currently."
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ self.nonlinearity = nn.SiLU() # get_activation(non_linearity)
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
+
+ # init block
+ self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
+
+ # upsample blocks
+ self.up_blocks = nn.ModuleList([])
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ if i > 0:
+ in_dim = in_dim // 2
+
+ # Determine if we need upsampling
+ upsample_mode = None
+ if i != len(dim_mult) - 1:
+ upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
+
+ # Create and add the upsampling block
+ up_block = QwenImageUpBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ num_res_blocks=num_res_blocks,
+ dropout=dropout,
+ upsample_mode=upsample_mode,
+ non_linearity=non_linearity,
+ )
+ self.up_blocks.append(up_block)
+
+ # Update scale for next iteration
+ if upsample_mode is not None:
+ scale *= 2.0
+
+ # output blocks
+ self.norm_out = QwenImageRMS_norm(out_dim, images=False)
+ self.conv_out = QwenImageCausalConv3d(out_dim, output_channels, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ ## conv1
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_in(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_in(x)
+
+ ## middle
+ x = self.mid_block(x, feat_cache, feat_idx)
+
+ ## upsamples
+ for up_block in self.up_blocks:
+ x = up_block(x, feat_cache, feat_idx)
+
+ ## head
+ x = self.norm_out(x)
+ x = self.nonlinearity(x)
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+ # cache last frame of last two chunk
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+ x = self.conv_out(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv_out(x)
+ return x
+
+
+class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = False
+
+ # @register_to_config
+ def __init__(
+ self,
+ base_dim: int = 96,
+ z_dim: int = 16,
+ dim_mult: Tuple[int] = [1, 2, 4, 4],
+ num_res_blocks: int = 2,
+ attn_scales: List[float] = [],
+ temperal_downsample: List[bool] = [False, True, True],
+ dropout: float = 0.0,
+ latents_mean: List[float] = [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921,
+ ],
+ latents_std: List[float] = [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.9160,
+ ],
+ input_channels: int = 3,
+ spatial_chunk_size: Optional[int] = None,
+ disable_cache: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.z_dim = z_dim
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+ self.latents_mean = latents_mean
+ self.latents_std = latents_std
+
+ self.encoder = QwenImageEncoder3d(
+ base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, input_channels
+ )
+ self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
+
+ self.decoder = QwenImageDecoder3d(
+ base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, input_channels
+ )
+
+ self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 256
+ self.tile_sample_min_width = 256
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 192
+ self.tile_sample_stride_width = 192
+
+ # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
+ self._cached_conv_counts = {
+ "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) if self.decoder is not None else 0,
+ "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) if self.encoder is not None else 0,
+ }
+
+ self.spatial_chunk_size = None
+ if spatial_chunk_size is not None and spatial_chunk_size > 0:
+ self.enable_spatial_chunking(spatial_chunk_size)
+
+ self.cache_disabled = False
+ if disable_cache:
+ self.disable_cache()
+
+ @property
+ def dtype(self):
+ return self.encoder.parameters().__next__().dtype
+
+ @property
+ def device(self):
+ return self.encoder.parameters().__next__().device
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def enable_spatial_chunking(self, spatial_chunk_size: int) -> None:
+ r"""
+ Enable memory-efficient convolution by chunking all causal Conv3d layers only along height.
+ """
+ if spatial_chunk_size is None or spatial_chunk_size <= 0:
+ raise ValueError(f"`spatial_chunk_size` must be a positive integer, got {spatial_chunk_size}.")
+ self.spatial_chunk_size = int(spatial_chunk_size)
+ for module in self.modules():
+ if isinstance(module, QwenImageCausalConv3d):
+ module.spatial_chunk_size = self.spatial_chunk_size
+ elif isinstance(module, ChunkedConv2d):
+ module.spatial_chunk_size = self.spatial_chunk_size
+
+ def disable_spatial_chunking(self) -> None:
+ r"""
+ Disable memory-efficient convolution chunking on all causal Conv3d layers.
+ """
+ self.spatial_chunk_size = None
+ for module in self.modules():
+ if isinstance(module, QwenImageCausalConv3d):
+ module.spatial_chunk_size = None
+ elif isinstance(module, ChunkedConv2d):
+ module.spatial_chunk_size = None
+
+ def disable_cache(self) -> None:
+ r"""
+ Disable caching mechanism in encoder and decoder.
+ """
+ self.cache_disabled = True
+ self.clear_cache = lambda: None
+ self._feat_map = None # Disable decoder cache
+ self._enc_feat_map = None # Disable encoder cache
+
+ def clear_cache(self):
+ def _count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if isinstance(m, QwenImageCausalConv3d):
+ count += 1
+ return count
+
+ self._conv_num = _count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ # cache encode
+ self._enc_conv_num = _count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
+
+ def _encode(self, x: torch.Tensor):
+ _, _, num_frame, height, width = x.shape
+ assert num_frame == 1 or not self.cache_disabled, "Caching must be enabled for encoding multiple frames."
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
+ self.clear_cache()
+ iter_ = 1 + (num_frame - 1) // 4
+ for i in range(iter_):
+ self._enc_conv_idx = [0]
+ if i == 0:
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+ else:
+ out_ = self.encoder(
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx,
+ )
+ out = torch.cat([out, out_], 2)
+
+ enc = self.quant_conv(out)
+ self.clear_cache()
+ return enc
+
+ # @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[Dict[str, torch.Tensor], Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a dictionary is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return {"latent_dist": posterior}
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
+ _, _, num_frame, height, width = z.shape
+ assert num_frame == 1 or not self.cache_disabled, "Caching must be enabled for encoding multiple frames."
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ self.clear_cache()
+ x = self.post_quant_conv(z)
+ for i in range(num_frame):
+ self._conv_idx = [0]
+ if i == 0:
+ out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ else:
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ out = torch.cat([out, out_], 2)
+
+ out = torch.clamp(out, min=-1.0, max=1.0)
+ self.clear_cache()
+ if not return_dict:
+ return (out,)
+
+ return {"sample": out}
+
+ # @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice)["sample"] for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z)["sample"]
+
+ if not return_dict:
+ return (decoded,)
+ return {"sample": decoded}
+
+ def decode_to_pixels(self, latents: torch.Tensor) -> torch.Tensor:
+ is_4d = latents.dim() == 4
+ if is_4d:
+ latents = latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
+
+ latents = latents.to(self.dtype)
+ latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = latents / latents_std + latents_mean
+
+ image = self.decode(latents, return_dict=False)[0] # -1 to 1
+ if is_4d:
+ image = image.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
+
+ return image.clamp(-1.0, 1.0)
+
+ def encode_pixels_to_latents(self, pixels: torch.Tensor) -> torch.Tensor:
+ """
+ Convert pixel values to latents and apply normalization using mean/std.
+
+ Args:
+ pixels (torch.Tensor): Input pixels in [0, 1] range with shape [B, C, H, W] or [B, C, T, H, W]
+
+ Returns:
+ torch.Tensor: Normalized latents
+ """
+ # # Convert from [0, 1] to [-1, 1] range
+ # pixels = (pixels * 2.0 - 1.0).clamp(-1.0, 1.0)
+
+ # Handle 2D input by adding temporal dimension
+ is_4d = pixels.dim() == 4
+ if is_4d:
+ pixels = pixels.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
+
+ pixels = pixels.to(self.dtype)
+
+ # Encode to latent space
+ posterior = self.encode(pixels, return_dict=False)[0]
+ latents = posterior.mode() # Use mode instead of sampling for deterministic results
+ # latents = posterior.sample()
+
+ # Apply normalization using mean/std
+ latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+ latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
+ latents = (latents - latents_mean) * latents_std
+
+ if is_4d:
+ latents = latents.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
+
+ return latents
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ _, _, num_frames, height, width = x.shape
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ self.clear_cache()
+ time = []
+ frame_range = 1 + (num_frames - 1) // 4
+ for k in range(frame_range):
+ self._enc_conv_idx = [0]
+ if k == 0:
+ tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+ else:
+ tile = x[
+ :,
+ :,
+ 1 + 4 * (k - 1) : 1 + 4 * k,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+ tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+ tile = self.quant_conv(tile)
+ time.append(tile)
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+ self.clear_cache()
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ return enc
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a dictionary instead of a plain tuple.
+
+ Returns:
+ `dict` or `tuple`:
+ If return_dict is True, a dictionary is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ _, _, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ self.clear_cache()
+ time = []
+ for k in range(num_frames):
+ self._conv_idx = [0]
+ tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ time.append(decoded)
+ row.append(torch.cat(time, dim=2))
+ rows.append(row)
+ self.clear_cache()
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+ if not return_dict:
+ return (dec,)
+ return {"sample": dec}
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
+ """
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`Dict[str, torch.Tensor]`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+ return dec
+
+
+# region utils
+
+# This region is not included in the original implementation. Added for musubi-tuner/sd-scripts.
+
+
+# Convert ComfyUI keys to standard keys if necessary
+def convert_comfyui_state_dict(sd):
+ if "conv1.bias" not in sd:
+ return sd
+
+ # Key mapping from ComfyUI VAE to official VAE, auto-generated by a script
+ key_map = {
+ "conv1": "quant_conv",
+ "conv2": "post_quant_conv",
+ "decoder.conv1": "decoder.conv_in",
+ "decoder.head.0": "decoder.norm_out",
+ "decoder.head.2": "decoder.conv_out",
+ "decoder.middle.0.residual.0": "decoder.mid_block.resnets.0.norm1",
+ "decoder.middle.0.residual.2": "decoder.mid_block.resnets.0.conv1",
+ "decoder.middle.0.residual.3": "decoder.mid_block.resnets.0.norm2",
+ "decoder.middle.0.residual.6": "decoder.mid_block.resnets.0.conv2",
+ "decoder.middle.1.norm": "decoder.mid_block.attentions.0.norm",
+ "decoder.middle.1.proj": "decoder.mid_block.attentions.0.proj",
+ "decoder.middle.1.to_qkv": "decoder.mid_block.attentions.0.to_qkv",
+ "decoder.middle.2.residual.0": "decoder.mid_block.resnets.1.norm1",
+ "decoder.middle.2.residual.2": "decoder.mid_block.resnets.1.conv1",
+ "decoder.middle.2.residual.3": "decoder.mid_block.resnets.1.norm2",
+ "decoder.middle.2.residual.6": "decoder.mid_block.resnets.1.conv2",
+ "decoder.upsamples.0.residual.0": "decoder.up_blocks.0.resnets.0.norm1",
+ "decoder.upsamples.0.residual.2": "decoder.up_blocks.0.resnets.0.conv1",
+ "decoder.upsamples.0.residual.3": "decoder.up_blocks.0.resnets.0.norm2",
+ "decoder.upsamples.0.residual.6": "decoder.up_blocks.0.resnets.0.conv2",
+ "decoder.upsamples.1.residual.0": "decoder.up_blocks.0.resnets.1.norm1",
+ "decoder.upsamples.1.residual.2": "decoder.up_blocks.0.resnets.1.conv1",
+ "decoder.upsamples.1.residual.3": "decoder.up_blocks.0.resnets.1.norm2",
+ "decoder.upsamples.1.residual.6": "decoder.up_blocks.0.resnets.1.conv2",
+ "decoder.upsamples.10.residual.0": "decoder.up_blocks.2.resnets.2.norm1",
+ "decoder.upsamples.10.residual.2": "decoder.up_blocks.2.resnets.2.conv1",
+ "decoder.upsamples.10.residual.3": "decoder.up_blocks.2.resnets.2.norm2",
+ "decoder.upsamples.10.residual.6": "decoder.up_blocks.2.resnets.2.conv2",
+ "decoder.upsamples.11.resample.1": "decoder.up_blocks.2.upsamplers.0.resample.1",
+ "decoder.upsamples.12.residual.0": "decoder.up_blocks.3.resnets.0.norm1",
+ "decoder.upsamples.12.residual.2": "decoder.up_blocks.3.resnets.0.conv1",
+ "decoder.upsamples.12.residual.3": "decoder.up_blocks.3.resnets.0.norm2",
+ "decoder.upsamples.12.residual.6": "decoder.up_blocks.3.resnets.0.conv2",
+ "decoder.upsamples.13.residual.0": "decoder.up_blocks.3.resnets.1.norm1",
+ "decoder.upsamples.13.residual.2": "decoder.up_blocks.3.resnets.1.conv1",
+ "decoder.upsamples.13.residual.3": "decoder.up_blocks.3.resnets.1.norm2",
+ "decoder.upsamples.13.residual.6": "decoder.up_blocks.3.resnets.1.conv2",
+ "decoder.upsamples.14.residual.0": "decoder.up_blocks.3.resnets.2.norm1",
+ "decoder.upsamples.14.residual.2": "decoder.up_blocks.3.resnets.2.conv1",
+ "decoder.upsamples.14.residual.3": "decoder.up_blocks.3.resnets.2.norm2",
+ "decoder.upsamples.14.residual.6": "decoder.up_blocks.3.resnets.2.conv2",
+ "decoder.upsamples.2.residual.0": "decoder.up_blocks.0.resnets.2.norm1",
+ "decoder.upsamples.2.residual.2": "decoder.up_blocks.0.resnets.2.conv1",
+ "decoder.upsamples.2.residual.3": "decoder.up_blocks.0.resnets.2.norm2",
+ "decoder.upsamples.2.residual.6": "decoder.up_blocks.0.resnets.2.conv2",
+ "decoder.upsamples.3.resample.1": "decoder.up_blocks.0.upsamplers.0.resample.1",
+ "decoder.upsamples.3.time_conv": "decoder.up_blocks.0.upsamplers.0.time_conv",
+ "decoder.upsamples.4.residual.0": "decoder.up_blocks.1.resnets.0.norm1",
+ "decoder.upsamples.4.residual.2": "decoder.up_blocks.1.resnets.0.conv1",
+ "decoder.upsamples.4.residual.3": "decoder.up_blocks.1.resnets.0.norm2",
+ "decoder.upsamples.4.residual.6": "decoder.up_blocks.1.resnets.0.conv2",
+ "decoder.upsamples.4.shortcut": "decoder.up_blocks.1.resnets.0.conv_shortcut",
+ "decoder.upsamples.5.residual.0": "decoder.up_blocks.1.resnets.1.norm1",
+ "decoder.upsamples.5.residual.2": "decoder.up_blocks.1.resnets.1.conv1",
+ "decoder.upsamples.5.residual.3": "decoder.up_blocks.1.resnets.1.norm2",
+ "decoder.upsamples.5.residual.6": "decoder.up_blocks.1.resnets.1.conv2",
+ "decoder.upsamples.6.residual.0": "decoder.up_blocks.1.resnets.2.norm1",
+ "decoder.upsamples.6.residual.2": "decoder.up_blocks.1.resnets.2.conv1",
+ "decoder.upsamples.6.residual.3": "decoder.up_blocks.1.resnets.2.norm2",
+ "decoder.upsamples.6.residual.6": "decoder.up_blocks.1.resnets.2.conv2",
+ "decoder.upsamples.7.resample.1": "decoder.up_blocks.1.upsamplers.0.resample.1",
+ "decoder.upsamples.7.time_conv": "decoder.up_blocks.1.upsamplers.0.time_conv",
+ "decoder.upsamples.8.residual.0": "decoder.up_blocks.2.resnets.0.norm1",
+ "decoder.upsamples.8.residual.2": "decoder.up_blocks.2.resnets.0.conv1",
+ "decoder.upsamples.8.residual.3": "decoder.up_blocks.2.resnets.0.norm2",
+ "decoder.upsamples.8.residual.6": "decoder.up_blocks.2.resnets.0.conv2",
+ "decoder.upsamples.9.residual.0": "decoder.up_blocks.2.resnets.1.norm1",
+ "decoder.upsamples.9.residual.2": "decoder.up_blocks.2.resnets.1.conv1",
+ "decoder.upsamples.9.residual.3": "decoder.up_blocks.2.resnets.1.norm2",
+ "decoder.upsamples.9.residual.6": "decoder.up_blocks.2.resnets.1.conv2",
+ "encoder.conv1": "encoder.conv_in",
+ "encoder.downsamples.0.residual.0": "encoder.down_blocks.0.norm1",
+ "encoder.downsamples.0.residual.2": "encoder.down_blocks.0.conv1",
+ "encoder.downsamples.0.residual.3": "encoder.down_blocks.0.norm2",
+ "encoder.downsamples.0.residual.6": "encoder.down_blocks.0.conv2",
+ "encoder.downsamples.1.residual.0": "encoder.down_blocks.1.norm1",
+ "encoder.downsamples.1.residual.2": "encoder.down_blocks.1.conv1",
+ "encoder.downsamples.1.residual.3": "encoder.down_blocks.1.norm2",
+ "encoder.downsamples.1.residual.6": "encoder.down_blocks.1.conv2",
+ "encoder.downsamples.10.residual.0": "encoder.down_blocks.10.norm1",
+ "encoder.downsamples.10.residual.2": "encoder.down_blocks.10.conv1",
+ "encoder.downsamples.10.residual.3": "encoder.down_blocks.10.norm2",
+ "encoder.downsamples.10.residual.6": "encoder.down_blocks.10.conv2",
+ "encoder.downsamples.2.resample.1": "encoder.down_blocks.2.resample.1",
+ "encoder.downsamples.3.residual.0": "encoder.down_blocks.3.norm1",
+ "encoder.downsamples.3.residual.2": "encoder.down_blocks.3.conv1",
+ "encoder.downsamples.3.residual.3": "encoder.down_blocks.3.norm2",
+ "encoder.downsamples.3.residual.6": "encoder.down_blocks.3.conv2",
+ "encoder.downsamples.3.shortcut": "encoder.down_blocks.3.conv_shortcut",
+ "encoder.downsamples.4.residual.0": "encoder.down_blocks.4.norm1",
+ "encoder.downsamples.4.residual.2": "encoder.down_blocks.4.conv1",
+ "encoder.downsamples.4.residual.3": "encoder.down_blocks.4.norm2",
+ "encoder.downsamples.4.residual.6": "encoder.down_blocks.4.conv2",
+ "encoder.downsamples.5.resample.1": "encoder.down_blocks.5.resample.1",
+ "encoder.downsamples.5.time_conv": "encoder.down_blocks.5.time_conv",
+ "encoder.downsamples.6.residual.0": "encoder.down_blocks.6.norm1",
+ "encoder.downsamples.6.residual.2": "encoder.down_blocks.6.conv1",
+ "encoder.downsamples.6.residual.3": "encoder.down_blocks.6.norm2",
+ "encoder.downsamples.6.residual.6": "encoder.down_blocks.6.conv2",
+ "encoder.downsamples.6.shortcut": "encoder.down_blocks.6.conv_shortcut",
+ "encoder.downsamples.7.residual.0": "encoder.down_blocks.7.norm1",
+ "encoder.downsamples.7.residual.2": "encoder.down_blocks.7.conv1",
+ "encoder.downsamples.7.residual.3": "encoder.down_blocks.7.norm2",
+ "encoder.downsamples.7.residual.6": "encoder.down_blocks.7.conv2",
+ "encoder.downsamples.8.resample.1": "encoder.down_blocks.8.resample.1",
+ "encoder.downsamples.8.time_conv": "encoder.down_blocks.8.time_conv",
+ "encoder.downsamples.9.residual.0": "encoder.down_blocks.9.norm1",
+ "encoder.downsamples.9.residual.2": "encoder.down_blocks.9.conv1",
+ "encoder.downsamples.9.residual.3": "encoder.down_blocks.9.norm2",
+ "encoder.downsamples.9.residual.6": "encoder.down_blocks.9.conv2",
+ "encoder.head.0": "encoder.norm_out",
+ "encoder.head.2": "encoder.conv_out",
+ "encoder.middle.0.residual.0": "encoder.mid_block.resnets.0.norm1",
+ "encoder.middle.0.residual.2": "encoder.mid_block.resnets.0.conv1",
+ "encoder.middle.0.residual.3": "encoder.mid_block.resnets.0.norm2",
+ "encoder.middle.0.residual.6": "encoder.mid_block.resnets.0.conv2",
+ "encoder.middle.1.norm": "encoder.mid_block.attentions.0.norm",
+ "encoder.middle.1.proj": "encoder.mid_block.attentions.0.proj",
+ "encoder.middle.1.to_qkv": "encoder.mid_block.attentions.0.to_qkv",
+ "encoder.middle.2.residual.0": "encoder.mid_block.resnets.1.norm1",
+ "encoder.middle.2.residual.2": "encoder.mid_block.resnets.1.conv1",
+ "encoder.middle.2.residual.3": "encoder.mid_block.resnets.1.norm2",
+ "encoder.middle.2.residual.6": "encoder.mid_block.resnets.1.conv2",
+ }
+
+ new_state_dict = {}
+ for key in sd.keys():
+ new_key = key
+ key_without_suffix = key.rsplit(".", 1)[0]
+ if key_without_suffix in key_map:
+ new_key = key.replace(key_without_suffix, key_map[key_without_suffix])
+ new_state_dict[new_key] = sd[key]
+
+ logger.info("Converted ComfyUI AutoencoderKL state dict keys to official format")
+ return new_state_dict
+
+
+def load_vae(
+ vae_path: str,
+ input_channels: int = 3,
+ device: Union[str, torch.device] = "cpu",
+ disable_mmap: bool = False,
+ spatial_chunk_size: Optional[int] = None,
+ disable_cache: bool = False,
+) -> AutoencoderKLQwenImage:
+ """Load VAE from a given path."""
+ VAE_CONFIG_JSON = """
+{
+ "_class_name": "AutoencoderKLQwenImage",
+ "_diffusers_version": "0.34.0.dev0",
+ "attn_scales": [],
+ "base_dim": 96,
+ "dim_mult": [
+ 1,
+ 2,
+ 4,
+ 4
+ ],
+ "dropout": 0.0,
+ "latents_mean": [
+ -0.7571,
+ -0.7089,
+ -0.9113,
+ 0.1075,
+ -0.1745,
+ 0.9653,
+ -0.1517,
+ 1.5508,
+ 0.4134,
+ -0.0715,
+ 0.5517,
+ -0.3632,
+ -0.1922,
+ -0.9497,
+ 0.2503,
+ -0.2921
+ ],
+ "latents_std": [
+ 2.8184,
+ 1.4541,
+ 2.3275,
+ 2.6558,
+ 1.2196,
+ 1.7708,
+ 2.6052,
+ 2.0743,
+ 3.2687,
+ 2.1526,
+ 2.8652,
+ 1.5579,
+ 1.6382,
+ 1.1253,
+ 2.8251,
+ 1.916
+ ],
+ "num_res_blocks": 2,
+ "temperal_downsample": [
+ false,
+ true,
+ true
+ ],
+ "z_dim": 16
+}
+"""
+ logger.info("Initializing VAE")
+
+ if spatial_chunk_size is not None and spatial_chunk_size % 2 != 0:
+ spatial_chunk_size += 1
+ logger.warning(f"Adjusted spatial_chunk_size to the next even number: {spatial_chunk_size}")
+
+ config = json.loads(VAE_CONFIG_JSON)
+ vae = AutoencoderKLQwenImage(
+ base_dim=config["base_dim"],
+ z_dim=config["z_dim"],
+ dim_mult=config["dim_mult"],
+ num_res_blocks=config["num_res_blocks"],
+ attn_scales=config["attn_scales"],
+ temperal_downsample=config["temperal_downsample"],
+ dropout=config["dropout"],
+ latents_mean=config["latents_mean"],
+ latents_std=config["latents_std"],
+ input_channels=input_channels,
+ spatial_chunk_size=spatial_chunk_size,
+ disable_cache=disable_cache,
+ )
+
+ logger.info(f"Loading VAE from {vae_path}")
+ state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap)
+
+ # Convert ComfyUI VAE keys to official VAE keys
+ state_dict = convert_comfyui_state_dict(state_dict)
+
+ info = vae.load_state_dict(state_dict, strict=True, assign=True)
+ logger.info(f"Loaded VAE: {info}")
+
+ vae.to(device)
+ return vae
+
+
+if __name__ == "__main__":
+ # Debugging / testing code
+ import argparse
+ import glob
+ import os
+ import time
+
+ from PIL import Image
+
+ from library.device_utils import get_preferred_device, synchronize_device
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vae", type=str, required=True, help="Path to the VAE model file.")
+ parser.add_argument("--input_image_dir", type=str, required=True, help="Path to the input image directory.")
+ parser.add_argument("--output_image_dir", type=str, required=True, help="Path to the output image directory.")
+ args = parser.parse_args()
+
+ # Load VAE
+ vae = load_vae(args.vae, device=get_preferred_device())
+
+ # Process images
+ def encode_decode_image(image_path, output_path):
+ image = Image.open(image_path).convert("RGB")
+
+ # Crop to multiple of 8
+ width, height = image.size
+ new_width = (width // 8) * 8
+ new_height = (height // 8) * 8
+ if new_width != width or new_height != height:
+ image = image.crop((0, 0, new_width, new_height))
+
+ image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0 * 2 - 1
+ image_tensor = image_tensor.to(vae.dtype).to(vae.device)
+
+ with torch.no_grad():
+ latents = vae.encode_pixels_to_latents(image_tensor)
+ reconstructed = vae.decode_to_pixels(latents)
+
+ diff = (image_tensor - reconstructed).abs().mean().item()
+ print(f"Processed {image_path} (size: {image.size}), reconstruction diff: {diff}")
+
+ reconstructed_image = ((reconstructed.squeeze(0).permute(1, 2, 0).float().cpu().numpy() + 1) / 2 * 255).astype(np.uint8)
+ Image.fromarray(reconstructed_image).save(output_path)
+
+ def process_directory(input_dir, output_dir):
+ if get_preferred_device().type == "cuda":
+ torch.cuda.empty_cache()
+ torch.cuda.reset_max_memory_allocated()
+
+ synchronize_device(get_preferred_device())
+ start_time = time.perf_counter()
+
+ os.makedirs(output_dir, exist_ok=True)
+ image_paths = glob.glob(os.path.join(input_dir, "*.jpg")) + glob.glob(os.path.join(input_dir, "*.png"))
+ for image_path in image_paths:
+ filename = os.path.basename(image_path)
+ output_path = os.path.join(output_dir, filename)
+ encode_decode_image(image_path, output_path)
+
+ if get_preferred_device().type == "cuda":
+ max_mem = torch.cuda.max_memory_allocated() / (1024**3)
+ print(f"Max GPU memory allocated: {max_mem:.2f} GB")
+
+ synchronize_device(get_preferred_device())
+ end_time = time.perf_counter()
+ print(f"Processing time: {end_time - start_time:.2f} seconds")
+
+ print("Starting image processing with default settings...")
+ process_directory(args.input_image_dir, args.output_image_dir)
+
+ print("Starting image processing with spatial chunking enabled with chunk size 64...")
+ vae.enable_spatial_chunking(64)
+ process_directory(args.input_image_dir, args.output_image_dir + "_chunked_64")
+
+ print("Starting image processing with spatial chunking enabled with chunk size 16...")
+ vae.enable_spatial_chunking(16)
+ process_directory(args.input_image_dir, args.output_image_dir + "_chunked_16")
+
+ print("Starting image processing without caching and chunking enabled with chunk size 64...")
+ vae.enable_spatial_chunking(64)
+ vae.disable_cache()
+ process_directory(args.input_image_dir, args.output_image_dir + "_no_cache_chunked_64")
+
+ print("Starting image processing without caching and chunking enabled with chunk size 16...")
+ vae.disable_cache()
+ process_directory(args.input_image_dir, args.output_image_dir + "_no_cache_chunked_16")
+
+ print("Starting image processing without caching and chunking disabled...")
+ vae.disable_spatial_chunking()
+ process_directory(args.input_image_dir, args.output_image_dir + "_no_cache")
+
+ print("Processing completed.")
diff --git a/library/safetensors_utils.py b/library/safetensors_utils.py
index c65cdfab..c7a3bdd7 100644
--- a/library/safetensors_utils.py
+++ b/library/safetensors_utils.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
import os
import re
import numpy as np
@@ -44,6 +45,7 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
validated[key] = value
return validated
+ # print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
@@ -88,15 +90,17 @@ class MemoryEfficientSafeOpen:
by using memory mapping for large tensors and avoiding unnecessary copies.
"""
- def __init__(self, filename):
+ def __init__(self, filename, disable_numpy_memmap=False):
"""Initialize the SafeTensor reader.
Args:
filename (str): Path to the safetensors file to read.
+ disable_numpy_memmap (bool): If True, disable numpy memory mapping for large tensors, using standard file read instead.
"""
self.filename = filename
self.file = open(filename, "rb")
self.header, self.header_size = self._read_header()
+ self.disable_numpy_memmap = disable_numpy_memmap
def __enter__(self):
"""Enter context manager."""
@@ -178,7 +182,8 @@ class MemoryEfficientSafeOpen:
# Use memmap for large tensors to avoid intermediate copies.
# If device is cpu, tensor is not copied to gpu, so using memmap locks the file, which is not desired.
# So we only use memmap if device is not cpu.
- if num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
+ # If disable_numpy_memmap is True, skip numpy memory mapping to load with standard file read.
+ if not self.disable_numpy_memmap and num_bytes > 10 * 1024 * 1024 and device is not None and device.type != "cpu":
# Create memory map for zero-copy reading
mm = np.memmap(self.filename, mode="c", dtype=np.uint8, offset=tensor_offset, shape=(num_bytes,))
byte_tensor = torch.from_numpy(mm) # zero copy
@@ -285,7 +290,11 @@ class MemoryEfficientSafeOpen:
def load_safetensors(
- path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
+ path: str,
+ device: Union[str, torch.device],
+ disable_mmap: bool = False,
+ dtype: Optional[torch.dtype] = None,
+ disable_numpy_memmap: bool = False,
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
@@ -293,7 +302,7 @@ def load_safetensors(
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
device = torch.device(device) if device is not None else None
- with MemoryEfficientSafeOpen(path) as f:
+ with MemoryEfficientSafeOpen(path, disable_numpy_memmap=disable_numpy_memmap) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key, device=device, dtype=dtype)
synchronize_device(device)
@@ -309,6 +318,29 @@ def load_safetensors(
return state_dict
+def get_split_weight_filenames(file_path: str) -> Optional[list[str]]:
+ """
+ Get the list of split weight filenames (full paths) if the file name ends with 00001-of-00004 etc.
+ Returns None if the file is not split.
+ """
+ basename = os.path.basename(file_path)
+ match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
+ if match:
+ prefix = basename[: match.start(2)]
+ count = int(match.group(3))
+ filenames = []
+ for i in range(count):
+ filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
+ filepath = os.path.join(os.path.dirname(file_path), filename)
+ if os.path.exists(filepath):
+ filenames.append(filepath)
+ else:
+ raise FileNotFoundError(f"File {filepath} not found")
+ return filenames
+ else:
+ return None
+
+
def load_split_weights(
file_path: str, device: Union[str, torch.device] = "cpu", disable_mmap: bool = False, dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]:
@@ -319,19 +351,11 @@ def load_split_weights(
device = torch.device(device)
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
- basename = os.path.basename(file_path)
- match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
- if match:
- prefix = basename[: match.start(2)]
- count = int(match.group(3))
+ split_filenames = get_split_weight_filenames(file_path)
+ if split_filenames is not None:
state_dict = {}
- for i in range(count):
- filename = f"{prefix}{i + 1:05d}-of-{count:05d}.safetensors"
- filepath = os.path.join(os.path.dirname(file_path), filename)
- if os.path.exists(filepath):
- state_dict.update(load_safetensors(filepath, device=device, disable_mmap=disable_mmap, dtype=dtype))
- else:
- raise FileNotFoundError(f"File {filepath} not found")
+ for filename in split_filenames:
+ state_dict.update(load_safetensors(filename, device=device, disable_mmap=disable_mmap, dtype=dtype))
else:
state_dict = load_safetensors(file_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
return state_dict
@@ -349,3 +373,106 @@ def find_key(safetensors_file: str, starts_with: Optional[str] = None, ends_with
if (starts_with is None or key.startswith(starts_with)) and (ends_with is None or key.endswith(ends_with)):
return key
return None
+
+
+@dataclass
+class WeightTransformHooks:
+ split_hook: Optional[callable] = None
+ concat_hook: Optional[callable] = None
+ rename_hook: Optional[callable] = None
+
+
+class TensorWeightAdapter:
+ """
+ A wrapper for weight conversion hooks (split and concat) to be used with MemoryEfficientSafeOpen.
+ This wrapper adapts the original MemoryEfficientSafeOpen to apply the provided split and concat hooks
+ when loading tensors.
+
+ split_hook: A callable that takes (original_key: str, original_tensor: torch.Tensor) and returns (new_keys: list[str], new_tensors: list[torch.Tensor]).
+ concat_hook: A callable that takes (original_key: str, tensors: dict[str, torch.Tensor]) and returns (new_key: str, concatenated_tensor: torch.Tensor).
+ rename_hook: A callable that takes (original_key: str) and returns (new_key: str).
+
+ If tensors is None, the hook should return only the new keys (for split) or new key (for concat), without tensors.
+
+ No need to implement __enter__ and __exit__ methods, as they are handled by the original MemoryEfficientSafeOpen.
+ Do not use this wrapper as a context manager directly, like `with WeightConvertHookWrapper(...) as f:`.
+
+ **concat_hook is not tested yet.**
+ """
+
+ def __init__(self, weight_convert_hook: WeightTransformHooks, original_f: MemoryEfficientSafeOpen):
+ self.original_f = original_f
+ self.new_key_to_original_key_map: dict[str, Union[str, list[str]]] = (
+ {}
+ ) # for split: new_key -> original_key; for concat: new_key -> list of original_keys; for direct mapping: new_key -> original_key
+ self.concat_key_set = set() # set of concatenated keys
+ self.split_key_set = set() # set of split keys
+ self.new_keys = []
+ self.tensor_cache = {} # cache for split tensors
+ self.split_hook = weight_convert_hook.split_hook
+ self.concat_hook = weight_convert_hook.concat_hook
+ self.rename_hook = weight_convert_hook.rename_hook
+
+ for key in self.original_f.keys():
+ if self.split_hook is not None:
+ converted_keys, _ = self.split_hook(key, None) # get new keys only
+ if converted_keys is not None:
+ for converted_key in converted_keys:
+ self.new_key_to_original_key_map[converted_key] = key
+ self.split_key_set.add(converted_key)
+ self.new_keys.extend(converted_keys)
+ continue # skip concat_hook if split_hook is applied
+
+ if self.concat_hook is not None:
+ converted_key, _ = self.concat_hook(key, None) # get new key only
+ if converted_key is not None:
+ if converted_key not in self.concat_key_set: # first time seeing this concatenated key
+ self.concat_key_set.add(converted_key)
+ self.new_key_to_original_key_map[converted_key] = []
+ self.new_keys.append(converted_key)
+
+ # multiple original keys map to the same concatenated key
+ self.new_key_to_original_key_map[converted_key].append(key)
+ continue # skip to next key
+
+ # direct mapping
+ if self.rename_hook is not None:
+ new_key = self.rename_hook(key)
+ self.new_key_to_original_key_map[new_key] = key
+ else:
+ new_key = key
+
+ self.new_keys.append(new_key)
+
+ def keys(self) -> list[str]:
+ return self.new_keys
+
+ def get_tensor(self, new_key: str, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+ # load tensor by new_key, applying split or concat hooks as needed
+ if new_key not in self.new_key_to_original_key_map:
+ # direct mapping
+ return self.original_f.get_tensor(new_key, device=device, dtype=dtype)
+
+ elif new_key in self.split_key_set:
+ # split hook: split key is requested multiple times, so we cache the result
+ original_key = self.new_key_to_original_key_map[new_key]
+ if original_key not in self.tensor_cache: # not yet split
+ original_tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
+ new_keys, new_tensors = self.split_hook(original_key, original_tensor) # apply split hook
+ for k, t in zip(new_keys, new_tensors):
+ self.tensor_cache[k] = t
+ return self.tensor_cache.pop(new_key) # return and remove from cache
+
+ elif new_key in self.concat_key_set:
+ # concat hook: concatenated key is requested only once, so we do not cache the result
+ tensors = {}
+ for original_key in self.new_key_to_original_key_map[new_key]:
+ tensor = self.original_f.get_tensor(original_key, device=device, dtype=dtype)
+ tensors[original_key] = tensor
+ _, concatenated_tensors = self.concat_hook(self.new_key_to_original_key_map[new_key][0], tensors) # apply concat hook
+ return concatenated_tensors
+
+ else:
+ # direct mapping
+ original_key = self.new_key_to_original_key_map[new_key]
+ return self.original_f.get_tensor(original_key, device=device, dtype=dtype)
diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py
index 32a4fd7b..0ac9b3be 100644
--- a/library/sai_model_spec.py
+++ b/library/sai_model_spec.py
@@ -81,6 +81,8 @@ ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"
ARCH_HUNYUAN_IMAGE_2_1 = "hunyuan-image-2.1"
ARCH_HUNYUAN_IMAGE_UNKNOWN = "hunyuan-image"
+ARCH_ANIMA_PREVIEW = "anima-preview"
+ARCH_ANIMA_UNKNOWN = "anima-unknown"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -92,6 +94,7 @@ IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"
IMPL_HUNYUAN_IMAGE = "https://github.com/Tencent-Hunyuan/HunyuanImage-2.1"
+IMPL_ANIMA = "https://huggingface.co/circlestone-labs/Anima"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -220,6 +223,12 @@ def determine_architecture(
arch = ARCH_HUNYUAN_IMAGE_2_1
else:
arch = ARCH_HUNYUAN_IMAGE_UNKNOWN
+ elif "anima" in model_config:
+ anima_type = model_config["anima"]
+ if anima_type == "preview":
+ arch = ARCH_ANIMA_PREVIEW
+ else:
+ arch = ARCH_ANIMA_UNKNOWN
elif v2:
arch = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
else:
@@ -252,6 +261,8 @@ def determine_implementation(
return IMPL_FLUX
elif "lumina" in model_config:
return IMPL_LUMINA
+ elif "anima" in model_config:
+ return IMPL_ANIMA
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
return IMPL_STABILITY_AI
else:
@@ -325,7 +336,7 @@ def determine_resolution(
reso = (reso[0], reso[0])
else:
# Determine default resolution based on model type
- if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config:
+ if sdxl or "sd3" in model_config or "flux" in model_config or "lumina" in model_config or "anima" in model_config:
reso = (1024, 1024)
elif v2 and v_parameterization:
reso = (768, 768)
diff --git a/library/strategy_anima.py b/library/strategy_anima.py
index 9c9b0126..d89df5b9 100644
--- a/library/strategy_anima.py
+++ b/library/strategy_anima.py
@@ -9,6 +9,7 @@ import torch
from library import anima_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
+from library import qwen_image_autoencoder_kl
from library.utils import setup_logging
@@ -45,8 +46,8 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
t5_tokenizer = anima_utils.load_t5_tokenizer(t5_tokenizer_path)
self.qwen3_tokenizer = qwen3_tokenizer
- self.t5_tokenizer = t5_tokenizer
self.qwen3_max_length = qwen3_max_length
+ self.t5_tokenizer = t5_tokenizer
self.t5_max_length = t5_max_length
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
@@ -54,26 +55,17 @@ class AnimaTokenizeStrategy(TokenizeStrategy):
# Tokenize with Qwen3
qwen3_encoding = self.qwen3_tokenizer.batch_encode_plus(
- text,
- return_tensors="pt",
- truncation=True,
- padding="max_length",
- max_length=self.qwen3_max_length,
+ text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.qwen3_max_length
)
qwen3_input_ids = qwen3_encoding["input_ids"]
qwen3_attn_mask = qwen3_encoding["attention_mask"]
# Tokenize with T5 (for LLM Adapter target tokens)
t5_encoding = self.t5_tokenizer.batch_encode_plus(
- text,
- return_tensors="pt",
- truncation=True,
- padding="max_length",
- max_length=self.t5_max_length,
+ text, return_tensors="pt", truncation=True, padding="max_length", max_length=self.t5_max_length
)
t5_input_ids = t5_encoding["input_ids"]
t5_attn_mask = t5_encoding["attention_mask"]
-
return [qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
@@ -84,46 +76,11 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
T5 tokens are passed through unchanged (only used by LLM Adapter).
"""
- def __init__(
- self,
- dropout_rate: float = 0.0,
- ) -> None:
- self.dropout_rate = dropout_rate
- # Cached unconditional embeddings (from encoding empty caption "")
- # Must be initialized via cache_uncond_embeddings() before text encoder is deleted
- self._uncond_prompt_embeds: Optional[torch.Tensor] = None # (1, seq_len, hidden)
- self._uncond_attn_mask: Optional[torch.Tensor] = None # (1, seq_len)
- self._uncond_t5_input_ids: Optional[torch.Tensor] = None # (1, t5_seq_len)
- self._uncond_t5_attn_mask: Optional[torch.Tensor] = None # (1, t5_seq_len)
-
- def cache_uncond_embeddings(
- self,
- tokenize_strategy: TokenizeStrategy,
- models: List[Any],
- ) -> None:
- """Pre-encode empty caption "" and cache the unconditional embeddings.
-
- Must be called before the text encoder is deleted from GPU.
- This matches diffusion-pipe-main behavior where empty caption embeddings
- are pre-cached and swapped in during caption dropout.
- """
- logger.info("Caching unconditional embeddings for caption dropout (encoding empty caption)...")
- tokens = tokenize_strategy.tokenize("")
- with torch.no_grad():
- uncond_outputs = self.encode_tokens(tokenize_strategy, models, tokens, enable_dropout=False)
- # Store as CPU tensors (1, seq_len, ...) to avoid GPU memory waste
- self._uncond_prompt_embeds = uncond_outputs[0].cpu()
- self._uncond_attn_mask = uncond_outputs[1].cpu()
- self._uncond_t5_input_ids = uncond_outputs[2].cpu()
- self._uncond_t5_attn_mask = uncond_outputs[3].cpu()
- logger.info(" Unconditional embeddings cached successfully")
+ def __init__(self) -> None:
+ super().__init__()
def encode_tokens(
- self,
- tokenize_strategy: TokenizeStrategy,
- models: List[Any],
- tokens: List[torch.Tensor],
- enable_dropout: bool = True,
+ self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""Encode Qwen3 tokens and return embeddings + T5 token IDs.
@@ -134,82 +91,20 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Returns:
[prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
"""
+ # Do not handle dropout here; handled dataset-side or in drop_cached_text_encoder_outputs()
qwen3_text_encoder = models[0]
qwen3_input_ids, qwen3_attn_mask, t5_input_ids, t5_attn_mask = tokens
- # Handle dropout: replace dropped items with unconditional embeddings (matching diffusion-pipe-main)
- batch_size = qwen3_input_ids.shape[0]
- non_drop_indices = []
- for i in range(batch_size):
- drop = enable_dropout and (self.dropout_rate > 0.0 and random.random() < self.dropout_rate)
- if not drop:
- non_drop_indices.append(i)
+ encoder_device = qwen3_text_encoder.device
- encoder_device = qwen3_text_encoder.device if hasattr(qwen3_text_encoder, 'device') else next(qwen3_text_encoder.parameters()).device
+ qwen3_input_ids = qwen3_input_ids.to(encoder_device)
+ qwen3_attn_mask = qwen3_attn_mask.to(encoder_device)
+ outputs = qwen3_text_encoder(input_ids=qwen3_input_ids, attention_mask=qwen3_attn_mask)
+ prompt_embeds = outputs.last_hidden_state
+ prompt_embeds[~qwen3_attn_mask.bool()] = 0
- if len(non_drop_indices) > 0 and len(non_drop_indices) < batch_size:
- # Only encode non-dropped items to save compute
- nd_input_ids = qwen3_input_ids[non_drop_indices].to(encoder_device)
- nd_attn_mask = qwen3_attn_mask[non_drop_indices].to(encoder_device)
- elif len(non_drop_indices) == batch_size:
- nd_input_ids = qwen3_input_ids.to(encoder_device)
- nd_attn_mask = qwen3_attn_mask.to(encoder_device)
- else:
- nd_input_ids = None
- nd_attn_mask = None
-
- if nd_input_ids is not None:
- outputs = qwen3_text_encoder(input_ids=nd_input_ids, attention_mask=nd_attn_mask)
- nd_encoded_text = outputs.last_hidden_state
- # Zero out padding positions
- nd_encoded_text[~nd_attn_mask.bool()] = 0
-
- # Build full batch: fill non-dropped with encoded, dropped with unconditional
- if len(non_drop_indices) == batch_size:
- prompt_embeds = nd_encoded_text
- attn_mask = qwen3_attn_mask.to(encoder_device)
- else:
- # Get unconditional embeddings
- if self._uncond_prompt_embeds is not None:
- uncond_pe = self._uncond_prompt_embeds[0]
- uncond_am = self._uncond_attn_mask[0]
- uncond_t5_ids = self._uncond_t5_input_ids[0]
- uncond_t5_am = self._uncond_t5_attn_mask[0]
- else:
- # Encode empty caption on-the-fly (text encoder still available)
- uncond_tokens = tokenize_strategy.tokenize("")
- uncond_ids = uncond_tokens[0].to(encoder_device)
- uncond_mask = uncond_tokens[1].to(encoder_device)
- uncond_out = qwen3_text_encoder(input_ids=uncond_ids, attention_mask=uncond_mask)
- uncond_pe = uncond_out.last_hidden_state[0]
- uncond_pe[~uncond_mask[0].bool()] = 0
- uncond_am = uncond_mask[0]
- uncond_t5_ids = uncond_tokens[2][0]
- uncond_t5_am = uncond_tokens[3][0]
-
- seq_len = qwen3_input_ids.shape[1]
- hidden_size = nd_encoded_text.shape[-1] if nd_encoded_text is not None else uncond_pe.shape[-1]
- dtype = nd_encoded_text.dtype if nd_encoded_text is not None else uncond_pe.dtype
-
- prompt_embeds = torch.zeros((batch_size, seq_len, hidden_size), device=encoder_device, dtype=dtype)
- attn_mask = torch.zeros((batch_size, seq_len), device=encoder_device, dtype=qwen3_attn_mask.dtype)
-
- if len(non_drop_indices) > 0:
- prompt_embeds[non_drop_indices] = nd_encoded_text
- attn_mask[non_drop_indices] = nd_attn_mask
-
- # Fill dropped items with unconditional embeddings
- t5_input_ids = t5_input_ids.clone()
- t5_attn_mask = t5_attn_mask.clone()
- drop_indices = [i for i in range(batch_size) if i not in non_drop_indices]
- for i in drop_indices:
- prompt_embeds[i] = uncond_pe.to(device=encoder_device, dtype=dtype)
- attn_mask[i] = uncond_am.to(device=encoder_device, dtype=qwen3_attn_mask.dtype)
- t5_input_ids[i] = uncond_t5_ids.to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
- t5_attn_mask[i] = uncond_t5_am.to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
-
- return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
+ return [prompt_embeds, qwen3_attn_mask, t5_input_ids, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
@@ -217,6 +112,7 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
attn_mask: torch.Tensor,
t5_input_ids: torch.Tensor,
t5_attn_mask: torch.Tensor,
+ caption_dropout_rates: Optional[torch.Tensor] = None,
) -> List[torch.Tensor]:
"""Apply dropout to cached text encoder outputs.
@@ -224,37 +120,30 @@ class AnimaTextEncodingStrategy(TextEncodingStrategy):
Replaces dropped items with pre-cached unconditional embeddings (from encoding "")
to match diffusion-pipe-main behavior.
"""
- if prompt_embeds is not None and self.dropout_rate > 0.0:
- # Clone to avoid in-place modification of cached tensors
- prompt_embeds = prompt_embeds.clone()
- if attn_mask is not None:
- attn_mask = attn_mask.clone()
- if t5_input_ids is not None:
- t5_input_ids = t5_input_ids.clone()
- if t5_attn_mask is not None:
- t5_attn_mask = t5_attn_mask.clone()
+ if caption_dropout_rates is None or torch.all(caption_dropout_rates == 0.0).item():
+ return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
- for i in range(prompt_embeds.shape[0]):
- if random.random() < self.dropout_rate:
- if self._uncond_prompt_embeds is not None:
- # Use pre-cached unconditional embeddings
- prompt_embeds[i] = self._uncond_prompt_embeds[0].to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
- if attn_mask is not None:
- attn_mask[i] = self._uncond_attn_mask[0].to(device=attn_mask.device, dtype=attn_mask.dtype)
- if t5_input_ids is not None:
- t5_input_ids[i] = self._uncond_t5_input_ids[0].to(device=t5_input_ids.device, dtype=t5_input_ids.dtype)
- if t5_attn_mask is not None:
- t5_attn_mask[i] = self._uncond_t5_attn_mask[0].to(device=t5_attn_mask.device, dtype=t5_attn_mask.dtype)
- else:
- # Fallback: zero out (should not happen if cache_uncond_embeddings was called)
- logger.warning("Unconditional embeddings not cached, falling back to zeros for caption dropout")
- prompt_embeds[i] = torch.zeros_like(prompt_embeds[i])
- if attn_mask is not None:
- attn_mask[i] = torch.zeros_like(attn_mask[i])
- if t5_input_ids is not None:
- t5_input_ids[i] = torch.zeros_like(t5_input_ids[i])
- if t5_attn_mask is not None:
- t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
+ # Clone to avoid in-place modification of cached tensors
+ prompt_embeds = prompt_embeds.clone()
+ if attn_mask is not None:
+ attn_mask = attn_mask.clone()
+ if t5_input_ids is not None:
+ t5_input_ids = t5_input_ids.clone()
+ if t5_attn_mask is not None:
+ t5_attn_mask = t5_attn_mask.clone()
+
+ for i in range(prompt_embeds.shape[0]):
+ if random.random() < caption_dropout_rates[i].item():
+ # Use pre-cached unconditional embeddings
+ prompt_embeds[i] = 0
+ if attn_mask is not None:
+ attn_mask[i] = 0
+ if t5_input_ids is not None:
+ t5_input_ids[i, 0] = 1 # Set to token ID
+ t5_input_ids[i, 1:] = 0
+ if t5_attn_mask is not None:
+ t5_attn_mask[i, 0] = 1
+ t5_attn_mask[i, 1:] = 0
return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
@@ -297,6 +186,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False
if "t5_attn_mask" not in npz:
return False
+ if "caption_dropout_rate" not in npz:
+ return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -309,7 +200,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask = data["attn_mask"]
t5_input_ids = data["t5_input_ids"]
t5_attn_mask = data["t5_attn_mask"]
- return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask]
+ caption_dropout_rate = data["caption_dropout_rate"]
+ return [prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask, caption_dropout_rate]
def cache_batch_outputs(
self,
@@ -323,12 +215,8 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
- # Always disable dropout during caching
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = anima_text_encoding_strategy.encode_tokens(
- tokenize_strategy,
- models,
- tokens_and_masks,
- enable_dropout=False,
+ tokenize_strategy, models, tokens_and_masks
)
# Convert to numpy for caching
@@ -344,6 +232,7 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask_i = attn_mask[i]
t5_input_ids_i = t5_input_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
+ caption_dropout_rate = torch.tensor(info.caption_dropout_rate, dtype=torch.float32)
if self.cache_to_disk:
np.savez(
@@ -352,9 +241,10 @@ class AnimaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
attn_mask=attn_mask_i,
t5_input_ids=t5_input_ids_i,
t5_attn_mask=t5_attn_mask_i,
+ caption_dropout_rate=caption_dropout_rate,
)
else:
- info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i)
+ info.text_encoder_outputs = (prompt_embeds_i, attn_mask_i, t5_input_ids_i, t5_attn_mask_i, caption_dropout_rate)
class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
@@ -374,18 +264,10 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
return self.ANIMA_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
- return (
- os.path.splitext(absolute_path)[0]
- + f"_{image_size[0]:04d}x{image_size[1]:04d}"
- + self.ANIMA_LATENTS_NPZ_SUFFIX
- )
+ return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.ANIMA_LATENTS_NPZ_SUFFIX
- def is_disk_cached_latents_expected(
- self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
- ):
- return self._default_is_disk_cached_latents_expected(
- 8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
- )
+ def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
+ return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
@@ -393,32 +275,23 @@ class AnimaLatentsCachingStrategy(LatentsCachingStrategy):
return self._default_load_latents_from_disk(8, npz_path, bucket_reso)
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
- """Cache batch of latents using WanVAE.
+ """Cache batch of latents using Qwen Image VAE.
- vae is expected to be the WanVAE_ model (not the wrapper).
+ vae is expected to be the Qwen Image VAE (AutoencoderKLQwenImage).
The encoding function handles the mean/std normalization.
"""
- from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
-
- vae_device = next(vae.parameters()).device
- vae_dtype = next(vae.parameters()).dtype
-
- # Create scale tensors on VAE device
- mean = torch.tensor(ANIMA_VAE_MEAN, dtype=vae_dtype, device=vae_device)
- std = torch.tensor(ANIMA_VAE_STD, dtype=vae_dtype, device=vae_device)
- scale = [mean, 1.0 / std]
+ vae: qwen_image_autoencoder_kl.AutoencoderKLQwenImage = vae
+ vae_device = vae.device
+ vae_dtype = vae.dtype
def encode_by_vae(img_tensor):
"""Encode image tensor to latents.
img_tensor: (B, C, H, W) in [-1, 1] range (already normalized by IMAGE_TRANSFORMS)
- Need to add temporal dim to get (B, C, T=1, H, W) for WanVAE
+ Qwen Image VAE accepts inputs in (B, C, H, W) or (B, C, 1, H, W) shape.
+ Returns latents in (B, 16, 1, H/8, W/8) shape on CPU.
"""
- # Add temporal dimension: (B, C, H, W) -> (B, C, 1, H, W)
- img_tensor = img_tensor.unsqueeze(2)
- img_tensor = img_tensor.to(vae_device, dtype=vae_dtype)
-
- latents = vae.encode(img_tensor, scale)
+ latents = vae.encode_pixels_to_latents(img_tensor) # Keep 4D for input/output
return latents.to("cpu")
self._default_cache_batch_latents(
diff --git a/library/train_util.py b/library/train_util.py
index 6874076d..d8577b9d 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -179,12 +179,15 @@ def split_train_val(
class ImageInfo:
- def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
+ def __init__(
+ self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str, caption_dropout_rate: float = 0.0
+ ) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.caption: str = caption
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
+ self.caption_dropout_rate: float = caption_dropout_rate
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
@@ -197,7 +200,7 @@ class ImageInfo:
)
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
- self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
+ self.text_encoder_outputs_npz: Optional[str] = None # filename. set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
@@ -1096,11 +1099,11 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
- def is_text_encoder_output_cacheable(self):
+ def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False):
return all(
[
not (
- subset.caption_dropout_rate > 0
+ subset.caption_dropout_rate > 0 and not cache_supports_dropout
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
@@ -2137,7 +2140,7 @@ class DreamBoothDataset(BaseDataset):
num_train_images += num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
- info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
+ info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path, subset.caption_dropout_rate)
info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
@@ -2338,7 +2341,7 @@ class FineTuningDataset(BaseDataset):
if caption is None:
caption = ""
- image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
+ image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path, subset.caption_dropout_rate)
image_info.resize_interpolation = (
subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
)
@@ -2661,8 +2664,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
- def is_text_encoder_output_cacheable(self) -> bool:
- return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
+ def is_text_encoder_output_cacheable(self, cache_supports_dropout: bool = False) -> bool:
+ return all([dataset.is_text_encoder_output_cacheable(cache_supports_dropout) for dataset in self.datasets])
def set_current_strategies(self):
for dataset in self.datasets:
@@ -3578,6 +3581,7 @@ def get_sai_model_spec_dataclass(
flux: str = None,
lumina: str = None,
hunyuan_image: str = None,
+ anima: str = None,
optional_metadata: dict[str, str] | None = None,
) -> sai_model_spec.ModelSpecMetadata:
"""
@@ -3609,7 +3613,8 @@ def get_sai_model_spec_dataclass(
model_config["lumina"] = lumina
if hunyuan_image is not None:
model_config["hunyuan_image"] = hunyuan_image
-
+ if anima is not None:
+ model_config["anima"] = anima
# Use the dataclass function directly
return sai_model_spec.build_metadata_dataclass(
state_dict,
diff --git a/networks/convert_anima_lora_to_comfy.py b/networks/convert_anima_lora_to_comfy.py
new file mode 100644
index 00000000..5ff2b9ee
--- /dev/null
+++ b/networks/convert_anima_lora_to_comfy.py
@@ -0,0 +1,160 @@
+import argparse
+from safetensors.torch import save_file
+from safetensors import safe_open
+
+
+from library import train_util
+from library.utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+COMFYUI_DIT_PREFIX = "diffusion_model."
+COMFYUI_QWEN3_PREFIX = "text_encoders.qwen3_06b.transformer.model."
+
+
+def main(args):
+ # load source safetensors
+ logger.info(f"Loading source file {args.src_path}")
+ state_dict = {}
+ with safe_open(args.src_path, framework="pt") as f:
+ metadata = f.metadata()
+ for k in f.keys():
+ state_dict[k] = f.get_tensor(k)
+
+ logger.info(f"Converting...")
+
+ keys = list(state_dict.keys())
+ count = 0
+
+ for k in keys:
+ if not args.reverse:
+ is_dit_lora = k.startswith("lora_unet_")
+ module_and_weight_name = "_".join(k.split("_")[2:]) # Remove `lora_unet_`or `lora_te_` prefix
+
+ # Split at the first dot, e.g., "block1_linear.weight" -> "block1_linear", "weight"
+ module_name, weight_name = module_and_weight_name.split(".", 1)
+
+ # Weight name conversion: lora_up/lora_down to lora_A/lora_B
+ if weight_name.startswith("lora_up"):
+ weight_name = weight_name.replace("lora_up", "lora_B")
+ elif weight_name.startswith("lora_down"):
+ weight_name = weight_name.replace("lora_down", "lora_A")
+ else:
+ # Keep other weight names as-is: e.g. alpha
+ pass
+
+ # Module name conversion: convert dots to underscores
+ original_module_name = module_name.replace("_", ".") # Convert to dot notation
+
+ # Convert back illegal dots in module names
+ # DiT
+ original_module_name = original_module_name.replace("llm.adapter", "llm_adapter")
+ original_module_name = original_module_name.replace(".linear.", ".linear_")
+ original_module_name = original_module_name.replace("t.embedding.norm", "t_embedding_norm")
+ original_module_name = original_module_name.replace("x.embedder", "x_embedder")
+ original_module_name = original_module_name.replace("adaln.modulation.cross_attn", "adaln_modulation_cross_attn")
+ original_module_name = original_module_name.replace("adaln.modulation.mlp", "adaln_modulation_mlp")
+ original_module_name = original_module_name.replace("cross.attn", "cross_attn")
+ original_module_name = original_module_name.replace("k.proj", "k_proj")
+ original_module_name = original_module_name.replace("k.norm", "k_norm")
+ original_module_name = original_module_name.replace("q.proj", "q_proj")
+ original_module_name = original_module_name.replace("q.norm", "q_norm")
+ original_module_name = original_module_name.replace("v.proj", "v_proj")
+ original_module_name = original_module_name.replace("o.proj", "o_proj")
+ original_module_name = original_module_name.replace("output.proj", "output_proj")
+ original_module_name = original_module_name.replace("self.attn", "self_attn")
+ original_module_name = original_module_name.replace("final.layer", "final_layer")
+ original_module_name = original_module_name.replace("adaln.modulation", "adaln_modulation")
+ original_module_name = original_module_name.replace("norm.cross.attn", "norm_cross_attn")
+ original_module_name = original_module_name.replace("norm.mlp", "norm_mlp")
+ original_module_name = original_module_name.replace("norm.self.attn", "norm_self_attn")
+ original_module_name = original_module_name.replace("out.proj", "out_proj")
+
+ # Qwen3
+ original_module_name = original_module_name.replace("embed.tokens", "embed_tokens")
+ original_module_name = original_module_name.replace("input.layernorm", "input_layernorm")
+ original_module_name = original_module_name.replace("down.proj", "down_proj")
+ original_module_name = original_module_name.replace("gate.proj", "gate_proj")
+ original_module_name = original_module_name.replace("up.proj", "up_proj")
+ original_module_name = original_module_name.replace("post.attention.layernorm", "post_attention_layernorm")
+
+ # Prefix conversion
+ new_prefix = COMFYUI_DIT_PREFIX if is_dit_lora else COMFYUI_QWEN3_PREFIX
+
+ new_k = f"{new_prefix}{original_module_name}.{weight_name}"
+ else:
+ if k.startswith(COMFYUI_DIT_PREFIX):
+ is_dit_lora = True
+ module_and_weight_name = k[len(COMFYUI_DIT_PREFIX) :]
+ elif k.startswith(COMFYUI_QWEN3_PREFIX):
+ is_dit_lora = False
+ module_and_weight_name = k[len(COMFYUI_QWEN3_PREFIX) :]
+ else:
+ logger.warning(f"Skipping unrecognized key {k}")
+ continue
+
+ # Get weight name
+ if ".lora_" in module_and_weight_name:
+ module_name, weight_name = module_and_weight_name.rsplit(".lora_", 1)
+ weight_name = "lora_" + weight_name
+ else:
+ module_name, weight_name = module_and_weight_name.rsplit(".", 1) # Keep other weight names as-is: e.g. alpha
+
+ # Weight name conversion: lora_A/lora_B to lora_up/lora_down
+ # Note: we only convert lora_A and lora_B weights, other weights are kept as-is
+ if weight_name.startswith("lora_B"):
+ weight_name = weight_name.replace("lora_B", "lora_up")
+ elif weight_name.startswith("lora_A"):
+ weight_name = weight_name.replace("lora_A", "lora_down")
+
+ # Module name conversion: convert dots to underscores
+ module_name = module_name.replace(".", "_") # Convert to underscore notation
+
+ # Prefix conversion
+ prefix = "lora_unet_" if is_dit_lora else "lora_te_"
+
+ new_k = f"{prefix}{module_name}.{weight_name}"
+
+ state_dict[new_k] = state_dict.pop(k)
+ count += 1
+
+ logger.info(f"Converted {count} keys")
+ if count == 0:
+ logger.warning("No keys were converted. Please check if the source file is in the expected format.")
+ elif count > 0 and count < len(keys):
+ logger.warning(
+ f"Only {count} out of {len(keys)} keys were converted. Please check if there are unexpected keys in the source file."
+ )
+
+ # Calculate hash
+ if metadata is not None:
+ logger.info(f"Calculating hashes and creating metadata...")
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+ metadata["sshs_model_hash"] = model_hash
+ metadata["sshs_legacy_hash"] = legacy_hash
+
+ # save destination safetensors
+ logger.info(f"Saving destination file {args.dst_path}")
+ save_file(state_dict, args.dst_path, metadata=metadata)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Convert LoRA format")
+ parser.add_argument(
+ "src_path",
+ type=str,
+ default=None,
+ help="source path, sd-scripts format (or ComfyUI compatible format if --reverse is set, only supported for LoRAs converted by this script)",
+ )
+ parser.add_argument(
+ "dst_path",
+ type=str,
+ default=None,
+ help="destination path, ComfyUI compatible format (or sd-scripts format if --reverse is set)",
+ )
+ parser.add_argument("--reverse", action="store_true", help="reverse conversion direction")
+ args = parser.parse_args()
+ main(args)
diff --git a/networks/lora_anima.py b/networks/lora_anima.py
index c375ead7..224ef20c 100644
--- a/networks/lora_anima.py
+++ b/networks/lora_anima.py
@@ -1,18 +1,17 @@
-# LoRA network module for Anima
-import math
+# LoRA network module for Anima
+import ast
import os
+import re
from typing import Dict, List, Optional, Tuple, Type, Union
-import numpy as np
import torch
from library.utils import setup_logging
+from networks.lora_flux import LoRAModule, LoRAInfModule
-setup_logging()
import logging
+setup_logging()
logger = logging.getLogger(__name__)
-from networks.lora_flux import LoRAModule, LoRAInfModule
-
def create_network(
multiplier: float,
@@ -29,68 +28,28 @@ def create_network(
if network_alpha is None:
network_alpha = 1.0
- # type_dims: [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
- self_attn_dim = kwargs.get("self_attn_dim", None)
- cross_attn_dim = kwargs.get("cross_attn_dim", None)
- mlp_dim = kwargs.get("mlp_dim", None)
- mod_dim = kwargs.get("mod_dim", None)
- llm_adapter_dim = kwargs.get("llm_adapter_dim", None)
-
- if self_attn_dim is not None:
- self_attn_dim = int(self_attn_dim)
- if cross_attn_dim is not None:
- cross_attn_dim = int(cross_attn_dim)
- if mlp_dim is not None:
- mlp_dim = int(mlp_dim)
- if mod_dim is not None:
- mod_dim = int(mod_dim)
- if llm_adapter_dim is not None:
- llm_adapter_dim = int(llm_adapter_dim)
-
- type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
- if all([d is None for d in type_dims]):
- type_dims = None
-
- # emb_dims: [x_embedder, t_embedder, final_layer]
- emb_dims = kwargs.get("emb_dims", None)
- if emb_dims is not None:
- emb_dims = emb_dims.strip()
- if emb_dims.startswith("[") and emb_dims.endswith("]"):
- emb_dims = emb_dims[1:-1]
- emb_dims = [int(d) for d in emb_dims.split(",")]
- assert len(emb_dims) == 3, f"invalid emb_dims: {emb_dims}, must be 3 dimensions (x_embedder, t_embedder, final_layer)"
-
- # block selection
- def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
- if selection == "all":
- return [True] * total_blocks
- if selection == "none" or selection == "":
- return [False] * total_blocks
-
- selected = [False] * total_blocks
- ranges = selection.split(",")
- for r in ranges:
- if "-" in r:
- start, end = map(str.strip, r.split("-"))
- start, end = int(start), int(end)
- assert 0 <= start < total_blocks and 0 <= end < total_blocks and start <= end
- for i in range(start, end + 1):
- selected[i] = True
- else:
- index = int(r)
- assert 0 <= index < total_blocks
- selected[index] = True
- return selected
-
- train_block_indices = kwargs.get("train_block_indices", None)
- if train_block_indices is not None:
- num_blocks = len(unet.blocks) if hasattr(unet, 'blocks') else 999
- train_block_indices = parse_block_selection(train_block_indices, num_blocks)
-
# train LLM adapter
- train_llm_adapter = kwargs.get("train_llm_adapter", False)
+ train_llm_adapter = kwargs.get("train_llm_adapter", "false")
if train_llm_adapter is not None:
- train_llm_adapter = True if train_llm_adapter == "True" else False
+ train_llm_adapter = True if train_llm_adapter.lower() == "true" else False
+
+ exclude_patterns = kwargs.get("exclude_patterns", None)
+ if exclude_patterns is None:
+ exclude_patterns = []
+ else:
+ exclude_patterns = ast.literal_eval(exclude_patterns)
+ if not isinstance(exclude_patterns, list):
+ exclude_patterns = [exclude_patterns]
+
+ # add default exclude patterns
+ exclude_patterns.append(r".*(_modulation|_norm|_embedder|final_layer).*")
+
+ # regular expression for module selection: exclude and include
+ include_patterns = kwargs.get("include_patterns", None)
+ if include_patterns is not None:
+ include_patterns = ast.literal_eval(include_patterns)
+ if not isinstance(include_patterns, list):
+ include_patterns = [include_patterns]
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
@@ -101,9 +60,43 @@ def create_network(
module_dropout = float(module_dropout)
# verbose
- verbose = kwargs.get("verbose", False)
+ verbose = kwargs.get("verbose", "false")
if verbose is not None:
- verbose = True if verbose == "True" else False
+ verbose = True if verbose.lower() == "true" else False
+
+ # regex-specific learning rates / dimensions
+ def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
+ """
+ Parse a string of key-value pairs separated by commas.
+ """
+ pairs = {}
+ for pair in kv_pair_str.split(","):
+ pair = pair.strip()
+ if not pair:
+ continue
+ if "=" not in pair:
+ logger.warning(f"Invalid format: {pair}, expected 'key=value'")
+ continue
+ key, value = pair.split("=", 1)
+ key = key.strip()
+ value = value.strip()
+ try:
+ pairs[key] = int(value) if is_int else float(value)
+ except ValueError:
+ logger.warning(f"Invalid value for {key}: {value}")
+ return pairs
+
+ network_reg_lrs = kwargs.get("network_reg_lrs", None)
+ if network_reg_lrs is not None:
+ reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
+ else:
+ reg_lrs = None
+
+ network_reg_dims = kwargs.get("network_reg_dims", None)
+ if network_reg_dims is not None:
+ reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
+ else:
+ reg_dims = None
network = LoRANetwork(
text_encoders,
@@ -115,9 +108,10 @@ def create_network(
rank_dropout=rank_dropout,
module_dropout=module_dropout,
train_llm_adapter=train_llm_adapter,
- type_dims=type_dims,
- emb_dims=emb_dims,
- train_block_indices=train_block_indices,
+ exclude_patterns=exclude_patterns,
+ include_patterns=include_patterns,
+ reg_dims=reg_dims,
+ reg_lrs=reg_lrs,
verbose=verbose,
)
@@ -137,6 +131,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
+
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
@@ -173,15 +168,15 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, unet, weigh
class LoRANetwork(torch.nn.Module):
- # Target modules: DiT blocks
- ANIMA_TARGET_REPLACE_MODULE = ["Block"]
+ # Target modules: DiT blocks, embedders, final layer. embedders and final layer are excluded by default.
+ ANIMA_TARGET_REPLACE_MODULE = ["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"]
# Target modules: LLM Adapter blocks
ANIMA_ADAPTER_TARGET_REPLACE_MODULE = ["LLMAdapterTransformerBlock"]
# Target modules for text encoder (Qwen3)
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"]
LORA_PREFIX_ANIMA = "lora_unet" # ComfyUI compatible
- LORA_PREFIX_TEXT_ENCODER = "lora_te1" # Qwen3
+ LORA_PREFIX_TEXT_ENCODER = "lora_te" # Qwen3
def __init__(
self,
@@ -197,9 +192,10 @@ class LoRANetwork(torch.nn.Module):
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
train_llm_adapter: bool = False,
- type_dims: Optional[List[int]] = None,
- emb_dims: Optional[List[int]] = None,
- train_block_indices: Optional[List[bool]] = None,
+ exclude_patterns: Optional[List[str]] = None,
+ include_patterns: Optional[List[str]] = None,
+ reg_dims: Optional[Dict[str, int]] = None,
+ reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -210,21 +206,36 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.train_llm_adapter = train_llm_adapter
- self.type_dims = type_dims
- self.emb_dims = emb_dims
- self.train_block_indices = train_block_indices
+ self.reg_dims = reg_dims
+ self.reg_lrs = reg_lrs
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
- logger.info(f"create LoRA network from weights")
- if self.emb_dims is None:
- self.emb_dims = [0] * 3
+ logger.info("create LoRA network from weights")
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
- logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+ logger.info(
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+ )
+
+ # compile regular expression if specified
+ def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]:
+ re_patterns = []
+ if patterns is not None:
+ for pattern in patterns:
+ try:
+ re_pattern = re.compile(pattern)
+ except re.error as e:
+ logger.error(f"Invalid pattern '{pattern}': {e}")
+ continue
+ re_patterns.append(re_pattern)
+ return re_patterns
+
+ exclude_re_patterns = str_to_re_patterns(exclude_patterns)
+ include_re_patterns = str_to_re_patterns(include_patterns)
# create module instances
def create_modules(
@@ -232,15 +243,9 @@ class LoRANetwork(torch.nn.Module):
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
- filter: Optional[str] = None,
default_dim: Optional[int] = None,
- include_conv2d_if_filter: bool = False,
) -> Tuple[List[LoRAModule], List[str]]:
- prefix = (
- self.LORA_PREFIX_ANIMA
- if is_unet
- else self.LORA_PREFIX_TEXT_ENCODER
- )
+ prefix = self.LORA_PREFIX_ANIMA if is_unet else self.LORA_PREFIX_TEXT_ENCODER
loras = []
skipped = []
@@ -255,14 +260,16 @@ class LoRANetwork(torch.nn.Module):
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
- lora_name = prefix + "." + (name + "." if name else "") + child_name
- lora_name = lora_name.replace(".", "_")
+ original_name = (name + "." if name else "") + child_name
+ lora_name = f"{prefix}.{original_name}".replace(".", "_")
- force_incl_conv2d = False
- if filter is not None:
- if filter not in lora_name:
- continue
- force_incl_conv2d = include_conv2d_if_filter
+ # exclude/include filter (fullmatch: pattern must match the entire original_name)
+ excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns)
+ included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns)
+ if excluded and not included:
+ if verbose:
+ logger.info(f"exclude: {original_name}")
+ continue
dim = None
alpha_val = None
@@ -272,43 +279,18 @@ class LoRANetwork(torch.nn.Module):
dim = modules_dim[lora_name]
alpha_val = modules_alpha[lora_name]
else:
- if is_linear or is_conv2d_1x1:
- dim = default_dim if default_dim is not None else self.lora_dim
- alpha_val = self.alpha
-
- if is_unet and type_dims is not None:
- # type_dims = [self_attn_dim, cross_attn_dim, mlp_dim, mod_dim, llm_adapter_dim]
- # Order matters: check most specific identifiers first to avoid mismatches.
- identifier_order = [
- (4, ("llm_adapter",)),
- (3, ("adaln_modulation",)),
- (0, ("self_attn",)),
- (1, ("cross_attn",)),
- (2, ("mlp",)),
- ]
- for idx, ids in identifier_order:
- d = type_dims[idx]
- if d is not None and all(id_str in lora_name for id_str in ids):
- dim = d # 0 means skip
- break
-
- # block index filtering
- if is_unet and dim and self.train_block_indices is not None and "blocks_" in lora_name:
- # Extract block index from lora_name: "lora_unet_blocks_0_self_attn..."
- parts = lora_name.split("_")
- for pi, part in enumerate(parts):
- if part == "blocks" and pi + 1 < len(parts):
- try:
- block_index = int(parts[pi + 1])
- if not self.train_block_indices[block_index]:
- dim = 0
- except (ValueError, IndexError):
- pass
- break
-
- elif force_incl_conv2d:
- dim = default_dim if default_dim is not None else self.lora_dim
- alpha_val = self.alpha
+ if self.reg_dims is not None:
+ for reg, d in self.reg_dims.items():
+ if re.fullmatch(reg, original_name):
+ dim = d
+ alpha_val = self.alpha
+ logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}")
+ break
+ # fallback to default dim if not matched by reg_dims or reg_dims is not specified
+ if dim is None:
+ if is_linear or is_conv2d_1x1:
+ dim = default_dim if default_dim is not None else self.lora_dim
+ alpha_val = self.alpha
if dim is None or dim == 0:
if is_linear or is_conv2d_1x1:
@@ -325,6 +307,7 @@ class LoRANetwork(torch.nn.Module):
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
+ lora.original_name = original_name
loras.append(lora)
if target_replace_modules is None:
@@ -339,9 +322,7 @@ class LoRANetwork(torch.nn.Module):
if text_encoder is None:
continue
logger.info(f"create LoRA for Text Encoder {i+1}:")
- te_loras, te_skipped = create_modules(
- False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
- )
+ te_loras, te_skipped = create_modules(False, i, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {i+1}: {len(te_loras)} modules.")
self.text_encoder_loras.extend(te_loras)
skipped_te += te_skipped
@@ -354,19 +335,6 @@ class LoRANetwork(torch.nn.Module):
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
- # emb_dims: [x_embedder, t_embedder, final_layer]
- if self.emb_dims:
- for filter_name, in_dim in zip(
- ["x_embedder", "t_embedder", "final_layer"],
- self.emb_dims,
- ):
- loras, _ = create_modules(
- True, None, unet, None,
- filter=filter_name, default_dim=in_dim,
- include_conv2d_if_filter=(filter_name == "x_embedder"),
- )
- self.unet_loras.extend(loras)
-
logger.info(f"create LoRA for Anima DiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
@@ -396,6 +364,7 @@ class LoRANetwork(torch.nn.Module):
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
+
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
@@ -443,10 +412,10 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
- sd_for_lora[key[len(lora.lora_name) + 1:]] = weights_sd[key]
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
- logger.info(f"weights are merged")
+ logger.info("weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
@@ -471,8 +440,29 @@ class LoRANetwork(torch.nn.Module):
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
+ reg_groups = {}
+ reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
+
for lora in loras:
+ matched_reg_lr = None
+ for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
+ if re.fullmatch(regex_str, lora.original_name):
+ matched_reg_lr = (i, reg_lr)
+ logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}")
+ break
+
for name, param in lora.named_parameters():
+ if matched_reg_lr is not None:
+ reg_idx, reg_lr = matched_reg_lr
+ group_key = f"reg_lr_{reg_idx}"
+ if group_key not in reg_groups:
+ reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
+ if loraplus_ratio is not None and "lora_up" in name:
+ reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param
+ else:
+ reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param
+ continue
+
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
@@ -480,6 +470,23 @@ class LoRANetwork(torch.nn.Module):
params = []
descriptions = []
+ for group_key, group in reg_groups.items():
+ reg_lr = group["lr"]
+ for key in ("lora", "plus"):
+ param_data = {"params": group[key].values()}
+ if len(param_data["params"]) == 0:
+ continue
+ if key == "plus":
+ param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr
+ else:
+ param_data["lr"] = reg_lr
+ if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+ logger.info("NO LR skipping!")
+ continue
+ params.append(param_data)
+ desc = f"reg_lr_{group_key.split('_')[-1]}"
+ descriptions.append(desc + (" plus" if key == "plus" else ""))
+
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
@@ -498,10 +505,7 @@ class LoRANetwork(torch.nn.Module):
if self.text_encoder_loras:
loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
- te1_loras = [
- lora for lora in self.text_encoder_loras
- if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)
- ]
+ te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (Qwen3): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_ratio)
diff --git a/tests/test_anima_cache.py b/tests/manual_test_anima_cache.py
similarity index 89%
rename from tests/test_anima_cache.py
rename to tests/manual_test_anima_cache.py
index 1684eb53..9809beba 100644
--- a/tests/test_anima_cache.py
+++ b/tests/manual_test_anima_cache.py
@@ -2,7 +2,7 @@
Diagnostic script to test Anima latent & text encoder caching independently.
Usage:
- python test_anima_cache.py \
+ python manual_test_anima_cache.py \
--image_dir /path/to/images \
--qwen3_path /path/to/qwen3 \
--vae_path /path/to/vae.safetensors \
@@ -30,10 +30,12 @@ from torchvision import transforms
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff"}
-IMAGE_TRANSFORMS = transforms.Compose([
- transforms.ToTensor(), # [0,1]
- transforms.Normalize([0.5], [0.5]), # [-1,1]
-])
+IMAGE_TRANSFORMS = transforms.Compose(
+ [
+ transforms.ToTensor(), # [0,1]
+ transforms.Normalize([0.5], [0.5]), # [-1,1]
+ ]
+)
def find_image_caption_pairs(image_dir: str):
@@ -60,35 +62,32 @@ def print_tensor_info(name: str, t, indent=2):
print(f"{prefix}{name}: None")
return
if isinstance(t, np.ndarray):
- print(f"{prefix}{name}: numpy {t.dtype} shape={t.shape} "
- f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
+ print(f"{prefix}{name}: numpy {t.dtype} shape={t.shape} " f"min={t.min():.4f} max={t.max():.4f} mean={t.mean():.4f}")
elif isinstance(t, torch.Tensor):
- print(f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
- f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}")
+ print(
+ f"{prefix}{name}: torch {t.dtype} shape={tuple(t.shape)} "
+ f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.float().mean().item():.4f}"
+ )
else:
print(f"{prefix}{name}: type={type(t)} value={t}")
# Test 1: Latent Cache
+
def test_latent_cache(args, pairs):
print("\n" + "=" * 70)
print("TEST 1: LATENT CACHING (VAE encode -> cache -> reload)")
print("=" * 70)
- from library import anima_utils
- from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
+ from library import qwen_image_autoencoder_kl
# Load VAE
print("\n[1.1] Loading VAE...")
device = "cuda" if torch.cuda.is_available() else "cpu"
vae_dtype = torch.float32
- vae, vae_mean, vae_std, vae_scale = anima_utils.load_anima_vae(
- args.vae_path, dtype=vae_dtype, device=device
- )
+ vae = qwen_image_autoencoder_kl.load_vae(args.vae_path, dtype=vae_dtype, device=device)
print(f" VAE loaded on {device}, dtype={vae_dtype}")
- print(f" VAE mean (first 4): {ANIMA_VAE_MEAN[:4]}")
- print(f" VAE std (first 4): {ANIMA_VAE_STD[:4]}")
for img_path, caption in pairs:
print(f"\n[1.2] Processing: {os.path.basename(img_path)}")
@@ -96,13 +95,13 @@ def test_latent_cache(args, pairs):
# Load image
img = Image.open(img_path).convert("RGB")
img_np = np.array(img)
- print(f" Raw image: {img_np.shape} dtype={img_np.dtype} "
- f"min={img_np.min()} max={img_np.max()}")
+ print(f" Raw image: {img_np.shape} dtype={img_np.dtype} " f"min={img_np.min()} max={img_np.max()}")
# Apply IMAGE_TRANSFORMS (same as sd-scripts training)
img_tensor = IMAGE_TRANSFORMS(img_np)
- print(f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} "
- f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}")
+ print(
+ f" After IMAGE_TRANSFORMS: shape={tuple(img_tensor.shape)} " f"min={img_tensor.min():.4f} max={img_tensor.max():.4f}"
+ )
# Check range is [-1, 1]
if img_tensor.min() < -1.01 or img_tensor.max() > 1.01:
@@ -116,7 +115,7 @@ def test_latent_cache(args, pairs):
print(f" VAE input: shape={tuple(img_5d.shape)} dtype={img_5d.dtype}")
with torch.no_grad():
- latents = vae.encode(img_5d, vae_scale)
+ latents = vae.encode_pixels_to_latents(img_5d)
latents_cpu = latents.cpu()
print_tensor_info("Encoded latents", latents_cpu)
@@ -165,7 +164,9 @@ def test_latent_cache(args, pairs):
# Test 2: Text Encoder Output Cache
+
def test_text_encoder_cache(args, pairs):
+ # TODO Rewrite this
print("\n" + "=" * 70)
print("TEST 2: TEXT ENCODER OUTPUT CACHING")
print("=" * 70)
@@ -175,9 +176,7 @@ def test_text_encoder_cache(args, pairs):
# Load tokenizers
print("\n[2.1] Loading tokenizers...")
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
- t5_tokenizer = anima_utils.load_t5_tokenizer(
- getattr(args, 't5_tokenizer_path', None)
- )
+ t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
print(f" Qwen3 tokenizer vocab: {qwen3_tokenizer.vocab_size}")
print(f" T5 tokenizer vocab: {t5_tokenizer.vocab_size}")
@@ -185,9 +184,7 @@ def test_text_encoder_cache(args, pairs):
print("\n[2.2] Loading Qwen3 text encoder...")
device = "cuda" if torch.cuda.is_available() else "cpu"
te_dtype = torch.bfloat16 if device == "cuda" else torch.float32
- qwen3_model, _ = anima_utils.load_qwen3_text_encoder(
- args.qwen3_path, dtype=te_dtype, device=device
- )
+ qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
qwen3_model.eval()
# Create strategy objects
@@ -199,9 +196,7 @@ def test_text_encoder_cache(args, pairs):
qwen3_max_length=args.qwen3_max_length,
t5_max_length=args.t5_max_length,
)
- text_encoding_strategy = AnimaTextEncodingStrategy(
- dropout_rate=0.0,
- )
+ text_encoding_strategy = AnimaTextEncodingStrategy()
captions = [cap for _, cap in pairs]
print(f"\n[2.3] Tokenizing {len(captions)} captions...")
@@ -221,10 +216,7 @@ def test_text_encoder_cache(args, pairs):
print(f"\n[2.4] Encoding with Qwen3 text encoder...")
with torch.no_grad():
prompt_embeds, attn_mask, t5_ids_out, t5_mask_out = text_encoding_strategy.encode_tokens(
- tokenize_strategy,
- [qwen3_model],
- tokens_and_masks,
- enable_dropout=False,
+ tokenize_strategy, [qwen3_model], tokens_and_masks
)
print(f" Encoding results:")
@@ -374,13 +366,13 @@ def test_text_encoder_cache(args, pairs):
# Test 3: Full batch simulation
+
def test_full_batch_simulation(args, pairs):
print("\n" + "=" * 70)
print("TEST 3: FULL BATCH SIMULATION (mimics process_batch flow)")
print("=" * 70)
from library import anima_utils
- from library.anima_models import ANIMA_VAE_MEAN, ANIMA_VAE_STD
from library.strategy_anima import AnimaTokenizeStrategy, AnimaTextEncodingStrategy
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -390,14 +382,16 @@ def test_full_batch_simulation(args, pairs):
# Load all models
print("\n[3.1] Loading models...")
qwen3_tokenizer = anima_utils.load_qwen3_tokenizer(args.qwen3_path)
- t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, 't5_tokenizer_path', None))
+ t5_tokenizer = anima_utils.load_t5_tokenizer(getattr(args, "t5_tokenizer_path", None))
qwen3_model, _ = anima_utils.load_qwen3_text_encoder(args.qwen3_path, dtype=te_dtype, device=device)
qwen3_model.eval()
vae, _, _, vae_scale = anima_utils.load_anima_vae(args.vae_path, dtype=vae_dtype, device=device)
tokenize_strategy = AnimaTokenizeStrategy(
- qwen3_tokenizer=qwen3_tokenizer, t5_tokenizer=t5_tokenizer,
- qwen3_max_length=args.qwen3_max_length, t5_max_length=args.t5_max_length,
+ qwen3_tokenizer=qwen3_tokenizer,
+ t5_tokenizer=t5_tokenizer,
+ qwen3_max_length=args.qwen3_max_length,
+ t5_max_length=args.t5_max_length,
)
text_encoding_strategy = AnimaTextEncodingStrategy(dropout_rate=0.0)
@@ -408,7 +402,10 @@ def test_full_batch_simulation(args, pairs):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
te_outputs = text_encoding_strategy.encode_tokens(
- tokenize_strategy, [qwen3_model], tokens_and_masks, enable_dropout=False,
+ tokenize_strategy,
+ [qwen3_model],
+ tokens_and_masks,
+ enable_dropout=False,
)
prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = te_outputs
@@ -473,7 +470,7 @@ def test_full_batch_simulation(args, pairs):
print(f" text_encoder_conds: empty (no cache)")
# The critical condition
- train_text_encoder_TRUE = True # OLD behavior (base class default, no override)
+ train_text_encoder_TRUE = True # OLD behavior (base class default, no override)
train_text_encoder_FALSE = False # NEW behavior (with is_train_text_encoder override)
cond_old = len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder_TRUE
@@ -541,26 +538,19 @@ def test_full_batch_simulation(args, pairs):
# Main
+
def main():
parser = argparse.ArgumentParser(description="Test Anima caching mechanisms")
- parser.add_argument("--image_dir", type=str, required=True,
- help="Directory with image+txt pairs")
- parser.add_argument("--qwen3_path", type=str, required=True,
- help="Path to Qwen3 model (directory or safetensors)")
- parser.add_argument("--vae_path", type=str, required=True,
- help="Path to WanVAE safetensors")
- parser.add_argument("--t5_tokenizer_path", type=str, default=None,
- help="Path to T5 tokenizer (optional, uses bundled config)")
+ parser.add_argument("--image_dir", type=str, required=True, help="Directory with image+txt pairs")
+ parser.add_argument("--qwen3_path", type=str, required=True, help="Path to Qwen3 model (directory or safetensors)")
+ parser.add_argument("--vae_path", type=str, required=True, help="Path to WanVAE safetensors")
+ parser.add_argument("--t5_tokenizer_path", type=str, default=None, help="Path to T5 tokenizer (optional, uses bundled config)")
parser.add_argument("--qwen3_max_length", type=int, default=512)
parser.add_argument("--t5_max_length", type=int, default=512)
- parser.add_argument("--cache_to_disk", action="store_true",
- help="Also test disk cache round-trip")
- parser.add_argument("--skip_latent", action="store_true",
- help="Skip latent cache test")
- parser.add_argument("--skip_text", action="store_true",
- help="Skip text encoder cache test")
- parser.add_argument("--skip_full", action="store_true",
- help="Skip full batch simulation")
+ parser.add_argument("--cache_to_disk", action="store_true", help="Also test disk cache round-trip")
+ parser.add_argument("--skip_latent", action="store_true", help="Skip latent cache test")
+ parser.add_argument("--skip_text", action="store_true", help="Skip text encoder cache test")
+ parser.add_argument("--skip_full", action="store_true", help="Skip full batch simulation")
args = parser.parse_args()
# Find pairs
diff --git a/tests/test_anima_real_training.py b/tests/manual_test_anima_real_training.py
similarity index 100%
rename from tests/test_anima_real_training.py
rename to tests/manual_test_anima_real_training.py
diff --git a/train_network.py b/train_network.py
index 6cebf5fc..2f8797d2 100644
--- a/train_network.py
+++ b/train_network.py
@@ -470,7 +470,7 @@ class NetworkTrainer:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
- loss = loss.mean([1, 2, 3])
+ loss = loss.mean(dim=list(range(1, loss.ndim))) # mean over all dims except batch
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights