Files
Kohya-ss-sd-scripts/library/hunyuan_image_utils.py

526 lines
17 KiB
Python

# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
# Re-implemented for license compliance for sd-scripts.
import math
from typing import Tuple, Union, Optional
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# Version identifier string for HunyuanImage 2.1.
# NOTE(review): presumably used by callers to tag/select this model version — confirm.
MODEL_VERSION_2_1 = "hunyuan-image-2.1"
# region model
def _to_tuple(x, dim=2):
"""
Convert int or sequence to tuple of specified dimension.
Args:
x: Int or sequence to convert.
dim: Target dimension for tuple.
Returns:
Tuple of length dim.
"""
if isinstance(x, int) or isinstance(x, float):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, dim=2):
    """
    Build an integer-valued float32 coordinate grid over ``dim`` axes.

    Each axis k of size n_k is sampled at 0, 1, ..., n_k - 1 and the axes are
    combined with ``indexing="ij"``, so ``grid[k, i_0, ..., i_{dim-1}] == i_k``.

    Args:
        start: Grid size per axis (a single int for all axes, or a sequence
            of ``dim`` sizes).
        dim: Number of axes.

    Returns:
        Tensor of shape ``[dim, n_0, ..., n_{dim-1}]``.
    """
    sizes = _to_tuple(start, dim=dim)
    # n + 1 evenly spaced points on [0, n] with the endpoint dropped -> 0..n-1.
    axes = [torch.linspace(0, size, size + 1, dtype=torch.float32)[:size] for size in sizes]
    return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=0)
def get_nd_rotary_pos_embed(rope_dim_list, start, theta=10000.0):
    """
    Generate n-dimensional rotary position embeddings for spatial tokens.

    The feature (head) dimension is split across spatial axes according to
    ``rope_dim_list``. For each axis, a 1D RoPE table is evaluated at that
    axis' coordinate of every grid point (grid flattened in row-major /
    ``indexing="ij"`` order), and the per-axis tables are concatenated along
    the feature dimension.

    Args:
        rope_dim_list: Feature dimensions allocated to each spatial axis
            (should sum to head_dim). Its length sets the number of axes.
        start: Spatial grid size per axis (int or one entry per axis).
        theta: Base frequency for RoPE computation.

    Returns:
        Tuple ``(cos_freqs, sin_freqs)``, each of shape
        ``[N, sum(rope_dim_list)]`` where ``N`` is the product of the grid
        sizes (e.g. ``H*W`` for 2D). Each axis contributes
        ``2 * (d // 2)`` columns (``d`` for even ``d``), because
        ``get_1d_rotary_pos_embed`` repeats every frequency twice.
    """
    grid = get_meshgrid_nd(start, dim=len(rope_dim_list))  # [n_axes, *grid_sizes]
    # One 2-tuple (cos, sin) per axis, each [N, rope_dim_list[i]].
    embs = [
        get_1d_rotary_pos_embed(axis_dim, grid[i].reshape(-1), theta)
        for i, axis_dim in enumerate(rope_dim_list)
    ]
    cos = torch.cat([emb[0] for emb in embs], dim=1)  # [N, sum(rope_dim_list)]
    sin = torch.cat([emb[1] for emb in embs], dim=1)  # [N, sum(rope_dim_list)]
    return cos, sin
