# Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 13:47:06 +00:00)
# 864 lines, 31 KiB, Python
# Original work: https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
|
|
# Re-implemented for license compliance for sd-scripts.
|
|
|
|
from typing import Tuple, Callable
|
|
import torch
|
|
import torch.nn as nn
|
|
from einops import rearrange
|
|
|
|
from library import custom_offloading_utils
|
|
from library.attention import AttentionParams, attention
|
|
from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate
|
|
from library.attention import attention
|
|
|
|
# region Modules
|
|
|
|
|
|
class ByT5Mapper(nn.Module):
    """
    Maps ByT5 character-level encoder outputs to transformer hidden space.

    Applies layer normalization followed by three linear layers
    (in_dim -> hidden_dim -> out_dim -> out_dim1) with GELU activations after
    the first two, and optionally adds the un-normalized input to the result.

    Args:
        in_dim: Input feature dimension from the ByT5 encoder.
        out_dim: Intermediate dimension after the second projection.
        hidden_dim: Hidden dimension after the first projection.
        out_dim1: Final output dimension matching transformer hidden size.
        use_residual: Whether to add the input to the output. Requires
            in_dim == out_dim == out_dim1, because the residual is added to
            the output of the last projection.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
            # forward() adds the raw input to the fc3 output, whose width is
            # out_dim1, so that width has to match the input width as well.
            assert in_dim == out_dim1, "use_residual=True requires in_dim == out_dim1."
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.fc3 = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        """
        Transform ByT5 embeddings to transformer space.

        Args:
            x: Input ByT5 embeddings [..., in_dim].

        Returns:
            Transformed embeddings [..., out_dim1].
        """
        # Keep the input before normalization; it is only needed for the residual.
        residual = x if self.use_residual else None
        x = self.layernorm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        x = self.act_fn(x)
        x = self.fc3(x)
        if self.use_residual:
            x = x + residual
        return x
|
|
|
|
|
|
class PatchEmbed2D(nn.Module):
    """
    2D patch embedding layer for converting image latents to transformer tokens.

    Uses a 2D convolution whose kernel size and stride both equal the patch
    size, so every patch is projected to one embedding vector.
    For HunyuanImage-2.1, patch_size=[1,1] means no spatial downsampling.

    Args:
        patch_size: Spatial size of patches. Either a single int (square
            patch) or an iterable of two ints.
        in_chans: Number of input channels.
        embed_dim: Output embedding dimension.
    """

    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # tuple() only accepts iterables, so a bare int (including the default)
        # is expanded to a square patch first.
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.patch_size = tuple(patch_size)

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=True)
        self.norm = nn.Identity()  # No normalization layer used

    def forward(self, x):
        """
        Project an image-like tensor to a token sequence.

        Args:
            x: Input tensor [B, in_chans, H, W].

        Returns:
            Token embeddings [B, N, embed_dim], where N is the number of
            patches produced by the convolution.
        """
        x = self.proj(x)
        # [B, C, H', W'] -> [B, C, H'*W'] -> [B, H'*W', C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x
|
|
|
|
|
|
class TimestepEmbedder(nn.Module):
    """
    Turns diffusion timesteps into embedding vectors.

    The timesteps are first encoded by ``timestep_embedding`` and the result is
    passed through a Linear -> activation -> Linear stack.

    Args:
        hidden_size: Width of the hidden layer (and of the output when
            out_size is not given).
        act_layer: Activation function class (e.g., nn.SiLU).
        frequency_embedding_size: Width of the encoding fed to the MLP.
        max_period: Forwarded to ``timestep_embedding``.
        out_size: Output dimension (defaults to hidden_size).
    """

    def __init__(self, hidden_size, act_layer, frequency_embedding_size=256, max_period=10000, out_size=None):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period

        output_dim = hidden_size if out_size is None else out_size
        stages = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            act_layer(),
            nn.Linear(hidden_size, output_dim, bias=True),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, t):
        """
        Embed timesteps.

        Args:
            t: Timesteps, passed unchanged to ``timestep_embedding``.

        Returns:
            Output of the MLP applied to the encoded timesteps.
        """
        encoded = timestep_embedding(t, self.frequency_embedding_size, self.max_period)
        # Match the dtype of the first linear layer so the matmul does not mix dtypes.
        weight_dtype = self.mlp[0].weight.dtype
        return self.mlp(encoded.type(weight_dtype))
|
|
|
|
|
|
class TextProjection(nn.Module):
    """
    Two-layer projection for text embeddings.

    Computes ``linear_2(act_1(linear_1(caption)))``.

    Args:
        in_channels: Input feature dimension.
        hidden_size: Hidden and output dimension.
        act_layer: Activation function class.
    """

    def __init__(self, in_channels, hidden_size, act_layer):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, hidden_size, bias=True)
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(hidden_size, hidden_size, bias=True)

    def forward(self, caption):
        """
        Project text features.

        Args:
            caption: Input features [..., in_channels].

        Returns:
            Projected features [..., hidden_size].
        """
        return self.linear_2(self.act_1(self.linear_1(caption)))
|
|
|
|
|
|
class MLP(nn.Module):
    """
    Two-layer feed-forward network.

    Runs Linear -> activation -> dropout -> (optional norm) -> Linear -> dropout.

    Args:
        in_channels: Input feature dimension.
        hidden_channels: Hidden layer dimension (defaults to in_channels).
        out_features: Output dimension (defaults to in_channels).
        act_layer: Activation function class.
        norm_layer: Optional normalization layer class, applied to the hidden
            features; nn.Identity is used when None.
        bias: Bias setting for the two linear layers, expanded with _to_tuple.
        drop: Dropout rate for the two dropout layers, expanded with _to_tuple.
        use_conv: Must be False; the convolutional variant is not supported.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        assert not use_conv, "Convolutional MLP not supported in this implementation."

        # Falsy widths (None or 0) fall back to the input width.
        output_dim = out_features or in_channels
        hidden_dim = hidden_channels or in_channels
        # One entry per linear / dropout layer.
        bias_pair = _to_tuple(bias, 2)
        drop_pair = _to_tuple(drop, 2)

        self.fc1 = nn.Linear(in_channels, hidden_dim, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        """Apply the layers in order and return the result."""
        for stage in (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2):
            x = stage(x)
        return x
|
|
|
|
|
|
class IndividualTokenRefinerBlock(nn.Module):
    """
    One refinement block: gated self-attention followed by a gated MLP.

    Both sub-layers are pre-normalized with LayerNorm. Their outputs are passed
    through ``apply_gate`` with gates computed from the conditioning vector and
    then added back to the input.

    Args:
        hidden_size: Model dimension.
        heads_num: Number of attention heads.
        mlp_width_ratio: MLP hidden width as a multiple of hidden_size.
        mlp_drop_rate: Dropout rate used inside the MLP.
        act_type: Activation function name (only "silu" is accepted).
        qk_norm: QK normalization flag (must be False).
        qk_norm_type: QK normalization type (only "layer" is accepted).
        qkv_bias: Use bias in the QKV and attention output projections.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
    ):
        super().__init__()
        assert qk_norm_type == "layer", "Only layer normalization supported for QK norm."
        assert act_type == "silu", "Only SiLU activation supported."
        assert not qk_norm, "QK normalization must be disabled."

        self.heads_num = heads_num

        # Attention sub-layer.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.self_attn_qkv = nn.Linear(hidden_size, 3 * hidden_size, bias=qkv_bias)
        # QK norm is disabled, so these are pass-through placeholders.
        self.self_attn_q_norm = nn.Identity()
        self.self_attn_k_norm = nn.Identity()
        self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)

        # Feed-forward sub-layer.
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=int(hidden_size * mlp_width_ratio),
            act_layer=nn.SiLU,
            drop=mlp_drop_rate,
        )

        # Maps the conditioning vector to the two residual gates.
        gate_layers = [nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)]
        self.adaLN_modulation = nn.Sequential(*gate_layers)

    def forward(self, x: torch.Tensor, c: torch.Tensor, attn_params: AttentionParams) -> torch.Tensor:
        """
        Run gated self-attention and the gated MLP.

        Args:
            x: Input token embeddings [B, L, C].
            c: Combined conditioning vector [B, C].
            attn_params: Attention parameters forwarded to ``attention``.

        Returns:
            Refined token embeddings [B, L, C].
        """
        gate_msa, gate_mlp = torch.chunk(self.adaLN_modulation(c), 2, dim=1)

        packed = self.self_attn_qkv(self.norm1(x))
        query, key, value = rearrange(packed, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        del packed
        query = self.self_attn_q_norm(query).to(value)
        key = self.self_attn_k_norm(key).to(value)
        # Hand over a list and drop the local names so this frame holds no
        # extra references to q/k/v while attention runs.
        attn_inputs = [query, key, value]
        del query, key, value
        attn_out = attention(attn_inputs, attn_params=attn_params)

        x = x + apply_gate(self.self_attn_proj(attn_out), gate_msa)
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
        return x
|
|
|
|
|
|
class IndividualTokenRefiner(nn.Module):
    """
    Stack of token refinement blocks with self-attention.

    Applies ``depth`` IndividualTokenRefinerBlock instances one after another,
    passing the same conditioning vector and attention parameters to each.

    Args:
        hidden_size: Model dimension.
        heads_num: Number of attention heads.
        depth: Number of refinement blocks.
        mlp_width_ratio: MLP expansion ratio.
        mlp_drop_rate: MLP dropout rate.
        act_type: Activation function type.
        qk_norm: QK normalization flag.
        qk_norm_type: QK normalization type.
        qkv_bias: Use bias in QKV projections.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        depth: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
    ):
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads_num=heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    act_type=act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                )
                for _ in range(depth)
            ]
        )

    # `c` is the float conditioning vector consumed by each block's adaLN
    # linear layer, so it is annotated as torch.Tensor (matching
    # IndividualTokenRefinerBlock.forward) rather than torch.LongTensor.
    def forward(self, x: torch.Tensor, c: torch.Tensor, attn_params: AttentionParams) -> torch.Tensor:
        """
        Apply sequential token refinement.

        Args:
            x: Input token embeddings [B, L, C].
            c: Combined conditioning vector [B, C].
            attn_params: Attention parameters including sequence lengths.

        Returns:
            Refined token embeddings [B, L, C].
        """
        for block in self.blocks:
            x = block(x, c, attn_params)
        return x
