mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
526 lines
17 KiB
Python
526 lines
17 KiB
Python
# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
|
|
# Re-implemented for license compliance for sd-scripts.
|
|
|
|
import math
|
|
from typing import Tuple, Union, Optional
|
|
import torch
|
|
|
|
from library.utils import setup_logging
|
|
|
|
setup_logging()
|
|
import logging
|
|
|
|
# Module-level logger (setup_logging() is invoked earlier in this module).
logger = logging.getLogger(name=__name__)

# Version tag string for HunyuanImage 2.1.
MODEL_VERSION_2_1 = "hunyuan-image-2.1"
|
|
|
|
# region model
|
|
|
|
|
|
def _to_tuple(x, dim=2):
|
|
"""
|
|
Convert int or sequence to tuple of specified dimension.
|
|
|
|
Args:
|
|
x: Int or sequence to convert.
|
|
dim: Target dimension for tuple.
|
|
|
|
Returns:
|
|
Tuple of length dim.
|
|
"""
|
|
if isinstance(x, int) or isinstance(x, float):
|
|
return (x,) * dim
|
|
elif len(x) == dim:
|
|
return x
|
|
else:
|
|
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
|
|
|
|
|
def get_meshgrid_nd(start, dim=2):
    """
    Build a float32 coordinate grid with integer-valued coordinates.

    Despite its name, ``start`` is the grid *size*: axis ``i`` receives the
    coordinates ``0, 1, ..., size_i - 1``.

    Args:
        start: Grid size per axis, either a scalar (same size on every axis)
            or a sequence of length ``dim``.
        dim: Number of axes.

    Returns:
        float32 tensor of shape ``[dim, *sizes]``; ``grid[i]`` holds the
        coordinate along axis ``i`` (``indexing="ij"``).
    """
    sizes = _to_tuple(start, dim=dim)

    # linspace over [0, n] with n + 1 points has a step of exactly 1.0;
    # dropping the final point leaves 0 .. n-1.
    axes = [torch.linspace(0, n, n + 1, dtype=torch.float32)[:n] for n in sizes]

    return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=0)
|
|
|
|
|
|
def get_nd_rotary_pos_embed(rope_dim_list, start, theta=10000.0):
    """
    Build multi-axis rotary position embeddings for a grid of tokens.

    Each axis ``i`` gets its own 1D RoPE table of width ``rope_dim_list[i]``
    (for even entries). The per-axis tables are concatenated along the
    feature dimension.

    Args:
        rope_dim_list: Feature width allotted to each axis. Its length sets
            the number of axes.
        start: Grid size per axis (scalar or sequence), forwarded to
            ``get_meshgrid_nd``.
        theta: RoPE base frequency.

    Returns:
        ``(cos, sin)``, each of shape ``[prod(grid sizes), sum(rope_dim_list)]``
        when all entries of ``rope_dim_list`` are even. Rows follow the
        row-major flattening of the grid.
    """
    num_axes = len(rope_dim_list)
    grid = get_meshgrid_nd(start, dim=num_axes)  # [num_axes, *sizes]

    cos_parts, sin_parts = [], []
    for axis, axis_dim in enumerate(rope_dim_list):
        axis_cos, axis_sin = get_1d_rotary_pos_embed(axis_dim, grid[axis].reshape(-1), theta)
        cos_parts.append(axis_cos)
        sin_parts.append(axis_sin)

    return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)