def get_1d_rotary_pos_embed(
    dim: int, pos: Union[torch.FloatTensor, int], theta: float = 10000.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate 1D rotary position embeddings.

    Frequency j is ``theta ** (-2j / dim)`` for j in ``[0, dim // 2)``. Each
    cos/sin value is repeated twice along the feature axis so that it lines up
    with the interleaved (even, odd) pairs rotated by ``rotate_half``.

    Args:
        dim: Embedding dimension (should be even; an odd ``dim`` yields
            ``dim - 1`` columns).
        pos: 1D tensor of positions ``[S]``, or an int ``S`` meaning
            positions ``0..S-1`` (created on CPU as float32).
        theta: Base frequency for sinusoidal encoding.

    Returns:
        Tuple ``(cos_freqs, sin_freqs)``, each of shape ``[S, 2 * (dim // 2)]``,
        on the same device as ``pos``.
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()
    # Build the inverse frequencies on the positions' device; previously they
    # were always on CPU, so torch.outer failed for GPU positions.
    exponent = torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta**exponent)  # [D/2]
    angles = torch.outer(pos, inv_freq)  # [S, D/2]
    freqs_cos = angles.cos().repeat_interleave(2, dim=1)  # [S, D]
    freqs_sin = angles.sin().repeat_interleave(2, dim=1)  # [S, D]
    return freqs_cos, freqs_sin
def timestep_embedding(t, dim, max_period=10000):
    """
    Map a batch of scalar timesteps to sinusoidal embeddings.

    Uses ``dim // 2`` frequencies ``exp(-log(max_period) * k / (dim // 2))``.
    The output is ``[cos(t * f), sin(t * f)]``, concatenated along the last
    axis. An odd ``dim`` gets one trailing zero column.

    Args:
        t: 1D timestep tensor ``[N]``.
        dim: Output embedding width.
        max_period: Controls the lowest frequency.

    Returns:
        float32 tensor ``[N, dim]`` on ``t.device``.
    """
    half_dim = dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim
    freqs = torch.exp(exponent).to(device=t.device)
    angles = t[:, None].float() * freqs[None]
    pieces = [torch.cos(angles), torch.sin(angles)]
    if dim % 2:
        # Pad odd widths with a zero column so the output is exactly `dim` wide.
        pieces.append(torch.zeros_like(angles[:, :1]))
    return torch.cat(pieces, dim=-1)
def modulate(x, shift=None, scale=None):
    """
    Apply adaptive-norm style scale/shift conditioning.

    ``scale`` and ``shift`` get a new axis inserted at position 1 before
    broadcasting against ``x``. They are applied as ``x * (1 + scale) + shift``,
    and either may be omitted.

    Args:
        x: Tensor to modulate.
        shift: Optional additive term.
        scale: Optional multiplicative term (applied as ``1 + scale``).

    Returns:
        The modulated tensor, or ``x`` itself when both are None.
    """
    out = x
    if scale is not None:
        out = out * (1 + scale.unsqueeze(1))
    if shift is not None:
        out = out + shift.unsqueeze(1)
    return out
def apply_gate(x, gate=None, tanh=False):
    """
    Multiply ``x`` by a per-sample gate.

    The gate gets a new axis at position 1 before broadcasting. It is passed
    through ``tanh`` first when ``tanh`` is True.

    Args:
        x: Tensor to gate.
        gate: Optional gate values; when None, ``x`` is returned unchanged.
        tanh: Squash the gate with ``tanh`` before multiplying.

    Returns:
        ``x * gate`` (or ``x * tanh(gate)``), or ``x`` itself if no gate.
    """
    if gate is None:
        return x
    factor = gate.unsqueeze(1)
    if tanh:
        factor = factor.tanh()
    return x * factor
def reshape_for_broadcast(
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
x: torch.Tensor,
head_first=False,
):
"""
Reshape RoPE frequency tensors for broadcasting with attention tensors.
Args:
freqs_cis: Tuple of (cos_freqs, sin_freqs) tensors.
x: Target tensor for broadcasting compatibility.
head_first: Must be False (only supported layout).
Returns:
Reshaped (cos_freqs, sin_freqs) tensors ready for broadcasting.
"""
assert not head_first, "Only head_first=False layout supported."
assert isinstance(freqs_cis, tuple), "Expected tuple of (cos, sin) frequency tensors."
assert x.ndim > 1, f"x should have at least 2 dimensions, but got {x.ndim}"
# Validate frequency tensor dimensions match target tensor
assert freqs_cis[0].shape == (
x.shape[1],
x.shape[-1],
), f"Frequency tensor shape {freqs_cis[0].shape} incompatible with target shape {x.shape}"
shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
def rotate_half(x):
    """
    Rotate interleaved feature pairs by 90 degrees for RoPE.

    The last dimension is viewed as consecutive pairs ``(a, b)``, and each
    pair is mapped to ``(-b, a)``. This matches the ``repeat_interleave(2)``
    layout of the cos/sin tables produced by ``get_1d_rotary_pos_embed``.

    Args:
        x: Input tensor ``[..., D]`` with even ``D`` (any rank >= 1).

    Returns:
        float32 tensor with the same shape as ``x``. The input is upcast
        via ``.float()``.
    """
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # each [..., D//2]
    # Merge the trailing (D//2, 2) axes back into D. flatten(-2) is rank-agnostic;
    # the former flatten(3) only produced the right shape for 4D input.
    return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)
def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor], head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors with RoPE.

    Computes ``t * cos + rotate_half(t) * sin`` in float32 for both inputs.
    The cos/sin tables are reshaped against ``xq`` by
    ``reshape_for_broadcast`` and moved to ``xq``'s device.

    Args:
        xq: Query tensor, e.g. ``[B, S, H, D]``.
        xk: Key tensor; assumed to share ``xq``'s sequence/feature sizes.
        freqs_cis: ``(cos_freqs, sin_freqs)`` tables of shape
            ``(xq.shape[1], xq.shape[-1])``.
        head_first: Forwarded to ``reshape_for_broadcast`` (only False is
            accepted there).

    Returns:
        ``(rotated_q, rotated_k)``. Both are cast to ``xq.dtype``.
    """
    out_dtype = xq.dtype
    target_device = xq.device
    cos, sin = (table.to(target_device) for table in reshape_for_broadcast(freqs_cis, xq, head_first))

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        t32 = t.float()
        return (t32 * cos + rotate_half(t32) * sin).to(out_dtype)

    return _rotate(xq), _rotate(xk)
# endregion
# region inference
def get_timesteps_sigmas(sampling_steps: int, shift: float, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build a shifted, linearly spaced sigma schedule and matching timesteps.

    Sigmas start as ``sampling_steps + 1`` points from 1 to 0. They are then
    remapped with ``shift * s / (1 + (shift - 1) * s)``, so ``shift == 1``
    leaves them linear.

    Args:
        sampling_steps: Number of sampling steps.
        shift: Schedule shift factor.
        device: Device for the returned timesteps.

    Returns:
        ``(timesteps, sigmas)``:
        - ``timesteps`` is ``sigmas[:-1] * 1000`` in float32 on ``device``
          (length ``sampling_steps``).
        - ``sigmas`` is float32 with length ``sampling_steps + 1``; it is not
          moved to ``device``.
    """
    base = torch.linspace(1, 0, sampling_steps + 1)
    numerator = shift * base
    denominator = 1 + (shift - 1) * base
    sigmas = (numerator / denominator).to(torch.float32)
    timesteps = (sigmas[:-1] * 1000).to(dtype=torch.float32, device=device)
    return timesteps, sigmas
def step(latents, noise_pred, sigmas, step_i):
    """
    Advance latents by one Euler step along the sigma schedule.

    Computes ``latents - (sigmas[step_i] - sigmas[step_i + 1]) * noise_pred``
    in float32.

    Args:
        latents: Current latent tensor.
        noise_pred: Model prediction for this step. NOTE(review): given this
            update rule, this is presumably a flow-matching velocity rather than
            epsilon noise — confirm.
        sigmas: Sigma schedule (needs index ``step_i + 1``).
        step_i: Index of the current step.

    Returns:
        Updated float32 latents.
    """
    sigma_now = sigmas[step_i]
    sigma_next = sigmas[step_i + 1]
    return latents.float() - (sigma_now - sigma_next) * noise_pred.float()
# endregion
# region AdaptiveProjectedGuidance
class MomentumBuffer:
    """
    Momentum accumulator used by adaptive projected guidance.

    Each ``update`` sets
    ``running_average = update_value + momentum * running_average``.
    The accumulator starts at the scalar ``0``. This is a momentum sum
    (there is no ``1 - momentum`` weighting), not a normalized average.
    """

    def __init__(self, momentum: float):
        self.momentum = momentum
        # Scalar zero until the first update replaces it with a tensor.
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        self.running_average = update_value + self.momentum * self.running_average
def normalized_guidance_apg(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
    use_original_formulation: bool = False,
):
    """
    Combine conditional/unconditional predictions with adaptive projected guidance.

    Steps:
      1. ``guidance = pred_cond - pred_uncond``. If a momentum buffer is given,
         it is updated with that difference and its running value is used instead.
      2. If ``norm_threshold > 0``, each sample's guidance is scaled down so its
         L2 norm (over all non-batch axes) is at most ``norm_threshold``.
      3. The guidance is split into parts parallel and orthogonal to
         ``pred_cond`` (computed in float64), and recombined as
         ``orthogonal + eta * parallel``.
      4. Returns ``base + guidance_scale * update``, where ``base`` is
         ``pred_cond`` if ``use_original_formulation`` else ``pred_uncond``.

    Args:
        pred_cond: Conditional prediction.
        pred_uncond: Unconditional prediction.
        guidance_scale: Guidance strength.
        momentum_buffer: Optional momentum accumulator (mutated in place).
        eta: Weight of the parallel component.
        norm_threshold: Per-sample L2 cap on the guidance; <= 0 disables it.
        use_original_formulation: Choose ``pred_cond`` as the base instead of
            ``pred_uncond``.

    Returns:
        Guided prediction tensor.
    """
    guidance = pred_cond - pred_uncond
    # Reduce over every axis except the leading (batch) axis.
    reduce_dims = [-axis for axis in range(1, len(guidance.shape))]

    if momentum_buffer is not None:
        momentum_buffer.update(guidance)
        guidance = momentum_buffer.running_average

    if norm_threshold > 0:
        norm = guidance.norm(p=2, dim=reduce_dims, keepdim=True)
        guidance = guidance * torch.minimum(torch.ones_like(norm), norm_threshold / norm)

    # The projection runs in float64 (presumably to limit rounding error) and is cast back.
    guidance64 = guidance.double()
    unit_cond = torch.nn.functional.normalize(pred_cond.double(), dim=reduce_dims)
    parallel64 = (guidance64 * unit_cond).sum(dim=reduce_dims, keepdim=True) * unit_cond
    orthogonal64 = guidance64 - parallel64
    parallel = parallel64.type_as(guidance)
    orthogonal = orthogonal64.type_as(guidance)

    update = orthogonal + eta * parallel
    base = pred_cond if use_original_formulation else pred_uncond
    return base + guidance_scale * update
class AdaptiveProjectedGuidance:
    """
    Stateful wrapper around ``normalized_guidance_apg``.

    It stores the APG hyper-parameters and owns the optional momentum buffer,
    which is recreated whenever it is called with ``step == 0``. When
    ``guidance_rescale > 0`` it also applies ``rescale_noise_cfg`` to the
    result.
    """

    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum: Optional[float] = None,
        adaptive_projected_guidance_rescale: float = 15.0,
        eta: float = 0.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
    ):
        self.guidance_scale = guidance_scale
        # None disables momentum; otherwise the MomentumBuffer coefficient.
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        # Forwarded to normalized_guidance_apg as norm_threshold.
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        # Created at step 0 when momentum is configured.
        self.momentum_buffer = None

    def __call__(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None, step=None) -> torch.Tensor:
        is_first_step = step == 0
        if is_first_step and self.adaptive_projected_guidance_momentum is not None:
            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)

        guided = normalized_guidance_apg(
            pred_cond,
            pred_uncond,
            self.guidance_scale,
            momentum_buffer=self.momentum_buffer,
            eta=self.eta,
            norm_threshold=self.adaptive_projected_guidance_rescale,
            use_original_formulation=self.use_original_formulation,
        )
        if self.guidance_rescale > 0.0:
            guided = rescale_noise_cfg(guided, pred_cond, self.guidance_rescale)
        return guided