|
|
|
|
|
|
class SingleTokenRefiner(nn.Module):
    """
    Refines text embeddings under timestep and pooled-text conditioning.

    The input embeddings are projected to the hidden size and passed through
    an IndividualTokenRefiner. The conditioning vector is the sum of a timestep
    embedding and a projection of the mean of each sample's leading tokens.

    Args:
        in_channels: Input text embedding dimension.
        hidden_size: Transformer hidden dimension.
        heads_num: Number of attention heads.
        depth: Number of refinement blocks.
    """

    def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int):
        super().__init__()
        activation = nn.SiLU

        self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size, activation)
        self.c_embedder = TextProjection(in_channels, hidden_size, activation)

        # Fixed architecture parameters for HunyuanImage-2.1
        self.individual_token_refiner = IndividualTokenRefiner(
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            mlp_width_ratio=4.0,  # 4x MLP expansion
            mlp_drop_rate=0.0,  # No MLP dropout
            act_type="silu",  # SiLU activation
            qk_norm=False,  # No QK normalization
            qk_norm_type="layer",  # Layer norm type (unused)
            qkv_bias=True,  # Use QKV bias
        )

    def forward(self, x: torch.Tensor, t: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor:
        """
        Refine text embeddings with timestep conditioning.

        Args:
            x: Input text embeddings [B, L, in_channels].
            t: Diffusion timestep [B].
            attn_params: Attention parameters; ``seqlens`` supplies the number
                of tokens averaged per sample.

        Returns:
            Refined embeddings [B, L, hidden_size].
        """
        time_cond = self.t_embedder(t)

        # Mean over the first seqlens[b] tokens of each sample; img_len is not
        # used for SingleTokenRefiner.
        lengths = attn_params.seqlens
        batch_size = x.shape[0]
        pooled = [x[b, : lengths[b]].mean(dim=0) for b in range(batch_size)]
        text_cond = self.c_embedder(torch.stack(pooled, dim=0))  # [B, C]
        del pooled

        c = time_cond + text_cond
        del time_cond, text_cond

        hidden = self.input_embedder(x)
        return self.individual_token_refiner(hidden, c, attn_params)
|
|
|
|
|
|
class FinalLayer(nn.Module):
    """
    Output head: modulated LayerNorm followed by a linear projection.

    The conditioning vector is mapped to a shift and a scale, which are applied
    to the normalized hidden states via ``modulate`` before the projection to
    ``patch_size[0] * patch_size[1] * out_channels`` features.

    Args:
        hidden_size: Input hidden dimension.
        patch_size: Spatial patch size (indexable, two entries are read).
        out_channels: Number of output channels.
        act_layer: Activation function class.
    """

    def __init__(self, hidden_size, patch_size, out_channels, act_layer):
        super().__init__()

        # LayerNorm has no learnable affine here; shift/scale come from the conditioning.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        patch_h, patch_w = patch_size[0], patch_size[1]
        self.linear = nn.Linear(hidden_size, (patch_h * patch_w) * out_channels, bias=True)

        # Produces shift and scale (hence 2 * hidden_size outputs).
        modulation_layers = [act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)]
        self.adaLN_modulation = nn.Sequential(*modulation_layers)

    def forward(self, x, c):
        """
        Project hidden states to the output patch features.

        Args:
            x: Hidden states with last dimension hidden_size.
            c: Conditioning vector; split in two along dim=1 after modulation.

        Returns:
            Projected features with last dimension
            patch_size[0] * patch_size[1] * out_channels.
        """
        shift, scale = torch.chunk(self.adaLN_modulation(c), 2, dim=1)
        x = modulate(self.norm_final(x), shift=shift, scale=scale)
        # Drop references that are no longer needed before the projection.
        del shift, scale, c
        return self.linear(x)
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Divides the input by the root mean square of its last dimension and
    multiplies by a learnable per-feature weight. No mean is subtracted.

    Args:
        dim: Input feature dimension (size of the last axis).
        eps: Small value added to the mean square for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """
        Apply RMS normalization over the last dimension.

        Args:
            x: Input tensor.

        Returns:
            RMS normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def reset_parameters(self):
        """Reset the learnable weight to ones."""
        # An in-place write to a leaf parameter that requires grad raises a
        # RuntimeError unless it happens outside autograd tracking.
        with torch.no_grad():
            self.weight.fill_(1)

    def forward(self, x):
        """
        Apply RMSNorm with learnable scaling.

        Args:
            x: Input tensor.

        Returns:
            Normalized and scaled tensor with the same dtype as ``x``.
        """
        # Normalize in float32, then cast back to the input dtype.
        output = self._norm(x.float()).type_as(x)
        del x
        # fp8 support: cast the weight to the activation dtype before multiplying.
        output = output * self.weight.to(output.dtype)
        return output
|
|
|
|
|
|
# kept for reference, not used in current implementation
|
|
# class LinearWarpforSingle(nn.Module):
|
|
# """
|
|
# Linear layer wrapper for concatenating and projecting two inputs.
|
|
|
|
# Used in single-stream blocks to combine attention output with MLP features.
|
|
|
|
# Args:
|
|
# in_dim: Input dimension (sum of both input feature dimensions).
|
|
# out_dim: Output dimension.
|
|
# bias: Whether to use bias in linear projection.
|
|
# """
|
|
|
|
# def __init__(self, in_dim: int, out_dim: int, bias=False):
|
|
# super().__init__()
|
|
# self.fc = nn.Linear(in_dim, out_dim, bias=bias)
|
|
|
|
# def forward(self, x, y):
|
|
# """Concatenate inputs along feature dimension and project."""
|
|
# x = torch.cat([x.contiguous(), y.contiguous()], dim=2).contiguous()
|
|
# return self.fc(x)
|
|
|
|
|
|
class ModulateDiT(nn.Module):
    """
    Conditioning-to-modulation projection.

    Applies an activation and then a linear layer that expands the conditioning
    vector to ``factor * hidden_size`` features; callers split the result into
    ``factor`` modulation tensors.

    Args:
        hidden_size: Input conditioning dimension.
        factor: Number of modulation parameters to generate.
        act_layer: Activation function class.
    """

    def __init__(self, hidden_size: int, factor: int, act_layer: Callable):
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``linear(act(x))`` with last dimension factor * hidden_size."""
        activated = self.act(x)
        return self.linear(activated)
|
|
|
|
|
|
class MMDoubleStreamBlock(nn.Module):
    """
    Multimodal double-stream transformer block.

    Image and text tokens each have their own modulation, normalization, QKV
    projection, output projection and MLP. Their queries/keys/values are
    concatenated along the sequence axis for a single joint attention call,
    after which the result is split back into the two streams.

    Args:
        hidden_size: Model dimension.
        heads_num: Number of attention heads.
        mlp_width_ratio: MLP expansion ratio.
        mlp_act_type: MLP activation function (only "gelu_tanh" supported).
        qk_norm: QK normalization flag (must be True).
        qk_norm_type: QK normalization type (only "rms" supported).
        qkv_bias: Use bias in QKV and attention output projections.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
    ):
        super().__init__()

        assert mlp_act_type == "gelu_tanh", "Only GELU-tanh activation supported."
        assert qk_norm_type == "rms", "Only RMS normalization supported."
        assert qk_norm, "QK normalization must be enabled."

        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        # Image stream processing components
        # factor=6: shift/scale/gate for the attention and for the MLP sub-layer.
        self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)

        self.img_attn_q_norm = RMSNorm(head_dim, eps=1e-6)
        self.img_attn_k_norm = RMSNorm(head_dim, eps=1e-6)
        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True)

        # Text stream processing components
        self.txt_mod = ModulateDiT(hidden_size, factor=6, act_layer=nn.SiLU)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
        self.txt_attn_q_norm = RMSNorm(head_dim, eps=1e-6)
        self.txt_attn_k_norm = RMSNorm(head_dim, eps=1e-6)
        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, act_layer=lambda: nn.GELU(approximate="tanh"), bias=True)

        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        """Turn on gradient checkpointing; ``cpu_offload`` additionally wraps the forward for CPU offloading."""
        self.gradient_checkpointing = True
        self.cpu_offload_checkpointing = cpu_offload

    def disable_gradient_checkpointing(self):
        """Turn off gradient checkpointing and CPU offloading."""
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def _forward(
        self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the block without checkpointing.

        Args:
            img: Image tokens [B, L_img, C].
            txt: Text tokens [B, L_txt, C].
            vec: Conditioning vector fed to both modulation layers.
            freqs_cis: Rotary embedding data for the image tokens; RoPE is
                skipped when None.
            attn_params: Attention parameters forwarded to ``attention``.

        Returns:
            Tuple of updated (img, txt) tokens.
        """
        # Extract modulation parameters for image and text streams
        (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
            6, dim=-1
        )
        (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
            6, dim=-1
        )

        # Process image stream for attention
        img_modulated = self.img_norm1(img)
        img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
        del img_mod1_shift, img_mod1_scale

        img_qkv = self.img_attn_qkv(img_modulated)
        del img_modulated
        img_q, img_k, img_v = img_qkv.chunk(3, dim=-1)
        del img_qkv

        img_q = rearrange(img_q, "B L (H D) -> B L H D", H=self.heads_num)
        img_k = rearrange(img_k, "B L (H D) -> B L H D", H=self.heads_num)
        img_v = rearrange(img_v, "B L (H D) -> B L H D", H=self.heads_num)

        # Apply QK-Norm, then match the dtype/device of the values
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # Apply rotary position embeddings to image tokens
        if freqs_cis is not None:
            img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
        del freqs_cis

        # Process text stream for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
        # Release the attention shift/scale as soon as they are consumed,
        # mirroring the image stream above.
        del txt_mod1_shift, txt_mod1_scale

        txt_qkv = self.txt_attn_qkv(txt_modulated)
        del txt_modulated
        txt_q, txt_k, txt_v = txt_qkv.chunk(3, dim=-1)
        del txt_qkv

        txt_q = rearrange(txt_q, "B L (H D) -> B L H D", H=self.heads_num)
        txt_k = rearrange(txt_k, "B L (H D) -> B L H D", H=self.heads_num)
        txt_v = rearrange(txt_v, "B L (H D) -> B L H D", H=self.heads_num)

        # Apply QK-Norm, then match the dtype/device of the values
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        # Concatenate image and text tokens for joint attention (image first)
        img_seq_len = img.shape[1]
        q = torch.cat([img_q, txt_q], dim=1)
        del img_q, txt_q
        k = torch.cat([img_k, txt_k], dim=1)
        del img_k, txt_k
        v = torch.cat([img_v, txt_v], dim=1)
        del img_v, txt_v

        # Hand over a list and drop the local names so this frame holds no
        # extra references to q/k/v while attention runs.
        qkv = [q, k, v]
        del q, k, v
        attn = attention(qkv, attn_params=attn_params)
        del qkv

        # Split attention outputs back to separate streams
        img_attn, txt_attn = (attn[:, :img_seq_len].contiguous(), attn[:, img_seq_len:].contiguous())
        del attn

        # Apply attention projection and residual connection for image stream
        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
        del img_attn, img_mod1_gate

        # Apply MLP and residual connection for image stream
        img = img + apply_gate(
            self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
            gate=img_mod2_gate,
        )
        del img_mod2_shift, img_mod2_scale, img_mod2_gate

        # Apply attention projection and residual connection for text stream
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
        del txt_attn, txt_mod1_gate

        # Apply MLP and residual connection for text stream
        txt = txt + apply_gate(
            self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
            gate=txt_mod2_gate,
        )
        del txt_mod2_shift, txt_mod2_scale, txt_mod2_gate

        return img, txt

    def forward(
        self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the block, using gradient checkpointing when enabled and training.

        Takes the same arguments and returns the same values as ``_forward``.
        """
        if self.gradient_checkpointing and self.training:
            forward_fn = self._forward
            if self.cpu_offload_checkpointing:
                # Same helper that MMSingleStreamBlock.forward uses for this purpose.
                # NOTE(review): confirm the helper name against library.custom_offloading_utils.
                forward_fn = custom_offloading_utils.create_cpu_offloading_wrapper(forward_fn, self.img_attn_qkv.weight.device)

            return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, attn_params, use_reentrant=False)
        else:
            return self._forward(img, txt, vec, freqs_cis, attn_params)
