mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
feat: initial commit for HunyuanImage-2.1 inference
This commit is contained in:
1197
hunyuan_image_minimal_inference.py
Normal file
1197
hunyuan_image_minimal_inference.py
Normal file
File diff suppressed because it is too large
Load Diff
50
library/attention.py
Normal file
50
library/attention.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_lens: list[int], attn_mode: str = "torch", drop_rate: float = 0.0
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute scaled dot-product attention with variable sequence lengths.
|
||||
|
||||
Handles batches with different sequence lengths by splitting and
|
||||
processing each sequence individually.
|
||||
|
||||
Args:
|
||||
q: Query tensor [B, L, H, D].
|
||||
k: Key tensor [B, L, H, D].
|
||||
v: Value tensor [B, L, H, D].
|
||||
seq_lens: Valid sequence length for each batch element.
|
||||
attn_mode: Attention implementation ("torch" or "sageattn").
|
||||
drop_rate: Attention dropout rate.
|
||||
|
||||
Returns:
|
||||
Attention output tensor [B, L, H*D].
|
||||
"""
|
||||
# Determine tensor layout based on attention implementation
|
||||
if attn_mode == "torch" or attn_mode == "sageattn":
|
||||
transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA
|
||||
else:
|
||||
transpose_fn = lambda x: x # [B, L, H, D] for other implementations
|
||||
|
||||
# Process each batch element with its valid sequence length
|
||||
q = [transpose_fn(q[i : i + 1, : seq_lens[i]]) for i in range(len(q))]
|
||||
k = [transpose_fn(k[i : i + 1, : seq_lens[i]]) for i in range(len(k))]
|
||||
v = [transpose_fn(v[i : i + 1, : seq_lens[i]]) for i in range(len(v))]
|
||||
|
||||
if attn_mode == "torch":
|
||||
x = []
|
||||
for i in range(len(q)):
|
||||
x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
|
||||
q[i] = None
|
||||
k[i] = None
|
||||
v[i] = None
|
||||
x.append(x_i)
|
||||
x = torch.cat(x, dim=0)
|
||||
del q, k, v
|
||||
# Currently only PyTorch SDPA is implemented
|
||||
|
||||
x = transpose_fn(x) # [B, L, H, D]
|
||||
x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D]
|
||||
return x
|
||||
391
library/fp8_optimization_utils.py
Normal file
391
library/fp8_optimization_utils.py
Normal file
@@ -0,0 +1,391 @@
|
||||
import os
|
||||
from typing import List, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import logging
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.device_utils import clean_memory_on_device
|
||||
from library.utils import MemoryEfficientSafeOpen, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_fp8_maxval(exp_bits=4, mantissa_bits=3, sign_bits=1):
|
||||
"""
|
||||
Calculate the maximum representable value in FP8 format.
|
||||
Default is E4M3 format (4-bit exponent, 3-bit mantissa, 1-bit sign).
|
||||
|
||||
Args:
|
||||
exp_bits (int): Number of exponent bits
|
||||
mantissa_bits (int): Number of mantissa bits
|
||||
sign_bits (int): Number of sign bits (0 or 1)
|
||||
|
||||
Returns:
|
||||
float: Maximum value representable in FP8 format
|
||||
"""
|
||||
assert exp_bits + mantissa_bits + sign_bits == 8, "Total bits must be 8"
|
||||
|
||||
# Calculate exponent bias
|
||||
bias = 2 ** (exp_bits - 1) - 1
|
||||
|
||||
# Calculate maximum mantissa value
|
||||
mantissa_max = 1.0
|
||||
for i in range(mantissa_bits - 1):
|
||||
mantissa_max += 2 ** -(i + 1)
|
||||
|
||||
# Calculate maximum value
|
||||
max_value = mantissa_max * (2 ** (2**exp_bits - 1 - bias))
|
||||
|
||||
return max_value
|
||||
|
||||
|
||||
def quantize_tensor_to_fp8(tensor, scale, exp_bits=4, mantissa_bits=3, sign_bits=1, max_value=None, min_value=None):
|
||||
"""
|
||||
Quantize a tensor to FP8 format.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Tensor to quantize
|
||||
scale (float or torch.Tensor): Scale factor
|
||||
exp_bits (int): Number of exponent bits
|
||||
mantissa_bits (int): Number of mantissa bits
|
||||
sign_bits (int): Number of sign bits
|
||||
|
||||
Returns:
|
||||
tuple: (quantized_tensor, scale_factor)
|
||||
"""
|
||||
# Create scaled tensor
|
||||
scaled_tensor = tensor / scale
|
||||
|
||||
# Calculate FP8 parameters
|
||||
bias = 2 ** (exp_bits - 1) - 1
|
||||
|
||||
if max_value is None:
|
||||
# Calculate max and min values
|
||||
max_value = calculate_fp8_maxval(exp_bits, mantissa_bits, sign_bits)
|
||||
min_value = -max_value if sign_bits > 0 else 0.0
|
||||
|
||||
# Clamp tensor to range
|
||||
clamped_tensor = torch.clamp(scaled_tensor, min_value, max_value)
|
||||
|
||||
# Quantization process
|
||||
abs_values = torch.abs(clamped_tensor)
|
||||
nonzero_mask = abs_values > 0
|
||||
|
||||
# Calculate log scales (only for non-zero elements)
|
||||
log_scales = torch.zeros_like(clamped_tensor)
|
||||
if nonzero_mask.any():
|
||||
log_scales[nonzero_mask] = torch.floor(torch.log2(abs_values[nonzero_mask]) + bias).detach()
|
||||
|
||||
# Limit log scales and calculate quantization factor
|
||||
log_scales = torch.clamp(log_scales, min=1.0)
|
||||
quant_factor = 2.0 ** (log_scales - mantissa_bits - bias)
|
||||
|
||||
# Quantize and dequantize
|
||||
quantized = torch.round(clamped_tensor / quant_factor) * quant_factor
|
||||
|
||||
return quantized, scale
|
||||
|
||||
|
||||
def optimize_state_dict_with_fp8(
|
||||
state_dict, calc_device, target_layer_keys=None, exclude_layer_keys=None, exp_bits=4, mantissa_bits=3, move_to_device=False
|
||||
):
|
||||
"""
|
||||
Optimize Linear layer weights in a model's state dict to FP8 format.
|
||||
|
||||
Args:
|
||||
state_dict (dict): State dict to optimize, replaced in-place
|
||||
calc_device (str): Device to quantize tensors on
|
||||
target_layer_keys (list, optional): Layer key patterns to target (None for all Linear layers)
|
||||
exclude_layer_keys (list, optional): Layer key patterns to exclude
|
||||
exp_bits (int): Number of exponent bits
|
||||
mantissa_bits (int): Number of mantissa bits
|
||||
move_to_device (bool): Move optimized tensors to the calculating device
|
||||
|
||||
Returns:
|
||||
dict: FP8 optimized state dict
|
||||
"""
|
||||
if exp_bits == 4 and mantissa_bits == 3:
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
elif exp_bits == 5 and mantissa_bits == 2:
|
||||
fp8_dtype = torch.float8_e5m2
|
||||
else:
|
||||
raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")
|
||||
|
||||
# Calculate FP8 max value
|
||||
max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
|
||||
min_value = -max_value # this function supports only signed FP8
|
||||
|
||||
# Create optimized state dict
|
||||
optimized_count = 0
|
||||
|
||||
# Enumerate tarket keys
|
||||
target_state_dict_keys = []
|
||||
for key in state_dict.keys():
|
||||
# Check if it's a weight key and matches target patterns
|
||||
is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
|
||||
is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
|
||||
is_target = is_target and not is_excluded
|
||||
|
||||
if is_target and isinstance(state_dict[key], torch.Tensor):
|
||||
target_state_dict_keys.append(key)
|
||||
|
||||
# Process each key
|
||||
for key in tqdm(target_state_dict_keys):
|
||||
value = state_dict[key]
|
||||
|
||||
# Save original device and dtype
|
||||
original_device = value.device
|
||||
original_dtype = value.dtype
|
||||
|
||||
# Move to calculation device
|
||||
if calc_device is not None:
|
||||
value = value.to(calc_device)
|
||||
|
||||
# Calculate scale factor
|
||||
scale = torch.max(torch.abs(value.flatten())) / max_value
|
||||
# print(f"Optimizing {key} with scale: {scale}")
|
||||
|
||||
# Quantize weight to FP8
|
||||
quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value)
|
||||
|
||||
# Add to state dict using original key for weight and new key for scale
|
||||
fp8_key = key # Maintain original key
|
||||
scale_key = key.replace(".weight", ".scale_weight")
|
||||
|
||||
quantized_weight = quantized_weight.to(fp8_dtype)
|
||||
|
||||
if not move_to_device:
|
||||
quantized_weight = quantized_weight.to(original_device)
|
||||
|
||||
scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device)
|
||||
|
||||
state_dict[fp8_key] = quantized_weight
|
||||
state_dict[scale_key] = scale_tensor
|
||||
|
||||
optimized_count += 1
|
||||
|
||||
if calc_device is not None: # optimized_count % 10 == 0 and
|
||||
# free memory on calculation device
|
||||
clean_memory_on_device(calc_device)
|
||||
|
||||
logger.info(f"Number of optimized Linear layers: {optimized_count}")
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_safetensors_with_fp8_optimization(
|
||||
model_files: List[str],
|
||||
calc_device: Union[str, torch.device],
|
||||
target_layer_keys=None,
|
||||
exclude_layer_keys=None,
|
||||
exp_bits=4,
|
||||
mantissa_bits=3,
|
||||
move_to_device=False,
|
||||
weight_hook=None,
|
||||
):
|
||||
"""
|
||||
Load weight tensors from safetensors files and merge LoRA weights into the state dict with explicit FP8 optimization.
|
||||
|
||||
Args:
|
||||
model_files (list[str]): List of model files to load
|
||||
calc_device (str or torch.device): Device to quantize tensors on
|
||||
target_layer_keys (list, optional): Layer key patterns to target for optimization (None for all Linear layers)
|
||||
exclude_layer_keys (list, optional): Layer key patterns to exclude from optimization
|
||||
exp_bits (int): Number of exponent bits
|
||||
mantissa_bits (int): Number of mantissa bits
|
||||
move_to_device (bool): Move optimized tensors to the calculating device
|
||||
weight_hook (callable, optional): Function to apply to each weight tensor before optimization
|
||||
|
||||
Returns:
|
||||
dict: FP8 optimized state dict
|
||||
"""
|
||||
if exp_bits == 4 and mantissa_bits == 3:
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
elif exp_bits == 5 and mantissa_bits == 2:
|
||||
fp8_dtype = torch.float8_e5m2
|
||||
else:
|
||||
raise ValueError(f"Unsupported FP8 format: E{exp_bits}M{mantissa_bits}")
|
||||
|
||||
# Calculate FP8 max value
|
||||
max_value = calculate_fp8_maxval(exp_bits, mantissa_bits)
|
||||
min_value = -max_value # this function supports only signed FP8
|
||||
|
||||
# Define function to determine if a key is a target key. target means fp8 optimization, not for weight hook.
|
||||
def is_target_key(key):
|
||||
# Check if weight key matches target patterns and does not match exclude patterns
|
||||
is_target = (target_layer_keys is None or any(pattern in key for pattern in target_layer_keys)) and key.endswith(".weight")
|
||||
is_excluded = exclude_layer_keys is not None and any(pattern in key for pattern in exclude_layer_keys)
|
||||
return is_target and not is_excluded
|
||||
|
||||
# Create optimized state dict
|
||||
optimized_count = 0
|
||||
|
||||
# Process each file
|
||||
state_dict = {}
|
||||
for model_file in model_files:
|
||||
with MemoryEfficientSafeOpen(model_file) as f:
|
||||
keys = f.keys()
|
||||
for key in tqdm(keys, desc=f"Loading {os.path.basename(model_file)}", unit="key"):
|
||||
value = f.get_tensor(key)
|
||||
if weight_hook is not None:
|
||||
# Apply weight hook if provided
|
||||
value = weight_hook(key, value)
|
||||
|
||||
if not is_target_key(key):
|
||||
state_dict[key] = value
|
||||
continue
|
||||
|
||||
# Save original device and dtype
|
||||
original_device = value.device
|
||||
original_dtype = value.dtype
|
||||
|
||||
# Move to calculation device
|
||||
if calc_device is not None:
|
||||
value = value.to(calc_device)
|
||||
|
||||
# Calculate scale factor
|
||||
scale = torch.max(torch.abs(value.flatten())) / max_value
|
||||
# print(f"Optimizing {key} with scale: {scale}")
|
||||
|
||||
# Quantize weight to FP8
|
||||
quantized_weight, _ = quantize_tensor_to_fp8(value, scale, exp_bits, mantissa_bits, 1, max_value, min_value)
|
||||
|
||||
# Add to state dict using original key for weight and new key for scale
|
||||
fp8_key = key # Maintain original key
|
||||
scale_key = key.replace(".weight", ".scale_weight")
|
||||
assert fp8_key != scale_key, "FP8 key and scale key must be different"
|
||||
|
||||
quantized_weight = quantized_weight.to(fp8_dtype)
|
||||
|
||||
if not move_to_device:
|
||||
quantized_weight = quantized_weight.to(original_device)
|
||||
|
||||
scale_tensor = torch.tensor([scale], dtype=original_dtype, device=quantized_weight.device)
|
||||
|
||||
state_dict[fp8_key] = quantized_weight
|
||||
state_dict[scale_key] = scale_tensor
|
||||
|
||||
optimized_count += 1
|
||||
|
||||
if calc_device is not None and optimized_count % 10 == 0:
|
||||
# free memory on calculation device
|
||||
clean_memory_on_device(calc_device)
|
||||
|
||||
logger.info(f"Number of optimized Linear layers: {optimized_count}")
|
||||
return state_dict
|
||||
|
||||
|
||||
def fp8_linear_forward_patch(self: nn.Linear, x, use_scaled_mm=False, max_value=None):
|
||||
"""
|
||||
Patched forward method for Linear layers with FP8 weights.
|
||||
|
||||
Args:
|
||||
self: Linear layer instance
|
||||
x (torch.Tensor): Input tensor
|
||||
use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
|
||||
max_value (float): Maximum value for FP8 quantization. If None, no quantization is applied for input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Result of linear transformation
|
||||
"""
|
||||
if use_scaled_mm:
|
||||
input_dtype = x.dtype
|
||||
original_weight_dtype = self.scale_weight.dtype
|
||||
weight_dtype = self.weight.dtype
|
||||
target_dtype = torch.float8_e5m2
|
||||
assert weight_dtype == torch.float8_e4m3fn, "Only FP8 E4M3FN format is supported"
|
||||
assert x.ndim == 3, "Input tensor must be 3D (batch_size, seq_len, hidden_dim)"
|
||||
|
||||
if max_value is None:
|
||||
# no input quantization
|
||||
scale_x = torch.tensor(1.0, dtype=torch.float32, device=x.device)
|
||||
else:
|
||||
# calculate scale factor for input tensor
|
||||
scale_x = (torch.max(torch.abs(x.flatten())) / max_value).to(torch.float32)
|
||||
|
||||
# quantize input tensor to FP8: this seems to consume a lot of memory
|
||||
x, _ = quantize_tensor_to_fp8(x, scale_x, 5, 2, 1, max_value, -max_value)
|
||||
|
||||
original_shape = x.shape
|
||||
x = x.reshape(-1, x.shape[2]).to(target_dtype)
|
||||
|
||||
weight = self.weight.t()
|
||||
scale_weight = self.scale_weight.to(torch.float32)
|
||||
|
||||
if self.bias is not None:
|
||||
# float32 is not supported with bias in scaled_mm
|
||||
o = torch._scaled_mm(x, weight, out_dtype=original_weight_dtype, bias=self.bias, scale_a=scale_x, scale_b=scale_weight)
|
||||
else:
|
||||
o = torch._scaled_mm(x, weight, out_dtype=input_dtype, scale_a=scale_x, scale_b=scale_weight)
|
||||
|
||||
return o.reshape(original_shape[0], original_shape[1], -1).to(input_dtype)
|
||||
|
||||
else:
|
||||
# Dequantize the weight
|
||||
original_dtype = self.scale_weight.dtype
|
||||
dequantized_weight = self.weight.to(original_dtype) * self.scale_weight
|
||||
|
||||
# Perform linear transformation
|
||||
if self.bias is not None:
|
||||
output = F.linear(x, dequantized_weight, self.bias)
|
||||
else:
|
||||
output = F.linear(x, dequantized_weight)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def apply_fp8_monkey_patch(model, optimized_state_dict, use_scaled_mm=False):
|
||||
"""
|
||||
Apply monkey patching to a model using FP8 optimized state dict.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model instance to patch
|
||||
optimized_state_dict (dict): FP8 optimized state dict
|
||||
use_scaled_mm (bool): Use scaled_mm for FP8 Linear layers, requires SM 8.9+ (RTX 40 series)
|
||||
|
||||
Returns:
|
||||
nn.Module: The patched model (same instance, modified in-place)
|
||||
"""
|
||||
# # Calculate FP8 float8_e5m2 max value
|
||||
# max_value = calculate_fp8_maxval(5, 2)
|
||||
max_value = None # do not quantize input tensor
|
||||
|
||||
# Find all scale keys to identify FP8-optimized layers
|
||||
scale_keys = [k for k in optimized_state_dict.keys() if k.endswith(".scale_weight")]
|
||||
|
||||
# Enumerate patched layers
|
||||
patched_module_paths = set()
|
||||
for scale_key in scale_keys:
|
||||
# Extract module path from scale key (remove .scale_weight)
|
||||
module_path = scale_key.rsplit(".scale_weight", 1)[0]
|
||||
patched_module_paths.add(module_path)
|
||||
|
||||
patched_count = 0
|
||||
|
||||
# Apply monkey patch to each layer with FP8 weights
|
||||
for name, module in model.named_modules():
|
||||
# Check if this module has a corresponding scale_weight
|
||||
has_scale = name in patched_module_paths
|
||||
|
||||
# Apply patch if it's a Linear layer with FP8 scale
|
||||
if isinstance(module, nn.Linear) and has_scale:
|
||||
# register the scale_weight as a buffer to load the state_dict
|
||||
module.register_buffer("scale_weight", torch.tensor(1.0, dtype=module.weight.dtype))
|
||||
|
||||
# Create a new forward method with the patched version.
|
||||
def new_forward(self, x):
|
||||
return fp8_linear_forward_patch(self, x, use_scaled_mm, max_value)
|
||||
|
||||
# Bind method to module
|
||||
module.forward = new_forward.__get__(module, type(module))
|
||||
|
||||
patched_count += 1
|
||||
|
||||
logger.info(f"Number of monkey-patched Linear layers: {patched_count}")
|
||||
return model
|
||||
374
library/hunyuan_image_models.py
Normal file
374
library/hunyuan_image_models.py
Normal file
@@ -0,0 +1,374 @@
|
||||
# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
|
||||
# Re-implemented for license compliance for sd-scripts.
|
||||
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from library.fp8_optimization_utils import apply_fp8_monkey_patch
|
||||
from library.lora_utils import load_safetensors_with_lora_and_fp8
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
from library.hunyuan_image_modules import (
|
||||
SingleTokenRefiner,
|
||||
ByT5Mapper,
|
||||
PatchEmbed2D,
|
||||
TimestepEmbedder,
|
||||
MMDoubleStreamBlock,
|
||||
MMSingleStreamBlock,
|
||||
FinalLayer,
|
||||
)
|
||||
from library.hunyuan_image_utils import get_nd_rotary_pos_embed
|
||||
|
||||
FP8_OPTIMIZATION_TARGET_KEYS = ["double_blocks", "single_blocks"]
|
||||
FP8_OPTIMIZATION_EXCLUDE_KEYS = [
|
||||
"norm",
|
||||
"_mod",
|
||||
"modulation",
|
||||
]
|
||||
|
||||
|
||||
# region DiT Model
|
||||
class HYImageDiffusionTransformer(nn.Module):
|
||||
"""
|
||||
HunyuanImage-2.1 Diffusion Transformer.
|
||||
|
||||
A multimodal transformer for image generation with text conditioning,
|
||||
featuring separate double-stream and single-stream processing blocks.
|
||||
|
||||
Args:
|
||||
attn_mode: Attention implementation mode ("torch" or "sageattn").
|
||||
"""
|
||||
|
||||
def __init__(self, attn_mode: str = "torch"):
|
||||
super().__init__()
|
||||
|
||||
# Fixed architecture parameters for HunyuanImage-2.1
|
||||
self.patch_size = [1, 1] # 1x1 patch size (no spatial downsampling)
|
||||
self.in_channels = 64 # Input latent channels
|
||||
self.out_channels = 64 # Output latent channels
|
||||
self.unpatchify_channels = self.out_channels
|
||||
self.guidance_embed = False # Guidance embedding disabled
|
||||
self.rope_dim_list = [64, 64] # RoPE dimensions for 2D positional encoding
|
||||
self.rope_theta = 256 # RoPE frequency scaling
|
||||
self.use_attention_mask = True
|
||||
self.text_projection = "single_refiner"
|
||||
self.hidden_size = 3584 # Model dimension
|
||||
self.heads_num = 28 # Number of attention heads
|
||||
|
||||
# Architecture configuration
|
||||
mm_double_blocks_depth = 20 # Double-stream transformer blocks
|
||||
mm_single_blocks_depth = 40 # Single-stream transformer blocks
|
||||
mlp_width_ratio = 4 # MLP expansion ratio
|
||||
text_states_dim = 3584 # Text encoder output dimension
|
||||
guidance_embed = False # No guidance embedding
|
||||
|
||||
# Layer configuration
|
||||
mlp_act_type: str = "gelu_tanh" # MLP activation function
|
||||
qkv_bias: bool = True # Use bias in QKV projections
|
||||
qk_norm: bool = True # Apply QK normalization
|
||||
qk_norm_type: str = "rms" # RMS normalization type
|
||||
|
||||
self.attn_mode = attn_mode
|
||||
|
||||
# ByT5 character-level text encoder mapping
|
||||
self.byt5_in = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.hidden_size, use_residual=False)
|
||||
|
||||
# Image latent patch embedding
|
||||
self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size)
|
||||
|
||||
# Text token refinement with cross-attention
|
||||
self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2, attn_mode=self.attn_mode)
|
||||
|
||||
# Timestep embedding for diffusion process
|
||||
self.time_in = TimestepEmbedder(self.hidden_size, nn.SiLU)
|
||||
|
||||
# MeanFlow not supported in this implementation
|
||||
self.time_r_in = None
|
||||
|
||||
# Guidance embedding (disabled for non-distilled model)
|
||||
self.guidance_in = TimestepEmbedder(self.hidden_size, nn.SiLU) if guidance_embed else None
|
||||
|
||||
# Double-stream blocks: separate image and text processing
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
MMDoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.heads_num,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_act_type=mlp_act_type,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=self.attn_mode,
|
||||
)
|
||||
for _ in range(mm_double_blocks_depth)
|
||||
]
|
||||
)
|
||||
|
||||
# Single-stream blocks: joint processing of concatenated features
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
MMSingleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.heads_num,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_act_type=mlp_act_type,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
attn_mode=self.attn_mode,
|
||||
)
|
||||
for _ in range(mm_single_blocks_depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(self.hidden_size, self.patch_size, self.out_channels, nn.SiLU)
|
||||
|
||||
def get_rotary_pos_embed(self, rope_sizes):
|
||||
"""
|
||||
Generate 2D rotary position embeddings for image tokens.
|
||||
|
||||
Args:
|
||||
rope_sizes: Tuple of (height, width) for spatial dimensions.
|
||||
|
||||
Returns:
|
||||
Tuple of (freqs_cos, freqs_sin) tensors for rotary position encoding.
|
||||
"""
|
||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(self.rope_dim_list, rope_sizes, theta=self.rope_theta)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
def reorder_txt_token(
|
||||
self, byt5_txt: torch.Tensor, txt: torch.Tensor, byt5_text_mask: torch.Tensor, text_mask: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, list[int]]:
|
||||
"""
|
||||
Combine and reorder ByT5 character-level and word-level text embeddings.
|
||||
|
||||
Concatenates valid tokens from both encoders and creates appropriate masks.
|
||||
|
||||
Args:
|
||||
byt5_txt: ByT5 character-level embeddings [B, L1, D].
|
||||
txt: Word-level text embeddings [B, L2, D].
|
||||
byt5_text_mask: Valid token mask for ByT5 [B, L1].
|
||||
text_mask: Valid token mask for word tokens [B, L2].
|
||||
|
||||
Returns:
|
||||
Tuple of (reordered_embeddings, combined_mask, sequence_lengths).
|
||||
"""
|
||||
# Process each batch element separately to handle variable sequence lengths
|
||||
|
||||
reorder_txt = []
|
||||
reorder_mask = []
|
||||
|
||||
txt_lens = []
|
||||
for i in range(text_mask.shape[0]):
|
||||
byt5_text_mask_i = byt5_text_mask[i].bool()
|
||||
text_mask_i = text_mask[i].bool()
|
||||
byt5_text_length = byt5_text_mask_i.sum()
|
||||
text_length = text_mask_i.sum()
|
||||
assert byt5_text_length == byt5_text_mask_i[:byt5_text_length].sum()
|
||||
assert text_length == text_mask_i[:text_length].sum()
|
||||
|
||||
byt5_txt_i = byt5_txt[i]
|
||||
txt_i = txt[i]
|
||||
reorder_txt_i = torch.cat(
|
||||
[byt5_txt_i[:byt5_text_length], txt_i[:text_length], byt5_txt_i[byt5_text_length:], txt_i[text_length:]], dim=0
|
||||
)
|
||||
|
||||
reorder_mask_i = torch.zeros(
|
||||
byt5_text_mask_i.shape[0] + text_mask_i.shape[0], dtype=torch.bool, device=byt5_text_mask_i.device
|
||||
)
|
||||
reorder_mask_i[: byt5_text_length + text_length] = True
|
||||
|
||||
reorder_txt.append(reorder_txt_i)
|
||||
reorder_mask.append(reorder_mask_i)
|
||||
txt_lens.append(byt5_text_length + text_length)
|
||||
|
||||
reorder_txt = torch.stack(reorder_txt)
|
||||
reorder_mask = torch.stack(reorder_mask).to(dtype=torch.int64)
|
||||
|
||||
return reorder_txt, reorder_mask, txt_lens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
text_states: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
byt5_text_states: Optional[torch.Tensor] = None,
|
||||
byt5_text_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass through the HunyuanImage diffusion transformer.
|
||||
|
||||
Args:
|
||||
hidden_states: Input image latents [B, C, H, W].
|
||||
timestep: Diffusion timestep [B].
|
||||
text_states: Word-level text embeddings [B, L, D].
|
||||
encoder_attention_mask: Text attention mask [B, L].
|
||||
byt5_text_states: ByT5 character-level embeddings [B, L_byt5, D_byt5].
|
||||
byt5_text_mask: ByT5 attention mask [B, L_byt5].
|
||||
|
||||
Returns:
|
||||
Tuple of (denoised_image, spatial_shape).
|
||||
"""
|
||||
img = x = hidden_states
|
||||
text_mask = encoder_attention_mask
|
||||
t = timestep
|
||||
txt = text_states
|
||||
|
||||
# Calculate spatial dimensions for rotary position embeddings
|
||||
_, _, oh, ow = x.shape
|
||||
th, tw = oh, ow # Height and width (patch_size=[1,1] means no spatial downsampling)
|
||||
freqs_cis = self.get_rotary_pos_embed((th, tw))
|
||||
|
||||
# Reshape image latents to sequence format: [B, C, H, W] -> [B, H*W, C]
|
||||
img = self.img_in(img)
|
||||
|
||||
# Generate timestep conditioning vector
|
||||
vec = self.time_in(t)
|
||||
|
||||
# MeanFlow and guidance embedding not used in this configuration
|
||||
|
||||
# Process text tokens through refinement layers
|
||||
txt_lens = text_mask.to(torch.bool).sum(dim=1).tolist()
|
||||
txt = self.txt_in(txt, t, txt_lens)
|
||||
|
||||
# Integrate character-level ByT5 features with word-level tokens
|
||||
# Use variable length sequences with sequence lengths
|
||||
byt5_txt = self.byt5_in(byt5_text_states)
|
||||
txt, _, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)
|
||||
|
||||
# Trim sequences to maximum length in the batch
|
||||
img_seq_len = img.shape[1]
|
||||
# print(f"img_seq_len: {img_seq_len}, txt_lens: {txt_lens}")
|
||||
seq_lens = [img_seq_len + l for l in txt_lens]
|
||||
max_txt_len = max(txt_lens)
|
||||
# print(f"max_txt_len: {max_txt_len}, seq_lens: {seq_lens}, txt.shape: {txt.shape}")
|
||||
txt = txt[:, :max_txt_len, :]
|
||||
txt_seq_len = txt.shape[1]
|
||||
|
||||
# Process through double-stream blocks (separate image/text attention)
|
||||
for index, block in enumerate(self.double_blocks):
|
||||
img, txt = block(img, txt, vec, freqs_cis, seq_lens)
|
||||
|
||||
# Concatenate image and text tokens for joint processing
|
||||
x = torch.cat((img, txt), 1)
|
||||
|
||||
# Process through single-stream blocks (joint attention)
|
||||
for index, block in enumerate(self.single_blocks):
|
||||
x = block(x, vec, txt_seq_len, freqs_cis, seq_lens)
|
||||
|
||||
img = x[:, :img_seq_len, ...]
|
||||
|
||||
# Apply final projection to output space
|
||||
img = self.final_layer(img, vec)
|
||||
|
||||
# Reshape from sequence to spatial format: [B, L, C] -> [B, C, H, W]
|
||||
img = self.unpatchify_2d(img, th, tw)
|
||||
return img
|
||||
|
||||
def unpatchify_2d(self, x, h, w):
|
||||
"""
|
||||
Convert sequence format back to spatial image format.
|
||||
|
||||
Args:
|
||||
x: Input tensor [B, H*W, C].
|
||||
h: Height dimension.
|
||||
w: Width dimension.
|
||||
|
||||
Returns:
|
||||
Spatial tensor [B, C, H, W].
|
||||
"""
|
||||
c = self.unpatchify_channels
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, c))
|
||||
imgs = x.permute(0, 3, 1, 2)
|
||||
return imgs
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Model Utils
|
||||
|
||||
|
||||
def create_model(attn_mode: str, split_attn: bool, dtype: Optional[torch.dtype]) -> HYImageDiffusionTransformer:
|
||||
with init_empty_weights():
|
||||
model = HYImageDiffusionTransformer(attn_mode=attn_mode)
|
||||
if dtype is not None:
|
||||
model.to(dtype)
|
||||
return model
|
||||
|
||||
|
||||
def load_hunyuan_image_model(
|
||||
device: Union[str, torch.device],
|
||||
dit_path: str,
|
||||
attn_mode: str,
|
||||
split_attn: bool,
|
||||
loading_device: Union[str, torch.device],
|
||||
dit_weight_dtype: Optional[torch.dtype],
|
||||
fp8_scaled: bool = False,
|
||||
lora_weights_list: Optional[Dict[str, torch.Tensor]] = None,
|
||||
lora_multipliers: Optional[list[float]] = None,
|
||||
) -> HYImageDiffusionTransformer:
|
||||
"""
|
||||
Load a HunyuanImage model from the specified checkpoint.
|
||||
|
||||
Args:
|
||||
device (Union[str, torch.device]): Device for optimization or merging
|
||||
dit_path (str): Path to the DiT model checkpoint.
|
||||
attn_mode (str): Attention mode to use, e.g., "torch", "flash", etc.
|
||||
split_attn (bool): Whether to use split attention.
|
||||
loading_device (Union[str, torch.device]): Device to load the model weights on.
|
||||
dit_weight_dtype (Optional[torch.dtype]): Data type of the DiT weights.
|
||||
If None, it will be loaded as is (same as the state_dict) or scaled for fp8. if not None, model weights will be casted to this dtype.
|
||||
fp8_scaled (bool): Whether to use fp8 scaling for the model weights.
|
||||
lora_weights_list (Optional[Dict[str, torch.Tensor]]): LoRA weights to apply, if any.
|
||||
lora_multipliers (Optional[List[float]]): LoRA multipliers for the weights, if any.
|
||||
"""
|
||||
# dit_weight_dtype is None for fp8_scaled
|
||||
assert (not fp8_scaled and dit_weight_dtype is not None) or (fp8_scaled and dit_weight_dtype is None)
|
||||
|
||||
device = torch.device(device)
|
||||
loading_device = torch.device(loading_device)
|
||||
|
||||
model = create_model(attn_mode, split_attn, dit_weight_dtype)
|
||||
|
||||
# load model weights with dynamic fp8 optimization and LoRA merging if needed
|
||||
logger.info(f"Loading DiT model from {dit_path}, device={loading_device}")
|
||||
|
||||
sd = load_safetensors_with_lora_and_fp8(
|
||||
model_files=dit_path,
|
||||
lora_weights_list=lora_weights_list,
|
||||
lora_multipliers=lora_multipliers,
|
||||
fp8_optimization=fp8_scaled,
|
||||
calc_device=device,
|
||||
move_to_device=(loading_device == device),
|
||||
dit_weight_dtype=dit_weight_dtype,
|
||||
target_keys=FP8_OPTIMIZATION_TARGET_KEYS,
|
||||
exclude_keys=FP8_OPTIMIZATION_EXCLUDE_KEYS,
|
||||
)
|
||||
|
||||
if fp8_scaled:
|
||||
apply_fp8_monkey_patch(model, sd, use_scaled_mm=False)
|
||||
|
||||
if loading_device.type != "cpu":
|
||||
# make sure all the model weights are on the loading_device
|
||||
logger.info(f"Moving weights to {loading_device}")
|
||||
for key in sd.keys():
|
||||
sd[key] = sd[key].to(loading_device)
|
||||
|
||||
info = model.load_state_dict(sd, strict=True, assign=True)
|
||||
logger.info(f"Loaded DiT model from {dit_path}, info={info}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# endregion
|
||||
804
library/hunyuan_image_modules.py
Normal file
804
library/hunyuan_image_modules.py
Normal file
@@ -0,0 +1,804 @@
|
||||
# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
|
||||
# Re-implemented for license compliance for sd-scripts.
|
||||
|
||||
from typing import Tuple, Callable
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
from library.attention import attention
|
||||
from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate
|
||||
from library.attention import attention
|
||||
|
||||
# region Modules
|
||||
|
||||
|
||||
class ByT5Mapper(nn.Module):
|
||||
"""
|
||||
Maps ByT5 character-level encoder outputs to transformer hidden space.
|
||||
|
||||
Applies layer normalization, two MLP layers with GELU activation,
|
||||
and optional residual connection.
|
||||
|
||||
Args:
|
||||
in_dim: Input dimension from ByT5 encoder (1472 for ByT5-large).
|
||||
out_dim: Intermediate dimension after first projection.
|
||||
hidden_dim: Hidden dimension for MLP layer.
|
||||
out_dim1: Final output dimension matching transformer hidden size.
|
||||
use_residual: Whether to add residual connection (requires in_dim == out_dim).
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
|
||||
super().__init__()
|
||||
if use_residual:
|
||||
assert in_dim == out_dim
|
||||
self.layernorm = nn.LayerNorm(in_dim)
|
||||
self.fc1 = nn.Linear(in_dim, hidden_dim)
|
||||
self.fc2 = nn.Linear(hidden_dim, out_dim)
|
||||
self.fc3 = nn.Linear(out_dim, out_dim1)
|
||||
self.use_residual = use_residual
|
||||
self.act_fn = nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Transform ByT5 embeddings to transformer space.
|
||||
|
||||
Args:
|
||||
x: Input ByT5 embeddings [..., in_dim].
|
||||
|
||||
Returns:
|
||||
Transformed embeddings [..., out_dim1].
|
||||
"""
|
||||
residual = x
|
||||
x = self.layernorm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc2(x)
|
||||
x = self.act_fn(x)
|
||||
x = self.fc3(x)
|
||||
if self.use_residual:
|
||||
x = x + residual
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed2D(nn.Module):
|
||||
"""
|
||||
2D patch embedding layer for converting image latents to transformer tokens.
|
||||
|
||||
Uses 2D convolution to project image patches to embedding space.
|
||||
For HunyuanImage-2.1, patch_size=[1,1] means no spatial downsampling.
|
||||
|
||||
Args:
|
||||
patch_size: Spatial size of patches (int or tuple).
|
||||
in_chans: Number of input channels.
|
||||
embed_dim: Output embedding dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
self.patch_size = tuple(patch_size)
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=True)
|
||||
self.norm = nn.Identity() # No normalization layer used
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar diffusion timesteps into vector representations.
|
||||
|
||||
Uses sinusoidal encoding followed by a two-layer MLP.
|
||||
|
||||
Args:
|
||||
hidden_size: Output embedding dimension.
|
||||
act_layer: Activation function class (e.g., nn.SiLU).
|
||||
frequency_embedding_size: Dimension of sinusoidal encoding.
|
||||
max_period: Maximum period for sinusoidal frequencies.
|
||||
out_size: Output dimension (defaults to hidden_size).
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, act_layer, frequency_embedding_size=256, max_period=10000, out_size=None):
|
||||
super().__init__()
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.max_period = max_period
|
||||
if out_size is None:
|
||||
out_size = hidden_size
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True), act_layer(), nn.Linear(hidden_size, out_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
|
||||
return self.mlp(t_freq)
|
||||
|
||||
|
||||
class TextProjection(nn.Module):
|
||||
"""
|
||||
Projects text embeddings through a two-layer MLP.
|
||||
|
||||
Used for context-aware representation computation in token refinement.
|
||||
|
||||
Args:
|
||||
in_channels: Input feature dimension.
|
||||
hidden_size: Hidden and output dimension.
|
||||
act_layer: Activation function class.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, hidden_size, act_layer):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True)
|
||||
self.act_1 = act_layer()
|
||||
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""
|
||||
Multi-layer perceptron with configurable activation and normalization.
|
||||
|
||||
Standard two-layer MLP with optional dropout and intermediate normalization.
|
||||
|
||||
Args:
|
||||
in_channels: Input feature dimension.
|
||||
hidden_channels: Hidden layer dimension (defaults to in_channels).
|
||||
out_features: Output dimension (defaults to in_channels).
|
||||
act_layer: Activation function class.
|
||||
norm_layer: Optional normalization layer class.
|
||||
bias: Whether to use bias (can be bool or tuple for each layer).
|
||||
drop: Dropout rate (can be float or tuple for each layer).
|
||||
use_conv: Whether to use convolution instead of linear (not supported).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
hidden_channels=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=None,
|
||||
bias=True,
|
||||
drop=0.0,
|
||||
use_conv=False,
|
||||
):
|
||||
super().__init__()
|
||||
assert not use_conv, "Convolutional MLP not supported in this implementation."
|
||||
|
||||
out_features = out_features or in_channels
|
||||
hidden_channels = hidden_channels or in_channels
|
||||
bias = _to_tuple(bias, 2)
|
||||
drop_probs = _to_tuple(drop, 2)
|
||||
|
||||
self.fc1 = nn.Linear(in_channels, hidden_channels, bias=bias[0])
|
||||
self.act = act_layer()
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_channels) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = nn.Linear(hidden_channels, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
|
||||
class IndividualTokenRefinerBlock(nn.Module):
|
||||
"""
|
||||
Single transformer block for individual token refinement.
|
||||
|
||||
Applies self-attention and MLP with adaptive layer normalization (AdaLN)
|
||||
conditioned on timestep and context information.
|
||||
|
||||
Args:
|
||||
hidden_size: Model dimension.
|
||||
heads_num: Number of attention heads.
|
||||
mlp_width_ratio: MLP expansion ratio.
|
||||
mlp_drop_rate: MLP dropout rate.
|
||||
act_type: Activation function (only "silu" supported).
|
||||
qk_norm: QK normalization flag (must be False).
|
||||
qk_norm_type: QK normalization type (only "layer" supported).
|
||||
qkv_bias: Use bias in QKV projections.
|
||||
attn_mode: Attention implementation mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
heads_num: int,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
act_type: str = "silu",
|
||||
qk_norm: bool = False,
|
||||
qk_norm_type: str = "layer",
|
||||
qkv_bias: bool = True,
|
||||
attn_mode: str = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
assert qk_norm_type == "layer", "Only layer normalization supported for QK norm."
|
||||
assert act_type == "silu", "Only SiLU activation supported."
|
||||
assert not qk_norm, "QK normalization must be disabled."
|
||||
|
||||
self.attn_mode = attn_mode
|
||||
|
||||
self.heads_num = heads_num
|
||||
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
|
||||
|
||||
self.self_attn_q_norm = nn.Identity()
|
||||
self.self_attn_k_norm = nn.Identity()
|
||||
self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = MLP(in_channels=hidden_size, hidden_channels=mlp_hidden_dim, act_layer=nn.SiLU, drop=mlp_drop_rate)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor, # Combined timestep and context conditioning
|
||||
txt_lens: list[int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply self-attention and MLP with adaptive conditioning.
|
||||
|
||||
Args:
|
||||
x: Input token embeddings [B, L, C].
|
||||
c: Combined conditioning vector [B, C].
|
||||
txt_lens: Valid sequence lengths for each batch element.
|
||||
|
||||
Returns:
|
||||
Refined token embeddings [B, L, C].
|
||||
"""
|
||||
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn_qkv(norm_x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
q = self.self_attn_q_norm(q).to(v)
|
||||
k = self.self_attn_k_norm(k).to(v)
|
||||
attn = attention(q, k, v, seq_lens=txt_lens, attn_mode=self.attn_mode)
|
||||
|
||||
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
||||
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
||||
return x
|
||||
|
||||
|
||||
class IndividualTokenRefiner(nn.Module):
|
||||
"""
|
||||
Stack of token refinement blocks with self-attention.
|
||||
|
||||
Processes tokens individually with adaptive layer normalization.
|
||||
|
||||
Args:
|
||||
hidden_size: Model dimension.
|
||||
heads_num: Number of attention heads.
|
||||
depth: Number of refinement blocks.
|
||||
mlp_width_ratio: MLP expansion ratio.
|
||||
mlp_drop_rate: MLP dropout rate.
|
||||
act_type: Activation function type.
|
||||
qk_norm: QK normalization flag.
|
||||
qk_norm_type: QK normalization type.
|
||||
qkv_bias: Use bias in QKV projections.
|
||||
attn_mode: Attention implementation mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
heads_num: int,
|
||||
depth: int,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
act_type: str = "silu",
|
||||
qk_norm: bool = False,
|
||||
qk_norm_type: str = "layer",
|
||||
qkv_bias: bool = True,
|
||||
attn_mode: str = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
IndividualTokenRefinerBlock(
|
||||
hidden_size=hidden_size,
|
||||
heads_num=heads_num,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
act_type=act_type,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor:
|
||||
"""
|
||||
Apply sequential token refinement.
|
||||
|
||||
Args:
|
||||
x: Input token embeddings [B, L, C].
|
||||
c: Combined conditioning vector [B, C].
|
||||
txt_lens: Valid sequence lengths for each batch element.
|
||||
|
||||
Returns:
|
||||
Refined token embeddings [B, L, C].
|
||||
"""
|
||||
for block in self.blocks:
|
||||
x = block(x, c, txt_lens)
|
||||
return x
|
||||
|
||||
|
||||
class SingleTokenRefiner(nn.Module):
|
||||
"""
|
||||
Text embedding refinement with timestep and context conditioning.
|
||||
|
||||
Projects input text embeddings and applies self-attention refinement
|
||||
conditioned on diffusion timestep and aggregate text context.
|
||||
|
||||
Args:
|
||||
in_channels: Input text embedding dimension.
|
||||
hidden_size: Transformer hidden dimension.
|
||||
heads_num: Number of attention heads.
|
||||
depth: Number of refinement blocks.
|
||||
attn_mode: Attention implementation mode.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int, attn_mode: str = "torch"):
|
||||
# Fixed architecture parameters for HunyuanImage-2.1
|
||||
mlp_drop_rate: float = 0.0 # No MLP dropout
|
||||
act_type: str = "silu" # SiLU activation
|
||||
mlp_width_ratio: float = 4.0 # 4x MLP expansion
|
||||
qk_norm: bool = False # No QK normalization
|
||||
qk_norm_type: str = "layer" # Layer norm type (unused)
|
||||
qkv_bias: bool = True # Use QKV bias
|
||||
|
||||
super().__init__()
|
||||
self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True)
|
||||
act_layer = nn.SiLU
|
||||
self.t_embedder = TimestepEmbedder(hidden_size, act_layer)
|
||||
self.c_embedder = TextProjection(in_channels, hidden_size, act_layer)
|
||||
self.individual_token_refiner = IndividualTokenRefiner(
|
||||
hidden_size=hidden_size,
|
||||
heads_num=heads_num,
|
||||
depth=depth,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
act_type=act_type,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor:
|
||||
"""
|
||||
Refine text embeddings with timestep conditioning.
|
||||
|
||||
Args:
|
||||
x: Input text embeddings [B, L, in_channels].
|
||||
t: Diffusion timestep [B].
|
||||
txt_lens: Valid sequence lengths for each batch element.
|
||||
|
||||
Returns:
|
||||
Refined embeddings [B, L, hidden_size].
|
||||
"""
|
||||
timestep_aware_representations = self.t_embedder(t)
|
||||
|
||||
# Compute context-aware representations by averaging valid tokens
|
||||
context_aware_representations = torch.stack([x[i, : txt_lens[i]].mean(dim=0) for i in range(x.shape[0])], dim=0) # [B, C]
|
||||
|
||||
context_aware_representations = self.c_embedder(context_aware_representations)
|
||||
c = timestep_aware_representations + context_aware_representations
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, txt_lens)
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
Final output projection layer with adaptive layer normalization.
|
||||
|
||||
Projects transformer hidden states to output patch space with
|
||||
timestep-conditioned modulation.
|
||||
|
||||
Args:
|
||||
hidden_size: Input hidden dimension.
|
||||
patch_size: Spatial patch size for output reshaping.
|
||||
out_channels: Number of output channels.
|
||||
act_layer: Activation function class.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size, out_channels, act_layer):
|
||||
super().__init__()
|
||||
|
||||
# Layer normalization without learnable parameters
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
out_size = (patch_size[0] * patch_size[1]) * out_channels
|
||||
self.linear = nn.Linear(hidden_size, out_size, bias=True)
|
||||
|
||||
# Adaptive layer normalization modulation
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
act_layer(),
|
||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
"""
|
||||
Root Mean Square Layer Normalization.
|
||||
|
||||
Normalizes input using RMS and applies learnable scaling.
|
||||
More efficient than LayerNorm as it doesn't compute mean.
|
||||
|
||||
Args:
|
||||
dim: Input feature dimension.
|
||||
eps: Small value for numerical stability.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply RMS normalization.
|
||||
|
||||
Args:
|
||||
x: Input tensor.
|
||||
|
||||
Returns:
|
||||
RMS normalized tensor.
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def reset_parameters(self):
|
||||
self.weight.fill_(1)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Apply RMSNorm with learnable scaling.
|
||||
|
||||
Args:
|
||||
x: Input tensor.
|
||||
|
||||
Returns:
|
||||
Normalized and scaled tensor.
|
||||
"""
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
output = output * self.weight
|
||||
return output
|
||||
|
||||
|
||||
# kept for reference, not used in current implementation
|
||||
# class LinearWarpforSingle(nn.Module):
|
||||
# """
|
||||
# Linear layer wrapper for concatenating and projecting two inputs.
|
||||
|
||||
# Used in single-stream blocks to combine attention output with MLP features.
|
||||
|
||||
# Args:
|
||||
# in_dim: Input dimension (sum of both input feature dimensions).
|
||||
# out_dim: Output dimension.
|
||||
# bias: Whether to use bias in linear projection.
|
||||
# """
|
||||
|
||||
# def __init__(self, in_dim: int, out_dim: int, bias=False):
|
||||
# super().__init__()
|
||||
# self.fc = nn.Linear(in_dim, out_dim, bias=bias)
|
||||
|
||||
# def forward(self, x, y):
|
||||
# """Concatenate inputs along feature dimension and project."""
|
||||
# x = torch.cat([x.contiguous(), y.contiguous()], dim=2).contiguous()
|
||||
# return self.fc(x)
|
||||
|
||||
|
||||
class ModulateDiT(nn.Module):
|
||||
"""
|
||||
Timestep conditioning modulation layer.
|
||||
|
||||
Projects timestep embeddings to multiple modulation parameters
|
||||
for adaptive layer normalization.
|
||||
|
||||
Args:
|
||||
hidden_size: Input conditioning dimension.
|
||||
factor: Number of modulation parameters to generate.
|
||||
act_layer: Activation function class.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, factor: int, act_layer: Callable):
|
||||
super().__init__()
|
||||
self.act = act_layer()
|
||||
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear(self.act(x))
|
||||
|
||||
|
||||
class MMDoubleStreamBlock(nn.Module):
|
||||
"""
|
||||
Multimodal double-stream transformer block.
|
||||
|
||||
Processes image and text tokens separately with cross-modal attention.
|
||||
Each stream has its own normalization and MLP layers but shares
|
||||
attention computation for cross-modal interaction.
|
||||
|
||||
Args:
|
||||
hidden_size: Model dimension.
|
||||
heads_num: Number of attention heads.
|
||||
mlp_width_ratio: MLP expansion ratio.
|
||||
mlp_act_type: MLP activation function (only "gelu_tanh" supported).
|
||||
qk_norm: QK normalization flag (must be True).
|
||||
qk_norm_type: QK normalization type (only "rms" supported).
|
||||
qkv_bias: Use bias in QKV projections.
|
||||
attn_mode: Attention implementation mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
heads_num: int,
|
||||
mlp_width_ratio: float,
|
||||
mlp_act_type: str = "gelu_tanh",
|
||||
qk_norm: bool = True,
|
||||
qk_norm_type: str = "rms",
|
||||
qkv_bias: bool = False,
|
||||
attn_mode: str = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mlp_act_type == "gelu_tanh", "Only GELU-tanh activation supported."
|
||||
assert qk_norm_type == "rms", "Only RMS normalization supported."
|
||||
assert qk_norm, "QK normalization must be enabled."
|
||||
|
||||
self.attn_mode = attn_mode
|
||||
self.heads_num = heads_num
|
||||
head_dim = hidden_size // heads_num
|
||||
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
||||
|
||||
# Image stream processing components
|
||||
self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU)
|
||||
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
|
||||
|
||||
self.img_attn_q_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
self.img_attn_k_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
|
||||
|
||||
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.img_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True)
|
||||
|
||||
# Text stream processing components
|
||||
self.txt_mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU)
|
||||
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
|
||||
self.txt_attn_q_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
self.txt_attn_k_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
|
||||
|
||||
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True)
|
||||
|
||||
def forward(
|
||||
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Extract modulation parameters for image and text streams
|
||||
(img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
|
||||
6, dim=-1
|
||||
)
|
||||
(txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
|
||||
6, dim=-1
|
||||
)
|
||||
|
||||
# Process image stream for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
|
||||
|
||||
img_qkv = self.img_attn_qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.chunk(3, dim=-1)
|
||||
del img_qkv
|
||||
|
||||
img_q = rearrange(img_q, "B L (H D) -> B L H D", H=self.heads_num)
|
||||
img_k = rearrange(img_k, "B L (H D) -> B L H D", H=self.heads_num)
|
||||
img_v = rearrange(img_v, "B L (H D) -> B L H D", H=self.heads_num)
|
||||
|
||||
# Apply QK-Norm if enabled
|
||||
img_q = self.img_attn_q_norm(img_q).to(img_v)
|
||||
img_k = self.img_attn_k_norm(img_k).to(img_v)
|
||||
|
||||
# Apply rotary position embeddings to image tokens
|
||||
if freqs_cis is not None:
|
||||
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
||||
assert (
|
||||
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
|
||||
), f"RoPE output shape mismatch: got {img_qq.shape}, {img_kk.shape}, expected {img_q.shape}, {img_k.shape}"
|
||||
img_q, img_k = img_qq, img_kk
|
||||
|
||||
# Process text stream for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
|
||||
|
||||
txt_qkv = self.txt_attn_qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.chunk(3, dim=-1)
|
||||
del txt_qkv
|
||||
|
||||
txt_q = rearrange(txt_q, "B L (H D) -> B L H D", H=self.heads_num)
|
||||
txt_k = rearrange(txt_k, "B L (H D) -> B L H D", H=self.heads_num)
|
||||
txt_v = rearrange(txt_v, "B L (H D) -> B L H D", H=self.heads_num)
|
||||
|
||||
# Apply QK-Norm if enabled
|
||||
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
|
||||
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
|
||||
|
||||
# Concatenate image and text tokens for joint attention
|
||||
q = torch.cat([img_q, txt_q], dim=1)
|
||||
k = torch.cat([img_k, txt_k], dim=1)
|
||||
v = torch.cat([img_v, txt_v], dim=1)
|
||||
attn = attention(q, k, v, seq_lens=seq_lens, attn_mode=self.attn_mode)
|
||||
|
||||
# Split attention outputs back to separate streams
|
||||
img_attn, txt_attn = (attn[:, : img_q.shape[1]].contiguous(), attn[:, img_q.shape[1] :].contiguous())
|
||||
|
||||
# Apply attention projection and residual connection for image stream
|
||||
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
|
||||
|
||||
# Apply MLP and residual connection for image stream
|
||||
img = img + apply_gate(
|
||||
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
|
||||
gate=img_mod2_gate,
|
||||
)
|
||||
|
||||
# Apply attention projection and residual connection for text stream
|
||||
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
|
||||
|
||||
# Apply MLP and residual connection for text stream
|
||||
txt = txt + apply_gate(
|
||||
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
|
||||
gate=txt_mod2_gate,
|
||||
)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class MMSingleStreamBlock(nn.Module):
|
||||
"""
|
||||
Multimodal single-stream transformer block.
|
||||
|
||||
Processes concatenated image and text tokens jointly with shared attention.
|
||||
Uses parallel linear layers for efficiency and applies RoPE only to image tokens.
|
||||
|
||||
Args:
|
||||
hidden_size: Model dimension.
|
||||
heads_num: Number of attention heads.
|
||||
mlp_width_ratio: MLP expansion ratio.
|
||||
mlp_act_type: MLP activation function (only "gelu_tanh" supported).
|
||||
qk_norm: QK normalization flag (must be True).
|
||||
qk_norm_type: QK normalization type (only "rms" supported).
|
||||
qk_scale: Attention scaling factor (computed automatically if None).
|
||||
attn_mode: Attention implementation mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
heads_num: int,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_act_type: str = "gelu_tanh",
|
||||
qk_norm: bool = True,
|
||||
qk_norm_type: str = "rms",
|
||||
qk_scale: float = None,
|
||||
attn_mode: str = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert mlp_act_type == "gelu_tanh", "Only GELU-tanh activation supported."
|
||||
assert qk_norm_type == "rms", "Only RMS normalization supported."
|
||||
assert qk_norm, "QK normalization must be enabled."
|
||||
|
||||
self.attn_mode = attn_mode
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
head_dim = hidden_size // heads_num
|
||||
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
||||
self.mlp_hidden_dim = mlp_hidden_dim
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
# Parallel linear projections for efficiency
|
||||
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)
|
||||
|
||||
# Combined output projection
|
||||
# self.linear2 = LinearWarpforSingle(hidden_size + mlp_hidden_dim, hidden_size, bias=True) # for reference
|
||||
self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, bias=True)
|
||||
|
||||
# QK normalization layers
|
||||
self.q_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
self.k_norm = RMSNorm(head_dim, eps=1e-6)
|
||||
|
||||
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=nn.SiLU)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
txt_len: int,
|
||||
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||
seq_lens: list[int] = None,
|
||||
) -> torch.Tensor:
|
||||
# Extract modulation parameters
|
||||
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
||||
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
||||
|
||||
# Compute Q, K, V, and MLP input
|
||||
qkv_mlp = self.linear1(x_mod)
|
||||
q, k, v, mlp = qkv_mlp.split([self.hidden_size, self.hidden_size, self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
del qkv_mlp
|
||||
|
||||
q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
|
||||
k = rearrange(k, "B L (H D) -> B L H D", H=self.heads_num)
|
||||
v = rearrange(v, "B L (H D) -> B L H D", H=self.heads_num)
|
||||
|
||||
# Apply QK-Norm if enabled
|
||||
q = self.q_norm(q).to(v)
|
||||
k = self.k_norm(k).to(v)
|
||||
|
||||
# Separate image and text tokens
|
||||
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
|
||||
|
||||
# Apply rotary position embeddings only to image tokens
|
||||
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
||||
assert (
|
||||
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
|
||||
), f"RoPE output shape mismatch: got {img_qq.shape}, {img_kk.shape}, expected {img_q.shape}, {img_k.shape}"
|
||||
img_q, img_k = img_qq, img_kk
|
||||
|
||||
# Recombine and compute joint attention
|
||||
q = torch.cat([img_q, txt_q], dim=1)
|
||||
k = torch.cat([img_k, txt_k], dim=1)
|
||||
v = torch.cat([img_v, txt_v], dim=1)
|
||||
attn = attention(q, k, v, seq_lens=seq_lens, attn_mode=self.attn_mode)
|
||||
|
||||
# Combine attention and MLP outputs, apply gating
|
||||
# output = self.linear2(attn, self.mlp_act(mlp))
|
||||
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = torch.cat([attn, mlp], dim=2).contiguous()
|
||||
output = self.linear2(output)
|
||||
|
||||
return x + apply_gate(output, gate=mod_gate)
|
||||
|
||||
|
||||
# endregion
|
||||
649
library/hunyuan_image_text_encoder.py
Normal file
649
library/hunyuan_image_text_encoder.py
Normal file
@@ -0,0 +1,649 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Tuple, Optional, Union
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
Qwen2_5_VLConfig,
|
||||
Qwen2_5_VLForConditionalGeneration,
|
||||
Qwen2Tokenizer,
|
||||
T5ForConditionalGeneration,
|
||||
T5Config,
|
||||
T5Tokenizer,
|
||||
)
|
||||
from transformers.models.t5.modeling_t5 import T5Stack
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from library import model_util
|
||||
from library.utils import load_safetensors, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
BYT5_TOKENIZER_PATH = "google/byt5-small"
|
||||
QWEN_2_5_VL_IMAGE_ID ="Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
|
||||
|
||||
# Copy from Glyph-SDXL-V2
|
||||
|
||||
COLOR_IDX_JSON = """{"white": 0, "black": 1, "darkslategray": 2, "dimgray": 3, "darkolivegreen": 4, "midnightblue": 5, "saddlebrown": 6, "sienna": 7, "whitesmoke": 8, "darkslateblue": 9,
|
||||
"indianred": 10, "linen": 11, "maroon": 12, "khaki": 13, "sandybrown": 14, "gray": 15, "gainsboro": 16, "teal": 17, "peru": 18, "gold": 19,
|
||||
"snow": 20, "firebrick": 21, "crimson": 22, "chocolate": 23, "tomato": 24, "brown": 25, "goldenrod": 26, "antiquewhite": 27, "rosybrown": 28, "steelblue": 29,
|
||||
"floralwhite": 30, "seashell": 31, "darkgreen": 32, "oldlace": 33, "darkkhaki": 34, "burlywood": 35, "red": 36, "darkgray": 37, "orange": 38, "royalblue": 39,
|
||||
"seagreen": 40, "lightgray": 41, "tan": 42, "coral": 43, "beige": 44, "palevioletred": 45, "wheat": 46, "lavender": 47, "darkcyan": 48, "slateblue": 49,
|
||||
"slategray": 50, "orangered": 51, "silver": 52, "olivedrab": 53, "forestgreen": 54, "darkgoldenrod": 55, "ivory": 56, "darkorange": 57, "yellow": 58, "hotpink": 59,
|
||||
"ghostwhite": 60, "lightcoral": 61, "indigo": 62, "bisque": 63, "darkred": 64, "darksalmon": 65, "lightslategray": 66, "dodgerblue": 67, "lightpink": 68, "mistyrose": 69,
|
||||
"mediumvioletred": 70, "cadetblue": 71, "deeppink": 72, "salmon": 73, "palegoldenrod": 74, "blanchedalmond": 75, "lightseagreen": 76, "cornflowerblue": 77, "yellowgreen": 78, "greenyellow": 79,
|
||||
"navajowhite": 80, "papayawhip": 81, "mediumslateblue": 82, "purple": 83, "blueviolet": 84, "pink": 85, "cornsilk": 86, "lightsalmon": 87, "mediumpurple": 88, "moccasin": 89,
|
||||
"turquoise": 90, "mediumseagreen": 91, "lavenderblush": 92, "mediumblue": 93, "darkseagreen": 94, "mediumturquoise": 95, "paleturquoise": 96, "skyblue": 97, "lemonchiffon": 98, "olive": 99,
|
||||
"peachpuff": 100, "lightyellow": 101, "lightsteelblue": 102, "mediumorchid": 103, "plum": 104, "darkturquoise": 105, "aliceblue": 106, "mediumaquamarine": 107, "orchid": 108, "powderblue": 109,
|
||||
"blue": 110, "darkorchid": 111, "violet": 112, "lightskyblue": 113, "lightcyan": 114, "lightgoldenrodyellow": 115, "navy": 116, "thistle": 117, "honeydew": 118, "mintcream": 119,
|
||||
"lightblue": 120, "darkblue": 121, "darkmagenta": 122, "deepskyblue": 123, "magenta": 124, "limegreen": 125, "darkviolet": 126, "cyan": 127, "palegreen": 128, "aquamarine": 129,
|
||||
"lawngreen": 130, "lightgreen": 131, "azure": 132, "chartreuse": 133, "green": 134, "mediumspringgreen": 135, "lime": 136, "springgreen": 137}"""
|
||||
|
||||
MULTILINGUAL_10_LANG_IDX_JSON = """{"en-Montserrat-Regular": 0, "en-Poppins-Italic": 1, "en-GlacialIndifference-Regular": 2, "en-OpenSans-ExtraBoldItalic": 3, "en-Montserrat-Bold": 4, "en-Now-Regular": 5, "en-Garet-Regular": 6, "en-LeagueSpartan-Bold": 7, "en-DMSans-Regular": 8, "en-OpenSauceOne-Regular": 9,
|
||||
"en-OpenSans-ExtraBold": 10, "en-KGPrimaryPenmanship": 11, "en-Anton-Regular": 12, "en-Aileron-BlackItalic": 13, "en-Quicksand-Light": 14, "en-Roboto-BoldItalic": 15, "en-TheSeasons-It": 16, "en-Kollektif": 17, "en-Inter-BoldItalic": 18, "en-Poppins-Medium": 19,
|
||||
"en-Poppins-Light": 20, "en-RoxboroughCF-RegularItalic": 21, "en-PlayfairDisplay-SemiBold": 22, "en-Agrandir-Italic": 23, "en-Lato-Regular": 24, "en-MoreSugarRegular": 25, "en-CanvaSans-RegularItalic": 26, "en-PublicSans-Italic": 27, "en-CodePro-NormalLC": 28, "en-Belleza-Regular": 29,
|
||||
"en-JosefinSans-Bold": 30, "en-HKGrotesk-Bold": 31, "en-Telegraf-Medium": 32, "en-BrittanySignatureRegular": 33, "en-Raleway-ExtraBoldItalic": 34, "en-Mont-RegularItalic": 35, "en-Arimo-BoldItalic": 36, "en-Lora-Italic": 37, "en-ArchivoBlack-Regular": 38, "en-Poppins": 39,
|
||||
"en-Barlow-Black": 40, "en-CormorantGaramond-Bold": 41, "en-LibreBaskerville-Regular": 42, "en-CanvaSchoolFontRegular": 43, "en-BebasNeueBold": 44, "en-LazydogRegular": 45, "en-FredokaOne-Regular": 46, "en-Horizon-Bold": 47, "en-Nourd-Regular": 48, "en-Hatton-Regular": 49,
|
||||
"en-Nunito-ExtraBoldItalic": 50, "en-CerebriSans-Regular": 51, "en-Montserrat-Light": 52, "en-TenorSans": 53, "en-Norwester-Regular": 54, "en-ClearSans-Bold": 55, "en-Cardo-Regular": 56, "en-Alice-Regular": 57, "en-Oswald-Regular": 58, "en-Gaegu-Bold": 59,
|
||||
"en-Muli-Black": 60, "en-TAN-PEARL-Regular": 61, "en-CooperHewitt-Book": 62, "en-Agrandir-Grand": 63, "en-BlackMango-Thin": 64, "en-DMSerifDisplay-Regular": 65, "en-Antonio-Bold": 66, "en-Sniglet-Regular": 67, "en-BeVietnam-Regular": 68, "en-NunitoSans10pt-BlackItalic": 69,
|
||||
"en-AbhayaLibre-ExtraBold": 70, "en-Rubik-Regular": 71, "en-PPNeueMachina-Regular": 72, "en-TAN - MON CHERI-Regular": 73, "en-Jua-Regular": 74, "en-Playlist-Script": 75, "en-SourceSansPro-BoldItalic": 76, "en-MoonTime-Regular": 77, "en-Eczar-ExtraBold": 78, "en-Gatwick-Regular": 79,
|
||||
"en-MonumentExtended-Regular": 80, "en-BarlowSemiCondensed-Regular": 81, "en-BarlowCondensed-Regular": 82, "en-Alegreya-Regular": 83, "en-DreamAvenue": 84, "en-RobotoCondensed-Italic": 85, "en-BobbyJones-Regular": 86, "en-Garet-ExtraBold": 87, "en-YesevaOne-Regular": 88, "en-Dosis-ExtraBold": 89,
|
||||
"en-LeagueGothic-Regular": 90, "en-OpenSans-Italic": 91, "en-TANAEGEAN-Regular": 92, "en-Maharlika-Regular": 93, "en-MarykateRegular": 94, "en-Cinzel-Regular": 95, "en-Agrandir-Wide": 96, "en-Chewy-Regular": 97, "en-BodoniFLF-BoldItalic": 98, "en-Nunito-BlackItalic": 99,
|
||||
"en-LilitaOne": 100, "en-HandyCasualCondensed-Regular": 101, "en-Ovo": 102, "en-Livvic-Regular": 103, "en-Agrandir-Narrow": 104, "en-CrimsonPro-Italic": 105, "en-AnonymousPro-Bold": 106, "en-NF-OneLittleFont-Bold": 107, "en-RedHatDisplay-BoldItalic": 108, "en-CodecPro-Regular": 109,
|
||||
"en-HalimunRegular": 110, "en-LibreFranklin-Black": 111, "en-TeXGyreTermes-BoldItalic": 112, "en-Shrikhand-Regular": 113, "en-TTNormsPro-Italic": 114, "en-Gagalin-Regular": 115, "en-OpenSans-Bold": 116, "en-GreatVibes-Regular": 117, "en-Breathing": 118, "en-HeroLight-Regular": 119,
|
||||
"en-KGPrimaryDots": 120, "en-Quicksand-Bold": 121, "en-Brice-ExtraLightSemiExpanded": 122, "en-Lato-BoldItalic": 123, "en-Fraunces9pt-Italic": 124, "en-AbrilFatface-Regular": 125, "en-BerkshireSwash-Regular": 126, "en-Atma-Bold": 127, "en-HolidayRegular": 128, "en-BebasNeueCyrillic": 129,
|
||||
"en-IntroRust-Base": 130, "en-Gistesy": 131, "en-BDScript-Regular": 132, "en-ApricotsRegular": 133, "en-Prompt-Black": 134, "en-TAN MERINGUE": 135, "en-Sukar Regular": 136, "en-GentySans-Regular": 137, "en-NeueEinstellung-Normal": 138, "en-Garet-Bold": 139,
|
||||
"en-FiraSans-Black": 140, "en-BantayogLight": 141, "en-NotoSerifDisplay-Black": 142, "en-TTChocolates-Regular": 143, "en-Ubuntu-Regular": 144, "en-Assistant-Bold": 145, "en-ABeeZee-Regular": 146, "en-LexendDeca-Regular": 147, "en-KingredSerif": 148, "en-Radley-Regular": 149,
|
||||
"en-BrownSugar": 150, "en-MigraItalic-ExtraboldItalic": 151, "en-ChildosArabic-Regular": 152, "en-PeaceSans": 153, "en-LondrinaSolid-Black": 154, "en-SpaceMono-BoldItalic": 155, "en-RobotoMono-Light": 156, "en-CourierPrime-Regular": 157, "en-Alata-Regular": 158, "en-Amsterdam-One": 159,
|
||||
"en-IreneFlorentina-Regular": 160, "en-CatchyMager": 161, "en-Alta_regular": 162, "en-ArticulatCF-Regular": 163, "en-Raleway-Regular": 164, "en-BrasikaDisplay": 165, "en-TANAngleton-Italic": 166, "en-NotoSerifDisplay-ExtraCondensedItalic": 167, "en-Bryndan Write": 168, "en-TTCommonsPro-It": 169,
|
||||
"en-AlexBrush-Regular": 170, "en-Antic-Regular": 171, "en-TTHoves-Bold": 172, "en-DroidSerif": 173, "en-AblationRegular": 174, "en-Marcellus-Regular": 175, "en-Sanchez-Italic": 176, "en-JosefinSans": 177, "en-Afrah-Regular": 178, "en-PinyonScript": 179,
|
||||
"en-TTInterphases-BoldItalic": 180, "en-Yellowtail-Regular": 181, "en-Gliker-Regular": 182, "en-BobbyJonesSoft-Regular": 183, "en-IBMPlexSans": 184, "en-Amsterdam-Three": 185, "en-Amsterdam-FourSlant": 186, "en-TTFors-Regular": 187, "en-Quattrocento": 188, "en-Sifonn-Basic": 189,
|
||||
"en-AlegreyaSans-Black": 190, "en-Daydream": 191, "en-AristotelicaProTx-Rg": 192, "en-NotoSerif": 193, "en-EBGaramond-Italic": 194, "en-HammersmithOne-Regular": 195, "en-RobotoSlab-Regular": 196, "en-DO-Sans-Regular": 197, "en-KGPrimaryDotsLined": 198, "en-Blinker-Regular": 199,
|
||||
"en-TAN NIMBUS": 200, "en-Blueberry-Regular": 201, "en-Rosario-Regular": 202, "en-Forum": 203, "en-MistrullyRegular": 204, "en-SourceSerifPro-Regular": 205, "en-Bugaki-Regular": 206, "en-CMUSerif-Roman": 207, "en-GulfsDisplay-NormalItalic": 208, "en-PTSans-Bold": 209,
|
||||
"en-Sensei-Medium": 210, "en-SquadaOne-Regular": 211, "en-Arapey-Italic": 212, "en-Parisienne-Regular": 213, "en-Aleo-Italic": 214, "en-QuicheDisplay-Italic": 215, "en-RocaOne-It": 216, "en-Funtastic-Regular": 217, "en-PTSerif-BoldItalic": 218, "en-Muller-RegularItalic": 219,
|
||||
"en-ArgentCF-Regular": 220, "en-Brightwall-Italic": 221, "en-Knewave-Regular": 222, "en-TYSerif-D": 223, "en-Agrandir-Tight": 224, "en-AlfaSlabOne-Regular": 225, "en-TANTangkiwood-Display": 226, "en-Kief-Montaser-Regular": 227, "en-Gotham-Book": 228, "en-JuliusSansOne-Regular": 229,
|
||||
"en-CocoGothic-Italic": 230, "en-SairaCondensed-Regular": 231, "en-DellaRespira-Regular": 232, "en-Questrial-Regular": 233, "en-BukhariScript-Regular": 234, "en-HelveticaWorld-Bold": 235, "en-TANKINDRED-Display": 236, "en-CinzelDecorative-Regular": 237, "en-Vidaloka-Regular": 238, "en-AlegreyaSansSC-Black": 239,
|
||||
"en-FeelingPassionate-Regular": 240, "en-QuincyCF-Regular": 241, "en-FiraCode-Regular": 242, "en-Genty-Regular": 243, "en-Nickainley-Normal": 244, "en-RubikOne-Regular": 245, "en-Gidole-Regular": 246, "en-Borsok": 247, "en-Gordita-RegularItalic": 248, "en-Scripter-Regular": 249,
|
||||
"en-Buffalo-Regular": 250, "en-KleinText-Regular": 251, "en-Creepster-Regular": 252, "en-Arvo-Bold": 253, "en-GabrielSans-NormalItalic": 254, "en-Heebo-Black": 255, "en-LexendExa-Regular": 256, "en-BrixtonSansTC-Regular": 257, "en-GildaDisplay-Regular": 258, "en-ChunkFive-Roman": 259,
|
||||
"en-Amaranth-BoldItalic": 260, "en-BubbleboddyNeue-Regular": 261, "en-MavenPro-Bold": 262, "en-TTDrugs-Italic": 263, "en-CyGrotesk-KeyRegular": 264, "en-VarelaRound-Regular": 265, "en-Ruda-Black": 266, "en-SafiraMarch": 267, "en-BloggerSans": 268, "en-TANHEADLINE-Regular": 269,
|
||||
"en-SloopScriptPro-Regular": 270, "en-NeueMontreal-Regular": 271, "en-Schoolbell-Regular": 272, "en-SigherRegular": 273, "en-InriaSerif-Regular": 274, "en-JetBrainsMono-Regular": 275, "en-MADEEvolveSans": 276, "en-Dekko": 277, "en-Handyman-Regular": 278, "en-Aileron-BoldItalic": 279,
|
||||
"en-Bright-Italic": 280, "en-Solway-Regular": 281, "en-Higuen-Regular": 282, "en-WedgesItalic": 283, "en-TANASHFORD-BOLD": 284, "en-IBMPlexMono": 285, "en-RacingSansOne-Regular": 286, "en-RegularBrush": 287, "en-OpenSans-LightItalic": 288, "en-SpecialElite-Regular": 289,
|
||||
"en-FuturaLTPro-Medium": 290, "en-MaragsaDisplay": 291, "en-BigShouldersDisplay-Regular": 292, "en-BDSans-Regular": 293, "en-RasputinRegular": 294, "en-Yvesyvesdrawing-BoldItalic": 295, "en-Bitter-Regular": 296, "en-LuckiestGuy-Regular": 297, "en-CanvaSchoolFontDotted": 298, "en-TTFirsNeue-Italic": 299,
|
||||
"en-Sunday-Regular": 300, "en-HKGothic-MediumItalic": 301, "en-CaveatBrush-Regular": 302, "en-HeliosExt": 303, "en-ArchitectsDaughter-Regular": 304, "en-Angelina": 305, "en-Calistoga-Regular": 306, "en-ArchivoNarrow-Regular": 307, "en-ObjectSans-MediumSlanted": 308, "en-AyrLucidityCondensed-Regular": 309,
|
||||
"en-Nexa-RegularItalic": 310, "en-Lustria-Regular": 311, "en-Amsterdam-TwoSlant": 312, "en-Virtual-Regular": 313, "en-Brusher-Regular": 314, "en-NF-Lepetitcochon-Regular": 315, "en-TANTWINKLE": 316, "en-LeJour-Serif": 317, "en-Prata-Regular": 318, "en-PPWoodland-Regular": 319,
|
||||
"en-PlayfairDisplay-BoldItalic": 320, "en-AmaticSC-Regular": 321, "en-Cabin-Regular": 322, "en-Manjari-Bold": 323, "en-MrDafoe-Regular": 324, "en-TTRamillas-Italic": 325, "en-Luckybones-Bold": 326, "en-DarkerGrotesque-Light": 327, "en-BellabooRegular": 328, "en-CormorantSC-Bold": 329,
|
||||
"en-GochiHand-Regular": 330, "en-Atteron": 331, "en-RocaTwo-Lt": 332, "en-ZCOOLXiaoWei-Regular": 333, "en-TANSONGBIRD": 334, "en-HeadingNow-74Regular": 335, "en-Luthier-BoldItalic": 336, "en-Oregano-Regular": 337, "en-AyrTropikaIsland-Int": 338, "en-Mali-Regular": 339,
|
||||
"en-DidactGothic-Regular": 340, "en-Lovelace-Regular": 341, "en-BakerieSmooth-Regular": 342, "en-CarterOne": 343, "en-HussarBd": 344, "en-OldStandard-Italic": 345, "en-TAN-ASTORIA-Display": 346, "en-rugratssans-Regular": 347, "en-BMHANNA": 348, "en-BetterSaturday": 349,
|
||||
"en-AdigianaToybox": 350, "en-Sailors": 351, "en-PlayfairDisplaySC-Italic": 352, "en-Etna-Regular": 353, "en-Revive80Signature": 354, "en-CAGenerated": 355, "en-Poppins-Regular": 356, "en-Jonathan-Regular": 357, "en-Pacifico-Regular": 358, "en-Saira-Black": 359,
|
||||
"en-Loubag-Regular": 360, "en-Decalotype-Black": 361, "en-Mansalva-Regular": 362, "en-Allura-Regular": 363, "en-ProximaNova-Bold": 364, "en-TANMIGNON-DISPLAY": 365, "en-ArsenicaAntiqua-Regular": 366, "en-BreulGroteskA-RegularItalic": 367, "en-HKModular-Bold": 368, "en-TANNightingale-Regular": 369,
|
||||
"en-AristotelicaProCndTxt-Rg": 370, "en-Aprila-Regular": 371, "en-Tomorrow-Regular": 372, "en-AngellaWhite": 373, "en-KaushanScript-Regular": 374, "en-NotoSans": 375, "en-LeJour-Script": 376, "en-BrixtonTC-Regular": 377, "en-OleoScript-Regular": 378, "en-Cakerolli-Regular": 379,
|
||||
"en-Lobster-Regular": 380, "en-FrunchySerif-Regular": 381, "en-PorcelainRegular": 382, "en-AlojaExtended": 383, "en-SergioTrendy-Italic": 384, "en-LovelaceText-Bold": 385, "en-Anaktoria": 386, "en-JimmyScript-Light": 387, "en-IBMPlexSerif": 388, "en-Marta": 389,
|
||||
"en-Mango-Regular": 390, "en-Overpass-Italic": 391, "en-Hagrid-Regular": 392, "en-ElikaGorica": 393, "en-Amiko-Regular": 394, "en-EFCOBrookshire-Regular": 395, "en-Caladea-Regular": 396, "en-MoonlightBold": 397, "en-Staatliches-Regular": 398, "en-Helios-Bold": 399,
|
||||
"en-Satisfy-Regular": 400, "en-NexaScript-Regular": 401, "en-Trocchi-Regular": 402, "en-March": 403, "en-IbarraRealNova-Regular": 404, "en-Nectarine-Regular": 405, "en-Overpass-Light": 406, "en-TruetypewriterPolyglOTT": 407, "en-Bangers-Regular": 408, "en-Lazord-BoldExpandedItalic": 409,
|
||||
"en-Chloe-Regular": 410, "en-BaskervilleDisplayPT-Regular": 411, "en-Bright-Regular": 412, "en-Vollkorn-Regular": 413, "en-Harmattan": 414, "en-SortsMillGoudy-Regular": 415, "en-Biryani-Bold": 416, "en-SugoProDisplay-Italic": 417, "en-Lazord-BoldItalic": 418, "en-Alike-Regular": 419,
|
||||
"en-PermanentMarker-Regular": 420, "en-Sacramento-Regular": 421, "en-HKGroteskPro-Italic": 422, "en-Aleo-BoldItalic": 423, "en-Noot": 424, "en-TANGARLAND-Regular": 425, "en-Twister": 426, "en-Arsenal-Italic": 427, "en-Bogart-Italic": 428, "en-BethEllen-Regular": 429,
|
||||
"en-Caveat-Regular": 430, "en-BalsamiqSans-Bold": 431, "en-BreeSerif-Regular": 432, "en-CodecPro-ExtraBold": 433, "en-Pierson-Light": 434, "en-CyGrotesk-WideRegular": 435, "en-Lumios-Marker": 436, "en-Comfortaa-Bold": 437, "en-TraceFontRegular": 438, "en-RTL-AdamScript-Regular": 439,
|
||||
"en-EastmanGrotesque-Italic": 440, "en-Kalam-Bold": 441, "en-ChauPhilomeneOne-Regular": 442, "en-Coiny-Regular": 443, "en-Lovera": 444, "en-Gellatio": 445, "en-TitilliumWeb-Bold": 446, "en-OilvareBase-Italic": 447, "en-Catamaran-Black": 448, "en-Anteb-Italic": 449,
|
||||
"en-SueEllenFrancisco": 450, "en-SweetApricot": 451, "en-BrightSunshine": 452, "en-IM_FELL_Double_Pica_Italic": 453, "en-Granaina-limpia": 454, "en-TANPARFAIT": 455, "en-AcherusGrotesque-Regular": 456, "en-AwesomeLathusca-Italic": 457, "en-Signika-Bold": 458, "en-Andasia": 459,
|
||||
"en-DO-AllCaps-Slanted": 460, "en-Zenaida-Regular": 461, "en-Fahkwang-Regular": 462, "en-Play-Regular": 463, "en-BERNIERRegular-Regular": 464, "en-PlumaThin-Regular": 465, "en-SportsWorld": 466, "en-Garet-Black": 467, "en-CarolloPlayscript-BlackItalic": 468, "en-Cheque-Regular": 469,
|
||||
"en-SEGO": 470, "en-BobbyJones-Condensed": 471, "en-NexaSlab-RegularItalic": 472, "en-DancingScript-Regular": 473, "en-PaalalabasDisplayWideBETA": 474, "en-Magnolia-Script": 475, "en-OpunMai-400It": 476, "en-MadelynFill-Regular": 477, "en-ZingRust-Base": 478, "en-FingerPaint-Regular": 479,
|
||||
"en-BostonAngel-Light": 480, "en-Gliker-RegularExpanded": 481, "en-Ahsing": 482, "en-Engagement-Regular": 483, "en-EyesomeScript": 484, "en-LibraSerifModern-Regular": 485, "en-London-Regular": 486, "en-AtkinsonHyperlegible-Regular": 487, "en-StadioNow-TextItalic": 488, "en-Aniyah": 489,
|
||||
"en-ITCAvantGardePro-Bold": 490, "en-Comica-Regular": 491, "en-Coustard-Regular": 492, "en-Brice-BoldCondensed": 493, "en-TANNEWYORK-Bold": 494, "en-TANBUSTER-Bold": 495, "en-Alatsi-Regular": 496, "en-TYSerif-Book": 497, "en-Jingleberry": 498, "en-Rajdhani-Bold": 499,
|
||||
"en-LobsterTwo-BoldItalic": 500, "en-BestLight-Medium": 501, "en-Hitchcut-Regular": 502, "en-GermaniaOne-Regular": 503, "en-Emitha-Script": 504, "en-LemonTuesday": 505, "en-Cubao_Free_Regular": 506, "en-MonterchiSerif-Regular": 507, "en-AllertaStencil-Regular": 508, "en-RTL-Sondos-Regular": 509,
|
||||
"en-HomemadeApple-Regular": 510, "en-CosmicOcto-Medium": 511, "cn-HelloFont-FangHuaTi": 0, "cn-HelloFont-ID-DianFangSong-Bold": 1, "cn-HelloFont-ID-DianFangSong": 2, "cn-HelloFont-ID-DianHei-CEJ": 3, "cn-HelloFont-ID-DianHei-DEJ": 4, "cn-HelloFont-ID-DianHei-EEJ": 5, "cn-HelloFont-ID-DianHei-FEJ": 6, "cn-HelloFont-ID-DianHei-GEJ": 7, "cn-HelloFont-ID-DianKai-Bold": 8, "cn-HelloFont-ID-DianKai": 9,
|
||||
"cn-HelloFont-WenYiHei": 10, "cn-Hellofont-ID-ChenYanXingKai": 11, "cn-Hellofont-ID-DaZiBao": 12, "cn-Hellofont-ID-DaoCaoRen": 13, "cn-Hellofont-ID-JianSong": 14, "cn-Hellofont-ID-JiangHuZhaoPaiHei": 15, "cn-Hellofont-ID-KeSong": 16, "cn-Hellofont-ID-LeYuanTi": 17, "cn-Hellofont-ID-Pinocchio": 18, "cn-Hellofont-ID-QiMiaoTi": 19,
|
||||
"cn-Hellofont-ID-QingHuaKai": 20, "cn-Hellofont-ID-QingHuaXingKai": 21, "cn-Hellofont-ID-ShanShuiXingKai": 22, "cn-Hellofont-ID-ShouXieQiShu": 23, "cn-Hellofont-ID-ShouXieTongZhenTi": 24, "cn-Hellofont-ID-TengLingTi": 25, "cn-Hellofont-ID-XiaoLiShu": 26, "cn-Hellofont-ID-XuanZhenSong": 27, "cn-Hellofont-ID-ZhongLingXingKai": 28, "cn-HellofontIDJiaoTangTi": 29,
|
||||
"cn-HellofontIDJiuZhuTi": 30, "cn-HuXiaoBao-SaoBao": 31, "cn-HuXiaoBo-NanShen": 32, "cn-HuXiaoBo-ZhenShuai": 33, "cn-SourceHanSansSC-Bold": 34, "cn-SourceHanSansSC-ExtraLight": 35, "cn-SourceHanSansSC-Heavy": 36, "cn-SourceHanSansSC-Light": 37, "cn-SourceHanSansSC-Medium": 38, "cn-SourceHanSansSC-Normal": 39,
|
||||
"cn-SourceHanSansSC-Regular": 40, "cn-SourceHanSerifSC-Bold": 41, "cn-SourceHanSerifSC-ExtraLight": 42, "cn-SourceHanSerifSC-Heavy": 43, "cn-SourceHanSerifSC-Light": 44, "cn-SourceHanSerifSC-Medium": 45, "cn-SourceHanSerifSC-Regular": 46, "cn-SourceHanSerifSC-SemiBold": 47, "cn-xiaowei": 48, "cn-AaJianHaoTi": 49,
|
||||
"cn-AlibabaPuHuiTi-Bold": 50, "cn-AlibabaPuHuiTi-Heavy": 51, "cn-AlibabaPuHuiTi-Light": 52, "cn-AlibabaPuHuiTi-Medium": 53, "cn-AlibabaPuHuiTi-Regular": 54, "cn-CanvaAcidBoldSC": 55, "cn-CanvaBreezeCN": 56, "cn-CanvaBumperCropSC": 57, "cn-CanvaCakeShopCN": 58, "cn-CanvaEndeavorBlackSC": 59,
|
||||
"cn-CanvaJoyHeiCN": 60, "cn-CanvaLiCN": 61, "cn-CanvaOrientalBrushCN": 62, "cn-CanvaPoster": 63, "cn-CanvaQinfuCalligraphyCN": 64, "cn-CanvaSweetHeartCN": 65, "cn-CanvaSwordLikeDreamCN": 66, "cn-CanvaTangyuanHandwritingCN": 67, "cn-CanvaWanderWorldCN": 68, "cn-CanvaWenCN": 69,
|
||||
"cn-DianZiChunYi": 70, "cn-GenSekiGothicTW-H": 71, "cn-GenWanMinTW-L": 72, "cn-GenYoMinTW-B": 73, "cn-GenYoMinTW-EL": 74, "cn-GenYoMinTW-H": 75, "cn-GenYoMinTW-M": 76, "cn-GenYoMinTW-R": 77, "cn-GenYoMinTW-SB": 78, "cn-HYQiHei-AZEJ": 79,
|
||||
"cn-HYQiHei-EES": 80, "cn-HanaMinA": 81, "cn-HappyZcool-2016": 82, "cn-HelloFont ZJ KeKouKeAiTi": 83, "cn-HelloFont-ID-BoBoTi": 84, "cn-HelloFont-ID-FuGuHei-25": 85, "cn-HelloFont-ID-FuGuHei-35": 86, "cn-HelloFont-ID-FuGuHei-45": 87, "cn-HelloFont-ID-FuGuHei-55": 88, "cn-HelloFont-ID-FuGuHei-65": 89,
|
||||
"cn-HelloFont-ID-FuGuHei-75": 90, "cn-HelloFont-ID-FuGuHei-85": 91, "cn-HelloFont-ID-HeiKa": 92, "cn-HelloFont-ID-HeiTang": 93, "cn-HelloFont-ID-JianSong-95": 94, "cn-HelloFont-ID-JueJiangHei-50": 95, "cn-HelloFont-ID-JueJiangHei-55": 96, "cn-HelloFont-ID-JueJiangHei-60": 97, "cn-HelloFont-ID-JueJiangHei-65": 98, "cn-HelloFont-ID-JueJiangHei-70": 99,
|
||||
"cn-HelloFont-ID-JueJiangHei-75": 100, "cn-HelloFont-ID-JueJiangHei-80": 101, "cn-HelloFont-ID-KuHeiTi": 102, "cn-HelloFont-ID-LingDongTi": 103, "cn-HelloFont-ID-LingLiTi": 104, "cn-HelloFont-ID-MuFengTi": 105, "cn-HelloFont-ID-NaiNaiJiangTi": 106, "cn-HelloFont-ID-PangDu": 107, "cn-HelloFont-ID-ReLieTi": 108, "cn-HelloFont-ID-RouRun": 109,
|
||||
"cn-HelloFont-ID-SaShuangShouXieTi": 110, "cn-HelloFont-ID-WangZheFengFan": 111, "cn-HelloFont-ID-YouQiTi": 112, "cn-Hellofont-ID-XiaLeTi": 113, "cn-Hellofont-ID-XianXiaTi": 114, "cn-HuXiaoBoKuHei": 115, "cn-IDDanMoXingKai": 116, "cn-IDJueJiangHei": 117, "cn-IDMeiLingTi": 118, "cn-IDQQSugar": 119,
|
||||
"cn-LiuJianMaoCao-Regular": 120, "cn-LongCang-Regular": 121, "cn-MaShanZheng-Regular": 122, "cn-PangMenZhengDao-3": 123, "cn-PangMenZhengDao-Cu": 124, "cn-PangMenZhengDao": 125, "cn-SentyCaramel": 126, "cn-SourceHanSerifSC": 127, "cn-WenCang-Regular": 128, "cn-WenQuanYiMicroHei": 129,
|
||||
"cn-XianErTi": 130, "cn-YRDZSTJF": 131, "cn-YS-HelloFont-BangBangTi": 132, "cn-ZCOOLKuaiLe-Regular": 133, "cn-ZCOOLQingKeHuangYou-Regular": 134, "cn-ZCOOLXiaoWei-Regular": 135, "cn-ZCOOL_KuHei": 136, "cn-ZhiMangXing-Regular": 137, "cn-baotuxiaobaiti": 138, "cn-jiangxizhuokai-Regular": 139,
|
||||
"cn-zcool-gdh": 140, "cn-zcoolqingkehuangyouti-Regular": 141, "cn-zcoolwenyiti": 142, "jp-04KanjyukuGothic": 0, "jp-07LightNovelPOP": 1, "jp-07NikumaruFont": 2, "jp-07YasashisaAntique": 3, "jp-07YasashisaGothic": 4, "jp-BokutachinoGothic2Bold": 5, "jp-BokutachinoGothic2Regular": 6, "jp-CHI_SpeedyRight_full_211128-Regular": 7, "jp-CHI_SpeedyRight_italic_full_211127-Regular": 8, "jp-CP-Font": 9,
|
||||
"jp-Canva_CezanneProN-B": 10, "jp-Canva_CezanneProN-M": 11, "jp-Canva_ChiaroStd-B": 12, "jp-Canva_CometStd-B": 13, "jp-Canva_DotMincho16Std-M": 14, "jp-Canva_GrecoStd-B": 15, "jp-Canva_GrecoStd-M": 16, "jp-Canva_LyraStd-DB": 17, "jp-Canva_MatisseHatsuhiPro-B": 18, "jp-Canva_MatisseHatsuhiPro-M": 19,
|
||||
"jp-Canva_ModeMinAStd-B": 20, "jp-Canva_NewCezanneProN-B": 21, "jp-Canva_NewCezanneProN-M": 22, "jp-Canva_PearlStd-L": 23, "jp-Canva_RaglanStd-UB": 24, "jp-Canva_RailwayStd-B": 25, "jp-Canva_ReggaeStd-B": 26, "jp-Canva_RocknRollStd-DB": 27, "jp-Canva_RodinCattleyaPro-B": 28, "jp-Canva_RodinCattleyaPro-M": 29,
|
||||
"jp-Canva_RodinCattleyaPro-UB": 30, "jp-Canva_RodinHimawariPro-B": 31, "jp-Canva_RodinHimawariPro-M": 32, "jp-Canva_RodinMariaPro-B": 33, "jp-Canva_RodinMariaPro-DB": 34, "jp-Canva_RodinProN-M": 35, "jp-Canva_ShadowTLStd-B": 36, "jp-Canva_StickStd-B": 37, "jp-Canva_TsukuAOldMinPr6N-B": 38, "jp-Canva_TsukuAOldMinPr6N-R": 39,
|
||||
"jp-Canva_UtrilloPro-DB": 40, "jp-Canva_UtrilloPro-M": 41, "jp-Canva_YurukaStd-UB": 42, "jp-FGUIGEN": 43, "jp-GlowSansJ-Condensed-Heavy": 44, "jp-GlowSansJ-Condensed-Light": 45, "jp-GlowSansJ-Normal-Bold": 46, "jp-GlowSansJ-Normal-Light": 47, "jp-HannariMincho": 48, "jp-HarenosoraMincho": 49,
|
||||
"jp-Jiyucho": 50, "jp-Kaiso-Makina-B": 51, "jp-Kaisotai-Next-UP-B": 52, "jp-KokoroMinchoutai": 53, "jp-Mamelon-3-Hi-Regular": 54, "jp-MotoyaAnemoneStd-W1": 55, "jp-MotoyaAnemoneStd-W5": 56, "jp-MotoyaAnticPro-W3": 57, "jp-MotoyaCedarStd-W3": 58, "jp-MotoyaCedarStd-W5": 59,
|
||||
"jp-MotoyaGochikaStd-W4": 60, "jp-MotoyaGochikaStd-W8": 61, "jp-MotoyaGothicMiyabiStd-W6": 62, "jp-MotoyaGothicStd-W3": 63, "jp-MotoyaGothicStd-W5": 64, "jp-MotoyaKoinStd-W3": 65, "jp-MotoyaKyotaiStd-W2": 66, "jp-MotoyaKyotaiStd-W4": 67, "jp-MotoyaMaruStd-W3": 68, "jp-MotoyaMaruStd-W5": 69,
|
||||
"jp-MotoyaMinchoMiyabiStd-W4": 70, "jp-MotoyaMinchoMiyabiStd-W6": 71, "jp-MotoyaMinchoModernStd-W4": 72, "jp-MotoyaMinchoModernStd-W6": 73, "jp-MotoyaMinchoStd-W3": 74, "jp-MotoyaMinchoStd-W5": 75, "jp-MotoyaReisyoStd-W2": 76, "jp-MotoyaReisyoStd-W6": 77, "jp-MotoyaTohitsuStd-W4": 78, "jp-MotoyaTohitsuStd-W6": 79,
|
||||
"jp-MtySousyokuEmBcJis-W6": 80, "jp-MtySousyokuLiBcJis-W6": 81, "jp-Mushin": 82, "jp-NotoSansJP-Bold": 83, "jp-NotoSansJP-Regular": 84, "jp-NudMotoyaAporoStd-W3": 85, "jp-NudMotoyaAporoStd-W5": 86, "jp-NudMotoyaCedarStd-W3": 87, "jp-NudMotoyaCedarStd-W5": 88, "jp-NudMotoyaMaruStd-W3": 89,
|
||||
"jp-NudMotoyaMaruStd-W5": 90, "jp-NudMotoyaMinchoStd-W5": 91, "jp-Ounen-mouhitsu": 92, "jp-Ronde-B-Square": 93, "jp-SMotoyaGyosyoStd-W5": 94, "jp-SMotoyaSinkaiStd-W3": 95, "jp-SMotoyaSinkaiStd-W5": 96, "jp-SourceHanSansJP-Bold": 97, "jp-SourceHanSansJP-Regular": 98, "jp-SourceHanSerifJP-Bold": 99,
|
||||
"jp-SourceHanSerifJP-Regular": 100, "jp-TazuganeGothicStdN-Bold": 101, "jp-TazuganeGothicStdN-Regular": 102, "jp-TelopMinProN-B": 103, "jp-Togalite-Bold": 104, "jp-Togalite-Regular": 105, "jp-TsukuMinPr6N-E": 106, "jp-TsukuMinPr6N-M": 107, "jp-mikachan_o": 108, "jp-nagayama_kai": 109,
|
||||
"jp-07LogoTypeGothic7": 110, "jp-07TetsubinGothic": 111, "jp-851CHIKARA-DZUYOKU-KANA-A": 112, "jp-ARMinchoJIS-Light": 113, "jp-ARMinchoJIS-Ultra": 114, "jp-ARPCrystalMinchoJIS-Medium": 115, "jp-ARPCrystalRGothicJIS-Medium": 116, "jp-ARShounanShinpitsuGyosyoJIS-Medium": 117, "jp-AozoraMincho-bold": 118, "jp-AozoraMinchoRegular": 119,
|
||||
"jp-ArialUnicodeMS-Bold": 120, "jp-ArialUnicodeMS": 121, "jp-CanvaBreezeJP": 122, "jp-CanvaLiCN": 123, "jp-CanvaLiJP": 124, "jp-CanvaOrientalBrushCN": 125, "jp-CanvaQinfuCalligraphyJP": 126, "jp-CanvaSweetHeartJP": 127, "jp-CanvaWenJP": 128, "jp-Corporate-Logo-Bold": 129,
|
||||
"jp-DelaGothicOne-Regular": 130, "jp-GN-Kin-iro_SansSerif": 131, "jp-GN-Koharuiro_Sunray": 132, "jp-GenEiGothicM-B": 133, "jp-GenEiGothicM-R": 134, "jp-GenJyuuGothic-Bold": 135, "jp-GenRyuMinTW-B": 136, "jp-GenRyuMinTW-R": 137, "jp-GenSekiGothicTW-B": 138, "jp-GenSekiGothicTW-R": 139,
|
||||
"jp-GenSenRoundedTW-B": 140, "jp-GenSenRoundedTW-R": 141, "jp-GenShinGothic-Bold": 142, "jp-GenShinGothic-Normal": 143, "jp-GenWanMinTW-L": 144, "jp-GenYoGothicTW-B": 145, "jp-GenYoGothicTW-R": 146, "jp-GenYoMinTW-B": 147, "jp-GenYoMinTW-R": 148, "jp-HGBouquet": 149,
|
||||
"jp-HanaMinA": 150, "jp-HanazomeFont": 151, "jp-HinaMincho-Regular": 152, "jp-Honoka-Antique-Maru": 153, "jp-Honoka-Mincho": 154, "jp-HuiFontP": 155, "jp-IPAexMincho": 156, "jp-JK-Gothic-L": 157, "jp-JK-Gothic-M": 158, "jp-JackeyFont": 159,
|
||||
"jp-KaiseiTokumin-Bold": 160, "jp-KaiseiTokumin-Regular": 161, "jp-Keifont": 162, "jp-KiwiMaru-Regular": 163, "jp-Koku-Mincho-Regular": 164, "jp-MotoyaLMaru-W3-90ms-RKSJ-H": 165, "jp-NewTegomin-Regular": 166, "jp-NicoKaku": 167, "jp-NicoMoji+": 168, "jp-Otsutome_font-Bold": 169,
|
||||
"jp-PottaOne-Regular": 170, "jp-RampartOne-Regular": 171, "jp-Senobi-Gothic-Bold": 172, "jp-Senobi-Gothic-Regular": 173, "jp-SmartFontUI-Proportional": 174, "jp-SoukouMincho": 175, "jp-TEST_Klee-DB": 176, "jp-TEST_Klee-M": 177, "jp-TEST_UDMincho-B": 178, "jp-TEST_UDMincho-L": 179,
|
||||
"jp-TT_Akakane-EB": 180, "jp-Tanuki-Permanent-Marker": 181, "jp-TrainOne-Regular": 182, "jp-TsunagiGothic-Black": 183, "jp-Ume-Hy-Gothic": 184, "jp-Ume-P-Mincho": 185, "jp-WenQuanYiMicroHei": 186, "jp-XANO-mincho-U32": 187, "jp-YOzFontM90-Regular": 188, "jp-Yomogi-Regular": 189,
|
||||
"jp-YujiBoku-Regular": 190, "jp-YujiSyuku-Regular": 191, "jp-ZenKakuGothicNew-Bold": 192, "jp-ZenKakuGothicNew-Regular": 193, "jp-ZenKurenaido-Regular": 194, "jp-ZenMaruGothic-Bold": 195, "jp-ZenMaruGothic-Regular": 196, "jp-darts-font": 197, "jp-irohakakuC-Bold": 198, "jp-irohakakuC-Medium": 199,
|
||||
"jp-irohakakuC-Regular": 200, "jp-katyou": 201, "jp-mplus-1m-bold": 202, "jp-mplus-1m-regular": 203, "jp-mplus-1p-bold": 204, "jp-mplus-1p-regular": 205, "jp-rounded-mplus-1p-bold": 206, "jp-rounded-mplus-1p-regular": 207, "jp-timemachine-wa": 208, "jp-ttf-GenEiLateMin-Medium": 209,
|
||||
"jp-uzura_font": 210, "kr-Arita-buri-Bold_OTF": 0, "kr-Arita-buri-HairLine_OTF": 1, "kr-Arita-buri-Light_OTF": 2, "kr-Arita-buri-Medium_OTF": 3, "kr-Arita-buri-SemiBold_OTF": 4, "kr-Canva_YDSunshineL": 5, "kr-Canva_YDSunshineM": 6, "kr-Canva_YoonGulimPro710": 7, "kr-Canva_YoonGulimPro730": 8, "kr-Canva_YoonGulimPro740": 9,
|
||||
"kr-Canva_YoonGulimPro760": 10, "kr-Canva_YoonGulimPro770": 11, "kr-Canva_YoonGulimPro790": 12, "kr-CreHappB": 13, "kr-CreHappL": 14, "kr-CreHappM": 15, "kr-CreHappS": 16, "kr-OTAuroraB": 17, "kr-OTAuroraL": 18, "kr-OTAuroraR": 19,
|
||||
"kr-OTDoldamgilB": 20, "kr-OTDoldamgilL": 21, "kr-OTDoldamgilR": 22, "kr-OTHamsterB": 23, "kr-OTHamsterL": 24, "kr-OTHamsterR": 25, "kr-OTHapchangdanB": 26, "kr-OTHapchangdanL": 27, "kr-OTHapchangdanR": 28, "kr-OTSupersizeBkBOX": 29,
|
||||
"kr-SourceHanSansKR-Bold": 30, "kr-SourceHanSansKR-ExtraLight": 31, "kr-SourceHanSansKR-Heavy": 32, "kr-SourceHanSansKR-Light": 33, "kr-SourceHanSansKR-Medium": 34, "kr-SourceHanSansKR-Normal": 35, "kr-SourceHanSansKR-Regular": 36, "kr-SourceHanSansSC-Bold": 37, "kr-SourceHanSansSC-ExtraLight": 38, "kr-SourceHanSansSC-Heavy": 39,
|
||||
"kr-SourceHanSansSC-Light": 40, "kr-SourceHanSansSC-Medium": 41, "kr-SourceHanSansSC-Normal": 42, "kr-SourceHanSansSC-Regular": 43, "kr-SourceHanSerifSC-Bold": 44, "kr-SourceHanSerifSC-SemiBold": 45, "kr-TDTDBubbleBubbleOTF": 46, "kr-TDTDConfusionOTF": 47, "kr-TDTDCuteAndCuteOTF": 48, "kr-TDTDEggTakOTF": 49,
|
||||
"kr-TDTDEmotionalLetterOTF": 50, "kr-TDTDGalapagosOTF": 51, "kr-TDTDHappyHourOTF": 52, "kr-TDTDLatteOTF": 53, "kr-TDTDMoonLightOTF": 54, "kr-TDTDParkForestOTF": 55, "kr-TDTDPencilOTF": 56, "kr-TDTDSmileOTF": 57, "kr-TDTDSproutOTF": 58, "kr-TDTDSunshineOTF": 59,
|
||||
"kr-TDTDWaferOTF": 60, "kr-777Chyaochyureu": 61, "kr-ArialUnicodeMS-Bold": 62, "kr-ArialUnicodeMS": 63, "kr-BMHANNA": 64, "kr-Baekmuk-Dotum": 65, "kr-BagelFatOne-Regular": 66, "kr-CoreBandi": 67, "kr-CoreBandiFace": 68, "kr-CoreBori": 69,
|
||||
"kr-DoHyeon-Regular": 70, "kr-Dokdo-Regular": 71, "kr-Gaegu-Bold": 72, "kr-Gaegu-Light": 73, "kr-Gaegu-Regular": 74, "kr-GamjaFlower-Regular": 75, "kr-GasoekOne-Regular": 76, "kr-GothicA1-Black": 77, "kr-GothicA1-Bold": 78, "kr-GothicA1-ExtraBold": 79,
|
||||
"kr-GothicA1-ExtraLight": 80, "kr-GothicA1-Light": 81, "kr-GothicA1-Medium": 82, "kr-GothicA1-Regular": 83, "kr-GothicA1-SemiBold": 84, "kr-GothicA1-Thin": 85, "kr-Gugi-Regular": 86, "kr-HiMelody-Regular": 87, "kr-Jua-Regular": 88, "kr-KirangHaerang-Regular": 89,
|
||||
"kr-NanumBrush": 90, "kr-NanumPen": 91, "kr-NanumSquareRoundB": 92, "kr-NanumSquareRoundEB": 93, "kr-NanumSquareRoundL": 94, "kr-NanumSquareRoundR": 95, "kr-SeH-CB": 96, "kr-SeH-CBL": 97, "kr-SeH-CEB": 98, "kr-SeH-CL": 99,
|
||||
"kr-SeH-CM": 100, "kr-SeN-CB": 101, "kr-SeN-CBL": 102, "kr-SeN-CEB": 103, "kr-SeN-CL": 104, "kr-SeN-CM": 105, "kr-Sunflower-Bold": 106, "kr-Sunflower-Light": 107, "kr-Sunflower-Medium": 108, "kr-TTClaytoyR": 109,
|
||||
"kr-TTDalpangiR": 110, "kr-TTMamablockR": 111, "kr-TTNauidongmuR": 112, "kr-TTOktapbangR": 113, "kr-UhBeeMiMi": 114, "kr-UhBeeMiMiBold": 115, "kr-UhBeeSe_hyun": 116, "kr-UhBeeSe_hyunBold": 117, "kr-UhBeenamsoyoung": 118, "kr-UhBeenamsoyoungBold": 119,
|
||||
"kr-WenQuanYiMicroHei": 120, "kr-YeonSung-Regular": 121}"""
|
||||
|
||||
|
||||
def add_special_token(tokenizer: T5Tokenizer, text_encoder: T5Stack):
|
||||
"""
|
||||
Add special tokens for color and font to tokenizer and text encoder.
|
||||
|
||||
Args:
|
||||
tokenizer: Huggingface tokenizer.
|
||||
text_encoder: Huggingface T5 encoder.
|
||||
"""
|
||||
idx_font_dict = json.loads(MULTILINGUAL_10_LANG_IDX_JSON)
|
||||
idx_color_dict = json.loads(COLOR_IDX_JSON)
|
||||
|
||||
font_token = [f"<{font_code[:2]}-font-{idx_font_dict[font_code]}>" for font_code in idx_font_dict]
|
||||
color_token = [f"<color-{i}>" for i in range(len(idx_color_dict))]
|
||||
additional_special_tokens = []
|
||||
additional_special_tokens += color_token
|
||||
additional_special_tokens += font_token
|
||||
|
||||
tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
|
||||
# Set mean_resizing=False to avoid PyTorch LAPACK dependency
|
||||
text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
|
||||
|
||||
|
||||
def load_byt5(
|
||||
ckpt_path: str,
|
||||
dtype: Optional[torch.dtype],
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[dict] = None,
|
||||
) -> Tuple[T5Stack, T5Tokenizer]:
|
||||
BYT5_CONFIG_JSON = """
|
||||
{
|
||||
"_name_or_path": "/home/patrick/t5/byt5-small",
|
||||
"architectures": [
|
||||
"T5ForConditionalGeneration"
|
||||
],
|
||||
"d_ff": 3584,
|
||||
"d_kv": 64,
|
||||
"d_model": 1472,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"feed_forward_proj": "gated-gelu",
|
||||
"gradient_checkpointing": false,
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 4,
|
||||
"num_heads": 6,
|
||||
"num_layers": 12,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"tokenizer_class": "ByT5Tokenizer",
|
||||
"transformers_version": "4.7.0.dev0",
|
||||
"use_cache": true,
|
||||
"vocab_size": 384
|
||||
}
|
||||
"""
|
||||
|
||||
logger.info(f"Loading BYT5 tokenizer from {BYT5_TOKENIZER_PATH}")
|
||||
byt5_tokenizer = AutoTokenizer.from_pretrained(BYT5_TOKENIZER_PATH)
|
||||
|
||||
logger.info("Initializing BYT5 text encoder")
|
||||
config = json.loads(BYT5_CONFIG_JSON)
|
||||
config = T5Config(**config)
|
||||
with init_empty_weights():
|
||||
byt5_text_encoder = T5ForConditionalGeneration._from_config(config).get_encoder()
|
||||
|
||||
add_special_token(byt5_tokenizer, byt5_text_encoder)
|
||||
|
||||
if state_dict is not None:
|
||||
sd = state_dict
|
||||
else:
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
# remove "encoder." prefix
|
||||
sd = {k[len("encoder.") :] if k.startswith("encoder.") else k: v for k, v in sd.items()}
|
||||
sd["embed_tokens.weight"] = sd.pop("shared.weight")
|
||||
|
||||
info = byt5_text_encoder.load_state_dict(sd, strict=True, assign=True)
|
||||
byt5_text_encoder.to(device)
|
||||
logger.info(f"BYT5 text encoder loaded with info: {info}")
|
||||
|
||||
return byt5_tokenizer, byt5_text_encoder
|
||||
|
||||
|
||||
def load_qwen2_5_vl(
|
||||
ckpt_path: str,
|
||||
dtype: Optional[torch.dtype],
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[dict] = None,
|
||||
) -> tuple[Qwen2Tokenizer, Qwen2_5_VLForConditionalGeneration]:
|
||||
QWEN2_5_VL_CONFIG_JSON = """
|
||||
{
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"image_token_id": 151655,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"max_position_embeddings": 128000,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [
|
||||
16,
|
||||
24,
|
||||
24
|
||||
],
|
||||
"rope_type": "default",
|
||||
"type": "default"
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": 32768,
|
||||
"text_config": {
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3584,
|
||||
"image_token_id": null,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 18944,
|
||||
"layer_types": [
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention",
|
||||
"full_attention"
|
||||
],
|
||||
"max_position_embeddings": 128000,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen2_5_vl_text",
|
||||
"num_attention_heads": 28,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 4,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": {
|
||||
"mrope_section": [
|
||||
16,
|
||||
24,
|
||||
24
|
||||
],
|
||||
"rope_type": "default",
|
||||
"type": "default"
|
||||
},
|
||||
"rope_theta": 1000000.0,
|
||||
"sliding_window": null,
|
||||
"torch_dtype": "float32",
|
||||
"use_cache": true,
|
||||
"use_sliding_window": false,
|
||||
"video_token_id": null,
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_start_token_id": 151652,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 152064
|
||||
},
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.53.1",
|
||||
"use_cache": true,
|
||||
"use_sliding_window": false,
|
||||
"video_token_id": 151656,
|
||||
"vision_config": {
|
||||
"depth": 32,
|
||||
"fullatt_block_indexes": [
|
||||
7,
|
||||
15,
|
||||
23,
|
||||
31
|
||||
],
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 1280,
|
||||
"in_channels": 3,
|
||||
"in_chans": 3,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3420,
|
||||
"model_type": "qwen2_5_vl",
|
||||
"num_heads": 16,
|
||||
"out_hidden_size": 3584,
|
||||
"patch_size": 14,
|
||||
"spatial_merge_size": 2,
|
||||
"spatial_patch_size": 14,
|
||||
"temporal_patch_size": 2,
|
||||
"tokens_per_second": 2,
|
||||
"torch_dtype": "float32",
|
||||
"window_size": 112
|
||||
},
|
||||
"vision_end_token_id": 151653,
|
||||
"vision_start_token_id": 151652,
|
||||
"vision_token_id": 151654,
|
||||
"vocab_size": 152064
|
||||
}
|
||||
"""
|
||||
config = json.loads(QWEN2_5_VL_CONFIG_JSON)
|
||||
config = Qwen2_5_VLConfig(**config)
|
||||
with init_empty_weights():
|
||||
qwen2_5_vl = Qwen2_5_VLForConditionalGeneration._from_config(config)
|
||||
|
||||
if state_dict is not None:
|
||||
sd = state_dict
|
||||
else:
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
# convert prefixes
|
||||
for key in list(sd.keys()):
|
||||
if key.startswith("model."):
|
||||
new_key = key.replace("model.", "model.language_model.", 1)
|
||||
elif key.startswith("visual."):
|
||||
new_key = key.replace("visual.", "model.visual.", 1)
|
||||
else:
|
||||
continue
|
||||
if key not in sd:
|
||||
logger.warning(f"Key {key} not found in state dict, skipping.")
|
||||
continue
|
||||
sd[new_key] = sd.pop(key)
|
||||
|
||||
info = qwen2_5_vl.load_state_dict(sd, strict=True, assign=True)
|
||||
logger.info(f"Loaded Qwen2.5-VL: {info}")
|
||||
qwen2_5_vl.to(device)
|
||||
|
||||
if dtype is not None:
|
||||
if dtype.itemsize == 1: # fp8
|
||||
org_dtype = torch.bfloat16 # model weight is fp8 in loading, but original dtype is bfloat16
|
||||
logger.info(f"prepare Qwen2.5-VL for fp8: set to {dtype} from {org_dtype}")
|
||||
qwen2_5_vl.to(dtype)
|
||||
|
||||
# prepare LLM for fp8
|
||||
def prepare_fp8(vl_model: Qwen2_5_VLForConditionalGeneration, target_dtype):
|
||||
def forward_hook(module):
|
||||
def forward(hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
|
||||
# return module.weight.to(input_dtype) * hidden_states.to(input_dtype)
|
||||
return (module.weight.to(torch.float32) * hidden_states.to(torch.float32)).to(input_dtype)
|
||||
|
||||
return forward
|
||||
|
||||
def decoder_forward_hook(module):
|
||||
def forward(
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = module.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights = module.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = residual.to(torch.float32) + hidden_states.to(torch.float32)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = module.post_attention_layernorm(hidden_states)
|
||||
hidden_states = module.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
for module in vl_model.modules():
|
||||
if module.__class__.__name__ in ["Embedding"]:
|
||||
# print("set", module.__class__.__name__, "to", target_dtype)
|
||||
module.to(target_dtype)
|
||||
if module.__class__.__name__ in ["Qwen2RMSNorm"]:
|
||||
# print("set", module.__class__.__name__, "hooks")
|
||||
module.forward = forward_hook(module)
|
||||
if module.__class__.__name__ in ["Qwen2_5_VLDecoderLayer"]:
|
||||
# print("set", module.__class__.__name__, "hooks")
|
||||
module.forward = decoder_forward_hook(module)
|
||||
if module.__class__.__name__ in ["Qwen2_5_VisionRotaryEmbedding"]:
|
||||
# print("set", module.__class__.__name__, "hooks")
|
||||
module.to(target_dtype)
|
||||
|
||||
prepare_fp8(qwen2_5_vl, org_dtype)
|
||||
|
||||
else:
|
||||
logger.info(f"Setting Qwen2.5-VL to dtype: {dtype}")
|
||||
qwen2_5_vl.to(dtype)
|
||||
|
||||
# Load tokenizer
|
||||
logger.info(f"Loading tokenizer from {QWEN_2_5_VL_IMAGE_ID}")
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained(QWEN_2_5_VL_IMAGE_ID)
|
||||
return tokenizer, qwen2_5_vl
|
||||
|
||||
|
||||
def get_qwen_prompt_embeds(
|
||||
tokenizer: Qwen2Tokenizer, vlm: Qwen2_5_VLForConditionalGeneration, prompt: Union[str, list[str]] = None
|
||||
):
|
||||
tokenizer_max_length = 1024
|
||||
|
||||
# HunyuanImage-2.1 does not use "<|im_start|>assistant\n" in the prompt template
|
||||
prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
# \n<|im_start|>assistant\n"
|
||||
prompt_template_encode_start_idx = 34
|
||||
# default_sample_size = 128
|
||||
|
||||
device = vlm.device
|
||||
dtype = vlm.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
template = prompt_template_encode
|
||||
drop_idx = prompt_template_encode_start_idx
|
||||
txt = [template.format(e) for e in prompt]
|
||||
txt_tokens = tokenizer(txt, max_length=tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(
|
||||
device
|
||||
)
|
||||
|
||||
if dtype.itemsize == 1: # fp8
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True):
|
||||
encoder_hidden_states = vlm(
|
||||
input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True
|
||||
)
|
||||
else:
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=dtype, enabled=True):
|
||||
encoder_hidden_states = vlm(
|
||||
input_ids=txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask, output_hidden_states=True
|
||||
)
|
||||
hidden_states = encoder_hidden_states.hidden_states[-3] # use the 3rd last layer's hidden states for HunyuanImage-2.1
|
||||
if hidden_states.shape[1] > tokenizer_max_length + drop_idx:
|
||||
logger.warning(f"Hidden states shape {hidden_states.shape} exceeds max length {tokenizer_max_length + drop_idx}")
|
||||
|
||||
# --- Unnecessary complicated processing, keep for reference ---
|
||||
# split_hidden_states = extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
|
||||
# split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
||||
# attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
||||
# max_seq_len = max([e.size(0) for e in split_hidden_states])
|
||||
# prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
|
||||
# encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
|
||||
# ----------------------------------------------------------
|
||||
|
||||
prompt_embeds = hidden_states[:, drop_idx:, :]
|
||||
encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, encoder_attention_mask
|
||||
|
||||
|
||||
def format_prompt(texts, styles):
|
||||
"""
|
||||
Text "{text}" in {color}, {type}.
|
||||
"""
|
||||
|
||||
prompt = ""
|
||||
for text, style in zip(texts, styles):
|
||||
# color and style are always None in official implementation, so we only use text
|
||||
text_prompt = f'Text "{text}"'
|
||||
text_prompt += ". "
|
||||
prompt = prompt + text_prompt
|
||||
return prompt
|
||||
|
||||
|
||||
def get_glyph_prompt_embeds(
|
||||
tokenizer: T5Tokenizer, text_encoder: T5Stack, prompt: Union[str, list[str]] = None
|
||||
) -> Tuple[list[bool], torch.Tensor, torch.Tensor]:
|
||||
byt5_max_length = 128
|
||||
if not prompt:
|
||||
return (
|
||||
[False],
|
||||
torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device),
|
||||
torch.zeros((1, byt5_max_length), device=text_encoder.device, dtype=torch.int64),
|
||||
)
|
||||
|
||||
try:
|
||||
text_prompt_texts = []
|
||||
# pattern_quote_single = r"\'(.*?)\'"
|
||||
pattern_quote_double = r"\"(.*?)\""
|
||||
pattern_quote_chinese_single = r"‘(.*?)’"
|
||||
pattern_quote_chinese_double = r"“(.*?)”"
|
||||
|
||||
# matches_quote_single = re.findall(pattern_quote_single, prompt)
|
||||
matches_quote_double = re.findall(pattern_quote_double, prompt)
|
||||
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt)
|
||||
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt)
|
||||
|
||||
# text_prompt_texts.extend(matches_quote_single)
|
||||
text_prompt_texts.extend(matches_quote_double)
|
||||
text_prompt_texts.extend(matches_quote_chinese_single)
|
||||
text_prompt_texts.extend(matches_quote_chinese_double)
|
||||
|
||||
if not text_prompt_texts:
|
||||
return (
|
||||
[False],
|
||||
torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device),
|
||||
torch.zeros((1, byt5_max_length), device=text_encoder.device, dtype=torch.int64),
|
||||
)
|
||||
|
||||
text_prompt_style_list = [{"color": None, "font-family": None} for _ in range(len(text_prompt_texts))]
|
||||
glyph_text_formatted = format_prompt(text_prompt_texts, text_prompt_style_list)
|
||||
|
||||
byt5_text_ids, byt5_text_mask = get_byt5_text_tokens(tokenizer, byt5_max_length, glyph_text_formatted)
|
||||
|
||||
byt5_text_ids = byt5_text_ids.to(device=text_encoder.device)
|
||||
byt5_text_mask = byt5_text_mask.to(device=text_encoder.device)
|
||||
|
||||
byt5_prompt_embeds = text_encoder(byt5_text_ids, attention_mask=byt5_text_mask.float())
|
||||
byt5_emb = byt5_prompt_embeds[0]
|
||||
|
||||
return [True], byt5_emb, byt5_text_mask
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Warning: Error in glyph encoding, using fallback: {e}")
|
||||
return (
|
||||
[False],
|
||||
torch.zeros((1, byt5_max_length, 1472), device=text_encoder.device),
|
||||
torch.zeros((1, byt5_max_length), device=text_encoder.device, dtype=torch.int64),
|
||||
)
|
||||
|
||||
|
||||
def get_byt5_text_tokens(tokenizer, max_length, text_list):
|
||||
"""
|
||||
Get byT5 text tokens.
|
||||
|
||||
Args:
|
||||
tokenizer: The tokenizer object
|
||||
max_length: Maximum token length
|
||||
text_list: List or string of text
|
||||
|
||||
Returns:
|
||||
Tuple of (byt5_text_ids, byt5_text_mask)
|
||||
"""
|
||||
if isinstance(text_list, list):
|
||||
text_prompt = " ".join(text_list)
|
||||
else:
|
||||
text_prompt = text_list
|
||||
|
||||
byt5_text_inputs = tokenizer(
|
||||
text_prompt, padding="max_length", max_length=max_length, truncation=True, add_special_tokens=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
byt5_text_ids = byt5_text_inputs.input_ids
|
||||
byt5_text_mask = byt5_text_inputs.attention_mask
|
||||
|
||||
return byt5_text_ids, byt5_text_mask
|
||||
461
library/hunyuan_image_utils.py
Normal file
461
library/hunyuan_image_utils.py
Normal file
@@ -0,0 +1,461 @@
|
||||
# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
|
||||
# Re-implemented for license compliance for sd-scripts.
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union, Optional
|
||||
import torch
|
||||
|
||||
|
||||
def _to_tuple(x, dim=2):
|
||||
"""
|
||||
Convert int or sequence to tuple of specified dimension.
|
||||
|
||||
Args:
|
||||
x: Int or sequence to convert.
|
||||
dim: Target dimension for tuple.
|
||||
|
||||
Returns:
|
||||
Tuple of length dim.
|
||||
"""
|
||||
if isinstance(x, int) or isinstance(x, float):
|
||||
return (x,) * dim
|
||||
elif len(x) == dim:
|
||||
return x
|
||||
else:
|
||||
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||
|
||||
|
||||
def get_meshgrid_nd(start, dim=2):
|
||||
"""
|
||||
Generate n-dimensional coordinate meshgrid from 0 to grid_size.
|
||||
|
||||
Creates coordinate grids for each spatial dimension, useful for
|
||||
generating position embeddings.
|
||||
|
||||
Args:
|
||||
start: Grid size for each dimension (int or tuple).
|
||||
dim: Number of spatial dimensions.
|
||||
|
||||
Returns:
|
||||
Coordinate grid tensor [dim, *grid_size].
|
||||
"""
|
||||
# Convert start to grid sizes
|
||||
num = _to_tuple(start, dim=dim)
|
||||
start = (0,) * dim
|
||||
stop = num
|
||||
|
||||
# Generate coordinate arrays for each dimension
|
||||
axis_grid = []
|
||||
for i in range(dim):
|
||||
a, b, n = start[i], stop[i], num[i]
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
axis_grid.append(g)
|
||||
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
||||
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def get_nd_rotary_pos_embed(rope_dim_list, start, theta=10000.0):
|
||||
"""
|
||||
Generate n-dimensional rotary position embeddings for spatial tokens.
|
||||
|
||||
Creates RoPE embeddings for multi-dimensional positional encoding,
|
||||
distributing head dimensions across spatial dimensions.
|
||||
|
||||
Args:
|
||||
rope_dim_list: Dimensions allocated to each spatial axis (should sum to head_dim).
|
||||
start: Spatial grid size for each dimension.
|
||||
theta: Base frequency for RoPE computation.
|
||||
|
||||
Returns:
|
||||
Tuple of (cos_freqs, sin_freqs) for rotary embedding [H*W, D/2].
|
||||
"""
|
||||
|
||||
grid = get_meshgrid_nd(start, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
|
||||
|
||||
# Generate RoPE embeddings for each spatial dimension
|
||||
embs = []
|
||||
for i in range(len(rope_dim_list)):
|
||||
emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta) # 2 x [WHD, rope_dim_list[i]]
|
||||
embs.append(emb)
|
||||
|
||||
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
||||
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
||||
return cos, sin
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int, pos: Union[torch.FloatTensor, int], theta: float = 10000.0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Generate 1D rotary position embeddings.
|
||||
|
||||
Args:
|
||||
dim: Embedding dimension (must be even).
|
||||
pos: Position indices [S] or scalar for sequence length.
|
||||
theta: Base frequency for sinusoidal encoding.
|
||||
|
||||
Returns:
|
||||
Tuple of (cos_freqs, sin_freqs) tensors [S, D].
|
||||
"""
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos).float()
|
||||
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
|
||||
freqs = torch.outer(pos, freqs) # [S, D/2]
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings for diffusion models.
|
||||
|
||||
Converts scalar timesteps to high-dimensional embeddings using
|
||||
sinusoidal encoding at different frequencies.
|
||||
|
||||
Args:
|
||||
t: Timestep tensor [N].
|
||||
dim: Output embedding dimension.
|
||||
max_period: Maximum period for frequency computation.
|
||||
|
||||
Returns:
|
||||
Timestep embeddings [N, dim].
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
def modulate(x, shift=None, scale=None):
|
||||
"""
|
||||
Apply adaptive layer normalization modulation.
|
||||
|
||||
Applies scale and shift transformations for conditioning
|
||||
in adaptive layer normalization.
|
||||
|
||||
Args:
|
||||
x: Input tensor to modulate.
|
||||
shift: Additive shift parameter (optional).
|
||||
scale: Multiplicative scale parameter (optional).
|
||||
|
||||
Returns:
|
||||
Modulated tensor x * (1 + scale) + shift.
|
||||
"""
|
||||
if scale is None and shift is None:
|
||||
return x
|
||||
elif shift is None:
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
elif scale is None:
|
||||
return x + shift.unsqueeze(1)
|
||||
else:
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def apply_gate(x, gate=None, tanh=False):
|
||||
"""
|
||||
Apply gating mechanism to tensor.
|
||||
|
||||
Multiplies input by gate values, optionally applying tanh activation.
|
||||
Used in residual connections for adaptive control.
|
||||
|
||||
Args:
|
||||
x: Input tensor to gate.
|
||||
gate: Gating values (optional).
|
||||
tanh: Whether to apply tanh to gate values.
|
||||
|
||||
Returns:
|
||||
Gated tensor x * gate (with optional tanh).
|
||||
"""
|
||||
if gate is None:
|
||||
return x
|
||||
if tanh:
|
||||
return x * gate.unsqueeze(1).tanh()
|
||||
else:
|
||||
return x * gate.unsqueeze(1)
|
||||
|
||||
|
||||
def reshape_for_broadcast(
|
||||
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
head_first=False,
|
||||
):
|
||||
"""
|
||||
Reshape RoPE frequency tensors for broadcasting with attention tensors.
|
||||
|
||||
Args:
|
||||
freqs_cis: Tuple of (cos_freqs, sin_freqs) tensors.
|
||||
x: Target tensor for broadcasting compatibility.
|
||||
head_first: Must be False (only supported layout).
|
||||
|
||||
Returns:
|
||||
Reshaped (cos_freqs, sin_freqs) tensors ready for broadcasting.
|
||||
"""
|
||||
assert not head_first, "Only head_first=False layout supported."
|
||||
assert isinstance(freqs_cis, tuple), "Expected tuple of (cos, sin) frequency tensors."
|
||||
assert x.ndim > 1, f"x should have at least 2 dimensions, but got {x.ndim}"
|
||||
|
||||
# Validate frequency tensor dimensions match target tensor
|
||||
assert freqs_cis[0].shape == (
|
||||
x.shape[1],
|
||||
x.shape[-1],
|
||||
), f"Frequency tensor shape {freqs_cis[0].shape} incompatible with target shape {x.shape}"
|
||||
|
||||
shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""
|
||||
Rotate half the dimensions for RoPE computation.
|
||||
|
||||
Splits the last dimension in half and applies a 90-degree rotation
|
||||
by swapping and negating components.
|
||||
|
||||
Args:
|
||||
x: Input tensor [..., D] where D is even.
|
||||
|
||||
Returns:
|
||||
Rotated tensor with same shape as input.
|
||||
"""
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor, xk: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor], head_first: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary position embeddings to query and key tensors.
|
||||
|
||||
Args:
|
||||
xq: Query tensor [B, S, H, D].
|
||||
xk: Key tensor [B, S, H, D].
|
||||
freqs_cis: Tuple of (cos_freqs, sin_freqs) for rotation.
|
||||
head_first: Whether head dimension precedes sequence dimension.
|
||||
|
||||
Returns:
|
||||
Tuple of rotated (query, key) tensors.
|
||||
"""
|
||||
device = xq.device
|
||||
dtype = xq.dtype
|
||||
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
|
||||
cos, sin = cos.to(device), sin.to(device)
|
||||
|
||||
# Apply rotation: x' = x * cos + rotate_half(x) * sin
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).to(dtype)
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).to(dtype)
|
||||
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
def get_timesteps_sigmas(sampling_steps: int, shift: float, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Generate timesteps and sigmas for diffusion sampling.
|
||||
|
||||
Args:
|
||||
sampling_steps: Number of sampling steps.
|
||||
shift: Sigma shift parameter for schedule modification.
|
||||
device: Target device for tensors.
|
||||
|
||||
Returns:
|
||||
Tuple of (timesteps, sigmas) tensors.
|
||||
"""
|
||||
sigmas = torch.linspace(1, 0, sampling_steps + 1)
|
||||
sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
|
||||
sigmas = sigmas.to(torch.float32)
|
||||
timesteps = (sigmas[:-1] * 1000).to(dtype=torch.float32, device=device)
|
||||
return timesteps, sigmas
|
||||
|
||||
|
||||
def step(latents, noise_pred, sigmas, step_i):
|
||||
"""
|
||||
Perform a single diffusion sampling step.
|
||||
|
||||
Args:
|
||||
latents: Current latent state.
|
||||
noise_pred: Predicted noise.
|
||||
sigmas: Noise schedule sigmas.
|
||||
step_i: Current step index.
|
||||
|
||||
Returns:
|
||||
Updated latents after the step.
|
||||
"""
|
||||
return latents.float() - (sigmas[step_i] - sigmas[step_i + 1]) * noise_pred.float()
|
||||
|
||||
|
||||
# region AdaptiveProjectedGuidance
|
||||
|
||||
|
||||
class MomentumBuffer:
|
||||
"""
|
||||
Exponential moving average buffer for APG momentum.
|
||||
"""
|
||||
def __init__(self, momentum: float):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
|
||||
def normalized_guidance_apg(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
eta: float = 1.0,
|
||||
norm_threshold: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
"""
|
||||
Apply normalized adaptive projected guidance.
|
||||
|
||||
Projects the guidance vector to reduce over-saturation while maintaining
|
||||
directional control by decomposing into parallel and orthogonal components.
|
||||
|
||||
Args:
|
||||
pred_cond: Conditional prediction.
|
||||
pred_uncond: Unconditional prediction.
|
||||
guidance_scale: Guidance scale factor.
|
||||
momentum_buffer: Optional momentum buffer for temporal smoothing.
|
||||
eta: Scaling factor for parallel component.
|
||||
norm_threshold: Maximum norm for guidance vector clipping.
|
||||
use_original_formulation: Whether to use original APG formulation.
|
||||
|
||||
Returns:
|
||||
Guided prediction tensor.
|
||||
"""
|
||||
diff = pred_cond - pred_uncond
|
||||
dim = [-i for i in range(1, len(diff.shape))] # All dimensions except batch
|
||||
|
||||
# Apply momentum smoothing if available
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
diff = momentum_buffer.running_average
|
||||
|
||||
# Apply norm clipping if threshold is set
|
||||
if norm_threshold > 0:
|
||||
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
||||
scale_factor = torch.minimum(torch.ones_like(diff_norm), norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
|
||||
# Project guidance vector into parallel and orthogonal components
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||
|
||||
# Combine components with different scaling
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
pred = pred + guidance_scale * normalized_update
|
||||
|
||||
return pred
|
||||
|
||||
|
||||
class AdaptiveProjectedGuidance:
|
||||
"""
|
||||
Adaptive Projected Guidance for classifier-free guidance.
|
||||
|
||||
Implements APG which projects the guidance vector to reduce over-saturation
|
||||
while maintaining directional control.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
adaptive_projected_guidance_momentum: Optional[float] = None,
|
||||
adaptive_projected_guidance_rescale: float = 15.0,
|
||||
eta: float = 0.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
assert guidance_rescale == 0.0, "guidance_rescale > 0.0 not supported."
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||
self.eta = eta
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def __call__(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None, step=None) -> torch.Tensor:
|
||||
if step == 0 and self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
|
||||
pred = normalized_guidance_apg(
|
||||
pred_cond,
|
||||
pred_uncond,
|
||||
self.guidance_scale,
|
||||
self.momentum_buffer,
|
||||
self.eta,
|
||||
self.adaptive_projected_guidance_rescale,
|
||||
self.use_original_formulation,
|
||||
)
|
||||
|
||||
return pred
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def apply_classifier_free_guidance(
|
||||
noise_pred_text: torch.Tensor,
|
||||
noise_pred_uncond: torch.Tensor,
|
||||
is_ocr: bool,
|
||||
guidance_scale: float,
|
||||
step: int,
|
||||
apg_start_step_ocr: int = 75,
|
||||
apg_start_step_general: int = 10,
|
||||
cfg_guider_ocr: AdaptiveProjectedGuidance = None,
|
||||
cfg_guider_general: AdaptiveProjectedGuidance = None,
|
||||
):
|
||||
"""
|
||||
Apply classifier-free guidance with OCR-aware APG for batch_size=1.
|
||||
|
||||
Args:
|
||||
noise_pred_text: Conditional noise prediction tensor [1, ...].
|
||||
noise_pred_uncond: Unconditional noise prediction tensor [1, ...].
|
||||
is_ocr: Whether this sample requires OCR-specific guidance.
|
||||
guidance_scale: Guidance scale for CFG.
|
||||
step: Current diffusion step index.
|
||||
apg_start_step_ocr: Step to start APG for OCR regions.
|
||||
apg_start_step_general: Step to start APG for general regions.
|
||||
cfg_guider_ocr: APG guider for OCR regions.
|
||||
cfg_guider_general: APG guider for general regions.
|
||||
|
||||
Returns:
|
||||
Guided noise prediction tensor [1, ...].
|
||||
"""
|
||||
if guidance_scale == 1.0:
|
||||
return noise_pred_text
|
||||
|
||||
# Select appropriate guider and start step based on OCR requirement
|
||||
if is_ocr:
|
||||
cfg_guider = cfg_guider_ocr
|
||||
apg_start_step = apg_start_step_ocr
|
||||
else:
|
||||
cfg_guider = cfg_guider_general
|
||||
apg_start_step = apg_start_step_general
|
||||
|
||||
# Apply standard CFG or APG based on current step
|
||||
if step <= apg_start_step:
|
||||
# Standard classifier-free guidance
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
# Initialize APG guider state
|
||||
_ = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
|
||||
else:
|
||||
# Use APG for guidance
|
||||
noise_pred = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
|
||||
|
||||
return noise_pred
|
||||
622
library/hunyuan_image_vae.py
Normal file
622
library/hunyuan_image_vae.py
Normal file
@@ -0,0 +1,622 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import Conv2d
|
||||
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
||||
|
||||
from library.utils import load_safetensors, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
VAE_SCALE_FACTOR = 32 # 32x spatial compression
|
||||
|
||||
|
||||
def swish(x: Tensor) -> Tensor:
|
||||
"""Swish activation function: x * sigmoid(x)."""
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
"""Self-attention block using scaled dot-product attention."""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.q = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.k = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.v = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def attention(self, x: Tensor) -> Tensor:
|
||||
x = self.norm(x)
|
||||
q = self.q(x)
|
||||
k = self.k(x)
|
||||
v = self.v(x)
|
||||
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, "b c h w -> b (h w) c").contiguous()
|
||||
k = rearrange(k, "b c h w -> b (h w) c").contiguous()
|
||||
v = rearrange(v, "b c h w -> b (h w) c").contiguous()
|
||||
|
||||
x = nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
return rearrange(x, "b (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x + self.proj_out(self.attention(x))
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
"""
|
||||
Residual block with two convolutions, group normalization, and swish activation.
|
||||
Includes skip connection with optional channel dimension matching.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels.
|
||||
out_channels : int
|
||||
Number of output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# Skip connection projection for channel dimension mismatch
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
h = x
|
||||
# First convolution block
|
||||
h = self.norm1(h)
|
||||
h = swish(h)
|
||||
h = self.conv1(h)
|
||||
# Second convolution block
|
||||
h = self.norm2(h)
|
||||
h = swish(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
# Apply skip connection with optional projection
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
return x + h
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
Spatial downsampling block that reduces resolution by 2x using convolution followed by
|
||||
pixel rearrangement. Includes skip connection with grouped averaging.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels.
|
||||
out_channels : int
|
||||
Number of output channels (must be divisible by 4).
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
factor = 4 # 2x2 spatial reduction factor
|
||||
assert out_channels % factor == 0
|
||||
|
||||
self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
||||
self.group_size = factor * in_channels // out_channels
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# Apply convolution and rearrange pixels for 2x downsampling
|
||||
h = self.conv(x)
|
||||
h = rearrange(h, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2)
|
||||
|
||||
# Create skip connection with pixel rearrangement
|
||||
shortcut = rearrange(x, "b c (h r1) (w r2) -> b (r1 r2 c) h w", r1=2, r2=2)
|
||||
B, C, H, W = shortcut.shape
|
||||
shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
|
||||
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
Spatial upsampling block that increases resolution by 2x using convolution followed by
|
||||
pixel rearrangement. Includes skip connection with channel repetition.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels.
|
||||
out_channels : int
|
||||
Number of output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
super().__init__()
|
||||
factor = 4 # 2x2 spatial expansion factor
|
||||
self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
|
||||
self.repeats = factor * out_channels // in_channels
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# Apply convolution and rearrange pixels for 2x upsampling
|
||||
h = self.conv(x)
|
||||
h = rearrange(h, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2)
|
||||
|
||||
# Create skip connection with channel repetition
|
||||
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
|
||||
shortcut = rearrange(shortcut, "b (r1 r2 c) h w -> b c (h r1) (w r2)", r1=2, r2=2)
|
||||
|
||||
return h + shortcut
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""
|
||||
VAE encoder that progressively downsamples input images to a latent representation.
|
||||
Uses residual blocks, attention, and spatial downsampling.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input image channels (e.g., 3 for RGB).
|
||||
z_channels : int
|
||||
Number of latent channels in the output.
|
||||
block_out_channels : Tuple[int, ...]
|
||||
Output channels for each downsampling block.
|
||||
num_res_blocks : int
|
||||
Number of residual blocks per downsampling stage.
|
||||
ffactor_spatial : int
|
||||
Total spatial downsampling factor (e.g., 32 for 32x compression).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
z_channels: int,
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_res_blocks: int,
|
||||
ffactor_spatial: int,
|
||||
):
|
||||
super().__init__()
|
||||
assert block_out_channels[-1] % (2 * z_channels) == 0
|
||||
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
block_in = block_out_channels[0]
|
||||
|
||||
# Build downsampling blocks
|
||||
for i_level, ch in enumerate(block_out_channels):
|
||||
block = nn.ModuleList()
|
||||
block_out = ch
|
||||
|
||||
# Add residual blocks for this level
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
|
||||
# Add spatial downsampling if needed
|
||||
add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
|
||||
if add_spatial_downsample:
|
||||
assert i_level < len(block_out_channels) - 1
|
||||
block_out = block_out_channels[i_level + 1]
|
||||
down.downsample = Downsample(block_in, block_out)
|
||||
block_in = block_out
|
||||
|
||||
self.down.append(down)
|
||||
|
||||
# Middle blocks with attention
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# Output layers
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# Initial convolution
|
||||
h = self.conv_in(x)
|
||||
|
||||
# Progressive downsampling through blocks
|
||||
for i_level in range(len(self.block_out_channels)):
|
||||
# Apply residual blocks at this level
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h)
|
||||
# Apply spatial downsampling if available
|
||||
if hasattr(self.down[i_level], "downsample"):
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# Middle processing with attention
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
|
||||
# Final output layers with skip connection
|
||||
group_size = self.block_out_channels[-1] // (2 * self.z_channels)
|
||||
shortcut = rearrange(h, "b (c r) h w -> b c r h w", r=group_size).mean(dim=2)
|
||||
h = self.norm_out(h)
|
||||
h = swish(h)
|
||||
h = self.conv_out(h)
|
||||
h += shortcut
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""
|
||||
VAE decoder that progressively upsamples latent representations back to images.
|
||||
Uses residual blocks, attention, and spatial upsampling.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z_channels : int
|
||||
Number of latent channels in the input.
|
||||
out_channels : int
|
||||
Number of output image channels (e.g., 3 for RGB).
|
||||
block_out_channels : Tuple[int, ...]
|
||||
Output channels for each upsampling block.
|
||||
num_res_blocks : int
|
||||
Number of residual blocks per upsampling stage.
|
||||
ffactor_spatial : int
|
||||
Total spatial upsampling factor (e.g., 32 for 32x expansion).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
z_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_res_blocks: int,
|
||||
ffactor_spatial: int,
|
||||
):
|
||||
super().__init__()
|
||||
assert block_out_channels[0] % z_channels == 0
|
||||
|
||||
self.z_channels = z_channels
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
block_in = block_out_channels[0]
|
||||
self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# Middle blocks with attention
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
|
||||
# Build upsampling blocks
|
||||
self.up = nn.ModuleList()
|
||||
for i_level, ch in enumerate(block_out_channels):
|
||||
block = nn.ModuleList()
|
||||
block_out = ch
|
||||
|
||||
# Add residual blocks for this level (extra block for decoder)
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block_in = block_out
|
||||
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
|
||||
# Add spatial upsampling if needed
|
||||
add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
|
||||
if add_spatial_upsample:
|
||||
assert i_level < len(block_out_channels) - 1
|
||||
block_out = block_out_channels[i_level + 1]
|
||||
up.upsample = Upsample(block_in, block_out)
|
||||
block_in = block_out
|
||||
|
||||
self.up.append(up)
|
||||
|
||||
# Output layers
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
# Initial processing with skip connection
|
||||
repeats = self.block_out_channels[0] // self.z_channels
|
||||
h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
|
||||
|
||||
# Middle processing with attention
|
||||
h = self.mid.block_1(h)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h)
|
||||
|
||||
# Progressive upsampling through blocks
|
||||
for i_level in range(len(self.block_out_channels)):
|
||||
# Apply residual blocks at this level
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h)
|
||||
# Apply spatial upsampling if available
|
||||
if hasattr(self.up[i_level], "upsample"):
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# Final output layers
|
||||
h = self.norm_out(h)
|
||||
h = swish(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class HunyuanVAE2D(nn.Module):
|
||||
"""
|
||||
VAE model for Hunyuan Image-2.1 with spatial tiling support.
|
||||
|
||||
This VAE uses a fixed architecture optimized for the Hunyuan Image-2.1 model,
|
||||
with 32x spatial compression and optional memory-efficient tiling for large images.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Fixed configuration for Hunyuan Image-2.1
|
||||
block_out_channels = (128, 256, 512, 512, 1024, 1024)
|
||||
in_channels = 3 # RGB input
|
||||
out_channels = 3 # RGB output
|
||||
latent_channels = 64
|
||||
layers_per_block = 2
|
||||
ffactor_spatial = 32 # 32x spatial compression
|
||||
sample_size = 384 # Minimum sample size for tiling
|
||||
scaling_factor = 0.75289 # Latent scaling factor
|
||||
|
||||
self.ffactor_spatial = ffactor_spatial
|
||||
self.scaling_factor = scaling_factor
|
||||
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
z_channels=latent_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
num_res_blocks=layers_per_block,
|
||||
ffactor_spatial=ffactor_spatial,
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
z_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=list(reversed(block_out_channels)),
|
||||
num_res_blocks=layers_per_block,
|
||||
ffactor_spatial=ffactor_spatial,
|
||||
)
|
||||
|
||||
# Spatial tiling configuration for memory efficiency
|
||||
self.use_spatial_tiling = False
|
||||
self.tile_sample_min_size = sample_size
|
||||
self.tile_latent_min_size = sample_size // ffactor_spatial
|
||||
self.tile_overlap_factor = 0.25 # 25% overlap between tiles
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""Get the data type of the model parameters."""
|
||||
return next(self.encoder.parameters()).dtype
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""Get the device of the model parameters."""
|
||||
return next(self.encoder.parameters()).device
|
||||
|
||||
def enable_spatial_tiling(self, use_tiling: bool = True):
|
||||
"""Enable or disable spatial tiling."""
|
||||
self.use_spatial_tiling = use_tiling
|
||||
|
||||
def disable_spatial_tiling(self):
|
||||
"""Disable spatial tiling."""
|
||||
self.use_spatial_tiling = False
|
||||
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
"""Enable or disable spatial tiling (alias for enable_spatial_tiling)."""
|
||||
self.enable_spatial_tiling(use_tiling)
|
||||
|
||||
def disable_tiling(self):
|
||||
"""Disable spatial tiling (alias for disable_spatial_tiling)."""
|
||||
self.disable_spatial_tiling()
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
"""
|
||||
Blend two tensors horizontally with smooth transition.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : torch.Tensor
|
||||
Left tensor.
|
||||
b : torch.Tensor
|
||||
Right tensor.
|
||||
blend_extent : int
|
||||
Number of columns to blend.
|
||||
"""
|
||||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
"""
|
||||
Blend two tensors vertically with smooth transition.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
a : torch.Tensor
|
||||
Top tensor.
|
||||
b : torch.Tensor
|
||||
Bottom tensor.
|
||||
blend_extent : int
|
||||
Number of rows to blend.
|
||||
"""
|
||||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Encode large images using spatial tiling to reduce memory usage.
|
||||
Tiles are processed independently and blended at boundaries.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor of shape (B, C, T, H, W).
|
||||
"""
|
||||
B, C, T, H, W = x.shape
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
rows = []
|
||||
for i in range(0, H, overlap_size):
|
||||
row = []
|
||||
for j in range(0, W, overlap_size):
|
||||
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
moments = torch.cat(result_rows, dim=-2)
|
||||
return moments
|
||||
|
||||
def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Decode large latents using spatial tiling to reduce memory usage.
|
||||
Tiles are processed independently and blended at boundaries.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : torch.Tensor
|
||||
Latent tensor of shape (B, C, H, W).
|
||||
"""
|
||||
B, C, H, W = z.shape
|
||||
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_sample_min_size - blend_extent
|
||||
|
||||
rows = []
|
||||
for i in range(0, H, overlap_size):
|
||||
row = []
|
||||
for j in range(0, W, overlap_size):
|
||||
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
dec = torch.cat(result_rows, dim=-2)
|
||||
return dec
|
||||
|
||||
def encode(self, x: Tensor) -> DiagonalGaussianDistribution:
|
||||
"""
|
||||
Encode input images to latent representation.
|
||||
Uses spatial tiling for large images if enabled.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor
|
||||
Input image tensor of shape (B, C, H, W) or (B, C, T, H, W).
|
||||
|
||||
Returns
|
||||
-------
|
||||
DiagonalGaussianDistribution
|
||||
Latent distribution with mean and logvar.
|
||||
"""
|
||||
# Handle 5D input (B, C, T, H, W) by removing time dimension
|
||||
original_ndim = x.ndim
|
||||
if original_ndim == 5:
|
||||
x = x.squeeze(2)
|
||||
|
||||
# Use tiling for large images to reduce memory usage
|
||||
if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
||||
h = self.spatial_tiled_encode(x)
|
||||
else:
|
||||
h = self.encoder(x)
|
||||
|
||||
# Restore time dimension if input was 5D
|
||||
if original_ndim == 5:
|
||||
h = h.unsqueeze(2)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
return posterior
|
||||
|
||||
def decode(self, z: Tensor):
|
||||
"""
|
||||
Decode latent representation back to images.
|
||||
Uses spatial tiling for large latents if enabled.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
z : Tensor
|
||||
Latent tensor of shape (B, C, H, W) or (B, C, T, H, W).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor
|
||||
Decoded image tensor.
|
||||
"""
|
||||
# Handle 5D input (B, C, T, H, W) by removing time dimension
|
||||
original_ndim = z.ndim
|
||||
if original_ndim == 5:
|
||||
z = z.squeeze(2)
|
||||
|
||||
# Use tiling for large latents to reduce memory usage
|
||||
if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
||||
decoded = self.spatial_tiled_decode(z)
|
||||
else:
|
||||
decoded = self.decoder(z)
|
||||
|
||||
# Restore time dimension if input was 5D
|
||||
if original_ndim == 5:
|
||||
decoded = decoded.unsqueeze(2)
|
||||
|
||||
return decoded
|
||||
|
||||
|
||||
def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False) -> HunyuanVAE2D:
|
||||
logger.info("Initializing VAE")
|
||||
vae = HunyuanVAE2D()
|
||||
|
||||
logger.info(f"Loading VAE from {vae_path}")
|
||||
state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap)
|
||||
info = vae.load_state_dict(state_dict, strict=True, assign=True)
|
||||
logger.info(f"Loaded VAE: {info}")
|
||||
|
||||
vae.to(device)
|
||||
return vae
|
||||
249
library/lora_utils.py
Normal file
249
library/lora_utils.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# copy from Musubi Tuner
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.custom_offloading_utils import synchronize_device
|
||||
from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization
|
||||
from library.utils import MemoryEfficientSafeOpen, setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def filter_lora_state_dict(
|
||||
weights_sd: Dict[str, torch.Tensor],
|
||||
include_pattern: Optional[str] = None,
|
||||
exclude_pattern: Optional[str] = None,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
# apply include/exclude patterns
|
||||
original_key_count = len(weights_sd.keys())
|
||||
if include_pattern is not None:
|
||||
regex_include = re.compile(include_pattern)
|
||||
weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)}
|
||||
logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}")
|
||||
|
||||
if exclude_pattern is not None:
|
||||
original_key_count_ex = len(weights_sd.keys())
|
||||
regex_exclude = re.compile(exclude_pattern)
|
||||
weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)}
|
||||
logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}")
|
||||
|
||||
if len(weights_sd) != original_key_count:
|
||||
remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()]))
|
||||
remaining_keys.sort()
|
||||
logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}")
|
||||
if len(weights_sd) == 0:
|
||||
logger.warning("No keys left after filtering.")
|
||||
|
||||
return weights_sd
|
||||
|
||||
|
||||
def load_safetensors_with_lora_and_fp8(
|
||||
model_files: Union[str, List[str]],
|
||||
lora_weights_list: Optional[Dict[str, torch.Tensor]],
|
||||
lora_multipliers: Optional[List[float]],
|
||||
fp8_optimization: bool,
|
||||
calc_device: torch.device,
|
||||
move_to_device: bool = False,
|
||||
dit_weight_dtype: Optional[torch.dtype] = None,
|
||||
target_keys: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
|
||||
|
||||
Args:
|
||||
model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix.
|
||||
lora_weights_list (Optional[Dict[str, torch.Tensor]]): Dictionary of LoRA weight tensors to load.
|
||||
lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights.
|
||||
fp8_optimization (bool): Whether to apply FP8 optimization.
|
||||
calc_device (torch.device): Device to calculate on.
|
||||
move_to_device (bool): Whether to move tensors to the calculation device after loading.
|
||||
target_keys (Optional[List[str]]): Keys to target for optimization.
|
||||
exclude_keys (Optional[List[str]]): Keys to exclude from optimization.
|
||||
"""
|
||||
|
||||
# if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix
|
||||
if isinstance(model_files, str):
|
||||
model_files = [model_files]
|
||||
|
||||
extended_model_files = []
|
||||
for model_file in model_files:
|
||||
basename = os.path.basename(model_file)
|
||||
match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
|
||||
if match:
|
||||
prefix = basename[: match.start(2)]
|
||||
count = int(match.group(3))
|
||||
state_dict = {}
|
||||
for i in range(count):
|
||||
filename = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors"
|
||||
filepath = os.path.join(os.path.dirname(model_file), filename)
|
||||
if os.path.exists(filepath):
|
||||
extended_model_files.append(filepath)
|
||||
else:
|
||||
raise FileNotFoundError(f"File {filepath} not found")
|
||||
else:
|
||||
extended_model_files.append(model_file)
|
||||
model_files = extended_model_files
|
||||
logger.info(f"Loading model files: {model_files}")
|
||||
|
||||
# load LoRA weights
|
||||
weight_hook = None
|
||||
if lora_weights_list is None or len(lora_weights_list) == 0:
|
||||
lora_weights_list = []
|
||||
lora_multipliers = []
|
||||
list_of_lora_weight_keys = []
|
||||
else:
|
||||
list_of_lora_weight_keys = []
|
||||
for lora_sd in lora_weights_list:
|
||||
lora_weight_keys = set(lora_sd.keys())
|
||||
list_of_lora_weight_keys.append(lora_weight_keys)
|
||||
|
||||
if lora_multipliers is None:
|
||||
lora_multipliers = [1.0] * len(lora_weights_list)
|
||||
while len(lora_multipliers) < len(lora_weights_list):
|
||||
lora_multipliers.append(1.0)
|
||||
if len(lora_multipliers) > len(lora_weights_list):
|
||||
lora_multipliers = lora_multipliers[: len(lora_weights_list)]
|
||||
|
||||
# Merge LoRA weights into the state dict
|
||||
logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}")
|
||||
|
||||
# make hook for LoRA merging
|
||||
def weight_hook_func(model_weight_key, model_weight):
|
||||
nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device
|
||||
|
||||
if not model_weight_key.endswith(".weight"):
|
||||
return model_weight
|
||||
|
||||
original_device = model_weight.device
|
||||
if original_device != calc_device:
|
||||
model_weight = model_weight.to(calc_device) # to make calculation faster
|
||||
|
||||
for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers):
|
||||
# check if this weight has LoRA weights
|
||||
lora_name = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight"
|
||||
lora_name = "lora_unet_" + lora_name.replace(".", "_")
|
||||
down_key = lora_name + ".lora_down.weight"
|
||||
up_key = lora_name + ".lora_up.weight"
|
||||
alpha_key = lora_name + ".alpha"
|
||||
if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
|
||||
continue
|
||||
|
||||
# get LoRA weights
|
||||
down_weight = lora_sd[down_key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
down_weight = down_weight.to(calc_device)
|
||||
up_weight = up_weight.to(calc_device)
|
||||
|
||||
# W <- W + U * D
|
||||
if len(model_weight.size()) == 2:
|
||||
# linear
|
||||
if len(up_weight.size()) == 4: # use linear projection mismatch
|
||||
up_weight = up_weight.squeeze(3).squeeze(2)
|
||||
down_weight = down_weight.squeeze(3).squeeze(2)
|
||||
model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
model_weight = (
|
||||
model_weight
|
||||
+ multiplier
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
model_weight = model_weight + multiplier * conved * scale
|
||||
|
||||
# remove LoRA keys from set
|
||||
lora_weight_keys.remove(down_key)
|
||||
lora_weight_keys.remove(up_key)
|
||||
if alpha_key in lora_weight_keys:
|
||||
lora_weight_keys.remove(alpha_key)
|
||||
|
||||
model_weight = model_weight.to(original_device) # move back to original device
|
||||
return model_weight
|
||||
|
||||
weight_hook = weight_hook_func
|
||||
|
||||
state_dict = load_safetensors_with_fp8_optimization_and_hook(
|
||||
model_files,
|
||||
fp8_optimization,
|
||||
calc_device,
|
||||
move_to_device,
|
||||
dit_weight_dtype,
|
||||
target_keys,
|
||||
exclude_keys,
|
||||
weight_hook=weight_hook,
|
||||
)
|
||||
|
||||
for lora_weight_keys in list_of_lora_weight_keys:
|
||||
# check if all LoRA keys are used
|
||||
if len(lora_weight_keys) > 0:
|
||||
# if there are still LoRA keys left, it means they are not used in the model
|
||||
# this is a warning, not an error
|
||||
logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}")
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_safetensors_with_fp8_optimization_and_hook(
|
||||
model_files: list[str],
|
||||
fp8_optimization: bool,
|
||||
calc_device: torch.device,
|
||||
move_to_device: bool = False,
|
||||
dit_weight_dtype: Optional[torch.dtype] = None,
|
||||
target_keys: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
weight_hook: callable = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed.
|
||||
"""
|
||||
if fp8_optimization:
|
||||
logger.info(
|
||||
f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
|
||||
)
|
||||
# dit_weight_dtype is not used because we use fp8 optimization
|
||||
state_dict = load_safetensors_with_fp8_optimization(
|
||||
model_files, calc_device, target_keys, exclude_keys, move_to_device=move_to_device, weight_hook=weight_hook
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Loading state dict without FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}"
|
||||
)
|
||||
state_dict = {}
|
||||
for model_file in model_files:
|
||||
with MemoryEfficientSafeOpen(model_file) as f:
|
||||
for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False):
|
||||
value = f.get_tensor(key)
|
||||
if weight_hook is not None:
|
||||
value = weight_hook(key, value)
|
||||
if move_to_device:
|
||||
if dit_weight_dtype is None:
|
||||
value = value.to(calc_device, non_blocking=True)
|
||||
else:
|
||||
value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True)
|
||||
elif dit_weight_dtype is not None:
|
||||
value = value.to(dit_weight_dtype)
|
||||
|
||||
state_dict[key] = value
|
||||
|
||||
if move_to_device:
|
||||
synchronize_device(calc_device)
|
||||
|
||||
return state_dict
|
||||
1444
networks/lora_hunyuan_image.py
Normal file
1444
networks/lora_hunyuan_image.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user