def rescale_noise_cfg(guided_noise, conditional_noise, rescale_factor=0.0):
    """
    Pull the guided prediction's per-sample std toward the conditional one.

    Follows Section 3.4 of "Common Diffusion Noise Schedules and Sample Steps
    are Flawed" (https://arxiv.org/pdf/2305.08891.pdf). The guided tensor is
    scaled so its std over all non-batch axes matches ``conditional_noise``.
    The result is then linearly blended with the unscaled guided tensor.

    Args:
        guided_noise (torch.Tensor): Prediction after classifier-free guidance.
        conditional_noise (torch.Tensor): Conditional-only prediction.
        rescale_factor (float): Blend weight; 0.0 returns ``guided_noise``
            unchanged, 1.0 returns the fully rescaled prediction.

    Returns:
        torch.Tensor: ``rescale_factor * rescaled + (1 - rescale_factor) * guided_noise``.
    """
    if rescale_factor == 0.0:
        return guided_noise

    reduce_dims = list(range(1, conditional_noise.ndim))
    ratio = conditional_noise.std(dim=reduce_dims, keepdim=True) / guided_noise.std(dim=reduce_dims, keepdim=True)
    rescaled = guided_noise * ratio
    return rescale_factor * rescaled + (1.0 - rescale_factor) * guided_noise