|
|
|
|
|
|
class MMSingleStreamBlock(nn.Module):
    """
    Multimodal single-stream transformer block.

    Processes the concatenated image+text token sequence jointly. A single
    linear layer produces Q, K, V and the MLP input in parallel; the attention
    output and the activated MLP branch are concatenated and projected back
    together. RoPE is applied only to the leading image tokens.

    Args:
        hidden_size: Model dimension.
        heads_num: Number of attention heads.
        mlp_width_ratio: MLP expansion ratio.
        mlp_act_type: MLP activation function (only "gelu_tanh" supported).
        qk_norm: QK normalization flag (must be True).
        qk_norm_type: QK normalization type (only "rms" supported).
        qk_scale: Attention scaling factor, stored as ``self.scale``
            (head_dim ** -0.5 when falsy). The forward path in this class does
            not pass it to ``attention``.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
    ):
        super().__init__()

        assert mlp_act_type == "gelu_tanh", "Only GELU-tanh activation supported."
        assert qk_norm_type == "rms", "Only RMS normalization supported."
        assert qk_norm, "QK normalization must be enabled."

        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        self.scale = qk_scale or head_dim**-0.5

        # Parallel linear projections for efficiency: Q, K, V and the MLP input at once
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)

        # Combined output projection over [attention output, activated MLP branch]
        self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, bias=True)

        # QK normalization layers
        self.q_norm = RMSNorm(head_dim, eps=1e-6)
        self.k_norm = RMSNorm(head_dim, eps=1e-6)

        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        # factor=3: shift, scale and gate
        self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=nn.SiLU)

        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        """Turn on gradient checkpointing; ``cpu_offload`` additionally wraps the forward for CPU offloading."""
        self.gradient_checkpointing = True
        self.cpu_offload_checkpointing = cpu_offload

    def disable_gradient_checkpointing(self):
        """Turn off gradient checkpointing and CPU offloading."""
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def _forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
        attn_params: AttentionParams = None,
    ) -> torch.Tensor:
        """
        Run the block without checkpointing.

        Args:
            x: Concatenated tokens [B, L, C], image tokens first.
            vec: Conditioning vector fed to the modulation layer.
            freqs_cis: Rotary embedding data for the image tokens; RoPE is
                skipped when None, as in MMDoubleStreamBlock.
            attn_params: Attention parameters; ``img_len`` gives the number of
                leading image tokens and is required when freqs_cis is given.

        Returns:
            Updated tokens [B, L, C].
        """
        # Extract modulation parameters
        mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
        x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)

        # Compute Q, K, V, and MLP input
        qkv_mlp = self.linear1(x_mod)
        del x_mod
        q, k, v, mlp = qkv_mlp.split([self.hidden_size, self.hidden_size, self.hidden_size, self.mlp_hidden_dim], dim=-1)
        del qkv_mlp

        q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
        k = rearrange(k, "B L (H D) -> B L H D", H=self.heads_num)
        v = rearrange(v, "B L (H D) -> B L H D", H=self.heads_num)

        # Apply QK-Norm, then match the dtype/device of the values
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Apply rotary position embeddings only to image tokens. Without
        # freqs_cis there is nothing to rotate, so q and k are used as they are.
        if freqs_cis is not None:
            img_len = attn_params.img_len

            # Separate image and text tokens
            img_q, txt_q = q[:, :img_len, :, :], q[:, img_len:, :, :]
            del q
            img_k, txt_k = k[:, :img_len, :, :], k[:, img_len:, :, :]
            del k

            img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            del freqs_cis

            # Recombine; v needs no split because RoPE is not applied to it
            q = torch.cat([img_q, txt_q], dim=1)
            del img_q, txt_q
            k = torch.cat([img_k, txt_k], dim=1)
            del img_k, txt_k

        # Hand over a list and drop the local names so this frame holds no
        # extra references to q/k/v while attention runs.
        qkv = [q, k, v]
        del q, k, v
        attn = attention(qkv, attn_params=attn_params)
        del qkv

        # Combine attention and MLP outputs, apply gating
        mlp = self.mlp_act(mlp)
        output = torch.cat([attn, mlp], dim=2).contiguous()
        del attn, mlp
        output = self.linear2(output)

        return x + apply_gate(output, gate=mod_gate)

    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
        attn_params: AttentionParams = None,
    ) -> torch.Tensor:
        """
        Run the block, using gradient checkpointing when enabled and training.

        Takes the same arguments and returns the same value as ``_forward``.
        """
        if self.gradient_checkpointing and self.training:
            forward_fn = self._forward
            if self.cpu_offload_checkpointing:
                forward_fn = custom_offloading_utils.create_cpu_offloading_wrapper(forward_fn, self.linear1.weight.device)

            return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, freqs_cis, attn_params, use_reentrant=False)
        else:
            return self._forward(x, vec, freqs_cis, attn_params)
|
|
|
|
|
|
# endregion
|