|
|
|
|
|
|
def get_1d_rotary_pos_embed(
    dim: int, pos: Union[torch.FloatTensor, int], theta: float = 10000.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute 1D rotary position embedding tables.

    Args:
        dim: Embedding width; expected to be even (an odd value yields
            ``dim - 1`` columns).
        pos: 1-D tensor of positions ``[S]``, or an int ``S`` meaning
            positions ``0 .. S-1``.
        theta: Base frequency for the sinusoidal encoding.

    Returns:
        ``(cos, sin)``, each of shape ``[S, 2 * (dim // 2)]``, on the same
        device as ``pos``. Each frequency column is repeated twice
        consecutively, i.e. the layout is interleaved pairs.
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # Build the frequencies on the same device as ``pos``. Previously they
    # always lived on CPU, so passing accelerator-resident positions failed
    # inside ``torch.outer`` with a device mismatch.
    exponents = torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim
    freqs = 1.0 / (theta ** exponents)  # [D/2]

    angles = torch.outer(pos, freqs)  # [S, D/2]
    freqs_cos = angles.cos().repeat_interleave(2, dim=1)  # [S, D]
    freqs_sin = angles.sin().repeat_interleave(2, dim=1)  # [S, D]
    return freqs_cos, freqs_sin
|
|
|
|
|
|
def timestep_embedding(t, dim, max_period=10000):
    """
    Encode timesteps as sinusoidal feature vectors.

    With ``half = dim // 2`` the frequencies are
    ``exp(-log(max_period) * k / half)`` for ``k = 0 .. half - 1``. The output
    concatenates all cosines first, then all sines.

    Args:
        t: Timesteps, expected as a 1-D tensor ``[N]``.
        dim: Width of the output embedding.
        max_period: Sets the frequency range (the smallest frequency is
            close to ``1 / max_period``).

    Returns:
        float32 tensor ``[N, dim]``. When ``dim`` is odd, the last column is zero.
    """
    half = dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(device=t.device)

    angles = t[:, None].float() * freqs[None]  # [N, half]
    parts = [torch.cos(angles), torch.sin(angles)]
    if dim % 2:
        # Odd width: pad with one zero column so the output is exactly `dim` wide.
        parts.append(torch.zeros_like(angles[:, :1]))
    return torch.cat(parts, dim=-1)
|
|
|
|
|
|
def modulate(x, shift=None, scale=None):
    """
    Apply adaptive-LayerNorm style modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a new axis inserted at position 1
    (``unsqueeze(1)``) before broadcasting against ``x``. Presumably ``x`` is
    ``[B, L, C]`` and the modulation tensors are ``[B, C]``; confirm at the
    call sites.

    Args:
        x: Tensor to modulate.
        shift: Optional additive term.
        scale: Optional multiplicative term (applied as ``1 + scale``).

    Returns:
        The modulated tensor; ``x`` itself when both ``shift`` and ``scale``
        are None.
    """
    out = x
    if scale is not None:
        out = out * (1 + scale.unsqueeze(1))
    if shift is not None:
        out = out + shift.unsqueeze(1)
    return out
|
|
|
|
|
|
def apply_gate(x, gate=None, tanh=False):
    """
    Scale ``x`` by a per-sample gate, optionally squashed with tanh.

    The gate gets a new axis at position 1 (``unsqueeze(1)``) before
    broadcasting against ``x``.

    Args:
        x: Tensor to gate.
        gate: Optional gate values; when None, ``x`` is returned as-is.
        tanh: If True, multiply by ``tanh(gate)`` instead of ``gate``.

    Returns:
        ``x * gate`` or ``x * tanh(gate)``, or ``x`` itself when ``gate`` is None.
    """
    if gate is None:
        return x

    factor = gate.unsqueeze(1)
    if tanh:
        factor = factor.tanh()
    return x * factor
|
|
|
|
|
|
def reshape_for_broadcast(
|
|
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
|
|
x: torch.Tensor,
|
|
head_first=False,
|
|
):
|
|
"""
|
|
Reshape RoPE frequency tensors for broadcasting with attention tensors.
|
|
|
|
Args:
|
|
freqs_cis: Tuple of (cos_freqs, sin_freqs) tensors.
|
|
x: Target tensor for broadcasting compatibility.
|
|
head_first: Must be False (only supported layout).
|
|
|
|
Returns:
|
|
Reshaped (cos_freqs, sin_freqs) tensors ready for broadcasting.
|
|
"""
|
|
assert not head_first, "Only head_first=False layout supported."
|
|
assert isinstance(freqs_cis, tuple), "Expected tuple of (cos, sin) frequency tensors."
|
|
assert x.ndim > 1, f"x should have at least 2 dimensions, but got {x.ndim}"
|
|
|
|
# Validate frequency tensor dimensions match target tensor
|
|
assert freqs_cis[0].shape == (
|
|
x.shape[1],
|
|
x.shape[-1],
|
|
), f"Frequency tensor shape {freqs_cis[0].shape} incompatible with target shape {x.shape}"
|
|
|
|
shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
|
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
|
|
|
|
|
def rotate_half(x):
    """
    Rotate interleaved feature pairs by 90 degrees for RoPE.

    The last dimension is treated as consecutive pairs ``(x[2k], x[2k+1])``,
    and each pair is mapped to ``(-x[2k+1], x[2k])``.

    Args:
        x: Tensor ``[..., D]`` with even ``D``; any number of leading dims.

    Returns:
        float32 tensor with the same shape as ``x``.
    """
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [..., D//2] each
    # Flatten only the trailing (D//2, 2) pair axes. The previous `.flatten(3)`
    # assumed a 4-D input and returned the wrong shape (or raised) for any
    # other rank.
    return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)
|
|
|
|
|
|
def apply_rotary_emb(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor], head_first: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors with rotary position embeddings.

    Computes ``t * cos + rotate_half(t) * sin`` in float32 for both inputs.
    The (cos, sin) tables are shaped against ``xq`` via
    ``reshape_for_broadcast`` and moved to ``xq``'s device.

    Args:
        xq: Query tensor, expected ``[B, S, H, D]`` (sequence on axis 1,
            features last).
        xk: Key tensor with a layout compatible with ``xq``.
        freqs_cis: Tuple ``(cos, sin)`` of shape ``[S, D]``.
        head_first: Forwarded to ``reshape_for_broadcast`` (only False is supported).

    Returns:
        ``(xq_rotated, xk_rotated)``. Both are cast back to ``xq.dtype``.
    """
    out_dtype = xq.dtype
    target_device = xq.device

    cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)
    cos = cos.to(target_device)
    sin = sin.to(target_device)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # x' = x * cos + rotate_half(x) * sin, evaluated in float32.
        return (t.float() * cos + rotate_half(t.float()) * sin).to(out_dtype)

    return _rotate(xq), _rotate(xk)
|
|
|
|
|
|
# endregion
|
|
|
|
# region inference
|
|
|
|
|
|
def get_timesteps_sigmas(sampling_steps: int, shift: float, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build a shifted sigma schedule and the matching timesteps.

    Starts from ``sampling_steps + 1`` evenly spaced values from 1 down to 0
    and applies ``shift * s / (1 + (shift - 1) * s)``.

    Args:
        sampling_steps: Number of sampling steps.
        shift: Schedule shift; 1.0 leaves the linear schedule unchanged.
        device: Device for the returned timesteps.

    Returns:
        ``(timesteps, sigmas)``. ``timesteps`` is ``sigmas[:-1] * 1000`` as
        float32 on ``device`` (length ``sampling_steps``). ``sigmas`` is
        float32, stays on CPU, and has length ``sampling_steps + 1``.
    """
    base = torch.linspace(1, 0, sampling_steps + 1)
    sigmas = ((shift * base) / (1 + (shift - 1) * base)).to(torch.float32)
    timesteps = (sigmas[:-1] * 1000).to(dtype=torch.float32, device=device)
    return timesteps, sigmas
|
|
|
|
|
|
def step(latents, noise_pred, sigmas, step_i):
    """
    Advance the latents by one Euler step along the sigma schedule.

    Computes ``latents - (sigmas[step_i] - sigmas[step_i + 1]) * noise_pred``.
    ``latents`` and ``noise_pred`` are upcast with ``.float()`` first.

    Args:
        latents: Current latent state.
        noise_pred: Model prediction used as the update direction
            (presumably a flow-matching velocity; confirm against the model).
        sigmas: Sigma schedule, indexable at ``step_i`` and ``step_i + 1``.
        step_i: Current step index.

    Returns:
        The updated latents.
    """
    sigma_gap = sigmas[step_i] - sigmas[step_i + 1]
    return latents.float() - sigma_gap * noise_pred.float()
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region AdaptiveProjectedGuidance
|
|
|
|
|
|
class MomentumBuffer:
    """
    Running accumulator used by APG to smooth the guidance vector across steps.

    Each ``update(v)`` sets ``running_average = v + momentum * running_average``.
    Despite the attribute name, this is a decayed sum, not a normalized
    average. The state starts as the scalar 0.
    """

    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0  # scalar 0 until the first update()

    def update(self, update_value: torch.Tensor):
        """Fold ``update_value`` into the running state."""
        self.running_average = update_value + self.momentum * self.running_average
|
|
|
|
|
|
def normalized_guidance_apg(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
    use_original_formulation: bool = False,
):
    """
    Adaptive Projected Guidance: a CFG update with projected, clipped guidance.

    Steps:
      1. The guidance vector is ``pred_cond - pred_uncond``.
      2. If a momentum buffer is given, it is updated with that vector and
         its running state is used instead.
      3. If ``norm_threshold > 0``, each sample's vector (L2 over all
         non-batch dims) is scaled down so its norm is at most the threshold.
      4. The vector is split, in float64, into components parallel and
         orthogonal to the normalized ``pred_cond``. It is recombined as
         ``orthogonal + eta * parallel`` in the vector's dtype.

    Args:
        pred_cond: Conditional prediction.
        pred_uncond: Unconditional prediction.
        guidance_scale: Multiplier on the recombined update.
        momentum_buffer: Optional buffer that accumulates guidance across steps.
        eta: Weight of the parallel component.
        norm_threshold: Per-sample L2 cap; ``<= 0`` disables clipping.
        use_original_formulation: Add the update to ``pred_cond`` instead of
            ``pred_uncond``.

    Returns:
        ``base + guidance_scale * update``, where ``base`` is chosen by
        ``use_original_formulation``.
    """
    guidance = pred_cond - pred_uncond
    # Every axis except the leading (batch) one, counted from the end.
    reduce_dims = list(range(-1, -guidance.ndim, -1))

    if momentum_buffer is not None:
        momentum_buffer.update(guidance)
        guidance = momentum_buffer.running_average

    if norm_threshold > 0:
        norms = guidance.norm(p=2, dim=reduce_dims, keepdim=True)
        guidance = guidance * torch.minimum(torch.ones_like(norms), norm_threshold / norms)

    # Projection is done in float64 for numerical stability.
    guidance64 = guidance.double()
    direction = torch.nn.functional.normalize(pred_cond.double(), dim=reduce_dims)
    parallel64 = (guidance64 * direction).sum(dim=reduce_dims, keepdim=True) * direction
    orthogonal64 = guidance64 - parallel64

    parallel = parallel64.type_as(guidance)
    orthogonal = orthogonal64.type_as(guidance)
    update = orthogonal + eta * parallel

    base = pred_cond if use_original_formulation else pred_uncond
    return base + guidance_scale * update
|
|
|
|
|
|
class AdaptiveProjectedGuidance:
    """
    Stateful wrapper around ``normalized_guidance_apg`` (Adaptive Projected Guidance).

    Stores the APG hyper-parameters. When momentum is enabled, it also holds
    a ``MomentumBuffer`` that is recreated each time the guider is called
    with ``step == 0``. It can optionally apply ``rescale_noise_cfg`` to the
    guided prediction.
    """

    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum: Optional[float] = None,
        adaptive_projected_guidance_rescale: float = 15.0,
        eta: float = 0.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
    ):
        """
        Args:
            guidance_scale: Multiplier on the projected guidance update.
            adaptive_projected_guidance_momentum: Momentum for the running
                buffer; None disables momentum.
            adaptive_projected_guidance_rescale: Passed to
                ``normalized_guidance_apg`` as ``norm_threshold``.
            eta: Weight of the guidance component parallel to ``pred_cond``.
            guidance_rescale: If ``> 0``, blend factor for ``rescale_noise_cfg``.
            use_original_formulation: Add the update to ``pred_cond`` instead
                of ``pred_uncond``.
        """
        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        # Created by the first call with step == 0 (only if momentum is enabled).
        self.momentum_buffer = None

    def __call__(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None, step=None) -> torch.Tensor:
        """
        Return the APG-guided prediction for one sampling step.

        Momentum is only active once the guider has been called with
        ``step == 0``; that call also resets any previous momentum state.
        """
        starting_new_run = step == 0
        if starting_new_run and self.adaptive_projected_guidance_momentum is not None:
            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)

        guided = normalized_guidance_apg(
            pred_cond,
            pred_uncond,
            self.guidance_scale,
            momentum_buffer=self.momentum_buffer,
            eta=self.eta,
            norm_threshold=self.adaptive_projected_guidance_rescale,
            use_original_formulation=self.use_original_formulation,
        )

        if self.guidance_rescale > 0.0:
            guided = rescale_noise_cfg(guided, pred_cond, self.guidance_rescale)
        return guided
|
|
|
|
|
|
def rescale_noise_cfg(guided_noise, conditional_noise, rescale_factor=0.0):
    """
    Counter CFG over-exposure by matching the guided prediction's std to the conditional one.

    Follows Section 3.4 of "Common Diffusion Noise Schedules and Sample Steps
    are Flawed" (https://arxiv.org/pdf/2305.08891.pdf). Per sample (std over
    every non-leading axis), ``guided_noise`` is rescaled by
    ``std(conditional) / std(guided)``. The result is then linearly blended
    with the unscaled guided prediction.

    Args:
        guided_noise (torch.Tensor): Prediction after classifier-free guidance.
        conditional_noise (torch.Tensor): Prediction from the conditional branch.
        rescale_factor (float): Blend weight of the rescaled prediction;
            0.0 returns ``guided_noise`` unchanged, 1.0 returns the fully
            rescaled prediction.

    Returns:
        torch.Tensor: ``rescale_factor * rescaled + (1 - rescale_factor) * guided_noise``.
    """
    if rescale_factor == 0.0:
        return guided_noise

    reduce_dims = list(range(1, conditional_noise.ndim))
    std_cond = conditional_noise.std(dim=reduce_dims, keepdim=True)
    std_guided = guided_noise.std(dim=reduce_dims, keepdim=True)

    rescaled = guided_noise * (std_cond / std_guided)
    return rescale_factor * rescaled + (1.0 - rescale_factor) * guided_noise
|
|
|
|
|
|
def apply_classifier_free_guidance(
    noise_pred_text: torch.Tensor,
    noise_pred_uncond: torch.Tensor,
    is_ocr: bool,
    guidance_scale: float,
    step: int,
    apg_start_step_ocr: int = 38,
    apg_start_step_general: int = 5,
    cfg_guider_ocr: Optional["AdaptiveProjectedGuidance"] = None,
    cfg_guider_general: Optional["AdaptiveProjectedGuidance"] = None,
    guidance_rescale: float = 0.0,
):
    """
    Apply classifier-free guidance, switching to APG after a per-mode start step.

    For ``step <= apg_start_step``, standard CFG is used
    (``uncond + scale * (text - uncond)``, optionally rescaled). The selected
    guider, if provided, is still called so it can initialize its state
    (e.g. its momentum buffer at step 0). Its output is discarded. For later
    steps the guider's output is returned. Intended for batch_size=1.

    Args:
        noise_pred_text: Conditional noise prediction tensor [1, ...].
        noise_pred_uncond: Unconditional noise prediction tensor [1, ...].
        is_ocr: Selects the OCR guider/start step instead of the general ones.
        guidance_scale: Guidance scale for standard CFG; 1.0 disables guidance
            and returns ``noise_pred_text`` as-is.
        step: Current diffusion step index.
        apg_start_step_ocr: Last step that uses standard CFG in OCR mode.
        apg_start_step_general: Last step that uses standard CFG in general mode.
        cfg_guider_ocr: APG guider used in OCR mode.
        cfg_guider_general: APG guider used in general mode.
        guidance_rescale: Blend factor for ``rescale_noise_cfg`` in the
            standard-CFG branch only (the APG branch uses the guider's own setting).

    Returns:
        Guided noise prediction tensor [1, ...].

    Raises:
        ValueError: If APG is required (``step > apg_start_step``) but the
            selected guider is None.
    """
    if guidance_scale == 1.0:
        return noise_pred_text

    # Select the guider and the APG start step for this prompt type.
    if is_ocr:
        cfg_guider, apg_start_step = cfg_guider_ocr, apg_start_step_ocr
    else:
        cfg_guider, apg_start_step = cfg_guider_general, apg_start_step_general

    if step <= apg_start_step:
        # Standard classifier-free guidance.
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        if guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale)

        # Keep the APG guider's state (momentum) in sync. Skipped when no
        # guider was supplied, instead of failing with "NoneType is not callable".
        if cfg_guider is not None:
            _ = cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
        return noise_pred

    if cfg_guider is None:
        mode = "OCR" if is_ocr else "general"
        raise ValueError(
            f"APG guider for {mode} mode is required at step {step} "
            f"(APG starts after step {apg_start_step}), but none was provided."
        )
    # Use APG for guidance.
    return cfg_guider(noise_pred_text, noise_pred_uncond, step=step)
|
|
|
|
|
|
# endregion
|