def apply_classifier_free_guidance(
    noise_pred_text: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    is_ocr: bool,
    guidance_scale: float,
    step: int,
    apg_start_step_ocr: int = 38,
    apg_start_step_general: int = 5,
    cfg_guider_ocr: Optional["AdaptiveProjectedGuidance"] = None,
    cfg_guider_general: Optional["AdaptiveProjectedGuidance"] = None,
    guidance_rescale: float = 0.0,
):
    """
    Apply classifier-free guidance with OCR-aware APG for batch_size=1.

    ``is_ocr`` selects which guider and APG start step to use. While
    ``step <= apg_start_step``, standard CFG is returned, optionally rescaled
    with ``guidance_rescale``. The guider is still called on those steps so
    that its internal state (e.g. the momentum buffer it creates at step 0)
    keeps tracking the predictions, but its output is discarded. After the
    start step, the guider's output is returned.

    Args:
        noise_pred_text: Conditional noise prediction tensor [1, ...].
        noise_pred_uncond: Unconditional noise prediction tensor [1, ...].
        is_ocr: Whether this sample requires OCR-specific guidance.
        guidance_scale: Guidance scale for CFG; 1.0 disables guidance.
        step: Current diffusion step index.
        apg_start_step_ocr: Last standard-CFG step for OCR samples (APG starts after it).
        apg_start_step_general: Last standard-CFG step for general samples (APG starts after it).
        cfg_guider_ocr: APG guider used when ``is_ocr`` is True.
        cfg_guider_general: APG guider used when ``is_ocr`` is False.
        guidance_rescale: Rescale factor for the standard-CFG steps (passed to
            ``rescale_noise_cfg``); 0.0 disables it. APG steps use the
            guider's own setting.

    Returns:
        Guided noise prediction tensor [1, ...]. When ``guidance_scale == 1.0``
        this is ``noise_pred_text`` itself.

    Raises:
        ValueError: If guidance is active and the guider selected by
            ``is_ocr`` is None.
    """
    if guidance_scale == 1.0:
        return noise_pred_text

    # Select appropriate guider and start step based on OCR requirement
    if is_ocr:
        cfg_guider, apg_start_step = cfg_guider_ocr, apg_start_step_ocr
    else:
        cfg_guider, apg_start_step = cfg_guider_general, apg_start_step_general

    # The guider is invoked on every guided step, so a missing one would otherwise
    # fail with an opaque "'NoneType' object is not callable".
    if cfg_guider is None:
        which = "cfg_guider_ocr" if is_ocr else "cfg_guider_general"
        raise ValueError(f"{which} must be provided when guidance_scale != 1.0")

    if step <= apg_start_step:
        # Standard classifier-free guidance
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        if guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale)
        # Keep the APG guider's state in sync; its output is intentionally unused here.
        _ = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
        return noise_pred

    # Use APG for guidance
    return cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
# endregion