diff --git a/hunyuan_image_minimal_inference.py b/hunyuan_image_minimal_inference.py new file mode 100644 index 00000000..8a956f49 --- /dev/null +++ b/hunyuan_image_minimal_inference.py @@ -0,0 +1,1197 @@ +import argparse +import datetime +import gc +from importlib.util import find_spec +import random +import os +import re +import time +import copy +from types import ModuleType +from typing import Tuple, Optional, List, Any, Dict + +import numpy as np +import torch +from safetensors.torch import load_file, save_file +from safetensors import safe_open +from tqdm import tqdm +from diffusers.utils.torch_utils import randn_tensor +from PIL import Image + +from library import hunyuan_image_models, hunyuan_image_text_encoder, hunyuan_image_utils +from library import hunyuan_image_vae +from library.hunyuan_image_vae import HunyuanVAE2D +from library.device_utils import clean_memory_on_device +from networks import lora_hunyuan_image + + +lycoris_available = find_spec("lycoris") is not None +if lycoris_available: + from lycoris.kohya import create_network_from_weights + +from library.custom_offloading_utils import synchronize_device +from library.utils import mem_eff_save_file, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +class GenerationSettings: + def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None): + self.device = device + self.dit_weight_dtype = dit_weight_dtype # not used currently because model may be optimized + + +def parse_args() -> argparse.Namespace: + """parse command line arguments""" + parser = argparse.ArgumentParser(description="HunyuanImage inference script") + + parser.add_argument("--dit", type=str, default=None, help="DiT directory or path") + parser.add_argument("--vae", type=str, default=None, help="VAE directory or path") + parser.add_argument("--text_encoder", type=str, required=True, help="Text Encoder 1 (Qwen2.5-VL) directory or path") + parser.add_argument("--byt5", type=str, default=None, help="ByT5 Text Encoder 2 directory or path") + + # LoRA + parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") + parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") + parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns") + parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns") + parser.add_argument( + "--save_merged_model", + type=str, + default=None, + help="Save merged model to path. If specified, no inference will be performed.", + ) + + # inference + parser.add_argument( + "--guidance_scale", type=float, default=4.0, help="Guidance scale for classifier free guidance. Default is 4.0." 
+ ) + parser.add_argument("--prompt", type=str, default=None, help="prompt for generation") + parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt for generation, default is empty string") + parser.add_argument("--image_size", type=int, nargs=2, default=[256, 256], help="image size, height and width") + parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, default is 25") + parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") + parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") + + # Flow Matching + parser.add_argument( + "--flow_shift", + type=float, + default=None, + help="Shift factor for flow matching schedulers. Default is None (default).", + ) + + parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") + parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") + + parser.add_argument("--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders") + parser.add_argument("--vae_enable_tiling", action="store_true", help="Enable tiling for VAE decoding") + parser.add_argument( + "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" + ) + parser.add_argument( + "--attn_mode", + type=str, + default="torch", + choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3", + help="attention mode", + ) + parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model") + parser.add_argument( + "--output_type", + type=str, + default="images", + choices=["images", "latent", "latent_images"], + help="output type", + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") + parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. 
no inference") + parser.add_argument( + "--lycoris", action="store_true", help=f"use lycoris for inference{'' if lycoris_available else ' (not available)'}" + ) + + # arguments for batch and interactive modes + parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file") + parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console") + + args = parser.parse_args() + + # Validate arguments + if args.from_file and args.interactive: + raise ValueError("Cannot use both --from_file and --interactive at the same time") + + if args.latent_path is None or len(args.latent_path) == 0: + if args.prompt is None and not args.from_file and not args.interactive: + raise ValueError("Either --prompt, --from_file or --interactive must be specified") + + if args.lycoris and not lycoris_available: + raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS") + + return args + + +def parse_prompt_line(line: str) -> Dict[str, Any]: + """Parse a prompt line into a dictionary of argument overrides + + Args: + line: Prompt line with options + + Returns: + Dict[str, Any]: Dictionary of argument overrides + """ + # TODO common function with hv_train_network.line_to_prompt_dict + parts = line.split(" --") + prompt = parts[0].strip() + + # Create dictionary of overrides + overrides = {"prompt": prompt} + + for part in parts[1:]: + if not part.strip(): + continue + option_parts = part.split(" ", 1) + option = option_parts[0].strip() + value = option_parts[1].strip() if len(option_parts) > 1 else "" + + # Map options to argument names + if option == "w": + overrides["image_size_width"] = int(value) + elif option == "h": + overrides["image_size_height"] = int(value) + elif option == "d": + overrides["seed"] = int(value) + elif option == "s": + overrides["infer_steps"] = int(value) + elif option == "g" or option == "l": + overrides["guidance_scale"] = float(value) + elif option == "fs": + overrides["flow_shift"] = float(value) + # elif option == "i": + # overrides["image_path"] = value + # elif option == "im": + # overrides["image_mask_path"] = value + # elif option == "cn": + # overrides["control_path"] = value + elif option == "n": + overrides["negative_prompt"] = value + # elif option == "ci": # control_image_path + # overrides["control_image_path"] = value + + return overrides + + +def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace: + """Apply overrides to args + + Args: + args: Original arguments + overrides: Dictionary of overrides + + Returns: + argparse.Namespace: New arguments with overrides applied + """ + args_copy = copy.deepcopy(args) + + for key, value in overrides.items(): + if key == "image_size_width": + args_copy.image_size[1] = value + elif key == "image_size_height": + args_copy.image_size[0] = value + else: + setattr(args_copy, key, value) + + return args_copy + + +def check_inputs(args: argparse.Namespace) -> Tuple[int, int]: + """Validate video size and length + + Args: + args: command line arguments + + Returns: + Tuple[int, int]: (height, width) + """ + height = args.image_size[0] + width = args.image_size[1] + + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + return height, width + + +# region Model + + +def load_dit_model( + args: argparse.Namespace, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None +) -> 
hunyuan_image_models.HYImageDiffusionTransformer: + """load DiT model + + Args: + args: command line arguments + device: device to use + dit_weight_dtype: data type for the model weights. None for as-is + + Returns: + qwen_image_model.HYImageDiffusionTransformer: DiT model instance + """ + # If LyCORIS is enabled, we will load the model to CPU and then merge LoRA weights (static method) + + loading_device = "cpu" + if args.blocks_to_swap == 0 and not args.lycoris: + loading_device = device + + # load LoRA weights + if not args.lycoris and args.lora_weight is not None and len(args.lora_weight) > 0: + lora_weights_list = [] + for lora_weight in args.lora_weight: + logger.info(f"Loading LoRA weight from: {lora_weight}") + lora_sd = load_file(lora_weight) # load on CPU, dtype is as is + # lora_sd = filter_lora_state_dict(lora_sd, args.include_patterns, args.exclude_patterns) + lora_weights_list.append(lora_sd) + else: + lora_weights_list = None + + loading_weight_dtype = dit_weight_dtype + if args.fp8_scaled and not args.lycoris: + loading_weight_dtype = None # we will load weights as-is and then optimize to fp8 + + model = hunyuan_image_models.load_hunyuan_image_model( + device, + args.dit, + args.attn_mode, + False, + loading_device, + loading_weight_dtype, + args.fp8_scaled and not args.lycoris, + lora_weights_list=lora_weights_list, + lora_multipliers=args.lora_multiplier, + ) + + # merge LoRA weights + if args.lycoris: + if args.lora_weight is not None and len(args.lora_weight) > 0: + merge_lora_weights(lora_hunyuan_image, model, args, device) + + if args.fp8_scaled: + # load state dict as-is and optimize to fp8 + state_dict = model.state_dict() + + # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy) + move_to_device = args.blocks_to_swap == 0 # if blocks_to_swap > 0, we will keep the model on CPU + state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast) + + info = model.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Loaded FP8 optimized weights: {info}") + + # if we only want to save the model, we can skip the rest + if args.save_merged_model: + return None + + if not args.fp8_scaled: + # simple cast to dit_weight_dtype + target_dtype = None # load as-is (dit_weight_dtype == dtype of the weights in state_dict) + target_device = None + + if dit_weight_dtype is not None: # in case of args.fp8 and not args.fp8_scaled + logger.info(f"Convert model to {dit_weight_dtype}") + target_dtype = dit_weight_dtype + + if args.blocks_to_swap == 0: + logger.info(f"Move model to device: {device}") + target_device = device + + model.to(target_device, target_dtype) # move and cast at the same time. 
this reduces redundant copy operations + + # if args.compile: + # compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args + # logger.info( + # f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" + # ) + # torch._dynamo.config.cache_size_limit = 32 + # for i in range(len(model.blocks)): + # model.blocks[i] = torch.compile( + # model.blocks[i], + # backend=compile_backend, + # mode=compile_mode, + # dynamic=compile_dynamic.lower() in "true", + # fullgraph=compile_fullgraph.lower() in "true", + # ) + + if args.blocks_to_swap > 0: + logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}") + model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False) + model.move_to_device_except_swap_blocks(device) + model.prepare_block_swap_before_forward() + else: + # make sure the model is on the right device + model.to(device) + + model.eval().requires_grad_(False) + clean_memory_on_device(device) + + return model + + +def merge_lora_weights( + lora_module: ModuleType, + model: torch.nn.Module, + lora_weights: List[str], + lora_multipliers: List[float], + include_patterns: Optional[List[str]], + exclude_patterns: Optional[List[str]], + device: torch.device, + lycoris: bool = False, + save_merged_model: Optional[str] = None, + converter: Optional[callable] = None, +) -> None: + """merge LoRA weights to the model + + Args: + lora_module: LoRA module, e.g. lora_wan + model: DiT model + lora_weights: paths to LoRA weights + lora_multipliers: multipliers for LoRA weights + include_patterns: regex patterns to include LoRA modules + exclude_patterns: regex patterns to exclude LoRA modules + device: torch.device + lycoris: use LyCORIS + save_merged_model: path to save merged model, if specified, no inference will be performed + converter: Optional[callable] = None + """ + if lora_weights is None or len(lora_weights) == 0: + return + + for i, lora_weight in enumerate(lora_weights): + if lora_multipliers is not None and len(lora_multipliers) > i: + lora_multiplier = lora_multipliers[i] + else: + lora_multiplier = 1.0 + + logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") + weights_sd = load_file(lora_weight) + if converter is not None: + weights_sd = converter(weights_sd) + + # apply include/exclude patterns + original_key_count = len(weights_sd.keys()) + if include_patterns is not None and len(include_patterns) > i: + include_pattern = include_patterns[i] + regex_include = re.compile(include_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} + logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") + if exclude_patterns is not None and len(exclude_patterns) > i: + original_key_count_ex = len(weights_sd.keys()) + exclude_pattern = exclude_patterns[i] + regex_exclude = re.compile(exclude_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} + logger.info( + f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}" + ) + if len(weights_sd) != original_key_count: + remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) + remaining_keys.sort() + logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") + if len(weights_sd) == 0: + logger.warning("No keys left after filtering.") + + if lycoris: + lycoris_net, 
_ = create_network_from_weights( + multiplier=lora_multiplier, + file=None, + weights_sd=weights_sd, + unet=model, + text_encoder=None, + vae=None, + for_inference=True, + ) + lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device) + else: + network = lora_module.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True) + network.merge_to(None, model, weights_sd, device=device, non_blocking=True) + + synchronize_device(device) + logger.info("LoRA weights loaded") + + # save model here before casting to dit_weight_dtype + if save_merged_model: + logger.info(f"Saving merged model to {save_merged_model}") + mem_eff_save_file(model.state_dict(), save_merged_model) # save_file needs a lot of memory + logger.info("Merged model saved") + + +# endregion + + +def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device, enable_tiling: bool = False) -> torch.Tensor: + logger.info(f"Decoding image. Latent shape {latent.shape}, device {device}") + + vae.to(device) + if enable_tiling: + vae.enable_tiling() + else: + vae.disable_tiling() + with torch.no_grad(): + latent = latent / vae.scaling_factor # scale latent back to original range + pixels = vae.decode(latent.to(device, dtype=vae.dtype)) + pixels = pixels.to("cpu", dtype=torch.float32) # move to CPU and convert to float32 (bfloat16 is not supported by numpy) + vae.to("cpu") + + logger.info(f"Decoded. Pixel shape {pixels.shape}") + return pixels[0] # remove batch dimension + + +def prepare_text_inputs( + args: argparse.Namespace, device: torch.device, shared_models: Optional[Dict] = None +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Prepare text-related inputs for T2I: LLM encoding.""" + + # load text encoder: conds_cache holds cached encodings for prompts without padding + conds_cache = {} + vl_device = torch.device("cpu") if args.text_encoder_cpu else device + if shared_models is not None: + tokenizer_vlm = shared_models.get("tokenizer_vlm") + text_encoder_vlm = shared_models.get("text_encoder_vlm") + tokenizer_byt5 = shared_models.get("tokenizer_byt5") + text_encoder_byt5 = shared_models.get("text_encoder_byt5") + + if "conds_cache" in shared_models: # Use shared cache if available + conds_cache = shared_models["conds_cache"] + + # text_encoder is on device (batched inference) or CPU (interactive inference) + else: # Load if not in shared_models + vl_dtype = torch.bfloat16 # Default dtype for Text Encoder + tokenizer_vlm, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl( + args.text_encoder, dtype=vl_dtype, device=vl_device, disable_mmap=True + ) + tokenizer_byt5, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5( + args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=True + ) + + # Store original devices to move back later if they were shared. 
This does nothing if shared_models is None + text_encoder_original_device = text_encoder_vlm.device if text_encoder_vlm else None + + # Ensure text_encoder is not None before proceeding + if not text_encoder_vlm or not tokenizer_vlm or not tokenizer_byt5 or not text_encoder_byt5: + raise ValueError("Text encoder or tokenizer is not loaded properly.") + + # Define a function to move models to device if needed + # This is to avoid moving models if not needed, especially in interactive mode + model_is_moved = False + + def move_models_to_device_if_needed(): + nonlocal model_is_moved + nonlocal shared_models + + if model_is_moved: + return + model_is_moved = True + + logger.info(f"Moving DiT and Text Encoder to appropriate device: {device} or CPU") + if shared_models and "model" in shared_models: # DiT model is shared + if args.blocks_to_swap > 0: + logger.info("Waiting for 5 seconds to finish block swap") + time.sleep(5) + model = shared_models["model"] + model.to("cpu") + clean_memory_on_device(device) # clean memory on device before moving models + + text_encoder_vlm.to(vl_device) # If text_encoder_cpu is True, this will be CPU + text_encoder_byt5.to(vl_device) + + logger.info("Encoding prompt with Text Encoder") + + prompt = args.prompt + cache_key = prompt + if cache_key in conds_cache: + embed, mask = conds_cache[cache_key] + else: + move_models_to_device_if_needed() + + embed, mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds(tokenizer_vlm, text_encoder_vlm, prompt) + ocr_mask, embed_byt5, mask_byt5 = hunyuan_image_text_encoder.get_glyph_prompt_embeds( + tokenizer_byt5, text_encoder_byt5, prompt + ) + embed = embed.cpu() + mask = mask.cpu() + embed_byt5 = embed_byt5.cpu() + mask_byt5 = mask_byt5.cpu() + + conds_cache[cache_key] = (embed, mask, embed_byt5, mask_byt5, ocr_mask) + + negative_prompt = args.negative_prompt + cache_key = negative_prompt + if cache_key in conds_cache: + negative_embed, negative_mask = conds_cache[cache_key] + else: + move_models_to_device_if_needed() + + negative_embed, negative_mask = hunyuan_image_text_encoder.get_qwen_prompt_embeds( + tokenizer_vlm, text_encoder_vlm, negative_prompt + ) + negative_ocr_mask, negative_embed_byt5, negative_mask_byt5 = hunyuan_image_text_encoder.get_glyph_prompt_embeds( + tokenizer_byt5, text_encoder_byt5, negative_prompt + ) + negative_embed = negative_embed.cpu() + negative_mask = negative_mask.cpu() + negative_embed_byt5 = negative_embed_byt5.cpu() + negative_mask_byt5 = negative_mask_byt5.cpu() + + conds_cache[cache_key] = (negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5, negative_ocr_mask) + + if not (shared_models and "text_encoder_vlm" in shared_models): # if loaded locally + # There is a bug text_encoder is not freed from GPU memory when text encoder is fp8 + del tokenizer_vlm, text_encoder_vlm, tokenizer_byt5, text_encoder_byt5 + gc.collect() # This may force Text Encoder to be freed from GPU memory + else: # if shared, move back to original device (likely CPU) + if text_encoder_vlm: + text_encoder_vlm.to(text_encoder_original_device) + if text_encoder_byt5: + text_encoder_byt5.to(text_encoder_original_device) + + clean_memory_on_device(device) + + arg_c = {"embed": embed, "mask": mask, "embed_byt5": embed_byt5, "mask_byt5": mask_byt5, "ocr_mask": ocr_mask, "prompt": prompt} + arg_null = { + "embed": negative_embed, + "mask": negative_mask, + "embed_byt5": negative_embed_byt5, + "mask_byt5": negative_mask_byt5, + "ocr_mask": negative_ocr_mask, + "prompt": negative_prompt, + } + + return 
arg_c, arg_null + + +def generate( + args: argparse.Namespace, + gen_settings: GenerationSettings, + shared_models: Optional[Dict] = None, + precomputed_text_data: Optional[Dict] = None, +) -> torch.Tensor: + """main function for generation + + Args: + args: command line arguments + shared_models: dictionary containing pre-loaded models (mainly for DiT) + precomputed_image_data: Optional dictionary with precomputed image data + precomputed_text_data: Optional dictionary with precomputed text data + + Returns: + tuple: (HunyuanVAE2D model (vae) or None, torch.Tensor generated latent) + """ + device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype) + + # prepare seed + seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + args.seed = seed # set seed to args for saving + + if precomputed_text_data is not None: + logger.info("Using precomputed text data.") + context = precomputed_text_data["context"] + context_null = precomputed_text_data["context_null"] + + else: + logger.info("No precomputed data. Preparing image and text inputs.") + context, context_null = prepare_text_inputs(args, device, shared_models) + + if shared_models is None or "model" not in shared_models: + # load DiT model + model = load_dit_model(args, device, dit_weight_dtype) + + # if we only want to save the model, we can skip the rest + if args.save_merged_model: + return None + + if shared_models is not None: + shared_models["model"] = model + else: + # use shared model + model: hunyuan_image_models.HYImageDiffusionTransformer = shared_models["model"] + # model.move_to_device_except_swap_blocks(device) # Handles block swap correctly + # model.prepare_block_swap_before_forward() + + # set random generator + seed_g = torch.Generator(device="cpu") + seed_g.manual_seed(seed) + + height, width = check_inputs(args) + logger.info(f"Image size: {height}x{width} (HxW), infer_steps: {args.infer_steps}") + + # image generation ###### + + logger.info(f"Prompt: {context['prompt']}") + + embed = context["embed"].to(device, dtype=torch.bfloat16) + mask = context["mask"].to(device, dtype=torch.bfloat16) + embed_byt5 = context["embed_byt5"].to(device, dtype=torch.bfloat16) + mask_byt5 = context["mask_byt5"].to(device, dtype=torch.bfloat16) + ocr_mask = context["ocr_mask"] # list of bool + negative_embed = context_null["embed"].to(device, dtype=torch.bfloat16) + negative_mask = context_null["mask"].to(device, dtype=torch.bfloat16) + negative_embed_byt5 = context_null["embed_byt5"].to(device, dtype=torch.bfloat16) + negative_mask_byt5 = context_null["mask_byt5"].to(device, dtype=torch.bfloat16) + # negative_ocr_mask = context_null["ocr_mask"] # list of bool + + # Prepare latent variables + num_channels_latents = model.in_channels + shape = (1, num_channels_latents, height // hunyuan_image_vae.VAE_SCALE_FACTOR, width // hunyuan_image_vae.VAE_SCALE_FACTOR) + latents = randn_tensor(shape, generator=seed_g, device=device, dtype=torch.bfloat16) + + logger.info( + f"Embed: {embed.shape}, embed byt5: {embed_byt5.shape}, negative_embed: {negative_embed.shape}, negative embed byt5: {negative_embed_byt5.shape}, latents: {latents.shape}" + ) + + # Prepare timesteps + timesteps, sigmas = hunyuan_image_utils.get_timesteps_sigmas(args.infer_steps, args.flow_shift, device) + + # Prepare Guider + cfg_guider_ocr = hunyuan_image_utils.AdaptiveProjectedGuidance( + guidance_scale=10.0, eta=0.0, adaptive_projected_guidance_rescale=10.0, adaptive_projected_guidance_momentum=-0.5 + ) + cfg_guider_general = 
hunyuan_image_utils.AdaptiveProjectedGuidance( + guidance_scale=10.0, eta=0.0, adaptive_projected_guidance_rescale=10.0, adaptive_projected_guidance_momentum=-0.5 + ) + + # Denoising loop + do_cfg = args.guidance_scale != 1.0 + with tqdm(total=len(timesteps), desc="Denoising steps") as pbar: + for i, t in enumerate(timesteps): + t_expand = t.expand(latents.shape[0]).to(latents.dtype) + + with torch.no_grad(): + noise_pred = model(latents, t_expand, embed, mask, embed_byt5, mask_byt5) + + if do_cfg: + with torch.no_grad(): + uncond_noise_pred = model( + latents, t_expand, negative_embed, negative_mask, negative_embed_byt5, negative_mask_byt5 + ) + noise_pred = hunyuan_image_utils.apply_classifier_free_guidance( + noise_pred, + uncond_noise_pred, + ocr_mask[0], + args.guidance_scale, + i, + cfg_guider_ocr=cfg_guider_ocr, + cfg_guider_general=cfg_guider_general, + ) + + # ensure latents dtype is consistent + latents = hunyuan_image_utils.step(latents, noise_pred, sigmas, i).to(latents.dtype) + + pbar.update() + + return latents + + +def get_time_flag(): + return datetime.datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S-%f")[:-3] + + +def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str: + """Save latent to file + + Args: + latent: Latent tensor + args: command line arguments + height: height of frame + width: width of frame + + Returns: + str: Path to saved latent file + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = get_time_flag() + + seed = args.seed + + latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors" + + if args.no_metadata: + metadata = None + else: + metadata = { + "seeds": f"{seed}", + "prompt": f"{args.prompt}", + "height": f"{height}", + "width": f"{width}", + "infer_steps": f"{args.infer_steps}", + # "embedded_cfg_scale": f"{args.embedded_cfg_scale}", + "guidance_scale": f"{args.guidance_scale}", + } + if args.negative_prompt is not None: + metadata["negative_prompt"] = f"{args.negative_prompt}" + + sd = {"latent": latent.contiguous()} + save_file(sd, latent_path, metadata=metadata) + logger.info(f"Latent saved to: {latent_path}") + + return latent_path + + +def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str: + """Save images to directory + + Args: + sample: Video tensor + args: command line arguments + original_base_name: Original base name (if latents are loaded from files) + + Returns: + str: Path to saved images directory + """ + save_path = args.save_path + os.makedirs(save_path, exist_ok=True) + time_flag = get_time_flag() + + seed = args.seed + original_name = "" if original_base_name is None else f"_{original_base_name}" + image_name = f"{time_flag}_{seed}{original_name}" + + x = torch.clamp(sample, -1.0, 1.0) + x = ((x + 1.0) * 127.5).to(torch.uint8).cpu().numpy() + x = x.transpose(1, 2, 0) # C, H, W -> H, W, C + + image = Image.fromarray(x) + image.save(os.path.join(save_path, f"{image_name}.png")) + + logger.info(f"Sample images saved to: {save_path}/{image_name}") + + return f"{save_path}/{image_name}" + + +def save_output( + args: argparse.Namespace, + vae: HunyuanVAE2D, + latent: torch.Tensor, + device: torch.device, + original_base_names: Optional[List[str]] = None, +) -> None: + """save output + + Args: + args: command line arguments + vae: VAE model + latent: latent tensor + device: device to use + original_base_names: original base names (if latents are loaded from files) + """ + height, width 
= latent.shape[-2], latent.shape[-1] # BCTHW + height *= hunyuan_image_vae.VAE_SCALE_FACTOR + width *= hunyuan_image_vae.VAE_SCALE_FACTOR + # print(f"Saving output. Latent shape {latent.shape}; pixel shape {height}x{width}") + if args.output_type == "latent" or args.output_type == "latent_images": + # save latent + save_latent(latent, args, height, width) + if args.output_type == "latent": + return + + if vae is None: + logger.error("VAE is None, cannot decode latents for saving video/images.") + return + + if latent.ndim == 2: # S,C. For packed latents from other inference scripts + latent = latent.unsqueeze(0) + height, width = check_inputs(args) # Get height/width from args + latent = latent.view( + 1, vae.latent_channels, height // hunyuan_image_vae.VAE_SCALE_FACTOR, width // hunyuan_image_vae.VAE_SCALE_FACTOR + ) + + image = decode_latent(vae, latent, device, args.vae_enable_tiling) + + if args.output_type == "images" or args.output_type == "latent_images": + # save images + if original_base_names is None or len(original_base_names) == 0: + original_name = "" + else: + original_name = f"_{original_base_names[0]}" + save_images(image, args, original_name) + + +def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]: + """Process multiple prompts for batch mode + + Args: + prompt_lines: List of prompt lines + base_args: Base command line arguments + + Returns: + List[Dict]: List of prompt data dictionaries + """ + prompts_data = [] + + for line in prompt_lines: + line = line.strip() + if not line or line.startswith("#"): # Skip empty lines and comments + continue + + # Parse prompt line and create override dictionary + prompt_data = parse_prompt_line(line) + logger.info(f"Parsed prompt data: {prompt_data}") + prompts_data.append(prompt_data) + + return prompts_data + + +def load_shared_models(args: argparse.Namespace) -> Dict: + """Load shared models for batch processing or interactive mode. + Models are loaded to CPU to save memory. VAE is NOT loaded here. + DiT model is also NOT loaded here, handled by process_batch_prompts or generate. + + Args: + args: Base command line arguments + + Returns: + Dict: Dictionary of shared models (text/image encoders) + """ + shared_models = {} + # Load text encoders to CPU + vl_dtype = torch.bfloat16 # Default dtype for Text Encoder + tokenizer_vlm, text_encoder_vlm = hunyuan_image_text_encoder.load_qwen2_5_vl( + args.text_encoder, dtype=vl_dtype, device="cpu", disable_mmap=True + ) + tokenizer_byt5, text_encoder_byt5 = hunyuan_image_text_encoder.load_byt5( + args.byt5, dtype=torch.float16, device="cpu", disable_mmap=True + ) + shared_models["tokenizer_vlm"] = tokenizer_vlm + shared_models["text_encoder_vlm"] = text_encoder_vlm + shared_models["tokenizer_byt5"] = tokenizer_byt5 + shared_models["text_encoder_byt5"] = text_encoder_byt5 + return shared_models + + +def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None: + """Process multiple prompts with model reuse and batched precomputation + + Args: + prompts_data: List of prompt data dictionaries + args: Base command line arguments + """ + if not prompts_data: + logger.warning("No valid prompts found") + return + + gen_settings = get_generation_settings(args) + dit_weight_dtype = gen_settings.dit_weight_dtype + device = gen_settings.device + + # 1. 
Prepare VAE + logger.info("Loading VAE for batch generation...") + vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True) + vae_for_batch.eval() + + all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first + for prompt_args in all_prompt_args_list: + check_inputs(prompt_args) # Validate each prompt's height/width + + # 2. Precompute Text Data (Text Encoder) + logger.info("Loading Text Encoder for batch text preprocessing...") + + # Text Encoder loaded to CPU by load_text_encoder + vl_dtype = torch.bfloat16 # Default dtype for Text Encoder + tokenizer_vlm, text_encoder_vlm_batch = hunyuan_image_text_encoder.load_qwen2_5_vl( + args.text_encoder, dtype=vl_dtype, device="cpu", disable_mmap=True + ) + tokenizer_byt5, text_encoder_byt5_batch = hunyuan_image_text_encoder.load_byt5( + args.byt5, dtype=torch.float16, device="cpu", disable_mmap=True + ) + + # Text Encoder to device for this phase + vl_device = torch.device("cpu") if args.text_encoder_cpu else device + text_encoder_vlm_batch.to(vl_device) # Moved into prepare_text_inputs logic + text_encoder_byt5_batch.to(vl_device) + + all_precomputed_text_data = [] + conds_cache_batch = {} + + logger.info("Preprocessing text and LLM/TextEncoder encoding for all prompts...") + temp_shared_models_txt = { + "tokenizer_vlm": tokenizer_vlm, + "text_encoder_vlm": text_encoder_vlm_batch, # on GPU if not text_encoder_cpu + "tokenizer_byt5": tokenizer_byt5, + "text_encoder_byt5": text_encoder_byt5_batch, # on GPU if not text_encoder_cpu + "conds_cache": conds_cache_batch, + } + + for i, prompt_args_item in enumerate(all_prompt_args_list): + logger.info(f"Text preprocessing for prompt {i+1}/{len(all_prompt_args_list)}: {prompt_args_item.prompt}") + + # prepare_text_inputs will move text_encoders to device temporarily + context, context_null = prepare_text_inputs(prompt_args_item, device, temp_shared_models_txt) + text_data = {"context": context, "context_null": context_null} + all_precomputed_text_data.append(text_data) + + # Models should be removed from device after prepare_text_inputs + del tokenizer_batch, text_encoder_batch, temp_shared_models_txt, conds_cache_batch + gc.collect() # Force cleanup of Text Encoder from GPU memory + clean_memory_on_device(device) + + # 3. Load DiT Model once + logger.info("Loading DiT model for batch generation...") + # Use args from the first prompt for DiT loading (LoRA etc. should be consistent for a batch) + first_prompt_args = all_prompt_args_list[0] + dit_model = load_dit_model(first_prompt_args, device, dit_weight_dtype) # Load directly to target device if possible + + if first_prompt_args.save_merged_model: + logger.info("Merged DiT model saved. Skipping generation.") + + shared_models_for_generate = {"model": dit_model} # Pass DiT via shared_models + + all_latents = [] + + logger.info("Generating latents for all prompts...") + with torch.no_grad(): + for i, prompt_args_item in enumerate(all_prompt_args_list): + current_text_data = all_precomputed_text_data[i] + height, width = check_inputs(prompt_args_item) # Get height/width for each prompt + + logger.info(f"Generating latent for prompt {i+1}/{len(all_prompt_args_list)}: {prompt_args_item.prompt}") + try: + # generate is called with precomputed data, so it won't load Text Encoders. + # It will use the DiT model from shared_models_for_generate. 
+ latent = generate(prompt_args_item, gen_settings, shared_models_for_generate, current_text_data) + + if latent is None: # and prompt_args_item.save_merged_model: # Should be caught earlier + continue + + # Save latent if needed (using data from precomputed_image_data for H/W) + if prompt_args_item.output_type in ["latent", "latent_images"]: + save_latent(latent, prompt_args_item, height, width) + + all_latents.append(latent) + except Exception as e: + logger.error(f"Error generating latent for prompt: {prompt_args_item.prompt}. Error: {e}", exc_info=True) + all_latents.append(None) # Add placeholder for failed generations + continue + + # Free DiT model + logger.info("Releasing DiT model from memory...") + if args.blocks_to_swap > 0: + logger.info("Waiting for 5 seconds to finish block swap") + time.sleep(5) + + del shared_models_for_generate["model"] + del dit_model + clean_memory_on_device(device) + synchronize_device(device) # Ensure memory is freed before loading VAE for decoding + + # 4. Decode latents and save outputs (using vae_for_batch) + if args.output_type != "latent": + logger.info("Decoding latents to videos/images using batched VAE...") + vae_for_batch.to(device) # Move VAE to device for decoding + + for i, latent in enumerate(all_latents): + if latent is None: # Skip failed generations + logger.warning(f"Skipping decoding for prompt {i+1} due to previous error.") + continue + + current_args = all_prompt_args_list[i] + logger.info(f"Decoding output {i+1}/{len(all_latents)} for prompt: {current_args.prompt}") + + # if args.output_type is "latent_images", we already saved latent above. + # so we skip saving latent here. + if current_args.output_type == "latent_images": + current_args.output_type = "images" + + # save_output expects latent to be [BCTHW] or [CTHW]. generate returns [BCTHW] (batch size 1). + # latent[0] is correct if generate returns it with batch dim. + # The latent from generate is (1, C, T, H, W) + save_output(current_args, vae_for_batch, latent[0], device) # Pass vae_for_batch + + vae_for_batch.to("cpu") # Move VAE back to CPU + + del vae_for_batch + clean_memory_on_device(device) + + +def process_interactive(args: argparse.Namespace) -> None: + """Process prompts in interactive mode + + Args: + args: Base command line arguments + """ + gen_settings = get_generation_settings(args) + device = gen_settings.device + shared_models = load_shared_models(args) + shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode + + print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):") + + try: + import prompt_toolkit + except ImportError: + logger.warning("prompt_toolkit not found. Using basic input instead.") + prompt_toolkit = None + + if prompt_toolkit: + session = prompt_toolkit.PromptSession() + + def input_line(prompt: str) -> str: + return session.prompt(prompt) + + else: + + def input_line(prompt: str) -> str: + return input(prompt) + + vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True) + vae.eval() + + try: + while True: + try: + line = input_line("> ") + if not line.strip(): + continue + if len(line.strip()) == 1 and line.strip() in ["\x04", "\x1a"]: # Ctrl+D or Ctrl+Z with prompt_toolkit + raise EOFError # Exit on Ctrl+D or Ctrl+Z + + # Parse prompt + prompt_data = parse_prompt_line(line) + prompt_args = apply_overrides(args, prompt_data) + + # Generate latent + # For interactive, precomputed data is None. shared_models contains text encoders. 
+ latent = generate(prompt_args, gen_settings, shared_models) + + # # If not one_frame_inference, move DiT model to CPU after generation + # if prompt_args.blocks_to_swap > 0: + # logger.info("Waiting for 5 seconds to finish block swap") + # time.sleep(5) + # model = shared_models.get("model") + # model.to("cpu") # Move DiT model to CPU after generation + + # Save latent and video + # returned_vae from generate will be used for decoding here. + save_output(prompt_args, vae, latent[0], device) + + except KeyboardInterrupt: + print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)") + continue + + except EOFError: + print("\nExiting interactive mode") + + +def get_generation_settings(args: argparse.Namespace) -> GenerationSettings: + device = torch.device(args.device) + + dit_weight_dtype = torch.bfloat16 # default + if args.fp8_scaled: + dit_weight_dtype = None # various precision weights, so don't cast to specific dtype + elif args.fp8: + dit_weight_dtype = torch.float8_e4m3fn + + logger.info(f"Using device: {device}, DiT weight weight precision: {dit_weight_dtype}") + + gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype) + return gen_settings + + +def main(): + # Parse arguments + args = parse_args() + + # Check if latents are provided + latents_mode = args.latent_path is not None and len(args.latent_path) > 0 + + # Set device + device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + logger.info(f"Using device: {device}") + args.device = device + + if latents_mode: + # Original latent decode mode + original_base_names = [] + latents_list = [] + seeds = [] + + # assert len(args.latent_path) == 1, "Only one latent path is supported for now" + + for latent_path in args.latent_path: + original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) + seed = 0 + + if os.path.splitext(latent_path)[1] != ".safetensors": + latents = torch.load(latent_path, map_location="cpu") + else: + latents = load_file(latent_path)["latent"] + with safe_open(latent_path, framework="pt") as f: + metadata = f.metadata() + if metadata is None: + metadata = {} + logger.info(f"Loaded metadata: {metadata}") + + if "seeds" in metadata: + seed = int(metadata["seeds"]) + if "height" in metadata and "width" in metadata: + height = int(metadata["height"]) + width = int(metadata["width"]) + args.image_size = [height, width] + + seeds.append(seed) + logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}") + + if latents.ndim == 5: # [BCTHW] + latents = latents.squeeze(0) # [CTHW] + + latents_list.append(latents) + + # latent = torch.stack(latents_list, dim=0) # [N, ...], must be same shape + + for i, latent in enumerate(latents_list): + args.seed = seeds[i] + + vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True) + vae.eval() + save_output(args, vae, latent, device, original_base_names) + + elif args.from_file: + # Batch mode from file + + # Read prompts from file + with open(args.from_file, "r", encoding="utf-8") as f: + prompt_lines = f.readlines() + + # Process prompts + prompts_data = preprocess_prompts_for_batch(prompt_lines, args) + process_batch_prompts(prompts_data, args) + + elif args.interactive: + # Interactive mode + process_interactive(args) + + else: + # Single prompt mode (original behavior) + + # Generate latent + gen_settings = get_generation_settings(args) + + # For single mode, precomputed data is None, shared_models is None. 
+ # generate will load all necessary models (Text Encoders, DiT). + latent = generate(args, gen_settings) + # print(f"Generated latent shape: {latent.shape}") + # if args.save_merged_model: + # return + + clean_memory_on_device(device) + + # Save latent and video + vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True) + vae.eval() + save_output(args, vae, latent, device) + + logger.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/library/attention.py b/library/attention.py new file mode 100644 index 00000000..10a09614 --- /dev/null +++ b/library/attention.py @@ -0,0 +1,50 @@ +import torch +from typing import Optional + + +def attention( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_lens: list[int], attn_mode: str = "torch", drop_rate: float = 0.0 +) -> torch.Tensor: + """ + Compute scaled dot-product attention with variable sequence lengths. + + Handles batches with different sequence lengths by splitting and + processing each sequence individually. + + Args: + q: Query tensor [B, L, H, D]. + k: Key tensor [B, L, H, D]. + v: Value tensor [B, L, H, D]. + seq_lens: Valid sequence length for each batch element. + attn_mode: Attention implementation ("torch" or "sageattn"). + drop_rate: Attention dropout rate. + + Returns: + Attention output tensor [B, L, H*D]. + """ + # Determine tensor layout based on attention implementation + if attn_mode == "torch" or attn_mode == "sageattn": + transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA + else: + transpose_fn = lambda x: x # [B, L, H, D] for other implementations + + # Process each batch element with its valid sequence length + q = [transpose_fn(q[i : i + 1, : seq_lens[i]]) for i in range(len(q))] + k = [transpose_fn(k[i : i + 1, : seq_lens[i]]) for i in range(len(k))] + v = [transpose_fn(v[i : i + 1, : seq_lens[i]]) for i in range(len(v))] + + if attn_mode == "torch": + x = [] + for i in range(len(q)): + x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate) + q[i] = None + k[i] = None + v[i] = None + x.append(x_i) + x = torch.cat(x, dim=0) + del q, k, v + # Currently only PyTorch SDPA is implemented + + x = transpose_fn(x) # [B, L, H, D] + x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D] + return x diff --git a/library/fp8_optimization_utils.py b/library/fp8_optimization_utils.py new file mode 100644 index 00000000..a91eb4e4 --- /dev/null +++ b/library/fp8_optimization_utils.py @@ -0,0 +1,391 @@ +import os +from typing import List, Union +import torch +import torch.nn as nn +import torch.nn.functional as F + +import logging + +from tqdm import tqdm + +from library.device_utils import clean_memory_on_device +from library.utils import MemoryEfficientSafeOpen, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1): + """ + Calculate the maximum representable value in FP8 format. + Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign). 
+ + Args: + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + sign_bits (int): Number of sign bits (0 or 1) + + Returns: + float: Maximum value representable in FP8 format + """ + assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8" + + # Calculate exponent bias + bias = 2 ** (exp_bits - 1) - 1 + + # Calculate maximum mantissa value + mantissa_max = 1.0 + for i in range(mantissa_bits - 1): + mantissa_max += 2 ** -(i + 1) + + # Calculate maximum value + max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias)) + + return max_value + + +def quantize_tensor_to_fp8(tensor, scale, exp_bits=4, mantissa_bits=3, sign_bits=1, max_value=None, min_value=None): + """ + Quantize a tensor to FP8 format. + + Args: + tensor (torch.Tensor): Tensor to quantize + scale (float or torch.Tensor): Scale factor + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + sign_bits (int): Number of sign bits + + Returns: + tuple: (quantized_tensor, scale_factor) + """ + # Create scaled tensor + scaled_tensor = tensor / scale + + # Calculate FP8 parameters + bias = 2 ** (exp_bits - 1) - 1 + + if max_value is None: + # Calculate max and min values + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits) + min_value = -max_value if sign_bits > 0 else 0.0 + + # Clamp tensor to range + clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value) + + # Quantization process + abs_values = torch.abs(clamped_tensor) + nonzero_mask = abs_values > 0 + + # Calculate log scales (only for non-zero elements) + log_scales = torch.zeros_like(clamped_tensor) + if nonzero_mask.any(): + log_scales[nonzero_mask] = torch.floor(torch.log2(abs_values[nonzero_mask]) + bias).detach() + + # Limit log scales and calculate quantization factor + log_scales = torch.clamp(log_scales, min=1.0) + quant_factor = 2.0 ** (log_scales - mantissa_bits - bias) + + # Quantize and dequantize + quantized = torch.round(clamped_tensor / quant_factor) * quant_factor + + return quantized, scale + + +def optimize_state_dict_with_fp8( + state_dict, calc_device, target_layer_keys=None, exclude_layer_keys=None, exp_bits=4, mantissa_bits=3, move_to_device=False +): + """ + Optimize Linear layer weights in a model's state dict to FP8 format. 
+ + Args: + state_dict (dict): State dict to optimize, replaced in-place + calc_device (str): Device to quantize tensors on + target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers) + exclude_layer_keys (list, optional): Layer key patterns to exclude + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + move_to_device (bool): Move optimized tensors to the calculating device + + Returns: + dict: FP8 optimized state dict + """ + if exp_bits == 4 and mantissa_bits == 3: + fp8_dtype = torch.float8_e4m3fn + elif exp_bits == 5 and mantissa_bits == 2: + fp8_dtype = torch.float8_e5m2 + else: + raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}") + + # Calculate FP8 max value + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits) + min_value = -max_value # this function supports only signed FP8 + + # Create optimized state dict + optimized_count = 0 + + # Enumerate tarket keys + target_state_dict_keys = [] + for key in state_dict.keys(): + # Check if it's a weight key and matches target patterns + is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight") + is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys) + is_target = is_target and not is_excluded + + if is_target and isinstance(state_dict[key], torch.Tensor): + target_state_dict_keys.append(key) + + # Process each key + for key in tqdm(target_state_dict_keys): + value = state_dict[key] + + # Save original device and dtype + original_device = value.device + original_dtype = value.dtype + + # Move to calculation device + if calc_device is not None: + value = value.to(calc_device) + + # Calculate scale factor + scale = torch.max(torch.abs(value.flatten())) / max_value + # print(f"Optimizing {key} with scale: {scale}") + + # Quantize weight to FP8 + quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value) + + # Add to state dict using original key for weight and new key for scale + fp8_key = key # Maintain original key + scale_key = key.replace(".weight", ".scale_weight") + + quantized_weight = quantized_weight.to(fp8_dtype) + + if not move_to_device: + quantized_weight = quantized_weight.to(original_device) + + scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device) + + state_dict[fp8_key] = quantized_weight + state_dict[scale_key] = scale_tensor + + optimized_count += 1 + + if calc_device is not None: # optimized_count % 10 == 0 and + # free memory on calculation device + clean_memory_on_device(calc_device) + + logger.info(f"Number of optimized Linear layers: {optimized_count}") + return state_dict + + +def load_safetensors_with_fp8_optimization( + model_files: List[str], + calc_device: Union[str, torch.device], + target_layer_keys=None, + exclude_layer_keys=None, + exp_bits=4, + mantissa_bits=3, + move_to_device=False, + weight_hook=None, +): + """ + Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization. 
+ + Args: + model_files (list[str]): List of model files to load + calc_device (str or torch.device): Device to quantize tensors on + target_layer_keys (list, optional): Layer key patterns to target for optimization (None for all Linear layers) + exclude_layer_keys (list, optional): Layer key patterns to exclude from optimization + exp_bits (int): Number of exponent bits + mantissa_bits (int): Number of mantissa bits + move_to_device (bool): Move optimized tensors to the calculating device + weight_hook (callable, optional): Function to apply to each weight tensor before optimization + + Returns: + dict: FP8 optimized state dict + """ + if exp_bits == 4 and mantissa_bits == 3: + fp8_dtype = torch.float8_e4m3fn + elif exp_bits == 5 and mantissa_bits == 2: + fp8_dtype = torch.float8_e5m2 + else: + raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}") + + # Calculate FP8 max value + max_value = calculate_fp8_maxval(exp_bits, mantissa_bits) + min_value = -max_value # this function supports only signed FP8 + + # Define function to determine if a key is a target key. target means fp8 optimization, not for weight hook. + def is_target_key(key): + # Check if weight key matches target patterns and does not match exclude patterns + is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight") + is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys) + return is_target and not is_excluded + + # Create optimized state dict + optimized_count = 0 + + # Process each file + state_dict = {} + for model_file in model_files: + with MemoryEfficientSafeOpen(model_file) as f: + keys = f.keys() + for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"): + value = f.get_tensor(key) + if weight_hook is not None: + # Apply weight hook if provided + value = weight_hook(key, value) + + if not is_target_key(key): + state_dict[key] = value + continue + + # Save original device and dtype + original_device = value.device + original_dtype = value.dtype + + # Move to calculation device + if calc_device is not None: + value = value.to(calc_device) + + # Calculate scale factor + scale = torch.max(torch.abs(value.flatten())) / max_value + # print(f"Optimizing {key} with scale: {scale}") + + # Quantize weight to FP8 + quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value) + + # Add to state dict using original key for weight and new key for scale + fp8_key = key # Maintain original key + scale_key = key.replace(".weight", ".scale_weight") + assert fp8_key != scale_key, "FP8 key and scale key must be different" + + quantized_weight = quantized_weight.to(fp8_dtype) + + if not move_to_device: + quantized_weight = quantized_weight.to(original_device) + + scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device) + + state_dict[fp8_key] = quantized_weight + state_dict[scale_key] = scale_tensor + + optimized_count += 1 + + if calc_device is not None and optimized_count % 10 == 0: + # free memory on calculation device + clean_memory_on_device(calc_device) + + logger.info(f"Number of optimized Linear layers: {optimized_count}") + return state_dict + + +def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None): + """ + Patched forward method for Linear layers with FP8 weights. 
+ + Args: + self: Linear layer instance + x (torch.Tensor): Input tensor + use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series) + max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied for input tensor. + + Returns: + torch.Tensor: Result of linear transformation + """ + if use_scaled_mm: + input_dtype = x.dtype + original_weight_dtype = self.scale_weight.dtype + weight_dtype = self.weight.dtype + target_dtype = torch.float8_e5m2 + assert weight_dtype == torch.float8_e4m3fn, "Only FP8 E4M3FN format is supported" + assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)" + + if max_value is None: + # no input quantization + scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device) + else: + # calculate scale factor for input tensor + scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32) + + # quantize input tensor to FP8: this seems to consume a lot of memory + x, _ = quantize_tensor_to_fp8(x, scale_x, 5, 2, 1, max_value, -max_value) + + original_shape = x.shape + x = x.reshape(-1, x.shape[2]).to(target_dtype) + + weight = self.weight.t() + scale_weight = self.scale_weight.to(torch.float32) + + if self.bias is not None: + # float32 is not supported with bias in scaled_mm + o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight) + else: + o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight) + + return o.reshape(original_shape[0], original_shape[1], -1).to(input_dtype) + + else: + # Dequantize the weight + original_dtype = self.scale_weight.dtype + dequantized_weight = self.weight.to(original_dtype) * self.scale_weight + + # Perform linear transformation + if self.bias is not None: + output = F.linear(x, dequantized_weight, self.bias) + else: + output = F.linear(x, dequantized_weight) + + return output + + +def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False): + """ + Apply monkey patching to a model using FP8 optimized state dict. + + Args: + model (nn.Module): Model instance to patch + optimized_state_dict (dict): FP8 optimized state dict + use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series) + + Returns: + nn.Module: The patched model (same instance, modified in-place) + """ + # # Calculate FP8 float8_e5m2 max value + # max_value = calculate_fp8_maxval(5, 2) + max_value = None # do not quantize input tensor + + # Find all scale keys to identify FP8-optimized layers + scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")] + + # Enumerate patched layers + patched_module_paths = set() + for scale_key in scale_keys: + # Extract module path from scale key (remove .scale_weight) + module_path = scale_key.rsplit(".scale_weight", 1)[0] + patched_module_paths.add(module_path) + + patched_count = 0 + + # Apply monkey patch to each layer with FP8 weights + for name, module in model.named_modules(): + # Check if this module has a corresponding scale_weight + has_scale = name in patched_module_paths + + # Apply patch if it's a Linear layer with FP8 scale + if isinstance(module, nn.Linear) and has_scale: + # register the scale_weight as a buffer to load the state_dict + module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype)) + + # Create a new forward method with the patched version. 
+ def new_forward(self, x): + return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value) + + # Bind method to module + module.forward = new_forward.__get__(module, type(module)) + + patched_count += 1 + + logger.info(f"Number of monkey-patched Linear layers: {patched_count}") + return model diff --git a/library/hunyuan_image_models.py b/library/hunyuan_image_models.py new file mode 100644 index 00000000..5bd08c5c --- /dev/null +++ b/library/hunyuan_image_models.py @@ -0,0 +1,374 @@ +# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1 +# Re-implemented for license compliance for sd-scripts. + +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +from accelerate import init_empty_weights + +from library.fp8_optimization_utils import apply_fp8_monkey_patch +from library.lora_utils import load_safetensors_with_lora_and_fp8 +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +from library.hunyuan_image_modules import ( + SingleTokenRefiner, + ByT5Mapper, + PatchEmbed2D, + TimestepEmbedder, + MMDoubleStreamBlock, + MMSingleStreamBlock, + FinalLayer, +) +from library.hunyuan_image_utils import get_nd_rotary_pos_embed + +FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"] +FP8_OPTIMIZATION_EXCLUDE_KEYS = [ + "norm", + "_mod", + "modulation", +] + + +# region DiT Model +class HYImageDiffusionTransformer(nn.Module): + """ + HunyuanImage-2.1 Diffusion Transformer. + + A multimodal transformer for image generation with text conditioning, + featuring separate double-stream and single-stream processing blocks. + + Args: + attn_mode: Attention implementation mode ("torch" or "sageattn"). + """ + + def __init__(self, attn_mode: str = "torch"): + super().__init__() + + # Fixed architecture parameters for HunyuanImage-2.1 + self.patch_size = [1, 1] # 1x1 patch size (no spatial downsampling) + self.in_channels = 64 # Input latent channels + self.out_channels = 64 # Output latent channels + self.unpatchify_channels = self.out_channels + self.guidance_embed = False # Guidance embedding disabled + self.rope_dim_list = [64, 64] # RoPE dimensions for 2D positional encoding + self.rope_theta = 256 # RoPE frequency scaling + self.use_attention_mask = True + self.text_projection = "single_refiner" + self.hidden_size = 3584 # Model dimension + self.heads_num = 28 # Number of attention heads + + # Architecture configuration + mm_double_blocks_depth = 20 # Double-stream transformer blocks + mm_single_blocks_depth = 40 # Single-stream transformer blocks + mlp_width_ratio = 4 # MLP expansion ratio + text_states_dim = 3584 # Text encoder output dimension + guidance_embed = False # No guidance embedding + + # Layer configuration + mlp_act_type: str = "gelu_tanh" # MLP activation function + qkv_bias: bool = True # Use bias in QKV projections + qk_norm: bool = True # Apply QK normalization + qk_norm_type: str = "rms" # RMS normalization type + + self.attn_mode = attn_mode + + # ByT5 character-level text encoder mapping + self.byt5_in = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.hidden_size, use_residual=False) + + # Image latent patch embedding + self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size) + + # Text token refinement with cross-attention + self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2, attn_mode=self.attn_mode) + + # Timestep embedding for diffusion process + self.time_in = 
TimestepEmbedder(self.hidden_size, nn.SiLU) + + # MeanFlow not supported in this implementation + self.time_r_in = None + + # Guidance embedding (disabled for non-distilled model) + self.guidance_in = TimestepEmbedder(self.hidden_size, nn.SiLU) if guidance_embed else None + + # Double-stream blocks: separate image and text processing + self.double_blocks = nn.ModuleList( + [ + MMDoubleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + attn_mode=self.attn_mode, + ) + for _ in range(mm_double_blocks_depth) + ] + ) + + # Single-stream blocks: joint processing of concatenated features + self.single_blocks = nn.ModuleList( + [ + MMSingleStreamBlock( + self.hidden_size, + self.heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_act_type=mlp_act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + attn_mode=self.attn_mode, + ) + for _ in range(mm_single_blocks_depth) + ] + ) + + self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels, nn.SiLU) + + def get_rotary_pos_embed(self, rope_sizes): + """ + Generate 2D rotary position embeddings for image tokens. + + Args: + rope_sizes: Tuple of (height, width) for spatial dimensions. + + Returns: + Tuple of (freqs_cos, freqs_sin) tensors for rotary position encoding. + """ + freqs_cos, freqs_sin = get_nd_rotary_pos_embed(self.rope_dim_list, rope_sizes, theta=self.rope_theta) + return freqs_cos, freqs_sin + + def reorder_txt_token( + self, byt5_txt: torch.Tensor, txt: torch.Tensor, byt5_text_mask: torch.Tensor, text_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, list[int]]: + """ + Combine and reorder ByT5 character-level and word-level text embeddings. + + Concatenates valid tokens from both encoders and creates appropriate masks. + + Args: + byt5_txt: ByT5 character-level embeddings [B, L1, D]. + txt: Word-level text embeddings [B, L2, D]. + byt5_text_mask: Valid token mask for ByT5 [B, L1]. + text_mask: Valid token mask for word tokens [B, L2]. + + Returns: + Tuple of (reordered_embeddings, combined_mask, sequence_lengths). 
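+
+ Example: for a batch element with 10 valid ByT5 tokens and 20 valid word tokens, the reordered
+ sequence is [byt5 valid (10), word valid (20), byt5 padding, word padding], the combined mask is
+ 1 for the first 30 positions, and the reported sequence length is 30.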
+ """ + # Process each batch element separately to handle variable sequence lengths + + reorder_txt = [] + reorder_mask = [] + + txt_lens = [] + for i in range(text_mask.shape[0]): + byt5_text_mask_i = byt5_text_mask[i].bool() + text_mask_i = text_mask[i].bool() + byt5_text_length = byt5_text_mask_i.sum() + text_length = text_mask_i.sum() + assert byt5_text_length == byt5_text_mask_i[:byt5_text_length].sum() + assert text_length == text_mask_i[:text_length].sum() + + byt5_txt_i = byt5_txt[i] + txt_i = txt[i] + reorder_txt_i = torch.cat( + [byt5_txt_i[:byt5_text_length], txt_i[:text_length], byt5_txt_i[byt5_text_length:], txt_i[text_length:]], dim=0 + ) + + reorder_mask_i = torch.zeros( + byt5_text_mask_i.shape[0] + text_mask_i.shape[0], dtype=torch.bool, device=byt5_text_mask_i.device + ) + reorder_mask_i[: byt5_text_length + text_length] = True + + reorder_txt.append(reorder_txt_i) + reorder_mask.append(reorder_mask_i) + txt_lens.append(byt5_text_length + text_length) + + reorder_txt = torch.stack(reorder_txt) + reorder_mask = torch.stack(reorder_mask).to(dtype=torch.int64) + + return reorder_txt, reorder_mask, txt_lens + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + text_states: torch.Tensor, + encoder_attention_mask: torch.Tensor, + byt5_text_states: Optional[torch.Tensor] = None, + byt5_text_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass through the HunyuanImage diffusion transformer. + + Args: + hidden_states: Input image latents [B, C, H, W]. + timestep: Diffusion timestep [B]. + text_states: Word-level text embeddings [B, L, D]. + encoder_attention_mask: Text attention mask [B, L]. + byt5_text_states: ByT5 character-level embeddings [B, L_byt5, D_byt5]. + byt5_text_mask: ByT5 attention mask [B, L_byt5]. + + Returns: + Tuple of (denoised_image, spatial_shape). 
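+
+ Note: this implementation returns only the denoised latent, shaped [B, C, H, W]; with
+ patch_size [1, 1] the spatial dimensions match the input.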
+ """ + img = x = hidden_states + text_mask = encoder_attention_mask + t = timestep + txt = text_states + + # Calculate spatial dimensions for rotary position embeddings + _, _, oh, ow = x.shape + th, tw = oh, ow # Height and width (patch_size=[1,1] means no spatial downsampling) + freqs_cis = self.get_rotary_pos_embed((th, tw)) + + # Reshape image latents to sequence format: [B, C, H, W] -> [B, H*W, C] + img = self.img_in(img) + + # Generate timestep conditioning vector + vec = self.time_in(t) + + # MeanFlow and guidance embedding not used in this configuration + + # Process text tokens through refinement layers + txt_lens = text_mask.to(torch.bool).sum(dim=1).tolist() + txt = self.txt_in(txt, t, txt_lens) + + # Integrate character-level ByT5 features with word-level tokens + # Use variable length sequences with sequence lengths + byt5_txt = self.byt5_in(byt5_text_states) + txt, _, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask) + + # Trim sequences to maximum length in the batch + img_seq_len = img.shape[1] + # print(f"img_seq_len: {img_seq_len}, txt_lens: {txt_lens}") + seq_lens = [img_seq_len + l for l in txt_lens] + max_txt_len = max(txt_lens) + # print(f"max_txt_len: {max_txt_len}, seq_lens: {seq_lens}, txt.shape: {txt.shape}") + txt = txt[:, :max_txt_len, :] + txt_seq_len = txt.shape[1] + + # Process through double-stream blocks (separate image/text attention) + for index, block in enumerate(self.double_blocks): + img, txt = block(img, txt, vec, freqs_cis, seq_lens) + + # Concatenate image and text tokens for joint processing + x = torch.cat((img, txt), 1) + + # Process through single-stream blocks (joint attention) + for index, block in enumerate(self.single_blocks): + x = block(x, vec, txt_seq_len, freqs_cis, seq_lens) + + img = x[:, :img_seq_len, ...] + + # Apply final projection to output space + img = self.final_layer(img, vec) + + # Reshape from sequence to spatial format: [B, L, C] -> [B, C, H, W] + img = self.unpatchify_2d(img, th, tw) + return img + + def unpatchify_2d(self, x, h, w): + """ + Convert sequence format back to spatial image format. + + Args: + x: Input tensor [B, H*W, C]. + h: Height dimension. + w: Width dimension. + + Returns: + Spatial tensor [B, C, H, W]. + """ + c = self.unpatchify_channels + + x = x.reshape(shape=(x.shape[0], h, w, c)) + imgs = x.permute(0, 3, 1, 2) + return imgs + + +# endregion + +# region Model Utils + + +def create_model(attn_mode: str, split_attn: bool, dtype: Optional[torch.dtype]) -> HYImageDiffusionTransformer: + with init_empty_weights(): + model = HYImageDiffusionTransformer(attn_mode=attn_mode) + if dtype is not None: + model.to(dtype) + return model + + +def load_hunyuan_image_model( + device: Union[str, torch.device], + dit_path: str, + attn_mode: str, + split_attn: bool, + loading_device: Union[str, torch.device], + dit_weight_dtype: Optional[torch.dtype], + fp8_scaled: bool = False, + lora_weights_list: Optional[Dict[str, torch.Tensor]] = None, + lora_multipliers: Optional[list[float]] = None, +) -> HYImageDiffusionTransformer: + """ + Load a HunyuanImage model from the specified checkpoint. + + Args: + device (Union[str, torch.device]): Device for optimization or merging + dit_path (str): Path to the DiT model checkpoint. + attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc. + split_attn (bool): Whether to use split attention. + loading_device (Union[str, torch.device]): Device to load the model weights on. 
+ dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights. + If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype. + fp8_scaled (bool): Whether to use fp8 scaling for the model weights. + lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any. + lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any. + """ + # dit_weight_dtype is None for fp8_scaled + assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None) + + device = torch.device(device) + loading_device = torch.device(loading_device) + + model = create_model(attn_mode, split_attn, dit_weight_dtype) + + # load model weights with dynamic fp8 optimization and LoRA merging if needed + logger.info(f"Loading DiT model from {dit_path}, device={loading_device}") + + sd = load_safetensors_with_lora_and_fp8( + model_files=dit_path, + lora_weights_list=lora_weights_list, + lora_multipliers=lora_multipliers, + fp8_optimization=fp8_scaled, + calc_device=device, + move_to_device=(loading_device == device), + dit_weight_dtype=dit_weight_dtype, + target_keys=FP8_OPTIMIZATION_TARGET_KEYS, + exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS, + ) + + if fp8_scaled: + apply_fp8_monkey_patch(model, sd, use_scaled_mm=False) + + if loading_device.type != "cpu": + # make sure all the model weights are on the loading_device + logger.info(f"Moving weights to {loading_device}") + for key in sd.keys(): + sd[key] = sd[key].to(loading_device) + + info = model.load_state_dict(sd, strict=True, assign=True) + logger.info(f"Loaded DiT model from {dit_path}, info={info}") + + return model + + +# endregion diff --git a/library/hunyuan_image_modules.py b/library/hunyuan_image_modules.py new file mode 100644 index 00000000..b4ded4c5 --- /dev/null +++ b/library/hunyuan_image_modules.py @@ -0,0 +1,804 @@ +# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1 +# Re-implemented for license compliance for sd-scripts. + +from typing import Tuple, Callable +import torch +import torch.nn as nn +from einops import rearrange + +from library.attention import attention +from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate +from library.attention import attention + +# region Modules + + +class ByT5Mapper(nn.Module): + """ + Maps ByT5 character-level encoder outputs to transformer hidden space. + + Applies layer normalization, two MLP layers with GELU activation, + and optional residual connection. + + Args: + in_dim: Input dimension from ByT5 encoder (1472 for ByT5-large). + out_dim: Intermediate dimension after first projection. + hidden_dim: Hidden dimension for MLP layer. + out_dim1: Final output dimension matching transformer hidden size. + use_residual: Whether to add residual connection (requires in_dim == out_dim). + """ + + def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = nn.LayerNorm(in_dim) + self.fc1 = nn.Linear(in_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, out_dim) + self.fc3 = nn.Linear(out_dim, out_dim1) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + """ + Transform ByT5 embeddings to transformer space. + + Args: + x: Input ByT5 embeddings [..., in_dim]. + + Returns: + Transformed embeddings [..., out_dim1]. 
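+
+ In HunyuanImage-2.1 this mapper is instantiated with in_dim=1472, hidden_dim=2048, out_dim=2048
+ and out_dim1=3584 (the transformer hidden size), with use_residual=False.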
+ """ + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x = self.act_fn(x) + x = self.fc3(x) + if self.use_residual: + x = x + residual + return x + + +class PatchEmbed2D(nn.Module): + """ + 2D patch embedding layer for converting image latents to transformer tokens. + + Uses 2D convolution to project image patches to embedding space. + For HunyuanImage-2.1, patch_size=[1,1] means no spatial downsampling. + + Args: + patch_size: Spatial size of patches (int or tuple). + in_chans: Number of input channels. + embed_dim: Output embedding dimension. + """ + + def __init__(self, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + self.patch_size = tuple(patch_size) + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=True) + self.norm = nn.Identity() # No normalization layer used + + def forward(self, x): + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + return x + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar diffusion timesteps into vector representations. + + Uses sinusoidal encoding followed by a two-layer MLP. + + Args: + hidden_size: Output embedding dimension. + act_layer: Activation function class (e.g., nn.SiLU). + frequency_embedding_size: Dimension of sinusoidal encoding. + max_period: Maximum period for sinusoidal frequencies. + out_size: Output dimension (defaults to hidden_size). + """ + + def __init__(self, hidden_size, act_layer, frequency_embedding_size=256, max_period=10000, out_size=None): + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + if out_size is None: + out_size = hidden_size + + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), act_layer(), nn.Linear(hidden_size, out_size, bias=True) + ) + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype) + return self.mlp(t_freq) + + +class TextProjection(nn.Module): + """ + Projects text embeddings through a two-layer MLP. + + Used for context-aware representation computation in token refinement. + + Args: + in_channels: Input feature dimension. + hidden_size: Hidden and output dimension. + act_layer: Activation function class. + """ + + def __init__(self, in_channels, hidden_size, act_layer): + super().__init__() + self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True) + self.act_1 = act_layer() + self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True) + + def forward(self, caption): + hidden_states = self.linear_1(caption) + hidden_states = self.act_1(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class MLP(nn.Module): + """ + Multi-layer perceptron with configurable activation and normalization. + + Standard two-layer MLP with optional dropout and intermediate normalization. + + Args: + in_channels: Input feature dimension. + hidden_channels: Hidden layer dimension (defaults to in_channels). + out_features: Output dimension (defaults to in_channels). + act_layer: Activation function class. + norm_layer: Optional normalization layer class. + bias: Whether to use bias (can be bool or tuple for each layer). + drop: Dropout rate (can be float or tuple for each layer). + use_conv: Whether to use convolution instead of linear (not supported). 
+ """ + + def __init__( + self, + in_channels, + hidden_channels=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + assert not use_conv, "Convolutional MLP not supported in this implementation." + + out_features = out_features or in_channels + hidden_channels = hidden_channels or in_channels + bias = _to_tuple(bias, 2) + drop_probs = _to_tuple(drop, 2) + + self.fc1 = nn.Linear(in_channels, hidden_channels, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_channels) if norm_layer is not None else nn.Identity() + self.fc2 = nn.Linear(hidden_channels, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class IndividualTokenRefinerBlock(nn.Module): + """ + Single transformer block for individual token refinement. + + Applies self-attention and MLP with adaptive layer normalization (AdaLN) + conditioned on timestep and context information. + + Args: + hidden_size: Model dimension. + heads_num: Number of attention heads. + mlp_width_ratio: MLP expansion ratio. + mlp_drop_rate: MLP dropout rate. + act_type: Activation function (only "silu" supported). + qk_norm: QK normalization flag (must be False). + qk_norm_type: QK normalization type (only "layer" supported). + qkv_bias: Use bias in QKV projections. + attn_mode: Attention implementation mode. + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + attn_mode: str = "torch", + ): + super().__init__() + assert qk_norm_type == "layer", "Only layer normalization supported for QK norm." + assert act_type == "silu", "Only SiLU activation supported." + assert not qk_norm, "QK normalization must be disabled." + + self.attn_mode = attn_mode + + self.heads_num = heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + + self.self_attn_q_norm = nn.Identity() + self.self_attn_k_norm = nn.Identity() + self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) + self.mlp = MLP(in_channels=hidden_size, hidden_channels=mlp_hidden_dim, act_layer=nn.SiLU, drop=mlp_drop_rate) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), + ) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, # Combined timestep and context conditioning + txt_lens: list[int], + ) -> torch.Tensor: + """ + Apply self-attention and MLP with adaptive conditioning. + + Args: + x: Input token embeddings [B, L, C]. + c: Combined conditioning vector [B, C]. + txt_lens: Valid sequence lengths for each batch element. + + Returns: + Refined token embeddings [B, L, C]. 
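+
+ The conditioning vector c is projected by adaLN_modulation into (gate_msa, gate_mlp); the
+ self-attention and MLP branches are each scaled by their gate before being added to the residual.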
+ """ + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) + norm_x = self.norm1(x) + qkv = self.self_attn_qkv(norm_x) + q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + q = self.self_attn_q_norm(q).to(v) + k = self.self_attn_k_norm(k).to(v) + attn = attention(q, k, v, seq_lens=txt_lens, attn_mode=self.attn_mode) + + x = x + apply_gate(self.self_attn_proj(attn), gate_msa) + x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) + return x + + +class IndividualTokenRefiner(nn.Module): + """ + Stack of token refinement blocks with self-attention. + + Processes tokens individually with adaptive layer normalization. + + Args: + hidden_size: Model dimension. + heads_num: Number of attention heads. + depth: Number of refinement blocks. + mlp_width_ratio: MLP expansion ratio. + mlp_drop_rate: MLP dropout rate. + act_type: Activation function type. + qk_norm: QK normalization flag. + qk_norm_type: QK normalization type. + qkv_bias: Use bias in QKV projections. + attn_mode: Attention implementation mode. + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + depth: int, + mlp_width_ratio: float = 4.0, + mlp_drop_rate: float = 0.0, + act_type: str = "silu", + qk_norm: bool = False, + qk_norm_type: str = "layer", + qkv_bias: bool = True, + attn_mode: str = "torch", + ): + super().__init__() + self.blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size=hidden_size, + heads_num=heads_num, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + attn_mode=attn_mode, + ) + for _ in range(depth) + ] + ) + + def forward(self, x: torch.Tensor, c: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor: + """ + Apply sequential token refinement. + + Args: + x: Input token embeddings [B, L, C]. + c: Combined conditioning vector [B, C]. + txt_lens: Valid sequence lengths for each batch element. + + Returns: + Refined token embeddings [B, L, C]. + """ + for block in self.blocks: + x = block(x, c, txt_lens) + return x + + +class SingleTokenRefiner(nn.Module): + """ + Text embedding refinement with timestep and context conditioning. + + Projects input text embeddings and applies self-attention refinement + conditioned on diffusion timestep and aggregate text context. + + Args: + in_channels: Input text embedding dimension. + hidden_size: Transformer hidden dimension. + heads_num: Number of attention heads. + depth: Number of refinement blocks. + attn_mode: Attention implementation mode. 
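+
+ The conditioning vector adds the timestep embedding to a projection of each sample's mean over
+ valid text tokens, so refinement sees both the diffusion step and the overall prompt content.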
+ """ + + def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int, attn_mode: str = "torch"): + # Fixed architecture parameters for HunyuanImage-2.1 + mlp_drop_rate: float = 0.0 # No MLP dropout + act_type: str = "silu" # SiLU activation + mlp_width_ratio: float = 4.0 # 4x MLP expansion + qk_norm: bool = False # No QK normalization + qk_norm_type: str = "layer" # Layer norm type (unused) + qkv_bias: bool = True # Use QKV bias + + super().__init__() + self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True) + act_layer = nn.SiLU + self.t_embedder = TimestepEmbedder(hidden_size, act_layer) + self.c_embedder = TextProjection(in_channels, hidden_size, act_layer) + self.individual_token_refiner = IndividualTokenRefiner( + hidden_size=hidden_size, + heads_num=heads_num, + depth=depth, + mlp_width_ratio=mlp_width_ratio, + mlp_drop_rate=mlp_drop_rate, + act_type=act_type, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + qkv_bias=qkv_bias, + attn_mode=attn_mode, + ) + + def forward(self, x: torch.Tensor, t: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor: + """ + Refine text embeddings with timestep conditioning. + + Args: + x: Input text embeddings [B, L, in_channels]. + t: Diffusion timestep [B]. + txt_lens: Valid sequence lengths for each batch element. + + Returns: + Refined embeddings [B, L, hidden_size]. + """ + timestep_aware_representations = self.t_embedder(t) + + # Compute context-aware representations by averaging valid tokens + context_aware_representations = torch.stack([x[i, : txt_lens[i]].mean(dim=0) for i in range(x.shape[0])], dim=0) # [B, C] + + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + x = self.input_embedder(x) + x = self.individual_token_refiner(x, c, txt_lens) + return x + + +class FinalLayer(nn.Module): + """ + Final output projection layer with adaptive layer normalization. + + Projects transformer hidden states to output patch space with + timestep-conditioned modulation. + + Args: + hidden_size: Input hidden dimension. + patch_size: Spatial patch size for output reshaping. + out_channels: Number of output channels. + act_layer: Activation function class. + """ + + def __init__(self, hidden_size, patch_size, out_channels, act_layer): + super().__init__() + + # Layer normalization without learnable parameters + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + out_size = (patch_size[0] * patch_size[1]) * out_channels + self.linear = nn.Linear(hidden_size, out_size, bias=True) + + # Adaptive layer normalization modulation + self.adaLN_modulation = nn.Sequential( + act_layer(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True), + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift=shift, scale=scale) + x = self.linear(x) + return x + + +class RMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization. + + Normalizes input using RMS and applies learnable scaling. + More efficient than LayerNorm as it doesn't compute mean. + + Args: + dim: Input feature dimension. + eps: Small value for numerical stability. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply RMS normalization. + + Args: + x: Input tensor. + + Returns: + RMS normalized tensor. 
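+
+ Computes x * rsqrt(mean(x**2, dim=-1, keepdim=True) + eps), i.e. x divided by its root mean square.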
+ """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def reset_parameters(self): + self.weight.fill_(1) + + def forward(self, x): + """ + Apply RMSNorm with learnable scaling. + + Args: + x: Input tensor. + + Returns: + Normalized and scaled tensor. + """ + output = self._norm(x.float()).type_as(x) + output = output * self.weight + return output + + +# kept for reference, not used in current implementation +# class LinearWarpforSingle(nn.Module): +# """ +# Linear layer wrapper for concatenating and projecting two inputs. + +# Used in single-stream blocks to combine attention output with MLP features. + +# Args: +# in_dim: Input dimension (sum of both input feature dimensions). +# out_dim: Output dimension. +# bias: Whether to use bias in linear projection. +# """ + +# def __init__(self, in_dim: int, out_dim: int, bias=False): +# super().__init__() +# self.fc = nn.Linear(in_dim, out_dim, bias=bias) + +# def forward(self, x, y): +# """Concatenate inputs along feature dimension and project.""" +# x = torch.cat([x.contiguous(), y.contiguous()], dim=2).contiguous() +# return self.fc(x) + + +class ModulateDiT(nn.Module): + """ + Timestep conditioning modulation layer. + + Projects timestep embeddings to multiple modulation parameters + for adaptive layer normalization. + + Args: + hidden_size: Input conditioning dimension. + factor: Number of modulation parameters to generate. + act_layer: Activation function class. + """ + + def __init__(self, hidden_size: int, factor: int, act_layer: Callable): + super().__init__() + self.act = act_layer() + self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.act(x)) + + +class MMDoubleStreamBlock(nn.Module): + """ + Multimodal double-stream transformer block. + + Processes image and text tokens separately with cross-modal attention. + Each stream has its own normalization and MLP layers but shares + attention computation for cross-modal interaction. + + Args: + hidden_size: Model dimension. + heads_num: Number of attention heads. + mlp_width_ratio: MLP expansion ratio. + mlp_act_type: MLP activation function (only "gelu_tanh" supported). + qk_norm: QK normalization flag (must be True). + qk_norm_type: QK normalization type (only "rms" supported). + qkv_bias: Use bias in QKV projections. + attn_mode: Attention implementation mode. + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qkv_bias: bool = False, + attn_mode: str = "torch", + ): + super().__init__() + + assert mlp_act_type == "gelu_tanh", "Only GELU-tanh activation supported." + assert qk_norm_type == "rms", "Only RMS normalization supported." + assert qk_norm, "QK normalization must be enabled." 
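+
+ # Each stream keeps its own modulation, QKV projection, QK-norm, output projection, and MLP below;
+ # only the attention itself is computed jointly over the concatenated image and text tokens in forward().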
+ + self.attn_mode = attn_mode + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + + # Image stream processing components + self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU) + self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + + self.img_attn_q_norm = RMSNorm(head_dim, eps=1e-6) + self.img_attn_k_norm = RMSNorm(head_dim, eps=1e-6) + self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.img_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True) + + # Text stream processing components + self.txt_mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU) + self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias) + self.txt_attn_q_norm = RMSNorm(head_dim, eps=1e-6) + self.txt_attn_k_norm = RMSNorm(head_dim, eps=1e-6) + self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + + self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True) + + def forward( + self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Extract modulation parameters for image and text streams + (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk( + 6, dim=-1 + ) + (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk( + 6, dim=-1 + ) + + # Process image stream for attention + img_modulated = self.img_norm1(img) + img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale) + + img_qkv = self.img_attn_qkv(img_modulated) + img_q, img_k, img_v = img_qkv.chunk(3, dim=-1) + del img_qkv + + img_q = rearrange(img_q, "B L (H D) -> B L H D", H=self.heads_num) + img_k = rearrange(img_k, "B L (H D) -> B L H D", H=self.heads_num) + img_v = rearrange(img_v, "B L (H D) -> B L H D", H=self.heads_num) + + # Apply QK-Norm if enabled + img_q = self.img_attn_q_norm(img_q).to(img_v) + img_k = self.img_attn_k_norm(img_k).to(img_v) + + # Apply rotary position embeddings to image tokens + if freqs_cis is not None: + img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + assert ( + img_qq.shape == img_q.shape and img_kk.shape == img_k.shape + ), f"RoPE output shape mismatch: got {img_qq.shape}, {img_kk.shape}, expected {img_q.shape}, {img_k.shape}" + img_q, img_k = img_qq, img_kk + + # Process text stream for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale) + + txt_qkv = self.txt_attn_qkv(txt_modulated) + txt_q, txt_k, txt_v = txt_qkv.chunk(3, dim=-1) + del txt_qkv + + txt_q = rearrange(txt_q, "B L (H D) -> B L H D", H=self.heads_num) + txt_k = rearrange(txt_k, "B L (H D) -> B L H D", H=self.heads_num) + txt_v = rearrange(txt_v, "B L (H D) -> B L H D", H=self.heads_num) + + # Apply QK-Norm if enabled + txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) + txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) 
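+ # Unlike the image stream, text q/k receive no rotary position embedding before the joint attention below.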
+ + # Concatenate image and text tokens for joint attention + q = torch.cat([img_q, txt_q], dim=1) + k = torch.cat([img_k, txt_k], dim=1) + v = torch.cat([img_v, txt_v], dim=1) + attn = attention(q, k, v, seq_lens=seq_lens, attn_mode=self.attn_mode) + + # Split attention outputs back to separate streams + img_attn, txt_attn = (attn[:, : img_q.shape[1]].contiguous(), attn[:, img_q.shape[1] :].contiguous()) + + # Apply attention projection and residual connection for image stream + img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) + + # Apply MLP and residual connection for image stream + img = img + apply_gate( + self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)), + gate=img_mod2_gate, + ) + + # Apply attention projection and residual connection for text stream + txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) + + # Apply MLP and residual connection for text stream + txt = txt + apply_gate( + self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)), + gate=txt_mod2_gate, + ) + + return img, txt + + +class MMSingleStreamBlock(nn.Module): + """ + Multimodal single-stream transformer block. + + Processes concatenated image and text tokens jointly with shared attention. + Uses parallel linear layers for efficiency and applies RoPE only to image tokens. + + Args: + hidden_size: Model dimension. + heads_num: Number of attention heads. + mlp_width_ratio: MLP expansion ratio. + mlp_act_type: MLP activation function (only "gelu_tanh" supported). + qk_norm: QK normalization flag (must be True). + qk_norm_type: QK normalization type (only "rms" supported). + qk_scale: Attention scaling factor (computed automatically if None). + attn_mode: Attention implementation mode. + """ + + def __init__( + self, + hidden_size: int, + heads_num: int, + mlp_width_ratio: float = 4.0, + mlp_act_type: str = "gelu_tanh", + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + attn_mode: str = "torch", + ): + super().__init__() + + assert mlp_act_type == "gelu_tanh", "Only GELU-tanh activation supported." + assert qk_norm_type == "rms", "Only RMS normalization supported." + assert qk_norm, "QK normalization must be enabled." 
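+
+ # linear1 fuses the QKV projection (3 * hidden_size) and the MLP up-projection (mlp_hidden_dim) into a single
+ # matmul; linear2 projects the concatenated attention output and activated MLP features back to hidden_size.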
+ + self.attn_mode = attn_mode + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + mlp_hidden_dim = int(hidden_size * mlp_width_ratio) + self.mlp_hidden_dim = mlp_hidden_dim + self.scale = qk_scale or head_dim**-0.5 + + # Parallel linear projections for efficiency + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim) + + # Combined output projection + # self.linear2 = LinearWarpforSingle(hidden_size + mlp_hidden_dim, hidden_size, bias=True) # for reference + self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, bias=True) + + # QK normalization layers + self.q_norm = RMSNorm(head_dim, eps=1e-6) + self.k_norm = RMSNorm(head_dim, eps=1e-6) + + self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approximate="tanh") + self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=nn.SiLU) + + def forward( + self, + x: torch.Tensor, + vec: torch.Tensor, + txt_len: int, + freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, + seq_lens: list[int] = None, + ) -> torch.Tensor: + # Extract modulation parameters + mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) + x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale) + + # Compute Q, K, V, and MLP input + qkv_mlp = self.linear1(x_mod) + q, k, v, mlp = qkv_mlp.split([self.hidden_size, self.hidden_size, self.hidden_size, self.mlp_hidden_dim], dim=-1) + del qkv_mlp + + q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num) + k = rearrange(k, "B L (H D) -> B L H D", H=self.heads_num) + v = rearrange(v, "B L (H D) -> B L H D", H=self.heads_num) + + # Apply QK-Norm if enabled + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + # Separate image and text tokens + img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] + img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] + img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :] + + # Apply rotary position embeddings only to image tokens + img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) + assert ( + img_qq.shape == img_q.shape and img_kk.shape == img_k.shape + ), f"RoPE output shape mismatch: got {img_qq.shape}, {img_kk.shape}, expected {img_q.shape}, {img_k.shape}" + img_q, img_k = img_qq, img_kk + + # Recombine and compute joint attention + q = torch.cat([img_q, txt_q], dim=1) + k = torch.cat([img_k, txt_k], dim=1) + v = torch.cat([img_v, txt_v], dim=1) + attn = attention(q, k, v, seq_lens=seq_lens, attn_mode=self.attn_mode) + + # Combine attention and MLP outputs, apply gating + # output = self.linear2(attn, self.mlp_act(mlp)) + + mlp = self.mlp_act(mlp) + output = torch.cat([attn, mlp], dim=2).contiguous() + output = self.linear2(output) + + return x + apply_gate(output, gate=mod_gate) + + +# endregion diff --git a/library/hunyuan_image_text_encoder.py b/library/hunyuan_image_text_encoder.py new file mode 100644 index 00000000..85bdaa43 --- /dev/null +++ b/library/hunyuan_image_text_encoder.py @@ -0,0 +1,649 @@ +import json +import re +from typing import Tuple, Optional, Union +import torch +from transformers import ( + AutoTokenizer, + Qwen2_5_VLConfig, + Qwen2_5_VLForConditionalGeneration, + Qwen2Tokenizer, + T5ForConditionalGeneration, + T5Config, + T5Tokenizer, +) +from transformers.models.t5.modeling_t5 import T5Stack +from accelerate import init_empty_weights + +from library import model_util +from library.utils import load_safetensors, setup_logging + +setup_logging() 
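+
+# This module handles the two text encoders used by HunyuanImage-2.1: Qwen2.5-VL for word-level
+# prompt embeddings and ByT5-small for character-level embeddings. The color and font index tables
+# below are copied from Glyph-SDXL-V2.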
+import logging + +logger = logging.getLogger(__name__) + + +BYT5_TOKENIZER_PATH = "google/byt5-small" +QWEN_2_5_VL_IMAGE_ID ="Qwen/Qwen2.5-VL-7B-Instruct" + + +# Copy from Glyph-SDXL-V2 + +COLOR_IDX_JSON = """{"white": 0, "black": 1, "darkslategray": 2, "dimgray": 3, "darkolivegreen": 4, "midnightblue": 5, "saddlebrown": 6, "sienna": 7, "whitesmoke": 8, "darkslateblue": 9, +"indianred": 10, "linen": 11, "maroon": 12, "khaki": 13, "sandybrown": 14, "gray": 15, "gainsboro": 16, "teal": 17, "peru": 18, "gold": 19, +"snow": 20, "firebrick": 21, "crimson": 22, "chocolate": 23, "tomato": 24, "brown": 25, "goldenrod": 26, "antiquewhite": 27, "rosybrown": 28, "steelblue": 29, +"floralwhite": 30, "seashell": 31, "darkgreen": 32, "oldlace": 33, "darkkhaki": 34, "burlywood": 35, "red": 36, "darkgray": 37, "orange": 38, "royalblue": 39, +"seagreen": 40, "lightgray": 41, "tan": 42, "coral": 43, "beige": 44, "palevioletred": 45, "wheat": 46, "lavender": 47, "darkcyan": 48, "slateblue": 49, +"slategray": 50, "orangered": 51, "silver": 52, "olivedrab": 53, "forestgreen": 54, "darkgoldenrod": 55, "ivory": 56, "darkorange": 57, "yellow": 58, "hotpink": 59, +"ghostwhite": 60, "lightcoral": 61, "indigo": 62, "bisque": 63, "darkred": 64, "darksalmon": 65, "lightslategray": 66, "dodgerblue": 67, "lightpink": 68, "mistyrose": 69, +"mediumvioletred": 70, "cadetblue": 71, "deeppink": 72, "salmon": 73, "palegoldenrod": 74, "blanchedalmond": 75, "lightseagreen": 76, "cornflowerblue": 77, "yellowgreen": 78, "greenyellow": 79, +"navajowhite": 80, "papayawhip": 81, "mediumslateblue": 82, "purple": 83, "blueviolet": 84, "pink": 85, "cornsilk": 86, "lightsalmon": 87, "mediumpurple": 88, "moccasin": 89, +"turquoise": 90, "mediumseagreen": 91, "lavenderblush": 92, "mediumblue": 93, "darkseagreen": 94, "mediumturquoise": 95, "paleturquoise": 96, "skyblue": 97, "lemonchiffon": 98, "olive": 99, +"peachpuff": 100, "lightyellow": 101, "lightsteelblue": 102, "mediumorchid": 103, "plum": 104, "darkturquoise": 105, "aliceblue": 106, "mediumaquamarine": 107, "orchid": 108, "powderblue": 109, +"blue": 110, "darkorchid": 111, "violet": 112, "lightskyblue": 113, "lightcyan": 114, "lightgoldenrodyellow": 115, "navy": 116, "thistle": 117, "honeydew": 118, "mintcream": 119, +"lightblue": 120, "darkblue": 121, "darkmagenta": 122, "deepskyblue": 123, "magenta": 124, "limegreen": 125, "darkviolet": 126, "cyan": 127, "palegreen": 128, "aquamarine": 129, +"lawngreen": 130, "lightgreen": 131, "azure": 132, "chartreuse": 133, "green": 134, "mediumspringgreen": 135, "lime": 136, "springgreen": 137}""" + +MULTILINGUAL_10_LANG_IDX_JSON = """{"en-Montserrat-Regular": 0, "en-Poppins-Italic": 1, "en-GlacialIndifference-Regular": 2, "en-OpenSans-ExtraBoldItalic": 3, "en-Montserrat-Bold": 4, "en-Now-Regular": 5, "en-Garet-Regular": 6, "en-LeagueSpartan-Bold": 7, "en-DMSans-Regular": 8, "en-OpenSauceOne-Regular": 9, +"en-OpenSans-ExtraBold": 10, "en-KGPrimaryPenmanship": 11, "en-Anton-Regular": 12, "en-Aileron-BlackItalic": 13, "en-Quicksand-Light": 14, "en-Roboto-BoldItalic": 15, "en-TheSeasons-It": 16, "en-Kollektif": 17, "en-Inter-BoldItalic": 18, "en-Poppins-Medium": 19, +"en-Poppins-Light": 20, "en-RoxboroughCF-RegularItalic": 21, "en-PlayfairDisplay-SemiBold": 22, "en-Agrandir-Italic": 23, "en-Lato-Regular": 24, "en-MoreSugarRegular": 25, "en-CanvaSans-RegularItalic": 26, "en-PublicSans-Italic": 27, "en-CodePro-NormalLC": 28, "en-Belleza-Regular": 29, +"en-JosefinSans-Bold": 30, "en-HKGrotesk-Bold": 31, "en-Telegraf-Medium": 32, 
"en-BrittanySignatureRegular": 33, "en-Raleway-ExtraBoldItalic": 34, "en-Mont-RegularItalic": 35, "en-Arimo-BoldItalic": 36, "en-Lora-Italic": 37, "en-ArchivoBlack-Regular": 38, "en-Poppins": 39, +"en-Barlow-Black": 40, "en-CormorantGaramond-Bold": 41, "en-LibreBaskerville-Regular": 42, "en-CanvaSchoolFontRegular": 43, "en-BebasNeueBold": 44, "en-LazydogRegular": 45, "en-FredokaOne-Regular": 46, "en-Horizon-Bold": 47, "en-Nourd-Regular": 48, "en-Hatton-Regular": 49, +"en-Nunito-ExtraBoldItalic": 50, "en-CerebriSans-Regular": 51, "en-Montserrat-Light": 52, "en-TenorSans": 53, "en-Norwester-Regular": 54, "en-ClearSans-Bold": 55, "en-Cardo-Regular": 56, "en-Alice-Regular": 57, "en-Oswald-Regular": 58, "en-Gaegu-Bold": 59, +"en-Muli-Black": 60, "en-TAN-PEARL-Regular": 61, "en-CooperHewitt-Book": 62, "en-Agrandir-Grand": 63, "en-BlackMango-Thin": 64, "en-DMSerifDisplay-Regular": 65, "en-Antonio-Bold": 66, "en-Sniglet-Regular": 67, "en-BeVietnam-Regular": 68, "en-NunitoSans10pt-BlackItalic": 69, +"en-AbhayaLibre-ExtraBold": 70, "en-Rubik-Regular": 71, "en-PPNeueMachina-Regular": 72, "en-TAN - MON CHERI-Regular": 73, "en-Jua-Regular": 74, "en-Playlist-Script": 75, "en-SourceSansPro-BoldItalic": 76, "en-MoonTime-Regular": 77, "en-Eczar-ExtraBold": 78, "en-Gatwick-Regular": 79, +"en-MonumentExtended-Regular": 80, "en-BarlowSemiCondensed-Regular": 81, "en-BarlowCondensed-Regular": 82, "en-Alegreya-Regular": 83, "en-DreamAvenue": 84, "en-RobotoCondensed-Italic": 85, "en-BobbyJones-Regular": 86, "en-Garet-ExtraBold": 87, "en-YesevaOne-Regular": 88, "en-Dosis-ExtraBold": 89, +"en-LeagueGothic-Regular": 90, "en-OpenSans-Italic": 91, "en-TANAEGEAN-Regular": 92, "en-Maharlika-Regular": 93, "en-MarykateRegular": 94, "en-Cinzel-Regular": 95, "en-Agrandir-Wide": 96, "en-Chewy-Regular": 97, "en-BodoniFLF-BoldItalic": 98, "en-Nunito-BlackItalic": 99, +"en-LilitaOne": 100, "en-HandyCasualCondensed-Regular": 101, "en-Ovo": 102, "en-Livvic-Regular": 103, "en-Agrandir-Narrow": 104, "en-CrimsonPro-Italic": 105, "en-AnonymousPro-Bold": 106, "en-NF-OneLittleFont-Bold": 107, "en-RedHatDisplay-BoldItalic": 108, "en-CodecPro-Regular": 109, +"en-HalimunRegular": 110, "en-LibreFranklin-Black": 111, "en-TeXGyreTermes-BoldItalic": 112, "en-Shrikhand-Regular": 113, "en-TTNormsPro-Italic": 114, "en-Gagalin-Regular": 115, "en-OpenSans-Bold": 116, "en-GreatVibes-Regular": 117, "en-Breathing": 118, "en-HeroLight-Regular": 119, +"en-KGPrimaryDots": 120, "en-Quicksand-Bold": 121, "en-Brice-ExtraLightSemiExpanded": 122, "en-Lato-BoldItalic": 123, "en-Fraunces9pt-Italic": 124, "en-AbrilFatface-Regular": 125, "en-BerkshireSwash-Regular": 126, "en-Atma-Bold": 127, "en-HolidayRegular": 128, "en-BebasNeueCyrillic": 129, +"en-IntroRust-Base": 130, "en-Gistesy": 131, "en-BDScript-Regular": 132, "en-ApricotsRegular": 133, "en-Prompt-Black": 134, "en-TAN MERINGUE": 135, "en-Sukar Regular": 136, "en-GentySans-Regular": 137, "en-NeueEinstellung-Normal": 138, "en-Garet-Bold": 139, +"en-FiraSans-Black": 140, "en-BantayogLight": 141, "en-NotoSerifDisplay-Black": 142, "en-TTChocolates-Regular": 143, "en-Ubuntu-Regular": 144, "en-Assistant-Bold": 145, "en-ABeeZee-Regular": 146, "en-LexendDeca-Regular": 147, "en-KingredSerif": 148, "en-Radley-Regular": 149, +"en-BrownSugar": 150, "en-MigraItalic-ExtraboldItalic": 151, "en-ChildosArabic-Regular": 152, "en-PeaceSans": 153, "en-LondrinaSolid-Black": 154, "en-SpaceMono-BoldItalic": 155, "en-RobotoMono-Light": 156, "en-CourierPrime-Regular": 157, "en-Alata-Regular": 158, "en-Amsterdam-One": 159, 
+"en-IreneFlorentina-Regular": 160, "en-CatchyMager": 161, "en-Alta_regular": 162, "en-ArticulatCF-Regular": 163, "en-Raleway-Regular": 164, "en-BrasikaDisplay": 165, "en-TANAngleton-Italic": 166, "en-NotoSerifDisplay-ExtraCondensedItalic": 167, "en-Bryndan Write": 168, "en-TTCommonsPro-It": 169, +"en-AlexBrush-Regular": 170, "en-Antic-Regular": 171, "en-TTHoves-Bold": 172, "en-DroidSerif": 173, "en-AblationRegular": 174, "en-Marcellus-Regular": 175, "en-Sanchez-Italic": 176, "en-JosefinSans": 177, "en-Afrah-Regular": 178, "en-PinyonScript": 179, +"en-TTInterphases-BoldItalic": 180, "en-Yellowtail-Regular": 181, "en-Gliker-Regular": 182, "en-BobbyJonesSoft-Regular": 183, "en-IBMPlexSans": 184, "en-Amsterdam-Three": 185, "en-Amsterdam-FourSlant": 186, "en-TTFors-Regular": 187, "en-Quattrocento": 188, "en-Sifonn-Basic": 189, +"en-AlegreyaSans-Black": 190, "en-Daydream": 191, "en-AristotelicaProTx-Rg": 192, "en-NotoSerif": 193, "en-EBGaramond-Italic": 194, "en-HammersmithOne-Regular": 195, "en-RobotoSlab-Regular": 196, "en-DO-Sans-Regular": 197, "en-KGPrimaryDotsLined": 198, "en-Blinker-Regular": 199, +"en-TAN NIMBUS": 200, "en-Blueberry-Regular": 201, "en-Rosario-Regular": 202, "en-Forum": 203, "en-MistrullyRegular": 204, "en-SourceSerifPro-Regular": 205, "en-Bugaki-Regular": 206, "en-CMUSerif-Roman": 207, "en-GulfsDisplay-NormalItalic": 208, "en-PTSans-Bold": 209, +"en-Sensei-Medium": 210, "en-SquadaOne-Regular": 211, "en-Arapey-Italic": 212, "en-Parisienne-Regular": 213, "en-Aleo-Italic": 214, "en-QuicheDisplay-Italic": 215, "en-RocaOne-It": 216, "en-Funtastic-Regular": 217, "en-PTSerif-BoldItalic": 218, "en-Muller-RegularItalic": 219, +"en-ArgentCF-Regular": 220, "en-Brightwall-Italic": 221, "en-Knewave-Regular": 222, "en-TYSerif-D": 223, "en-Agrandir-Tight": 224, "en-AlfaSlabOne-Regular": 225, "en-TANTangkiwood-Display": 226, "en-Kief-Montaser-Regular": 227, "en-Gotham-Book": 228, "en-JuliusSansOne-Regular": 229, +"en-CocoGothic-Italic": 230, "en-SairaCondensed-Regular": 231, "en-DellaRespira-Regular": 232, "en-Questrial-Regular": 233, "en-BukhariScript-Regular": 234, "en-HelveticaWorld-Bold": 235, "en-TANKINDRED-Display": 236, "en-CinzelDecorative-Regular": 237, "en-Vidaloka-Regular": 238, "en-AlegreyaSansSC-Black": 239, +"en-FeelingPassionate-Regular": 240, "en-QuincyCF-Regular": 241, "en-FiraCode-Regular": 242, "en-Genty-Regular": 243, "en-Nickainley-Normal": 244, "en-RubikOne-Regular": 245, "en-Gidole-Regular": 246, "en-Borsok": 247, "en-Gordita-RegularItalic": 248, "en-Scripter-Regular": 249, +"en-Buffalo-Regular": 250, "en-KleinText-Regular": 251, "en-Creepster-Regular": 252, "en-Arvo-Bold": 253, "en-GabrielSans-NormalItalic": 254, "en-Heebo-Black": 255, "en-LexendExa-Regular": 256, "en-BrixtonSansTC-Regular": 257, "en-GildaDisplay-Regular": 258, "en-ChunkFive-Roman": 259, +"en-Amaranth-BoldItalic": 260, "en-BubbleboddyNeue-Regular": 261, "en-MavenPro-Bold": 262, "en-TTDrugs-Italic": 263, "en-CyGrotesk-KeyRegular": 264, "en-VarelaRound-Regular": 265, "en-Ruda-Black": 266, "en-SafiraMarch": 267, "en-BloggerSans": 268, "en-TANHEADLINE-Regular": 269, +"en-SloopScriptPro-Regular": 270, "en-NeueMontreal-Regular": 271, "en-Schoolbell-Regular": 272, "en-SigherRegular": 273, "en-InriaSerif-Regular": 274, "en-JetBrainsMono-Regular": 275, "en-MADEEvolveSans": 276, "en-Dekko": 277, "en-Handyman-Regular": 278, "en-Aileron-BoldItalic": 279, +"en-Bright-Italic": 280, "en-Solway-Regular": 281, "en-Higuen-Regular": 282, "en-WedgesItalic": 283, "en-TANASHFORD-BOLD": 284, "en-IBMPlexMono": 285, 
"en-RacingSansOne-Regular": 286, "en-RegularBrush": 287, "en-OpenSans-LightItalic": 288, "en-SpecialElite-Regular": 289, +"en-FuturaLTPro-Medium": 290, "en-MaragsaDisplay": 291, "en-BigShouldersDisplay-Regular": 292, "en-BDSans-Regular": 293, "en-RasputinRegular": 294, "en-Yvesyvesdrawing-BoldItalic": 295, "en-Bitter-Regular": 296, "en-LuckiestGuy-Regular": 297, "en-CanvaSchoolFontDotted": 298, "en-TTFirsNeue-Italic": 299, +"en-Sunday-Regular": 300, "en-HKGothic-MediumItalic": 301, "en-CaveatBrush-Regular": 302, "en-HeliosExt": 303, "en-ArchitectsDaughter-Regular": 304, "en-Angelina": 305, "en-Calistoga-Regular": 306, "en-ArchivoNarrow-Regular": 307, "en-ObjectSans-MediumSlanted": 308, "en-AyrLucidityCondensed-Regular": 309, +"en-Nexa-RegularItalic": 310, "en-Lustria-Regular": 311, "en-Amsterdam-TwoSlant": 312, "en-Virtual-Regular": 313, "en-Brusher-Regular": 314, "en-NF-Lepetitcochon-Regular": 315, "en-TANTWINKLE": 316, "en-LeJour-Serif": 317, "en-Prata-Regular": 318, "en-PPWoodland-Regular": 319, +"en-PlayfairDisplay-BoldItalic": 320, "en-AmaticSC-Regular": 321, "en-Cabin-Regular": 322, "en-Manjari-Bold": 323, "en-MrDafoe-Regular": 324, "en-TTRamillas-Italic": 325, "en-Luckybones-Bold": 326, "en-DarkerGrotesque-Light": 327, "en-BellabooRegular": 328, "en-CormorantSC-Bold": 329, +"en-GochiHand-Regular": 330, "en-Atteron": 331, "en-RocaTwo-Lt": 332, "en-ZCOOLXiaoWei-Regular": 333, "en-TANSONGBIRD": 334, "en-HeadingNow-74Regular": 335, "en-Luthier-BoldItalic": 336, "en-Oregano-Regular": 337, "en-AyrTropikaIsland-Int": 338, "en-Mali-Regular": 339, +"en-DidactGothic-Regular": 340, "en-Lovelace-Regular": 341, "en-BakerieSmooth-Regular": 342, "en-CarterOne": 343, "en-HussarBd": 344, "en-OldStandard-Italic": 345, "en-TAN-ASTORIA-Display": 346, "en-rugratssans-Regular": 347, "en-BMHANNA": 348, "en-BetterSaturday": 349, +"en-AdigianaToybox": 350, "en-Sailors": 351, "en-PlayfairDisplaySC-Italic": 352, "en-Etna-Regular": 353, "en-Revive80Signature": 354, "en-CAGenerated": 355, "en-Poppins-Regular": 356, "en-Jonathan-Regular": 357, "en-Pacifico-Regular": 358, "en-Saira-Black": 359, +"en-Loubag-Regular": 360, "en-Decalotype-Black": 361, "en-Mansalva-Regular": 362, "en-Allura-Regular": 363, "en-ProximaNova-Bold": 364, "en-TANMIGNON-DISPLAY": 365, "en-ArsenicaAntiqua-Regular": 366, "en-BreulGroteskA-RegularItalic": 367, "en-HKModular-Bold": 368, "en-TANNightingale-Regular": 369, +"en-AristotelicaProCndTxt-Rg": 370, "en-Aprila-Regular": 371, "en-Tomorrow-Regular": 372, "en-AngellaWhite": 373, "en-KaushanScript-Regular": 374, "en-NotoSans": 375, "en-LeJour-Script": 376, "en-BrixtonTC-Regular": 377, "en-OleoScript-Regular": 378, "en-Cakerolli-Regular": 379, +"en-Lobster-Regular": 380, "en-FrunchySerif-Regular": 381, "en-PorcelainRegular": 382, "en-AlojaExtended": 383, "en-SergioTrendy-Italic": 384, "en-LovelaceText-Bold": 385, "en-Anaktoria": 386, "en-JimmyScript-Light": 387, "en-IBMPlexSerif": 388, "en-Marta": 389, +"en-Mango-Regular": 390, "en-Overpass-Italic": 391, "en-Hagrid-Regular": 392, "en-ElikaGorica": 393, "en-Amiko-Regular": 394, "en-EFCOBrookshire-Regular": 395, "en-Caladea-Regular": 396, "en-MoonlightBold": 397, "en-Staatliches-Regular": 398, "en-Helios-Bold": 399, +"en-Satisfy-Regular": 400, "en-NexaScript-Regular": 401, "en-Trocchi-Regular": 402, "en-March": 403, "en-IbarraRealNova-Regular": 404, "en-Nectarine-Regular": 405, "en-Overpass-Light": 406, "en-TruetypewriterPolyglOTT": 407, "en-Bangers-Regular": 408, "en-Lazord-BoldExpandedItalic": 409, +"en-Chloe-Regular": 410, 
"en-BaskervilleDisplayPT-Regular": 411, "en-Bright-Regular": 412, "en-Vollkorn-Regular": 413, "en-Harmattan": 414, "en-SortsMillGoudy-Regular": 415, "en-Biryani-Bold": 416, "en-SugoProDisplay-Italic": 417, "en-Lazord-BoldItalic": 418, "en-Alike-Regular": 419, +"en-PermanentMarker-Regular": 420, "en-Sacramento-Regular": 421, "en-HKGroteskPro-Italic": 422, "en-Aleo-BoldItalic": 423, "en-Noot": 424, "en-TANGARLAND-Regular": 425, "en-Twister": 426, "en-Arsenal-Italic": 427, "en-Bogart-Italic": 428, "en-BethEllen-Regular": 429, +"en-Caveat-Regular": 430, "en-BalsamiqSans-Bold": 431, "en-BreeSerif-Regular": 432, "en-CodecPro-ExtraBold": 433, "en-Pierson-Light": 434, "en-CyGrotesk-WideRegular": 435, "en-Lumios-Marker": 436, "en-Comfortaa-Bold": 437, "en-TraceFontRegular": 438, "en-RTL-AdamScript-Regular": 439, +"en-EastmanGrotesque-Italic": 440, "en-Kalam-Bold": 441, "en-ChauPhilomeneOne-Regular": 442, "en-Coiny-Regular": 443, "en-Lovera": 444, "en-Gellatio": 445, "en-TitilliumWeb-Bold": 446, "en-OilvareBase-Italic": 447, "en-Catamaran-Black": 448, "en-Anteb-Italic": 449, +"en-SueEllenFrancisco": 450, "en-SweetApricot": 451, "en-BrightSunshine": 452, "en-IM_FELL_Double_Pica_Italic": 453, "en-Granaina-limpia": 454, "en-TANPARFAIT": 455, "en-AcherusGrotesque-Regular": 456, "en-AwesomeLathusca-Italic": 457, "en-Signika-Bold": 458, "en-Andasia": 459, +"en-DO-AllCaps-Slanted": 460, "en-Zenaida-Regular": 461, "en-Fahkwang-Regular": 462, "en-Play-Regular": 463, "en-BERNIERRegular-Regular": 464, "en-PlumaThin-Regular": 465, "en-SportsWorld": 466, "en-Garet-Black": 467, "en-CarolloPlayscript-BlackItalic": 468, "en-Cheque-Regular": 469, +"en-SEGO": 470, "en-BobbyJones-Condensed": 471, "en-NexaSlab-RegularItalic": 472, "en-DancingScript-Regular": 473, "en-PaalalabasDisplayWideBETA": 474, "en-Magnolia-Script": 475, "en-OpunMai-400It": 476, "en-MadelynFill-Regular": 477, "en-ZingRust-Base": 478, "en-FingerPaint-Regular": 479, +"en-BostonAngel-Light": 480, "en-Gliker-RegularExpanded": 481, "en-Ahsing": 482, "en-Engagement-Regular": 483, "en-EyesomeScript": 484, "en-LibraSerifModern-Regular": 485, "en-London-Regular": 486, "en-AtkinsonHyperlegible-Regular": 487, "en-StadioNow-TextItalic": 488, "en-Aniyah": 489, +"en-ITCAvantGardePro-Bold": 490, "en-Comica-Regular": 491, "en-Coustard-Regular": 492, "en-Brice-BoldCondensed": 493, "en-TANNEWYORK-Bold": 494, "en-TANBUSTER-Bold": 495, "en-Alatsi-Regular": 496, "en-TYSerif-Book": 497, "en-Jingleberry": 498, "en-Rajdhani-Bold": 499, +"en-LobsterTwo-BoldItalic": 500, "en-BestLight-Medium": 501, "en-Hitchcut-Regular": 502, "en-GermaniaOne-Regular": 503, "en-Emitha-Script": 504, "en-LemonTuesday": 505, "en-Cubao_Free_Regular": 506, "en-MonterchiSerif-Regular": 507, "en-AllertaStencil-Regular": 508, "en-RTL-Sondos-Regular": 509, +"en-HomemadeApple-Regular": 510, "en-CosmicOcto-Medium": 511, "cn-HelloFont-FangHuaTi": 0, "cn-HelloFont-ID-DianFangSong-Bold": 1, "cn-HelloFont-ID-DianFangSong": 2, "cn-HelloFont-ID-DianHei-CEJ": 3, "cn-HelloFont-ID-DianHei-DEJ": 4, "cn-HelloFont-ID-DianHei-EEJ": 5, "cn-HelloFont-ID-DianHei-FEJ": 6, "cn-HelloFont-ID-DianHei-GEJ": 7, "cn-HelloFont-ID-DianKai-Bold": 8, "cn-HelloFont-ID-DianKai": 9, +"cn-HelloFont-WenYiHei": 10, "cn-Hellofont-ID-ChenYanXingKai": 11, "cn-Hellofont-ID-DaZiBao": 12, "cn-Hellofont-ID-DaoCaoRen": 13, "cn-Hellofont-ID-JianSong": 14, "cn-Hellofont-ID-JiangHuZhaoPaiHei": 15, "cn-Hellofont-ID-KeSong": 16, "cn-Hellofont-ID-LeYuanTi": 17, "cn-Hellofont-ID-Pinocchio": 18, "cn-Hellofont-ID-QiMiaoTi": 19, 
+"cn-Hellofont-ID-QingHuaKai": 20, "cn-Hellofont-ID-QingHuaXingKai": 21, "cn-Hellofont-ID-ShanShuiXingKai": 22, "cn-Hellofont-ID-ShouXieQiShu": 23, "cn-Hellofont-ID-ShouXieTongZhenTi": 24, "cn-Hellofont-ID-TengLingTi": 25, "cn-Hellofont-ID-XiaoLiShu": 26, "cn-Hellofont-ID-XuanZhenSong": 27, "cn-Hellofont-ID-ZhongLingXingKai": 28, "cn-HellofontIDJiaoTangTi": 29, +"cn-HellofontIDJiuZhuTi": 30, "cn-HuXiaoBao-SaoBao": 31, "cn-HuXiaoBo-NanShen": 32, "cn-HuXiaoBo-ZhenShuai": 33, "cn-SourceHanSansSC-Bold": 34, "cn-SourceHanSansSC-ExtraLight": 35, "cn-SourceHanSansSC-Heavy": 36, "cn-SourceHanSansSC-Light": 37, "cn-SourceHanSansSC-Medium": 38, "cn-SourceHanSansSC-Normal": 39, +"cn-SourceHanSansSC-Regular": 40, "cn-SourceHanSerifSC-Bold": 41, "cn-SourceHanSerifSC-ExtraLight": 42, "cn-SourceHanSerifSC-Heavy": 43, "cn-SourceHanSerifSC-Light": 44, "cn-SourceHanSerifSC-Medium": 45, "cn-SourceHanSerifSC-Regular": 46, "cn-SourceHanSerifSC-SemiBold": 47, "cn-xiaowei": 48, "cn-AaJianHaoTi": 49, +"cn-AlibabaPuHuiTi-Bold": 50, "cn-AlibabaPuHuiTi-Heavy": 51, "cn-AlibabaPuHuiTi-Light": 52, "cn-AlibabaPuHuiTi-Medium": 53, "cn-AlibabaPuHuiTi-Regular": 54, "cn-CanvaAcidBoldSC": 55, "cn-CanvaBreezeCN": 56, "cn-CanvaBumperCropSC": 57, "cn-CanvaCakeShopCN": 58, "cn-CanvaEndeavorBlackSC": 59, +"cn-CanvaJoyHeiCN": 60, "cn-CanvaLiCN": 61, "cn-CanvaOrientalBrushCN": 62, "cn-CanvaPoster": 63, "cn-CanvaQinfuCalligraphyCN": 64, "cn-CanvaSweetHeartCN": 65, "cn-CanvaSwordLikeDreamCN": 66, "cn-CanvaTangyuanHandwritingCN": 67, "cn-CanvaWanderWorldCN": 68, "cn-CanvaWenCN": 69, +"cn-DianZiChunYi": 70, "cn-GenSekiGothicTW-H": 71, "cn-GenWanMinTW-L": 72, "cn-GenYoMinTW-B": 73, "cn-GenYoMinTW-EL": 74, "cn-GenYoMinTW-H": 75, "cn-GenYoMinTW-M": 76, "cn-GenYoMinTW-R": 77, "cn-GenYoMinTW-SB": 78, "cn-HYQiHei-AZEJ": 79, +"cn-HYQiHei-EES": 80, "cn-HanaMinA": 81, "cn-HappyZcool-2016": 82, "cn-HelloFont ZJ KeKouKeAiTi": 83, "cn-HelloFont-ID-BoBoTi": 84, "cn-HelloFont-ID-FuGuHei-25": 85, "cn-HelloFont-ID-FuGuHei-35": 86, "cn-HelloFont-ID-FuGuHei-45": 87, "cn-HelloFont-ID-FuGuHei-55": 88, "cn-HelloFont-ID-FuGuHei-65": 89, +"cn-HelloFont-ID-FuGuHei-75": 90, "cn-HelloFont-ID-FuGuHei-85": 91, "cn-HelloFont-ID-HeiKa": 92, "cn-HelloFont-ID-HeiTang": 93, "cn-HelloFont-ID-JianSong-95": 94, "cn-HelloFont-ID-JueJiangHei-50": 95, "cn-HelloFont-ID-JueJiangHei-55": 96, "cn-HelloFont-ID-JueJiangHei-60": 97, "cn-HelloFont-ID-JueJiangHei-65": 98, "cn-HelloFont-ID-JueJiangHei-70": 99, +"cn-HelloFont-ID-JueJiangHei-75": 100, "cn-HelloFont-ID-JueJiangHei-80": 101, "cn-HelloFont-ID-KuHeiTi": 102, "cn-HelloFont-ID-LingDongTi": 103, "cn-HelloFont-ID-LingLiTi": 104, "cn-HelloFont-ID-MuFengTi": 105, "cn-HelloFont-ID-NaiNaiJiangTi": 106, "cn-HelloFont-ID-PangDu": 107, "cn-HelloFont-ID-ReLieTi": 108, "cn-HelloFont-ID-RouRun": 109, +"cn-HelloFont-ID-SaShuangShouXieTi": 110, "cn-HelloFont-ID-WangZheFengFan": 111, "cn-HelloFont-ID-YouQiTi": 112, "cn-Hellofont-ID-XiaLeTi": 113, "cn-Hellofont-ID-XianXiaTi": 114, "cn-HuXiaoBoKuHei": 115, "cn-IDDanMoXingKai": 116, "cn-IDJueJiangHei": 117, "cn-IDMeiLingTi": 118, "cn-IDQQSugar": 119, +"cn-LiuJianMaoCao-Regular": 120, "cn-LongCang-Regular": 121, "cn-MaShanZheng-Regular": 122, "cn-PangMenZhengDao-3": 123, "cn-PangMenZhengDao-Cu": 124, "cn-PangMenZhengDao": 125, "cn-SentyCaramel": 126, "cn-SourceHanSerifSC": 127, "cn-WenCang-Regular": 128, "cn-WenQuanYiMicroHei": 129, +"cn-XianErTi": 130, "cn-YRDZSTJF": 131, "cn-YS-HelloFont-BangBangTi": 132, "cn-ZCOOLKuaiLe-Regular": 133, "cn-ZCOOLQingKeHuangYou-Regular": 134, 
"cn-ZCOOLXiaoWei-Regular": 135, "cn-ZCOOL_KuHei": 136, "cn-ZhiMangXing-Regular": 137, "cn-baotuxiaobaiti": 138, "cn-jiangxizhuokai-Regular": 139, +"cn-zcool-gdh": 140, "cn-zcoolqingkehuangyouti-Regular": 141, "cn-zcoolwenyiti": 142, "jp-04KanjyukuGothic": 0, "jp-07LightNovelPOP": 1, "jp-07NikumaruFont": 2, "jp-07YasashisaAntique": 3, "jp-07YasashisaGothic": 4, "jp-BokutachinoGothic2Bold": 5, "jp-BokutachinoGothic2Regular": 6, "jp-CHI_SpeedyRight_full_211128-Regular": 7, "jp-CHI_SpeedyRight_italic_full_211127-Regular": 8, "jp-CP-Font": 9, +"jp-Canva_CezanneProN-B": 10, "jp-Canva_CezanneProN-M": 11, "jp-Canva_ChiaroStd-B": 12, "jp-Canva_CometStd-B": 13, "jp-Canva_DotMincho16Std-M": 14, "jp-Canva_GrecoStd-B": 15, "jp-Canva_GrecoStd-M": 16, "jp-Canva_LyraStd-DB": 17, "jp-Canva_MatisseHatsuhiPro-B": 18, "jp-Canva_MatisseHatsuhiPro-M": 19, +"jp-Canva_ModeMinAStd-B": 20, "jp-Canva_NewCezanneProN-B": 21, "jp-Canva_NewCezanneProN-M": 22, "jp-Canva_PearlStd-L": 23, "jp-Canva_RaglanStd-UB": 24, "jp-Canva_RailwayStd-B": 25, "jp-Canva_ReggaeStd-B": 26, "jp-Canva_RocknRollStd-DB": 27, "jp-Canva_RodinCattleyaPro-B": 28, "jp-Canva_RodinCattleyaPro-M": 29, +"jp-Canva_RodinCattleyaPro-UB": 30, "jp-Canva_RodinHimawariPro-B": 31, "jp-Canva_RodinHimawariPro-M": 32, "jp-Canva_RodinMariaPro-B": 33, "jp-Canva_RodinMariaPro-DB": 34, "jp-Canva_RodinProN-M": 35, "jp-Canva_ShadowTLStd-B": 36, "jp-Canva_StickStd-B": 37, "jp-Canva_TsukuAOldMinPr6N-B": 38, "jp-Canva_TsukuAOldMinPr6N-R": 39, +"jp-Canva_UtrilloPro-DB": 40, "jp-Canva_UtrilloPro-M": 41, "jp-Canva_YurukaStd-UB": 42, "jp-FGUIGEN": 43, "jp-GlowSansJ-Condensed-Heavy": 44, "jp-GlowSansJ-Condensed-Light": 45, "jp-GlowSansJ-Normal-Bold": 46, "jp-GlowSansJ-Normal-Light": 47, "jp-HannariMincho": 48, "jp-HarenosoraMincho": 49, +"jp-Jiyucho": 50, "jp-Kaiso-Makina-B": 51, "jp-Kaisotai-Next-UP-B": 52, "jp-KokoroMinchoutai": 53, "jp-Mamelon-3-Hi-Regular": 54, "jp-MotoyaAnemoneStd-W1": 55, "jp-MotoyaAnemoneStd-W5": 56, "jp-MotoyaAnticPro-W3": 57, "jp-MotoyaCedarStd-W3": 58, "jp-MotoyaCedarStd-W5": 59, +"jp-MotoyaGochikaStd-W4": 60, "jp-MotoyaGochikaStd-W8": 61, "jp-MotoyaGothicMiyabiStd-W6": 62, "jp-MotoyaGothicStd-W3": 63, "jp-MotoyaGothicStd-W5": 64, "jp-MotoyaKoinStd-W3": 65, "jp-MotoyaKyotaiStd-W2": 66, "jp-MotoyaKyotaiStd-W4": 67, "jp-MotoyaMaruStd-W3": 68, "jp-MotoyaMaruStd-W5": 69, +"jp-MotoyaMinchoMiyabiStd-W4": 70, "jp-MotoyaMinchoMiyabiStd-W6": 71, "jp-MotoyaMinchoModernStd-W4": 72, "jp-MotoyaMinchoModernStd-W6": 73, "jp-MotoyaMinchoStd-W3": 74, "jp-MotoyaMinchoStd-W5": 75, "jp-MotoyaReisyoStd-W2": 76, "jp-MotoyaReisyoStd-W6": 77, "jp-MotoyaTohitsuStd-W4": 78, "jp-MotoyaTohitsuStd-W6": 79, +"jp-MtySousyokuEmBcJis-W6": 80, "jp-MtySousyokuLiBcJis-W6": 81, "jp-Mushin": 82, "jp-NotoSansJP-Bold": 83, "jp-NotoSansJP-Regular": 84, "jp-NudMotoyaAporoStd-W3": 85, "jp-NudMotoyaAporoStd-W5": 86, "jp-NudMotoyaCedarStd-W3": 87, "jp-NudMotoyaCedarStd-W5": 88, "jp-NudMotoyaMaruStd-W3": 89, +"jp-NudMotoyaMaruStd-W5": 90, "jp-NudMotoyaMinchoStd-W5": 91, "jp-Ounen-mouhitsu": 92, "jp-Ronde-B-Square": 93, "jp-SMotoyaGyosyoStd-W5": 94, "jp-SMotoyaSinkaiStd-W3": 95, "jp-SMotoyaSinkaiStd-W5": 96, "jp-SourceHanSansJP-Bold": 97, "jp-SourceHanSansJP-Regular": 98, "jp-SourceHanSerifJP-Bold": 99, +"jp-SourceHanSerifJP-Regular": 100, "jp-TazuganeGothicStdN-Bold": 101, "jp-TazuganeGothicStdN-Regular": 102, "jp-TelopMinProN-B": 103, "jp-Togalite-Bold": 104, "jp-Togalite-Regular": 105, "jp-TsukuMinPr6N-E": 106, "jp-TsukuMinPr6N-M": 107, "jp-mikachan_o": 108, "jp-nagayama_kai": 109, 
+"jp-07LogoTypeGothic7": 110, "jp-07TetsubinGothic": 111, "jp-851CHIKARA-DZUYOKU-KANA-A": 112, "jp-ARMinchoJIS-Light": 113, "jp-ARMinchoJIS-Ultra": 114, "jp-ARPCrystalMinchoJIS-Medium": 115, "jp-ARPCrystalRGothicJIS-Medium": 116, "jp-ARShounanShinpitsuGyosyoJIS-Medium": 117, "jp-AozoraMincho-bold": 118, "jp-AozoraMinchoRegular": 119, +"jp-ArialUnicodeMS-Bold": 120, "jp-ArialUnicodeMS": 121, "jp-CanvaBreezeJP": 122, "jp-CanvaLiCN": 123, "jp-CanvaLiJP": 124, "jp-CanvaOrientalBrushCN": 125, "jp-CanvaQinfuCalligraphyJP": 126, "jp-CanvaSweetHeartJP": 127, "jp-CanvaWenJP": 128, "jp-Corporate-Logo-Bold": 129, +"jp-DelaGothicOne-Regular": 130, "jp-GN-Kin-iro_SansSerif": 131, "jp-GN-Koharuiro_Sunray": 132, "jp-GenEiGothicM-B": 133, "jp-GenEiGothicM-R": 134, "jp-GenJyuuGothic-Bold": 135, "jp-GenRyuMinTW-B": 136, "jp-GenRyuMinTW-R": 137, "jp-GenSekiGothicTW-B": 138, "jp-GenSekiGothicTW-R": 139, +"jp-GenSenRoundedTW-B": 140, "jp-GenSenRoundedTW-R": 141, "jp-GenShinGothic-Bold": 142, "jp-GenShinGothic-Normal": 143, "jp-GenWanMinTW-L": 144, "jp-GenYoGothicTW-B": 145, "jp-GenYoGothicTW-R": 146, "jp-GenYoMinTW-B": 147, "jp-GenYoMinTW-R": 148, "jp-HGBouquet": 149, +"jp-HanaMinA": 150, "jp-HanazomeFont": 151, "jp-HinaMincho-Regular": 152, "jp-Honoka-Antique-Maru": 153, "jp-Honoka-Mincho": 154, "jp-HuiFontP": 155, "jp-IPAexMincho": 156, "jp-JK-Gothic-L": 157, "jp-JK-Gothic-M": 158, "jp-JackeyFont": 159, +"jp-KaiseiTokumin-Bold": 160, "jp-KaiseiTokumin-Regular": 161, "jp-Keifont": 162, "jp-KiwiMaru-Regular": 163, "jp-Koku-Mincho-Regular": 164, "jp-MotoyaLMaru-W3-90ms-RKSJ-H": 165, "jp-NewTegomin-Regular": 166, "jp-NicoKaku": 167, "jp-NicoMoji+": 168, "jp-Otsutome_font-Bold": 169, +"jp-PottaOne-Regular": 170, "jp-RampartOne-Regular": 171, "jp-Senobi-Gothic-Bold": 172, "jp-Senobi-Gothic-Regular": 173, "jp-SmartFontUI-Proportional": 174, "jp-SoukouMincho": 175, "jp-TEST_Klee-DB": 176, "jp-TEST_Klee-M": 177, "jp-TEST_UDMincho-B": 178, "jp-TEST_UDMincho-L": 179, +"jp-TT_Akakane-EB": 180, "jp-Tanuki-Permanent-Marker": 181, "jp-TrainOne-Regular": 182, "jp-TsunagiGothic-Black": 183, "jp-Ume-Hy-Gothic": 184, "jp-Ume-P-Mincho": 185, "jp-WenQuanYiMicroHei": 186, "jp-XANO-mincho-U32": 187, "jp-YOzFontM90-Regular": 188, "jp-Yomogi-Regular": 189, +"jp-YujiBoku-Regular": 190, "jp-YujiSyuku-Regular": 191, "jp-ZenKakuGothicNew-Bold": 192, "jp-ZenKakuGothicNew-Regular": 193, "jp-ZenKurenaido-Regular": 194, "jp-ZenMaruGothic-Bold": 195, "jp-ZenMaruGothic-Regular": 196, "jp-darts-font": 197, "jp-irohakakuC-Bold": 198, "jp-irohakakuC-Medium": 199, +"jp-irohakakuC-Regular": 200, "jp-katyou": 201, "jp-mplus-1m-bold": 202, "jp-mplus-1m-regular": 203, "jp-mplus-1p-bold": 204, "jp-mplus-1p-regular": 205, "jp-rounded-mplus-1p-bold": 206, "jp-rounded-mplus-1p-regular": 207, "jp-timemachine-wa": 208, "jp-ttf-GenEiLateMin-Medium": 209, +"jp-uzura_font": 210, "kr-Arita-buri-Bold_OTF": 0, "kr-Arita-buri-HairLine_OTF": 1, "kr-Arita-buri-Light_OTF": 2, "kr-Arita-buri-Medium_OTF": 3, "kr-Arita-buri-SemiBold_OTF": 4, "kr-Canva_YDSunshineL": 5, "kr-Canva_YDSunshineM": 6, "kr-Canva_YoonGulimPro710": 7, "kr-Canva_YoonGulimPro730": 8, "kr-Canva_YoonGulimPro740": 9, +"kr-Canva_YoonGulimPro760": 10, "kr-Canva_YoonGulimPro770": 11, "kr-Canva_YoonGulimPro790": 12, "kr-CreHappB": 13, "kr-CreHappL": 14, "kr-CreHappM": 15, "kr-CreHappS": 16, "kr-OTAuroraB": 17, "kr-OTAuroraL": 18, "kr-OTAuroraR": 19, +"kr-OTDoldamgilB": 20, "kr-OTDoldamgilL": 21, "kr-OTDoldamgilR": 22, "kr-OTHamsterB": 23, "kr-OTHamsterL": 24, "kr-OTHamsterR": 25, "kr-OTHapchangdanB": 26, 
"kr-OTHapchangdanL": 27, "kr-OTHapchangdanR": 28, "kr-OTSupersizeBkBOX": 29, +"kr-SourceHanSansKR-Bold": 30, "kr-SourceHanSansKR-ExtraLight": 31, "kr-SourceHanSansKR-Heavy": 32, "kr-SourceHanSansKR-Light": 33, "kr-SourceHanSansKR-Medium": 34, "kr-SourceHanSansKR-Normal": 35, "kr-SourceHanSansKR-Regular": 36, "kr-SourceHanSansSC-Bold": 37, "kr-SourceHanSansSC-ExtraLight": 38, "kr-SourceHanSansSC-Heavy": 39, +"kr-SourceHanSansSC-Light": 40, "kr-SourceHanSansSC-Medium": 41, "kr-SourceHanSansSC-Normal": 42, "kr-SourceHanSansSC-Regular": 43, "kr-SourceHanSerifSC-Bold": 44, "kr-SourceHanSerifSC-SemiBold": 45, "kr-TDTDBubbleBubbleOTF": 46, "kr-TDTDConfusionOTF": 47, "kr-TDTDCuteAndCuteOTF": 48, "kr-TDTDEggTakOTF": 49, +"kr-TDTDEmotionalLetterOTF": 50, "kr-TDTDGalapagosOTF": 51, "kr-TDTDHappyHourOTF": 52, "kr-TDTDLatteOTF": 53, "kr-TDTDMoonLightOTF": 54, "kr-TDTDParkForestOTF": 55, "kr-TDTDPencilOTF": 56, "kr-TDTDSmileOTF": 57, "kr-TDTDSproutOTF": 58, "kr-TDTDSunshineOTF": 59, +"kr-TDTDWaferOTF": 60, "kr-777Chyaochyureu": 61, "kr-ArialUnicodeMS-Bold": 62, "kr-ArialUnicodeMS": 63, "kr-BMHANNA": 64, "kr-Baekmuk-Dotum": 65, "kr-BagelFatOne-Regular": 66, "kr-CoreBandi": 67, "kr-CoreBandiFace": 68, "kr-CoreBori": 69, +"kr-DoHyeon-Regular": 70, "kr-Dokdo-Regular": 71, "kr-Gaegu-Bold": 72, "kr-Gaegu-Light": 73, "kr-Gaegu-Regular": 74, "kr-GamjaFlower-Regular": 75, "kr-GasoekOne-Regular": 76, "kr-GothicA1-Black": 77, "kr-GothicA1-Bold": 78, "kr-GothicA1-ExtraBold": 79, +"kr-GothicA1-ExtraLight": 80, "kr-GothicA1-Light": 81, "kr-GothicA1-Medium": 82, "kr-GothicA1-Regular": 83, "kr-GothicA1-SemiBold": 84, "kr-GothicA1-Thin": 85, "kr-Gugi-Regular": 86, "kr-HiMelody-Regular": 87, "kr-Jua-Regular": 88, "kr-KirangHaerang-Regular": 89, +"kr-NanumBrush": 90, "kr-NanumPen": 91, "kr-NanumSquareRoundB": 92, "kr-NanumSquareRoundEB": 93, "kr-NanumSquareRoundL": 94, "kr-NanumSquareRoundR": 95, "kr-SeH-CB": 96, "kr-SeH-CBL": 97, "kr-SeH-CEB": 98, "kr-SeH-CL": 99, +"kr-SeH-CM": 100, "kr-SeN-CB": 101, "kr-SeN-CBL": 102, "kr-SeN-CEB": 103, "kr-SeN-CL": 104, "kr-SeN-CM": 105, "kr-Sunflower-Bold": 106, "kr-Sunflower-Light": 107, "kr-Sunflower-Medium": 108, "kr-TTClaytoyR": 109, +"kr-TTDalpangiR": 110, "kr-TTMamablockR": 111, "kr-TTNauidongmuR": 112, "kr-TTOktapbangR": 113, "kr-UhBeeMiMi": 114, "kr-UhBeeMiMiBold": 115, "kr-UhBeeSe_hyun": 116, "kr-UhBeeSe_hyunBold": 117, "kr-UhBeenamsoyoung": 118, "kr-UhBeenamsoyoungBold": 119, +"kr-WenQuanYiMicroHei": 120, "kr-YeonSung-Regular": 121}""" + + +def add_special_token(tokenizer: T5Tokenizer, text_encoder: T5Stack): + """ + Add special tokens for color and font to tokenizer and text encoder. + + Args: + tokenizer: Huggingface tokenizer. + text_encoder: Huggingface T5 encoder. 
+ """ + idx_font_dict = json.loads(MULTILINGUAL_10_LANG_IDX_JSON) + idx_color_dict = json.loads(COLOR_IDX_JSON) + + font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict] + color_token = [f"" for i in range(len(idx_color_dict))] + additional_special_tokens = [] + additional_special_tokens += color_token + additional_special_tokens += font_token + + tokenizer.add_tokens(additional_special_tokens, special_tokens=True) + # Set mean_resizing=False to avoid PyTorch LAPACK dependency + text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False) + + +def load_byt5( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> Tuple[T5Stack, T5Tokenizer]: + BYT5_CONFIG_JSON = """ +{ + "_name_or_path": "/home/patrick/t5/byt5-small", + "architectures": [ + "T5ForConditionalGeneration" + ], + "d_ff": 3584, + "d_kv": 64, + "d_model": 1472, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "gradient_checkpointing": false, + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 4, + "num_heads": 6, + "num_layers": 12, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "tokenizer_class": "ByT5Tokenizer", + "transformers_version": "4.7.0.dev0", + "use_cache": true, + "vocab_size": 384 + } +""" + + logger.info(f"Loading BYT5 tokenizer from {BYT5_TOKENIZER_PATH}") + byt5_tokenizer = AutoTokenizer.from_pretrained(BYT5_TOKENIZER_PATH) + + logger.info("Initializing BYT5 text encoder") + config = json.loads(BYT5_CONFIG_JSON) + config = T5Config(**config) + with init_empty_weights(): + byt5_text_encoder = T5ForConditionalGeneration._from_config(config).get_encoder() + + add_special_token(byt5_tokenizer, byt5_text_encoder) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device, disable_mmap=disable_mmap, dtype=dtype) + + # remove "encoder." 
prefix + sd = {k[len("encoder.") :] if k.startswith("encoder.") else k: v for k, v in sd.items()} + sd["embed_tokens.weight"] = sd.pop("shared.weight") + + info = byt5_text_encoder.load_state_dict(sd, strict=True, assign=True) + byt5_text_encoder.to(device) + logger.info(f"BYT5 text encoder loaded with info: {info}") + + return byt5_tokenizer, byt5_text_encoder + + +def load_qwen2_5_vl( + ckpt_path: str, + dtype: Optional[torch.dtype], + device: Union[str, torch.device], + disable_mmap: bool = False, + state_dict: Optional[dict] = None, +) -> tuple[Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration]: + QWEN2_5_VL_CONFIG_JSON = """ +{ + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "text_config": { + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": null, + "initializer_range": 0.02, + "intermediate_size": 18944, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl_text", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": null, + "torch_dtype": "float32", + "use_cache": true, + "use_sliding_window": false, + "video_token_id": null, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.53.1", + "use_cache": true, + "use_sliding_window": false, + "video_token_id": 151656, + "vision_config": { + "depth": 32, + "fullatt_block_indexes": [ + 7, + 15, + 23, + 31 + ], + "hidden_act": "silu", + "hidden_size": 1280, + "in_channels": 3, + "in_chans": 3, + "initializer_range": 0.02, + "intermediate_size": 3420, + "model_type": "qwen2_5_vl", + "num_heads": 16, + "out_hidden_size": 3584, + "patch_size": 14, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "temporal_patch_size": 2, + "tokens_per_second": 2, + "torch_dtype": "float32", + "window_size": 112 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + 
"vocab_size": 152064 +} +""" + config = json.loads(QWEN2_5_VL_CONFIG_JSON) + config = Qwen2_5_VLConfig(**config) + with init_empty_weights(): + qwen2_5_vl = Qwen2_5_VLForConditionalGeneration._from_config(config) + + if state_dict is not None: + sd = state_dict + else: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device, disable_mmap=disable_mmap, dtype=dtype) + + # convert prefixes + for key in list(sd.keys()): + if key.startswith("model."): + new_key = key.replace("model.", "model.language_model.", 1) + elif key.startswith("visual."): + new_key = key.replace("visual.", "model.visual.", 1) + else: + continue + if key not in sd: + logger.warning(f"Key {key} not found in state dict, skipping.") + continue + sd[new_key] = sd.pop(key) + + info = qwen2_5_vl.load_state_dict(sd, strict=True, assign=True) + logger.info(f"Loaded Qwen2.5-VL: {info}") + qwen2_5_vl.to(device) + + if dtype is not None: + if dtype.itemsize == 1: # fp8 + org_dtype = torch.bfloat16 # model weight is fp8 in loading, but original dtype is bfloat16 + logger.info(f"prepare Qwen2.5-VL for fp8: set to {dtype} from {org_dtype}") + qwen2_5_vl.to(dtype) + + # prepare LLM for fp8 + def prepare_fp8(vl_model: Qwen2_5_VLForConditionalGeneration, target_dtype): + def forward_hook(module): + def forward(hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon) + # return module.weight.to(input_dtype) * hidden_states.to(input_dtype) + return (module.weight.to(torch.float32) * hidden_states.to(torch.float32)).to(input_dtype) + + return forward + + def decoder_forward_hook(module): + def forward( + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs, + ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = module.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = module.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + input_dtype = hidden_states.dtype + hidden_states = residual.to(torch.float32) + hidden_states.to(torch.float32) + hidden_states = hidden_states.to(input_dtype) + + # Fully Connected + residual = hidden_states + hidden_states = module.post_attention_layernorm(hidden_states) + hidden_states = module.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + return forward + + for module in vl_model.modules(): + if module.__class__.__name__ in ["Embedding"]: + # print("set", module.__class__.__name__, "to", target_dtype) + module.to(target_dtype) + if module.__class__.__name__ in ["Qwen2RMSNorm"]: + # print("set", module.__class__.__name__, "hooks") + 
module.forward = forward_hook(module) + if module.__class__.__name__ in ["Qwen2_5_VLDecoderLayer"]: + # print("set", module.__class__.__name__, "hooks") + module.forward = decoder_forward_hook(module) + if module.__class__.__name__ in ["Qwen2_5_VisionRotaryEmbedding"]: + # print("set", module.__class__.__name__, "hooks") + module.to(target_dtype) + + prepare_fp8(qwen2_5_vl, org_dtype) + + else: + logger.info(f"Setting Qwen2.5-VL to dtype: {dtype}") + qwen2_5_vl.to(dtype) + + # Load tokenizer + logger.info(f"Loading tokenizer from {QWEN_2_5_VL_IMAGE_ID}") + tokenizer = Qwen2Tokenizer.from_pretrained(QWEN_2_5_VL_IMAGE_ID) + return tokenizer, qwen2_5_vl + + +def get_qwen_prompt_embeds( + tokenizer: Qwen2Tokenizer, vlm: Qwen2_5_VLForConditionalGeneration, prompt: Union[str, list[str]] = None +): + tokenizer_max_length = 1024 + + # HunyuanImage-2.1 does not use "<|im_start|>assistant\n" in the prompt template + prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + # \n<|im_start|>assistant\n" + prompt_template_encode_start_idx = 34 + # default_sample_size = 128 + + device = vlm.device + dtype = vlm.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = prompt_template_encode + drop_idx = prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = tokenizer(txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to( + device + ) + + if dtype.itemsize == 1: # fp8 + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True): + encoder_hidden_states = vlm( + input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True + ) + else: + with torch.no_grad(), torch.autocast(device_type=device.type, dtype=dtype, enabled=True): + encoder_hidden_states = vlm( + input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True + ) + hidden_states = encoder_hidden_states.hidden_states[-3] # use the 3rd last layer's hidden states for HunyuanImage-2.1 + if hidden_states.shape[1] > tokenizer_max_length + drop_idx: + logger.warning(f"Hidden states shape {hidden_states.shape} exceeds max length {tokenizer_max_length + drop_idx}") + + # --- Unnecessary complicated processing, keep for reference --- + # split_hidden_states = extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + # split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + # attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + # max_seq_len = max([e.size(0) for e in split_hidden_states]) + # prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]) + # encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]) + # ---------------------------------------------------------- + + prompt_embeds = hidden_states[:, drop_idx:, :] + encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + +def format_prompt(texts, styles): + """ + Text "{text}" in {color}, {type}. 
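    Example (illustrative; in the current implementation the color/style entries are
    placeholders and only the quoted text ends up in the glyph prompt):

        >>> format_prompt(["Happy Birthday"], [{"color": None, "font-family": None}])
        'Text "Happy Birthday". '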
+ """ + + prompt = "" + for text, style in zip(texts, styles): + # color and style are always None in official implementation, so we only use text + text_prompt = f'Text "{text}"' + text_prompt += ". " + prompt = prompt + text_prompt + return prompt + + +def get_glyph_prompt_embeds( + tokenizer: T5Tokenizer, text_encoder: T5Stack, prompt: Union[str, list[str]] = None +) -> Tuple[list[bool], torch.Tensor, torch.Tensor]: + byt5_max_length = 128 + if not prompt: + return ( + [False], + torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device), + torch.zeros((1, byt5_max_length), device=text_encoder.device, dtype=torch.int64), + ) + + try: + text_prompt_texts = [] + # pattern_quote_single = r"\'(.*?)\'" + pattern_quote_double = r"\"(.*?)\"" + pattern_quote_chinese_single = r"‘(.*?)’" + pattern_quote_chinese_double = r"“(.*?)”" + + # matches_quote_single = re.findall(pattern_quote_single, prompt) + matches_quote_double = re.findall(pattern_quote_double, prompt) + matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt) + matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt) + + # text_prompt_texts.extend(matches_quote_single) + text_prompt_texts.extend(matches_quote_double) + text_prompt_texts.extend(matches_quote_chinese_single) + text_prompt_texts.extend(matches_quote_chinese_double) + + if not text_prompt_texts: + return ( + [False], + torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device), + torch.zeros((1, byt5_max_length), device=text_encoder.device, dtype=torch.int64), + ) + + text_prompt_style_list = [{"color": None, "font-family": None} for _ in range(len(text_prompt_texts))] + glyph_text_formatted = format_prompt(text_prompt_texts, text_prompt_style_list) + + byt5_text_ids, byt5_text_mask = get_byt5_text_tokens(tokenizer, byt5_max_length, glyph_text_formatted) + + byt5_text_ids = byt5_text_ids.to(device=text_encoder.device) + byt5_text_mask = byt5_text_mask.to(device=text_encoder.device) + + byt5_prompt_embeds = text_encoder(byt5_text_ids, attention_mask=byt5_text_mask.float()) + byt5_emb = byt5_prompt_embeds[0] + + return [True], byt5_emb, byt5_text_mask + + except Exception as e: + logger.warning(f"Warning: Error in glyph encoding, using fallback: {e}") + return ( + [False], + torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device), + torch.zeros((1, byt5_max_length), device=text_encoder.device, dtype=torch.int64), + ) + + +def get_byt5_text_tokens(tokenizer, max_length, text_list): + """ + Get byT5 text tokens. + + Args: + tokenizer: The tokenizer object + max_length: Maximum token length + text_list: List or string of text + + Returns: + Tuple of (byt5_text_ids, byt5_text_mask) + """ + if isinstance(text_list, list): + text_prompt = " ".join(text_list) + else: + text_prompt = text_list + + byt5_text_inputs = tokenizer( + text_prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt" + ) + + byt5_text_ids = byt5_text_inputs.input_ids + byt5_text_mask = byt5_text_inputs.attention_mask + + return byt5_text_ids, byt5_text_mask diff --git a/library/hunyuan_image_utils.py b/library/hunyuan_image_utils.py new file mode 100644 index 00000000..17847104 --- /dev/null +++ b/library/hunyuan_image_utils.py @@ -0,0 +1,461 @@ +# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1 +# Re-implemented for license compliance for sd-scripts. 
+ +import math +from typing import Tuple, Union, Optional +import torch + + +def _to_tuple(x, dim=2): + """ + Convert int or sequence to tuple of specified dimension. + + Args: + x: Int or sequence to convert. + dim: Target dimension for tuple. + + Returns: + Tuple of length dim. + """ + if isinstance(x, int) or isinstance(x, float): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd(start, dim=2): + """ + Generate n-dimensional coordinate meshgrid from 0 to grid_size. + + Creates coordinate grids for each spatial dimension, useful for + generating position embeddings. + + Args: + start: Grid size for each dimension (int or tuple). + dim: Number of spatial dimensions. + + Returns: + Coordinate grid tensor [dim, *grid_size]. + """ + # Convert start to grid sizes + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + + # Generate coordinate arrays for each dimension + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + +def get_nd_rotary_pos_embed(rope_dim_list, start, theta=10000.0): + """ + Generate n-dimensional rotary position embeddings for spatial tokens. + + Creates RoPE embeddings for multi-dimensional positional encoding, + distributing head dimensions across spatial dimensions. + + Args: + rope_dim_list: Dimensions allocated to each spatial axis (should sum to head_dim). + start: Spatial grid size for each dimension. + theta: Base frequency for RoPE computation. + + Returns: + Tuple of (cos_freqs, sin_freqs) for rotary embedding [H*W, D/2]. + """ + + grid = get_meshgrid_nd(start, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H] + + # Generate RoPE embeddings for each spatial dimension + embs = [] + for i in range(len(rope_dim_list)): + emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta) # 2 x [WHD, rope_dim_list[i]] + embs.append(emb) + + cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2) + sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2) + return cos, sin + + +def get_1d_rotary_pos_embed( + dim: int, pos: Union[torch.FloatTensor, int], theta: float = 10000.0 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generate 1D rotary position embeddings. + + Args: + dim: Embedding dimension (must be even). + pos: Position indices [S] or scalar for sequence length. + theta: Base frequency for sinusoidal encoding. + + Returns: + Tuple of (cos_freqs, sin_freqs) tensors [S, D]. + """ + if isinstance(pos, int): + pos = torch.arange(pos).float() + + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2] + freqs = torch.outer(pos, freqs) # [S, D/2] + freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] + freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] + return freqs_cos, freqs_sin + + +def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings for diffusion models. + + Converts scalar timesteps to high-dimensional embeddings using + sinusoidal encoding at different frequencies. + + Args: + t: Timestep tensor [N]. + dim: Output embedding dimension. + max_period: Maximum period for frequency computation. + + Returns: + Timestep embeddings [N, dim]. 
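    Example (illustrative shape check):

        >>> t = torch.tensor([0.0, 500.0, 999.0])
        >>> timestep_embedding(t, 256).shape
        torch.Size([3, 256])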
+ """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def modulate(x, shift=None, scale=None): + """ + Apply adaptive layer normalization modulation. + + Applies scale and shift transformations for conditioning + in adaptive layer normalization. + + Args: + x: Input tensor to modulate. + shift: Additive shift parameter (optional). + scale: Multiplicative scale parameter (optional). + + Returns: + Modulated tensor x * (1 + scale) + shift. + """ + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale.unsqueeze(1)) + elif scale is None: + return x + shift.unsqueeze(1) + else: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +def apply_gate(x, gate=None, tanh=False): + """ + Apply gating mechanism to tensor. + + Multiplies input by gate values, optionally applying tanh activation. + Used in residual connections for adaptive control. + + Args: + x: Input tensor to gate. + gate: Gating values (optional). + tanh: Whether to apply tanh to gate values. + + Returns: + Gated tensor x * gate (with optional tanh). + """ + if gate is None: + return x + if tanh: + return x * gate.unsqueeze(1).tanh() + else: + return x * gate.unsqueeze(1) + + +def reshape_for_broadcast( + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + x: torch.Tensor, + head_first=False, +): + """ + Reshape RoPE frequency tensors for broadcasting with attention tensors. + + Args: + freqs_cis: Tuple of (cos_freqs, sin_freqs) tensors. + x: Target tensor for broadcasting compatibility. + head_first: Must be False (only supported layout). + + Returns: + Reshaped (cos_freqs, sin_freqs) tensors ready for broadcasting. + """ + assert not head_first, "Only head_first=False layout supported." + assert isinstance(freqs_cis, tuple), "Expected tuple of (cos, sin) frequency tensors." + assert x.ndim > 1, f"x should have at least 2 dimensions, but got {x.ndim}" + + # Validate frequency tensor dimensions match target tensor + assert freqs_cis[0].shape == ( + x.shape[1], + x.shape[-1], + ), f"Frequency tensor shape {freqs_cis[0].shape} incompatible with target shape {x.shape}" + + shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + + +def rotate_half(x): + """ + Rotate half the dimensions for RoPE computation. + + Splits the last dimension in half and applies a 90-degree rotation + by swapping and negating components. + + Args: + x: Input tensor [..., D] where D is even. + + Returns: + Rotated tensor with same shape as input. + """ + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor], head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary position embeddings to query and key tensors. + + Args: + xq: Query tensor [B, S, H, D]. + xk: Key tensor [B, S, H, D]. + freqs_cis: Tuple of (cos_freqs, sin_freqs) for rotation. + head_first: Whether head dimension precedes sequence dimension. + + Returns: + Tuple of rotated (query, key) tensors. 
+ """ + device = xq.device + dtype = xq.dtype + + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) + cos, sin = cos.to(device), sin.to(device) + + # Apply rotation: x' = x * cos + rotate_half(x) * sin + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).to(dtype) + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).to(dtype) + + return xq_out, xk_out + + +def get_timesteps_sigmas(sampling_steps: int, shift: float, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Generate timesteps and sigmas for diffusion sampling. + + Args: + sampling_steps: Number of sampling steps. + shift: Sigma shift parameter for schedule modification. + device: Target device for tensors. + + Returns: + Tuple of (timesteps, sigmas) tensors. + """ + sigmas = torch.linspace(1, 0, sampling_steps + 1) + sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas) + sigmas = sigmas.to(torch.float32) + timesteps = (sigmas[:-1] * 1000).to(dtype=torch.float32, device=device) + return timesteps, sigmas + + +def step(latents, noise_pred, sigmas, step_i): + """ + Perform a single diffusion sampling step. + + Args: + latents: Current latent state. + noise_pred: Predicted noise. + sigmas: Noise schedule sigmas. + step_i: Current step index. + + Returns: + Updated latents after the step. + """ + return latents.float() - (sigmas[step_i] - sigmas[step_i + 1]) * noise_pred.float() + + +# region AdaptiveProjectedGuidance + + +class MomentumBuffer: + """ + Exponential moving average buffer for APG momentum. + """ + def __init__(self, momentum: float): + self.momentum = momentum + self.running_average = 0 + + def update(self, update_value: torch.Tensor): + new_average = self.momentum * self.running_average + self.running_average = update_value + new_average + + +def normalized_guidance_apg( + pred_cond: torch.Tensor, + pred_uncond: torch.Tensor, + guidance_scale: float, + momentum_buffer: Optional[MomentumBuffer] = None, + eta: float = 1.0, + norm_threshold: float = 0.0, + use_original_formulation: bool = False, +): + """ + Apply normalized adaptive projected guidance. + + Projects the guidance vector to reduce over-saturation while maintaining + directional control by decomposing into parallel and orthogonal components. + + Args: + pred_cond: Conditional prediction. + pred_uncond: Unconditional prediction. + guidance_scale: Guidance scale factor. + momentum_buffer: Optional momentum buffer for temporal smoothing. + eta: Scaling factor for parallel component. + norm_threshold: Maximum norm for guidance vector clipping. + use_original_formulation: Whether to use original APG formulation. + + Returns: + Guided prediction tensor. 
+ """ + diff = pred_cond - pred_uncond + dim = [-i for i in range(1, len(diff.shape))] # All dimensions except batch + + # Apply momentum smoothing if available + if momentum_buffer is not None: + momentum_buffer.update(diff) + diff = momentum_buffer.running_average + + # Apply norm clipping if threshold is set + if norm_threshold > 0: + diff_norm = diff.norm(p=2, dim=dim, keepdim=True) + scale_factor = torch.minimum(torch.ones_like(diff_norm), norm_threshold / diff_norm) + diff = diff * scale_factor + + # Project guidance vector into parallel and orthogonal components + v0, v1 = diff.double(), pred_cond.double() + v1 = torch.nn.functional.normalize(v1, dim=dim) + v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1 + v0_orthogonal = v0 - v0_parallel + diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff) + + # Combine components with different scaling + normalized_update = diff_orthogonal + eta * diff_parallel + pred = pred_cond if use_original_formulation else pred_uncond + pred = pred + guidance_scale * normalized_update + + return pred + + +class AdaptiveProjectedGuidance: + """ + Adaptive Projected Guidance for classifier-free guidance. + + Implements APG which projects the guidance vector to reduce over-saturation + while maintaining directional control. + """ + def __init__( + self, + guidance_scale: float = 7.5, + adaptive_projected_guidance_momentum: Optional[float] = None, + adaptive_projected_guidance_rescale: float = 15.0, + eta: float = 0.0, + guidance_rescale: float = 0.0, + use_original_formulation: bool = False, + ): + assert guidance_rescale == 0.0, "guidance_rescale > 0.0 not supported." + + self.guidance_scale = guidance_scale + self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum + self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale + self.eta = eta + self.guidance_rescale = guidance_rescale + self.use_original_formulation = use_original_formulation + self.momentum_buffer = None + + def __call__(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None, step=None) -> torch.Tensor: + if step == 0 and self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + + pred = normalized_guidance_apg( + pred_cond, + pred_uncond, + self.guidance_scale, + self.momentum_buffer, + self.eta, + self.adaptive_projected_guidance_rescale, + self.use_original_formulation, + ) + + return pred + + +# endregion + + +def apply_classifier_free_guidance( + noise_pred_text: torch.Tensor, + noise_pred_uncond: torch.Tensor, + is_ocr: bool, + guidance_scale: float, + step: int, + apg_start_step_ocr: int = 75, + apg_start_step_general: int = 10, + cfg_guider_ocr: AdaptiveProjectedGuidance = None, + cfg_guider_general: AdaptiveProjectedGuidance = None, +): + """ + Apply classifier-free guidance with OCR-aware APG for batch_size=1. + + Args: + noise_pred_text: Conditional noise prediction tensor [1, ...]. + noise_pred_uncond: Unconditional noise prediction tensor [1, ...]. + is_ocr: Whether this sample requires OCR-specific guidance. + guidance_scale: Guidance scale for CFG. + step: Current diffusion step index. + apg_start_step_ocr: Step to start APG for OCR regions. + apg_start_step_general: Step to start APG for general regions. + cfg_guider_ocr: APG guider for OCR regions. + cfg_guider_general: APG guider for general regions. + + Returns: + Guided noise prediction tensor [1, ...]. 
+ """ + if guidance_scale == 1.0: + return noise_pred_text + + # Select appropriate guider and start step based on OCR requirement + if is_ocr: + cfg_guider = cfg_guider_ocr + apg_start_step = apg_start_step_ocr + else: + cfg_guider = cfg_guider_general + apg_start_step = apg_start_step_general + + # Apply standard CFG or APG based on current step + if step <= apg_start_step: + # Standard classifier-free guidance + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # Initialize APG guider state + _ = cfg_guider(noise_pred_text, noise_pred_uncond, step=step) + else: + # Use APG for guidance + noise_pred = cfg_guider(noise_pred_text, noise_pred_uncond, step=step) + + return noise_pred diff --git a/library/hunyuan_image_vae.py b/library/hunyuan_image_vae.py new file mode 100644 index 00000000..6eb035c3 --- /dev/null +++ b/library/hunyuan_image_vae.py @@ -0,0 +1,622 @@ +from typing import Optional, Tuple + +from einops import rearrange +import numpy as np +import torch +from torch import Tensor, nn +from torch.nn import Conv2d +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution + +from library.utils import load_safetensors, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +VAE_SCALE_FACTOR = 32 # 32x spatial compression + + +def swish(x: Tensor) -> Tensor: + """Swish activation function: x * sigmoid(x).""" + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + """Self-attention block using scaled dot-product attention.""" + + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.q = Conv2d(in_channels, in_channels, kernel_size=1) + self.k = Conv2d(in_channels, in_channels, kernel_size=1) + self.v = Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1) + + def attention(self, x: Tensor) -> Tensor: + x = self.norm(x) + q = self.q(x) + k = self.k(x) + v = self.v(x) + + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c").contiguous() + k = rearrange(k, "b c h w -> b (h w) c").contiguous() + v = rearrange(v, "b c h w -> b (h w) c").contiguous() + + x = nn.functional.scaled_dot_product_attention(q, k, v) + return rearrange(x, "b (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x: Tensor) -> Tensor: + return x + self.proj_out(self.attention(x)) + + +class ResnetBlock(nn.Module): + """ + Residual block with two convolutions, group normalization, and swish activation. + Includes skip connection with optional channel dimension matching. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. 
+ """ + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + # Skip connection projection for channel dimension mismatch + if self.in_channels != self.out_channels: + self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: Tensor) -> Tensor: + h = x + # First convolution block + h = self.norm1(h) + h = swish(h) + h = self.conv1(h) + # Second convolution block + h = self.norm2(h) + h = swish(h) + h = self.conv2(h) + + # Apply skip connection with optional projection + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + h + + +class Downsample(nn.Module): + """ + Spatial downsampling block that reduces resolution by 2x using convolution followed by + pixel rearrangement. Includes skip connection with grouped averaging. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels (must be divisible by 4). + """ + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + factor = 4 # 2x2 spatial reduction factor + assert out_channels % factor == 0 + + self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1) + self.group_size = factor * in_channels // out_channels + + def forward(self, x: Tensor) -> Tensor: + # Apply convolution and rearrange pixels for 2x downsampling + h = self.conv(x) + h = rearrange(h, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2) + + # Create skip connection with pixel rearrangement + shortcut = rearrange(x, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2) + B, C, H, W = shortcut.shape + shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2) + + return h + shortcut + + +class Upsample(nn.Module): + """ + Spatial upsampling block that increases resolution by 2x using convolution followed by + pixel rearrangement. Includes skip connection with channel repetition. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + """ + + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + factor = 4 # 2x2 spatial expansion factor + self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1) + self.repeats = factor * out_channels // in_channels + + def forward(self, x: Tensor) -> Tensor: + # Apply convolution and rearrange pixels for 2x upsampling + h = self.conv(x) + h = rearrange(h, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2) + + # Create skip connection with channel repetition + shortcut = x.repeat_interleave(repeats=self.repeats, dim=1) + shortcut = rearrange(shortcut, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2) + + return h + shortcut + + +class Encoder(nn.Module): + """ + VAE encoder that progressively downsamples input images to a latent representation. + Uses residual blocks, attention, and spatial downsampling. + + Parameters + ---------- + in_channels : int + Number of input image channels (e.g., 3 for RGB). 
+ z_channels : int + Number of latent channels in the output. + block_out_channels : Tuple[int, ...] + Output channels for each downsampling block. + num_res_blocks : int + Number of residual blocks per downsampling stage. + ffactor_spatial : int + Total spatial downsampling factor (e.g., 32 for 32x compression). + """ + + def __init__( + self, + in_channels: int, + z_channels: int, + block_out_channels: Tuple[int, ...], + num_res_blocks: int, + ffactor_spatial: int, + ): + super().__init__() + assert block_out_channels[-1] % (2 * z_channels) == 0 + + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + + self.down = nn.ModuleList() + block_in = block_out_channels[0] + + # Build downsampling blocks + for i_level, ch in enumerate(block_out_channels): + block = nn.ModuleList() + block_out = ch + + # Add residual blocks for this level + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + + down = nn.Module() + down.block = block + + # Add spatial downsampling if needed + add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial)) + if add_spatial_downsample: + assert i_level < len(block_out_channels) - 1 + block_out = block_out_channels[i_level + 1] + down.downsample = Downsample(block_in, block_out) + block_in = block_out + + self.down.append(down) + + # Middle blocks with attention + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # Output layers + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + # Initial convolution + h = self.conv_in(x) + + # Progressive downsampling through blocks + for i_level in range(len(self.block_out_channels)): + # Apply residual blocks at this level + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h) + # Apply spatial downsampling if available + if hasattr(self.down[i_level], "downsample"): + h = self.down[i_level].downsample(h) + + # Middle processing with attention + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # Final output layers with skip connection + group_size = self.block_out_channels[-1] // (2 * self.z_channels) + shortcut = rearrange(h, "b (c r) h w -> b c r h w", r=group_size).mean(dim=2) + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + h += shortcut + return h + + +class Decoder(nn.Module): + """ + VAE decoder that progressively upsamples latent representations back to images. + Uses residual blocks, attention, and spatial upsampling. + + Parameters + ---------- + z_channels : int + Number of latent channels in the input. + out_channels : int + Number of output image channels (e.g., 3 for RGB). + block_out_channels : Tuple[int, ...] + Output channels for each upsampling block. + num_res_blocks : int + Number of residual blocks per upsampling stage. + ffactor_spatial : int + Total spatial upsampling factor (e.g., 32 for 32x expansion). 
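    Example (illustrative shape check with random weights, using the fixed
    HunyuanImage-2.1 configuration, i.e. the reversed encoder channel list):

        >>> dec = Decoder(z_channels=64, out_channels=3,
        ...               block_out_channels=(1024, 1024, 512, 512, 256, 128),
        ...               num_res_blocks=2, ffactor_spatial=32)
        >>> dec(torch.randn(1, 64, 8, 8)).shape
        torch.Size([1, 3, 256, 256])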
+ """ + + def __init__( + self, + z_channels: int, + out_channels: int, + block_out_channels: Tuple[int, ...], + num_res_blocks: int, + ffactor_spatial: int, + ): + super().__init__() + assert block_out_channels[0] % z_channels == 0 + + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + block_in = block_out_channels[0] + self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # Middle blocks with attention + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # Build upsampling blocks + self.up = nn.ModuleList() + for i_level, ch in enumerate(block_out_channels): + block = nn.ModuleList() + block_out = ch + + # Add residual blocks for this level (extra block for decoder) + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + + up = nn.Module() + up.block = block + + # Add spatial upsampling if needed + add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial)) + if add_spatial_upsample: + assert i_level < len(block_out_channels) - 1 + block_out = block_out_channels[i_level + 1] + up.upsample = Upsample(block_in, block_out) + block_in = block_out + + self.up.append(up) + + # Output layers + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + # Initial processing with skip connection + repeats = self.block_out_channels[0] // self.z_channels + h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1) + + # Middle processing with attention + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # Progressive upsampling through blocks + for i_level in range(len(self.block_out_channels)): + # Apply residual blocks at this level + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + # Apply spatial upsampling if available + if hasattr(self.up[i_level], "upsample"): + h = self.up[i_level].upsample(h) + + # Final output layers + h = self.norm_out(h) + h = swish(h) + h = self.conv_out(h) + return h + + +class HunyuanVAE2D(nn.Module): + """ + VAE model for Hunyuan Image-2.1 with spatial tiling support. + + This VAE uses a fixed architecture optimized for the Hunyuan Image-2.1 model, + with 32x spatial compression and optional memory-efficient tiling for large images. 
+ """ + + def __init__(self): + super().__init__() + + # Fixed configuration for Hunyuan Image-2.1 + block_out_channels = (128, 256, 512, 512, 1024, 1024) + in_channels = 3 # RGB input + out_channels = 3 # RGB output + latent_channels = 64 + layers_per_block = 2 + ffactor_spatial = 32 # 32x spatial compression + sample_size = 384 # Minimum sample size for tiling + scaling_factor = 0.75289 # Latent scaling factor + + self.ffactor_spatial = ffactor_spatial + self.scaling_factor = scaling_factor + + self.encoder = Encoder( + in_channels=in_channels, + z_channels=latent_channels, + block_out_channels=block_out_channels, + num_res_blocks=layers_per_block, + ffactor_spatial=ffactor_spatial, + ) + + self.decoder = Decoder( + z_channels=latent_channels, + out_channels=out_channels, + block_out_channels=list(reversed(block_out_channels)), + num_res_blocks=layers_per_block, + ffactor_spatial=ffactor_spatial, + ) + + # Spatial tiling configuration for memory efficiency + self.use_spatial_tiling = False + self.tile_sample_min_size = sample_size + self.tile_latent_min_size = sample_size // ffactor_spatial + self.tile_overlap_factor = 0.25 # 25% overlap between tiles + + @property + def dtype(self): + """Get the data type of the model parameters.""" + return next(self.encoder.parameters()).dtype + + @property + def device(self): + """Get the device of the model parameters.""" + return next(self.encoder.parameters()).device + + def enable_spatial_tiling(self, use_tiling: bool = True): + """Enable or disable spatial tiling.""" + self.use_spatial_tiling = use_tiling + + def disable_spatial_tiling(self): + """Disable spatial tiling.""" + self.use_spatial_tiling = False + + def enable_tiling(self, use_tiling: bool = True): + """Enable or disable spatial tiling (alias for enable_spatial_tiling).""" + self.enable_spatial_tiling(use_tiling) + + def disable_tiling(self): + """Disable spatial tiling (alias for disable_spatial_tiling).""" + self.disable_spatial_tiling() + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + """ + Blend two tensors horizontally with smooth transition. + + Parameters + ---------- + a : torch.Tensor + Left tensor. + b : torch.Tensor + Right tensor. + blend_extent : int + Number of columns to blend. + """ + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + """ + Blend two tensors vertically with smooth transition. + + Parameters + ---------- + a : torch.Tensor + Top tensor. + b : torch.Tensor + Bottom tensor. + blend_extent : int + Number of rows to blend. + """ + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode large images using spatial tiling to reduce memory usage. + Tiles are processed independently and blended at boundaries. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, T, H, W). 
+ """ + B, C, T, H, W = x.shape + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + rows = [] + for i in range(0, H, overlap_size): + row = [] + for j in range(0, W, overlap_size): + tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + moments = torch.cat(result_rows, dim=-2) + return moments + + def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + """ + Decode large latents using spatial tiling to reduce memory usage. + Tiles are processed independently and blended at boundaries. + + Parameters + ---------- + z : torch.Tensor + Latent tensor of shape (B, C, H, W). + """ + B, C, H, W = z.shape + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + rows = [] + for i in range(0, H, overlap_size): + row = [] + for j in range(0, W, overlap_size): + tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=-2) + return dec + + def encode(self, x: Tensor) -> DiagonalGaussianDistribution: + """ + Encode input images to latent representation. + Uses spatial tiling for large images if enabled. + + Parameters + ---------- + x : Tensor + Input image tensor of shape (B, C, H, W) or (B, C, T, H, W). + + Returns + ------- + DiagonalGaussianDistribution + Latent distribution with mean and logvar. + """ + # Handle 5D input (B, C, T, H, W) by removing time dimension + original_ndim = x.ndim + if original_ndim == 5: + x = x.squeeze(2) + + # Use tiling for large images to reduce memory usage + if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + h = self.spatial_tiled_encode(x) + else: + h = self.encoder(x) + + # Restore time dimension if input was 5D + if original_ndim == 5: + h = h.unsqueeze(2) + + posterior = DiagonalGaussianDistribution(h) + return posterior + + def decode(self, z: Tensor): + """ + Decode latent representation back to images. + Uses spatial tiling for large latents if enabled. + + Parameters + ---------- + z : Tensor + Latent tensor of shape (B, C, H, W) or (B, C, T, H, W). + + Returns + ------- + Tensor + Decoded image tensor. 
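+            Spatial tiling is used automatically when tiling is enabled and either
+            spatial dimension of z exceeds tile_latent_min_size.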
+ """ + # Handle 5D input (B, C, T, H, W) by removing time dimension + original_ndim = z.ndim + if original_ndim == 5: + z = z.squeeze(2) + + # Use tiling for large latents to reduce memory usage + if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + decoded = self.spatial_tiled_decode(z) + else: + decoded = self.decoder(z) + + # Restore time dimension if input was 5D + if original_ndim == 5: + decoded = decoded.unsqueeze(2) + + return decoded + + +def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False) -> HunyuanVAE2D: + logger.info("Initializing VAE") + vae = HunyuanVAE2D() + + logger.info(f"Loading VAE from {vae_path}") + state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap) + info = vae.load_state_dict(state_dict, strict=True, assign=True) + logger.info(f"Loaded VAE: {info}") + + vae.to(device) + return vae diff --git a/library/lora_utils.py b/library/lora_utils.py new file mode 100644 index 00000000..db004622 --- /dev/null +++ b/library/lora_utils.py @@ -0,0 +1,249 @@ +# copy from Musubi Tuner + +import os +import re +from typing import Dict, List, Optional, Union +import torch + +from tqdm import tqdm + +from library.custom_offloading_utils import synchronize_device +from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization +from library.utils import MemoryEfficientSafeOpen, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def filter_lora_state_dict( + weights_sd: Dict[str, torch.Tensor], + include_pattern: Optional[str] = None, + exclude_pattern: Optional[str] = None, +) -> Dict[str, torch.Tensor]: + # apply include/exclude patterns + original_key_count = len(weights_sd.keys()) + if include_pattern is not None: + regex_include = re.compile(include_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} + logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") + + if exclude_pattern is not None: + original_key_count_ex = len(weights_sd.keys()) + regex_exclude = re.compile(exclude_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} + logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}") + + if len(weights_sd) != original_key_count: + remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) + remaining_keys.sort() + logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") + if len(weights_sd) == 0: + logger.warning("No keys left after filtering.") + + return weights_sd + + +def load_safetensors_with_lora_and_fp8( + model_files: Union[str, List[str]], + lora_weights_list: Optional[Dict[str, torch.Tensor]], + lora_multipliers: Optional[List[float]], + fp8_optimization: bool, + calc_device: torch.device, + move_to_device: bool = False, + dit_weight_dtype: Optional[torch.dtype] = None, + target_keys: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, +) -> dict[str, torch.Tensor]: + """ + Merge LoRA weights into the state dict of a model with fp8 optimization if needed. + + Args: + model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix. + lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load. 
+ lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights. + fp8_optimization (bool): Whether to apply FP8 optimization. + calc_device (torch.device): Device to calculate on. + move_to_device (bool): Whether to move tensors to the calculation device after loading. + target_keys (Optional[List[str]]): Keys to target for optimization. + exclude_keys (Optional[List[str]]): Keys to exclude from optimization. + """ + + # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix + if isinstance(model_files, str): + model_files = [model_files] + + extended_model_files = [] + for model_file in model_files: + basename = os.path.basename(model_file) + match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename) + if match: + prefix = basename[: match.start(2)] + count = int(match.group(3)) + state_dict = {} + for i in range(count): + filename = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors" + filepath = os.path.join(os.path.dirname(model_file), filename) + if os.path.exists(filepath): + extended_model_files.append(filepath) + else: + raise FileNotFoundError(f"File {filepath} not found") + else: + extended_model_files.append(model_file) + model_files = extended_model_files + logger.info(f"Loading model files: {model_files}") + + # load LoRA weights + weight_hook = None + if lora_weights_list is None or len(lora_weights_list) == 0: + lora_weights_list = [] + lora_multipliers = [] + list_of_lora_weight_keys = [] + else: + list_of_lora_weight_keys = [] + for lora_sd in lora_weights_list: + lora_weight_keys = set(lora_sd.keys()) + list_of_lora_weight_keys.append(lora_weight_keys) + + if lora_multipliers is None: + lora_multipliers = [1.0] * len(lora_weights_list) + while len(lora_multipliers) < len(lora_weights_list): + lora_multipliers.append(1.0) + if len(lora_multipliers) > len(lora_weights_list): + lora_multipliers = lora_multipliers[: len(lora_weights_list)] + + # Merge LoRA weights into the state dict + logger.info(f"Merging LoRA weights into state dict. 
multipliers: {lora_multipliers}") + + # make hook for LoRA merging + def weight_hook_func(model_weight_key, model_weight): + nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device + + if not model_weight_key.endswith(".weight"): + return model_weight + + original_device = model_weight.device + if original_device != calc_device: + model_weight = model_weight.to(calc_device) # to make calculation faster + + for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers): + # check if this weight has LoRA weights + lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight" + lora_name = "lora_unet_" + lora_name.replace(".", "_") + down_key = lora_name + ".lora_down.weight" + up_key = lora_name + ".lora_up.weight" + alpha_key = lora_name + ".alpha" + if down_key not in lora_weight_keys or up_key not in lora_weight_keys: + continue + + # get LoRA weights + down_weight = lora_sd[down_key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + down_weight = down_weight.to(calc_device) + up_weight = up_weight.to(calc_device) + + # W <- W + U * D + if len(model_weight.size()) == 2: + # linear + if len(up_weight.size()) == 4: # use linear projection mismatch + up_weight = up_weight.squeeze(3).squeeze(2) + down_weight = down_weight.squeeze(3).squeeze(2) + model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + model_weight = ( + model_weight + + multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + model_weight = model_weight + multiplier * conved * scale + + # remove LoRA keys from set + lora_weight_keys.remove(down_key) + lora_weight_keys.remove(up_key) + if alpha_key in lora_weight_keys: + lora_weight_keys.remove(alpha_key) + + model_weight = model_weight.to(original_device) # move back to original device + return model_weight + + weight_hook = weight_hook_func + + state_dict = load_safetensors_with_fp8_optimization_and_hook( + model_files, + fp8_optimization, + calc_device, + move_to_device, + dit_weight_dtype, + target_keys, + exclude_keys, + weight_hook=weight_hook, + ) + + for lora_weight_keys in list_of_lora_weight_keys: + # check if all LoRA keys are used + if len(lora_weight_keys) > 0: + # if there are still LoRA keys left, it means they are not used in the model + # this is a warning, not an error + logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}") + + return state_dict + + +def load_safetensors_with_fp8_optimization_and_hook( + model_files: list[str], + fp8_optimization: bool, + calc_device: torch.device, + move_to_device: bool = False, + dit_weight_dtype: Optional[torch.dtype] = None, + target_keys: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + weight_hook: callable = None, +) -> dict[str, torch.Tensor]: + """ + Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed. + """ + if fp8_optimization: + logger.info( + f"Loading state dict with FP8 optimization. 
Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" + ) + # dit_weight_dtype is not used because we use fp8 optimization + state_dict = load_safetensors_with_fp8_optimization( + model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook + ) + else: + logger.info( + f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" + ) + state_dict = {} + for model_file in model_files: + with MemoryEfficientSafeOpen(model_file) as f: + for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False): + value = f.get_tensor(key) + if weight_hook is not None: + value = weight_hook(key, value) + if move_to_device: + if dit_weight_dtype is None: + value = value.to(calc_device, non_blocking=True) + else: + value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True) + elif dit_weight_dtype is not None: + value = value.to(dit_weight_dtype) + + state_dict[key] = value + + if move_to_device: + synchronize_device(calc_device) + + return state_dict diff --git a/networks/lora_hunyuan_image.py b/networks/lora_hunyuan_image.py new file mode 100644 index 00000000..e9ad5f68 --- /dev/null +++ b/networks/lora_hunyuan_image.py @@ -0,0 +1,1444 @@ +# temporary minimum implementation of LoRA +# FLUX doesn't have Conv2d, so we ignore it +# TODO commonize with the original implementation + +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py + +import math +import os +from contextlib import contextmanager +from typing import Dict, List, Optional, Tuple, Type, Union +from diffusers import AutoencoderKL +from transformers import CLIPTextModel +import numpy as np +import torch +from torch import Tensor +import re +from library.utils import setup_logging +from library.sdxl_original_unet import SdxlUNet2DConditionModel + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +NUM_DOUBLE_BLOCKS = 19 +NUM_SINGLE_BLOCKS = 38 + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + split_dims: Optional[List[int]] = None, + ggpo_beta: Optional[float] = None, + ggpo_sigma: Optional[float] = None, + ): + """ + if alpha == 0 or None, alpha is rank (no scaling). 
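+        The effective scaling applied to the LoRA delta is alpha / lora_dim.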
+ + split_dims is used to mimic the split qkv of FLUX as same as Diffusers + """ + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + self.split_dims = split_dims + + if split_dims is None: + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + else: + # conv2d not supported + assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" + assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" + # print(f"split_dims: {split_dims}") + self.lora_down = torch.nn.ModuleList( + [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + ) + self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) + for lora_down in self.lora_down: + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + torch.nn.init.zeros_(lora_up.weight) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + self.ggpo_sigma = ggpo_sigma + self.ggpo_beta = ggpo_beta + + if self.ggpo_beta is not None and self.ggpo_sigma is not None: + self.combined_weight_norms = None + self.grad_norms = None + self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0]) + self.initialize_norm_cache(org_module.weight) + self.org_module_shape: tuple[int] = org_module.weight.shape + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + if self.split_dims is None: + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale 
* (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + # LoRA Gradient-Guided Perturbation Optimization + if ( + self.training + and self.ggpo_sigma is not None + and self.ggpo_beta is not None + and self.combined_weight_norms is not None + and self.grad_norms is not None + ): + with torch.no_grad(): + perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + ( + self.ggpo_beta * (self.grad_norms**2) + ) + perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device) + perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device) + perturbation.mul_(perturbation_scale_factor) + perturbation_output = x @ perturbation.T # Result: (batch × n) + return org_forwarded + (self.multiplier * scale * lx) + perturbation_output + else: + return org_forwarded + lx * self.multiplier * scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + + # normal dropout + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + + # rank dropout + if self.rank_dropout is not None and self.training: + masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] + for i in range(len(lxs)): + if len(lx.size()) == 3: + masks[i] = masks[i].unsqueeze(1) + elif len(lx.size()) == 4: + masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) + lxs[i] = lxs[i] * masks[i] + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale + + @torch.no_grad() + def initialize_norm_cache(self, org_module_weight: Tensor): + # Choose a reasonable sample size + n_rows = org_module_weight.shape[0] + sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller + + # Sample random indices across all rows + indices = torch.randperm(n_rows)[:sample_size] + + # Convert to a supported data type first, then index + # Use float32 for indexing operations + weights_float32 = org_module_weight.to(dtype=torch.float32) + sampled_weights = weights_float32[indices].to(device=self.device) + + # Calculate sampled norms + sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True) + + # Store the mean norm as our estimate + self.org_weight_norm_estimate = sampled_norms.mean() + + # Optional: store standard deviation for confidence intervals + self.org_weight_norm_std = sampled_norms.std() + + # Free memory + del sampled_weights, weights_float32 + + @torch.no_grad() + def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True): + # Calculate the true norm (this will be slow but it's just for validation) + true_norms = [] + chunk_size = 1024 # Process in chunks to avoid OOM + + for i in range(0, org_module_weight.shape[0], chunk_size): + end_idx = min(i + chunk_size, org_module_weight.shape[0]) + chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype) + chunk_norms = torch.norm(chunk, dim=1, keepdim=True) + true_norms.append(chunk_norms.cpu()) + del chunk + + true_norms = torch.cat(true_norms, dim=0) + true_mean_norm = true_norms.mean().item() + + # Compare with our estimate + estimated_norm = self.org_weight_norm_estimate.item() + + # Calculate error metrics + absolute_error = 
abs(true_mean_norm - estimated_norm) + relative_error = absolute_error / true_mean_norm * 100 # as percentage + + if verbose: + logger.info(f"True mean norm: {true_mean_norm:.6f}") + logger.info(f"Estimated norm: {estimated_norm:.6f}") + logger.info(f"Absolute error: {absolute_error:.6f}") + logger.info(f"Relative error: {relative_error:.2f}%") + + return { + "true_mean_norm": true_mean_norm, + "estimated_norm": estimated_norm, + "absolute_error": absolute_error, + "relative_error": relative_error, + } + + @torch.no_grad() + def update_norms(self): + # Not running GGPO so not currently running update norms + if self.ggpo_beta is None or self.ggpo_sigma is None: + return + + # only update norms when we are training + if self.training is False: + return + + module_weights = self.lora_up.weight @ self.lora_down.weight + module_weights.mul(self.scale) + + self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True) + self.combined_weight_norms = torch.sqrt( + (self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True) + ) + + @torch.no_grad() + def update_grad_norms(self): + if self.training is False: + print(f"skipping update_grad_norms for {self.lora_name}") + return + + lora_down_grad = None + lora_up_grad = None + + for name, param in self.named_parameters(): + if name == "lora_down.weight": + lora_down_grad = param.grad + elif name == "lora_up.weight": + lora_up_grad = param.grad + + # Calculate gradient norms if we have both gradients + if lora_down_grad is not None and lora_up_grad is not None: + with torch.autocast(self.device.type): + approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight)) + self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True) + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + if self.split_dims is None: + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + 
org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + else: + # split_dims + total_dims = sum(self.split_dims) + for i in range(len(self.split_dims)): + # get up/down weight + down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device) # (rank, in_dim) + up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device) # (split dim, rank) + + # pad up_weight -> (total_dims, rank) + padded_up_weight = torch.zeros((total_dims, up_weight.size(0)), device=device, dtype=torch.float) + padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight + + # merge weight + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def set_region(self, region): + self.region = region + self.region_mask = None + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + if self.split_dims is None: + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + else: + lxs = [lora_down(x) for lora_down in self.lora_down] + lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + ae: AutoencoderKL, + text_encoders: List[CLIPTextModel], + flux, + neuron_dropout: Optional[float] = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + # extract dim/alpha for conv2d, and block dim + conv_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_dim is not None: + conv_dim = int(conv_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # attn dim, mlp dim: only for DoubleStreamBlock. 
SingleStreamBlock is not supported because of combined qkv + img_attn_dim = kwargs.get("img_attn_dim", None) + txt_attn_dim = kwargs.get("txt_attn_dim", None) + img_mlp_dim = kwargs.get("img_mlp_dim", None) + txt_mlp_dim = kwargs.get("txt_mlp_dim", None) + img_mod_dim = kwargs.get("img_mod_dim", None) + txt_mod_dim = kwargs.get("txt_mod_dim", None) + single_dim = kwargs.get("single_dim", None) # SingleStreamBlock + single_mod_dim = kwargs.get("single_mod_dim", None) # SingleStreamBlock + if img_attn_dim is not None: + img_attn_dim = int(img_attn_dim) + if txt_attn_dim is not None: + txt_attn_dim = int(txt_attn_dim) + if img_mlp_dim is not None: + img_mlp_dim = int(img_mlp_dim) + if txt_mlp_dim is not None: + txt_mlp_dim = int(txt_mlp_dim) + if img_mod_dim is not None: + img_mod_dim = int(img_mod_dim) + if txt_mod_dim is not None: + txt_mod_dim = int(txt_mod_dim) + if single_dim is not None: + single_dim = int(single_dim) + if single_mod_dim is not None: + single_mod_dim = int(single_mod_dim) + type_dims = [img_attn_dim, txt_attn_dim, img_mlp_dim, txt_mlp_dim, img_mod_dim, txt_mod_dim, single_dim, single_mod_dim] + if all([d is None for d in type_dims]): + type_dims = None + + # in_dims [img, time, vector, guidance, txt] + in_dims = kwargs.get("in_dims", None) + if in_dims is not None: + in_dims = in_dims.strip() + if in_dims.startswith("[") and in_dims.endswith("]"): + in_dims = in_dims[1:-1] + in_dims = [int(d) for d in in_dims.split(",")] # is it better to use ast.literal_eval? + assert len(in_dims) == 5, f"invalid in_dims: {in_dims}, must be 5 dimensions (img, time, vector, guidance, txt)" + + # double/single train blocks + def parse_block_selection(selection: str, total_blocks: int) -> List[bool]: + """ + Parse a block selection string and return a list of booleans. + + Args: + selection (str): A string specifying which blocks to select. + total_blocks (int): The total number of blocks available. + + Returns: + List[bool]: A list of booleans indicating which blocks are selected. 
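+
+        Example:
+            parse_block_selection("0,2-4", 6) -> [True, False, True, True, True, False]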
+ """ + if selection == "all": + return [True] * total_blocks + if selection == "none" or selection == "": + return [False] * total_blocks + + selected = [False] * total_blocks + ranges = selection.split(",") + + for r in ranges: + if "-" in r: + start, end = map(str.strip, r.split("-")) + start = int(start) + end = int(end) + assert 0 <= start < total_blocks, f"invalid start index: {start}" + assert 0 <= end < total_blocks, f"invalid end index: {end}" + assert start <= end, f"invalid range: {start}-{end}" + for i in range(start, end + 1): + selected[i] = True + else: + index = int(r) + assert 0 <= index < total_blocks, f"invalid index: {index}" + selected[index] = True + + return selected + + train_double_block_indices = kwargs.get("train_double_block_indices", None) + train_single_block_indices = kwargs.get("train_single_block_indices", None) + if train_double_block_indices is not None: + train_double_block_indices = parse_block_selection(train_double_block_indices, NUM_DOUBLE_BLOCKS) + if train_single_block_indices is not None: + train_single_block_indices = parse_block_selection(train_single_block_indices, NUM_SINGLE_BLOCKS) + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # single or double blocks + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + if train_blocks is not None: + assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + + # split qkv + split_qkv = kwargs.get("split_qkv", False) + if split_qkv is not None: + split_qkv = True if split_qkv == "True" else False + + ggpo_beta = kwargs.get("ggpo_beta", None) + ggpo_sigma = kwargs.get("ggpo_sigma", None) + + if ggpo_beta is not None: + ggpo_beta = float(ggpo_beta) + + if ggpo_sigma is not None: + ggpo_sigma = float(ggpo_sigma) + + # train T5XXL + train_t5xxl = kwargs.get("train_t5xxl", False) + if train_t5xxl is not None: + train_t5xxl = True if train_t5xxl == "True" else False + + # verbose + verbose = kwargs.get("verbose", False) + if verbose is not None: + verbose = True if verbose == "True" else False + + # regex-specific learning rates + def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]: + """ + Parse a string of key-value pairs separated by commas. 
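+        Example: "attn=8,mlp=4" with is_int=True -> {"attn": 8, "mlp": 4}.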
+ """ + pairs = {} + for pair in kv_pair_str.split(","): + pair = pair.strip() + if not pair: + continue + if "=" not in pair: + logger.warning(f"Invalid format: {pair}, expected 'key=value'") + continue + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip() + try: + pairs[key] = int(value) if is_int else float(value) + except ValueError: + logger.warning(f"Invalid value for {key}: {value}") + return pairs + + # parse regular expression based learning rates + network_reg_lrs = kwargs.get("network_reg_lrs", None) + if network_reg_lrs is not None: + reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False) + else: + reg_lrs = None + + # regex-specific dimensions (ranks) + network_reg_dims = kwargs.get("network_reg_dims", None) + if network_reg_dims is not None: + reg_dims = parse_kv_pairs(network_reg_dims, is_int=True) + else: + reg_dims = None + + # すごく引数が多いな ( ^ω^)・・・ + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + conv_lora_dim=conv_dim, + conv_alpha=conv_alpha, + train_blocks=train_blocks, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + type_dims=type_dims, + in_dims=in_dims, + train_double_block_indices=train_double_block_indices, + train_single_block_indices=train_single_block_indices, + reg_dims=reg_dims, + ggpo_beta=ggpo_beta, + ggpo_sigma=ggpo_sigma, + reg_lrs=reg_lrs, + verbose=verbose, + ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +# Create network from weights for inference, weights are not loaded here (because can be merged) +def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file, safe_open + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # get dim/alpha mapping, and train t5xxl + modules_dim = {} + modules_alpha = {} + train_t5xxl = None + for key, value in weights_sd.items(): + if "." 
not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "lora_down" in key: + dim = value.size()[0] + modules_dim[lora_name] = dim + # logger.info(lora_name, value.size(), dim) + + if train_t5xxl is None or train_t5xxl is False: + train_t5xxl = "lora_te3" in lora_name + + if train_t5xxl is None: + train_t5xxl = False + + split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined + + module_class = LoRAInfModule if for_inference else LoRAModule + + network = LoRANetwork( + text_encoders, + flux, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + split_qkv=split_qkv, + train_t5xxl=train_t5xxl, + ) + return network, weights_sd + + +class LoRANetwork(torch.nn.Module): + # FLUX_TARGET_REPLACE_MODULE = ["DoubleStreamBlock", "SingleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_DOUBLE = ["DoubleStreamBlock"] + FLUX_TARGET_REPLACE_MODULE_SINGLE = ["SingleStreamBlock"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"] + LORA_PREFIX_FLUX = "lora_unet" # make ComfyUI compatible + LORA_PREFIX_TEXT_ENCODER_CLIP = "lora_te1" + LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible + + def __init__( + self, + text_encoders: Union[List[CLIPTextModel], CLIPTextModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + module_class: Type[object] = LoRAModule, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + train_blocks: Optional[str] = None, + split_qkv: bool = False, + train_t5xxl: bool = False, + type_dims: Optional[List[int]] = None, + in_dims: Optional[List[int]] = None, + train_double_block_indices: Optional[List[bool]] = None, + train_single_block_indices: Optional[List[bool]] = None, + reg_dims: Optional[Dict[str, int]] = None, + ggpo_beta: Optional[float] = None, + ggpo_sigma: Optional[float] = None, + reg_lrs: Optional[Dict[str, float]] = None, + verbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.train_blocks = train_blocks if train_blocks is not None else "all" + self.split_qkv = split_qkv + self.train_t5xxl = train_t5xxl + + self.type_dims = type_dims + self.in_dims = in_dims + self.train_double_block_indices = train_double_block_indices + self.train_single_block_indices = train_single_block_indices + self.reg_dims = reg_dims + self.reg_lrs = reg_lrs + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if modules_dim is not None: + logger.info(f"create LoRA network from weights") + self.in_dims = [0] * 5 # create in_dims + # verbose = True + else: + logger.info(f"create LoRA network. 
base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + # if self.conv_lora_dim is not None: + # logger.info( + # f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + # ) + + if ggpo_beta is not None and ggpo_sigma is not None: + logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}") + + if self.split_qkv: + logger.info(f"split qkv for LoRA") + if self.train_blocks is not None: + logger.info(f"train {self.train_blocks} blocks only") + + if train_t5xxl: + logger.info(f"train T5XXL as well") + + # create module instances + def create_modules( + is_flux: bool, + text_encoder_idx: Optional[int], + root_module: torch.nn.Module, + target_replace_modules: List[str], + filter: Optional[str] = None, + default_dim: Optional[int] = None, + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_FLUX + if is_flux + else (self.LORA_PREFIX_TEXT_ENCODER_CLIP if text_encoder_idx == 0 else self.LORA_PREFIX_TEXT_ENCODER_T5) + ) + + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: # dirty hack for all modules + module = root_module # search all modules + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") + + if filter is not None and not filter in lora_name: + continue + + dim = None + alpha = None + + if modules_dim is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha = modules_alpha[lora_name] + elif self.reg_dims is not None: + for reg, d in self.reg_dims.items(): + if re.search(reg, lora_name): + dim = d + alpha = self.alpha + logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}") + break + + # if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default) + if dim is None and modules_dim is None: + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha + + if is_flux and type_dims is not None: + identifier = [ + ("img_attn",), + ("txt_attn",), + ("img_mlp",), + ("txt_mlp",), + ("img_mod",), + ("txt_mod",), + ("single_blocks", "linear"), + ("modulation",), + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break + + if ( + is_flux + and dim + and ( + self.train_double_block_indices is not None + or self.train_single_block_indices is not None + ) + and ("double" in lora_name or "single" in lora_name) + ): + # "lora_unet_double_blocks_0_..." or "lora_unet_single_blocks_0_..." 
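+                            # e.g. "lora_unet_double_blocks_7_img_attn_qkv".split("_")[4] == "7"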
+ block_index = int(lora_name.split("_")[4]) # bit dirty + if ( + "double" in lora_name + and self.train_double_block_indices is not None + and not self.train_double_block_indices[block_index] + ): + dim = 0 + elif ( + "single" in lora_name + and self.train_single_block_indices is not None + and not self.train_single_block_indices[block_index] + ): + dim = 0 + + elif self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha = self.conv_alpha + + if dim is None or dim == 0: + # skipした情報を出力 + if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): + skipped.append(lora_name) + continue + + # qkv split + split_dims = None + if is_flux and split_qkv: + if "double" in lora_name and "qkv" in lora_name: + split_dims = [3072] * 3 + elif "single" in lora_name and "linear1" in lora_name: + split_dims = [3072] * 3 + [12288] + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + split_dims=split_dims, + ggpo_beta=ggpo_beta, + ggpo_sigma=ggpo_sigma, + ) + loras.append(lora) + + if target_replace_modules is None: + break # all modules are searched + return loras, skipped + + # create LoRA for text encoder + # 毎回すべてのモジュールを作るのは無駄なので要検討 + self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + index = i + if text_encoder is None: + logger.info(f"Text Encoder {index+1} is None, skipping LoRA creation for this encoder.") + continue + if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False + break + + logger.info(f"create LoRA for Text Encoder {index+1}:") + + text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.") + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + + # create LoRA for U-Net + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "single": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_SINGLE + elif self.train_blocks == "double": + target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE + + self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] + self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules) + + # img, time, vector, guidance, txt + if self.in_dims: + for filter, in_dim in zip(["_img_in", "_time_in", "_vector_in", "_guidance_in", "_txt_in"], self.in_dims): + loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim) + self.unet_loras.extend(loras) + + logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.") + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning( + f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" + ) + for name in skipped: + logger.info(f"\t{name}") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def 
set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def update_norms(self): + for lora in self.text_encoder_loras + self.unet_loras: + lora.update_norms() + + def update_grad_norms(self): + for lora in self.text_encoder_loras + self.unet_loras: + lora.update_grad_norms() + + def grad_norms(self) -> Tensor | None: + grad_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "grad_norms") and lora.grad_norms is not None: + grad_norms.append(lora.grad_norms.mean(dim=0)) + return torch.stack(grad_norms) if len(grad_norms) > 0 else None + + def weight_norms(self) -> Tensor | None: + weight_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "weight_norms") and lora.weight_norms is not None: + weight_norms.append(lora.weight_norms.mean(dim=0)) + return torch.stack(weight_norms) if len(weight_norms) > 0 else None + + def combined_weight_norms(self) -> Tensor | None: + combined_weight_norms = [] + for lora in self.text_encoder_loras + self.unet_loras: + if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None: + combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0)) + return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def load_state_dict(self, state_dict, strict=True): + # override to convert original weight to split qkv + if not self.split_qkv: + return super().load_state_dict(state_dict, strict) + + # split qkv + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + continue + + weight = state_dict[key] + lora_name = key.split(".")[0] + if "lora_down" in key and "weight" in key: + # dense weight (rank*3, in_dim) + split_weight = torch.chunk(weight, len(split_dims), dim=0) + for i, split_w in enumerate(split_weight): + state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w + + del state_dict[key] + # print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}") + elif "lora_up" in key and "weight" in key: + # sparse weight (out_dim=sum(split_dims), rank*3) + rank = weight.size(1) // len(split_dims) + i = 0 + for j in range(len(split_dims)): + state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dims[j], j * rank : (j + 1) * rank] + i += split_dims[j] + del state_dict[key] + + # # check is sparse + # i = 0 + # is_zero = True + # for j in range(len(split_dims)): + # for k in range(len(split_dims)): + # if j == k: + # continue + # is_zero = is_zero and torch.all(weight[i : i + split_dims[j], k * rank : (k + 1) * rank] == 0) + # i += split_dims[j] + # if not is_zero: + # logger.warning(f"weight is not sparse: {key}") + # else: + # logger.info(f"weight is sparse: {key}") + + # print( + # f"split {key}: {weight.shape} to {[state_dict[k].shape for k in [f'{lora_name}.lora_up.{j}.weight' for j in range(len(split_dims))]]}" + # ) + + # alpha is unchanged + + return 
super().load_state_dict(state_dict, strict) + + def state_dict(self, destination=None, prefix="", keep_vars=False): + if not self.split_qkv: + return super().state_dict(destination, prefix, keep_vars) + + # merge qkv + state_dict = super().state_dict(destination, prefix, keep_vars) + new_state_dict = {} + for key in list(state_dict.keys()): + if "double" in key and "qkv" in key: + split_dims = [3072] * 3 + elif "single" in key and "linear1" in key: + split_dims = [3072] * 3 + [12288] + else: + new_state_dict[key] = state_dict[key] + continue + + if key not in state_dict: + continue # already merged + + lora_name = key.split(".")[0] + + # (rank, in_dim) * 3 + down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))] + # (split dim, rank) * 3 + up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))] + + alpha = state_dict.pop(f"{lora_name}.alpha") + + # merge down weight + down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim) + + # merge up weight (sum of split_dim, rank*3) + rank = up_weights[0].size(1) + up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype) + i = 0 + for j in range(len(split_dims)): + up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j] + i += split_dims[j] + + new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight + new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight + new_state_dict[f"{lora_name}.alpha"] = alpha + + # print( + # f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}" + # ) + print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha") + + return new_state_dict + + def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + # マージできるかどうかを返す + def is_mergeable(self): + return True + + # TODO refactor to common function with apply_to + def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + for key in weights_sd.keys(): + if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP) or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5): + apply_text_encoder = True + elif key.startswith(LoRANetwork.LORA_PREFIX_FLUX): + apply_unet = True + + if apply_text_encoder: + logger.info("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info(f"weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + 
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + # make sure text_encoder_lr as list of two elements + # if float, use the same value for both text encoders + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr, default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]] + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + # regular expression param groups: {"reg_lr_0": {"lora": {}, "plus": {}}, ...} + reg_groups = {} + + for lora in loras: + # check if this lora matches any regex learning rate + matched_reg_lr = None + for i, (regex_str, reg_lr) in enumerate(reg_lrs_list): + try: + if re.search(regex_str, lora.lora_name): + matched_reg_lr = (i, reg_lr) + logger.info(f"Module {lora.lora_name} matched regex '{regex_str}' -> LR {reg_lr}") + break + except re.error: + # regex error should have been caught during parsing, but just in case + continue + + for name, param in lora.named_parameters(): + param_key = f"{lora.lora_name}.{name}" + is_plus = loraplus_ratio is not None and "lora_up" in name + + if matched_reg_lr is not None: + # use regex-specific learning rate + reg_idx, reg_lr = matched_reg_lr + group_key = f"reg_lr_{reg_idx}" + if group_key not in reg_groups: + reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr} + + if is_plus: + reg_groups[group_key]["plus"][param_key] = param + else: + reg_groups[group_key]["lora"][param_key] = param + else: + # use default learning rate + if is_plus: + param_groups["plus"][param_key] = param + else: + param_groups["lora"][param_key] = param + + params = [] + descriptions = [] + + # process regex-specific groups first (higher priority) + for group_key in sorted(reg_groups.keys()): + group = reg_groups[group_key] + reg_lr = group["lr"] + + for param_type in ["lora", "plus"]: + if len(group[param_type]) == 0: + continue + + param_data = {"params": group[param_type].values()} + + if param_type == "plus" and loraplus_ratio is not None: + param_data["lr"] = reg_lr * loraplus_ratio + else: + param_data["lr"] = reg_lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + continue + + params.append(param_data) + desc = f"reg_lr_{group_key.split('_')[-1]}" + if param_type == "plus": + desc += " plus" + descriptions.append(desc) + + # process default groups + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + + return params, 
descriptions + + if self.text_encoder_loras: + loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + + # split text encoder loras for te1 and te3 + te1_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP)] + te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)] + if len(te1_loras) > 0: + logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}") + params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions]) + if len(te3_loras) > 0: + logger.info(f"Text Encoder 2 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[1]}") + params, descriptions = assemble_params(te3_loras, text_encoder_lr[1], loraplus_lr_ratio) + all_params.extend(params) + lr_descriptions.extend(["textencoder 2 " + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def enable_gradient_checkpointing(self): + # not supported + pass + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + # 重みのバックアップを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + # 重みのリストアを行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + # 事前計算を行う + loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + 
lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False + + def apply_max_norm_regularization(self, max_norm_value, device): + downkeys = [] + upkeys = [] + alphakeys = [] + norms = [] + keys_scaled = 0 + + state_dict = self.state_dict() + for key in state_dict.keys(): + if "lora_down" in key and "weight" in key: + downkeys.append(key) + upkeys.append(key.replace("lora_down", "lora_up")) + alphakeys.append(key.replace("lora_down.weight", "alpha")) + + for i in range(len(downkeys)): + down = state_dict[downkeys[i]].to(device) + up = state_dict[upkeys[i]].to(device) + alpha = state_dict[alphakeys[i]].to(device) + dim = down.shape[0] + scale = alpha / dim + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): + updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + else: + updown = up @ down + + updown *= scale + + norm = updown.norm().clamp(min=max_norm_value / 2) + desired = torch.clamp(norm, max=max_norm_value) + ratio = desired.cpu() / norm.cpu() + sqrt_ratio = ratio**0.5 + if ratio != 1: + keys_scaled += 1 + state_dict[upkeys[i]] *= sqrt_ratio + state_dict[downkeys[i]] *= sqrt_ratio + scalednorm = updown.norm() * ratio + norms.append(scalednorm.item()) + + return keys_scaled, sum(norms) / len(norms), max(norms)
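+
+
+# Usage sketch (inference), assuming a loaded HunyuanImage DiT `dit` and a list of
+# text encoder modules `text_encoders`; the argument order follows
+# create_network_from_weights above:
+#
+#   network, weights_sd = create_network_from_weights(
+#       1.0, "lora.safetensors", None, text_encoders, dit, for_inference=True
+#   )
+#   network.merge_to(text_encoders, dit, weights_sd, dtype=torch.bfloat16, device="cuda")
+#   # or keep the LoRA separate from the base weights:
+#   # network.apply_to(text_encoders, dit, apply_text_encoder=False, apply_unet=True)
+#   # info = network.load_weights("lora.safetensors")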