feat: add multi-backend attention and related updates for HunyuanImage-2.1 models and scripts

Kohya S
2025-09-20 19:45:33 +09:00
parent f834b2e0d4
commit b090d15f7d
6 changed files with 286 additions and 102 deletions

View File

@@ -126,7 +126,8 @@ accelerate launch --num_cpu_threads_per_process 1 hunyuan_image_train_network.py
--learning_rate=1e-4 \
--optimizer_type="AdamW8bit" \
--lr_scheduler="constant" \
--sdpa \
--attn_mode="torch" \
--split_attn \
--max_train_epochs=10 \
--save_every_n_epochs=1 \
--mixed_precision="bf16" \
@@ -175,6 +176,10 @@ The script adds HunyuanImage-2.1 specific arguments. For common arguments (like
#### Memory/Speed Related
* `--attn_mode=<choice>`
- Specifies the attention implementation to use. Options are `torch`, `xformers`, `flash`, `sageattn`. Default is `torch` (uses scaled dot product attention). Each library other than `torch` must be installed separately. When using `xformers` with a batch size greater than 1, also specify `--split_attn`.
* `--split_attn`
- Splits the batch during attention computation to process one item at a time, reducing VRAM usage by avoiding attention mask computation. Can improve speed when using `torch`. Required when using `xformers` with a batch size greater than 1. A minimal sketch of this idea appears after this argument list.
* `--fp8_scaled`
- Enables training the DiT model in scaled FP8 format. This can significantly reduce VRAM usage (can run with as little as 8GB VRAM when combined with `--blocks_to_swap`), but the training results may vary. This is a newer alternative to the unsupported `--fp8_base` option.
* `--fp8_vl`
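Below is a minimal, self-contained sketch of the idea behind `--split_attn` (toy tensors and shapes invented for illustration; this is not the repository's implementation): each batch item is attended over only its valid tokens, so no per-batch attention mask has to be materialized.

```python
import torch
import torch.nn.functional as F

# Toy shapes: batch of 2, up to 8 tokens, 4 heads, head dim 16.
q, k, v = (torch.randn(2, 8, 4, 16) for _ in range(3))
valid_lens = [8, 5]  # second sample has 3 padding tokens

# split_attn-style: one item at a time, trimmed to its valid length,
# so no [B, 1, L, L] attention mask is ever built.
outs = []
for i, n in enumerate(valid_lens):
    qi, ki, vi = (t[i : i + 1, :n].transpose(1, 2) for t in (q, k, v))  # [1, H, n, D]
    oi = F.scaled_dot_product_attention(qi, ki, vi)                     # [1, H, n, D]
    outs.append(F.pad(oi, (0, 0, 0, q.shape[1] - n)))                   # pad back to max length
x = torch.cat(outs, dim=0).transpose(1, 2)                              # [B, L, H, D]
print(x.shape)  # torch.Size([2, 8, 4, 16])
```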
@@ -429,6 +434,7 @@ python hunyuan_image_minimal_inference.py \
--vae "<path to hunyuan_image_2.1_vae_fp16.safetensors>" \
--lora_weight "<path to your trained LoRA>" \
--lora_multiplier 1.0 \
--attn_mode "torch" \
--prompt "A cute cartoon penguin in a snowy landscape" \
--image_size 2048 2048 \
--infer_steps 50 \
@@ -445,6 +451,8 @@ python hunyuan_image_minimal_inference.py \
- `--guidance_scale`: CFG scale (default: 3.5)
- `--flow_shift`: Flow matching shift parameter (default: 5.0)
`--split_attn` is not supported (since inference is done one at a time).
<details>
<summary>日本語</summary>
@@ -457,6 +465,8 @@ python hunyuan_image_minimal_inference.py \
- `--guidance_scale`: CFGスケール（推奨: 3.5）
- `--flow_shift`: Flow Matchingシフトパラメータ（デフォルト: 5.0）
`--split_attn`はサポートされていません（1件ずつ推論するため）
</details>
## 9. Related Tools / 関連ツール

View File

@@ -96,7 +96,7 @@ def parse_args() -> argparse.Namespace:
"--attn_mode",
type=str,
default="torch",
choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3",
choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "sdpa" for backward compatibility
help="attention mode",
)
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
@@ -130,6 +130,9 @@ def parse_args() -> argparse.Namespace:
if args.lycoris and not lycoris_available:
raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS")
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
return args
@@ -265,7 +268,7 @@ def load_dit_model(
device,
args.dit,
args.attn_mode,
False,
True, # enable split_attn to trim masked tokens
loading_device,
loading_weight_dtype,
args.fp8_scaled and not args.lycoris,

View File

@@ -379,18 +379,19 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
loading_dtype = None if args.fp8_scaled else weight_dtype
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
split_attn = True
attn_mode = "torch"
if args.xformers:
attn_mode = "xformers"
logger.info("xformers is enabled for attention")
if args.attn_mode is not None:
attn_mode = args.attn_mode
logger.info(f"Loading DiT model with attn_mode: {attn_mode}, split_attn: {args.split_attn}, fp8_scaled: {args.fp8_scaled}")
model = hunyuan_image_models.load_hunyuan_image_model(
accelerator.device,
args.pretrained_model_name_or_path,
attn_mode,
split_attn,
args.split_attn,
loading_device,
loading_dtype,
args.fp8_scaled,
@@ -674,6 +675,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="Enable tiling for VAE decoding and encoding / VAEデコーディングとエンコーディングのタイルを有効にする",
)
parser.add_argument(
"--attn_mode",
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
default=None,
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
" / 使用するAttentionの実装。デフォルトはNonetorchです。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません推論のみ。このオプションは--xformersまたは--sdpaを上書きします。",
)
parser.add_argument(
"--split_attn",
action="store_true",
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
)
return parser
@@ -684,5 +698,8 @@ if __name__ == "__main__":
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
if args.attn_mode == "sdpa":
args.attn_mode = "torch" # backward compatibility
trainer = HunyuanImageNetworkTrainer()
trainer.train(args)
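Putting the pieces of this file together, the effective attention mode is resolved roughly as sketched below (a simplified, standalone illustration of the precedence implied by the diff, not the script's actual code): an explicit `--attn_mode` overrides the legacy `--xformers`/`--sdpa` flags, and `sdpa` is normalized to `torch` for backward compatibility.

```python
from types import SimpleNamespace

def resolve_attn_mode(args) -> str:
    # Sketch of the precedence: default -> legacy flag -> explicit --attn_mode.
    attn_mode = "torch"                       # default (PyTorch SDPA)
    if getattr(args, "xformers", False):
        attn_mode = "xformers"                # legacy flag
    if getattr(args, "attn_mode", None) is not None:
        attn_mode = args.attn_mode            # explicit option wins
    if attn_mode == "sdpa":
        attn_mode = "torch"                   # backward compatibility
    return attn_mode

print(resolve_attn_mode(SimpleNamespace(xformers=True, attn_mode=None)))    # -> "xformers"
print(resolve_attn_mode(SimpleNamespace(xformers=True, attn_mode="sdpa")))  # -> "torch"
```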

View File

@@ -1,18 +1,88 @@
# Unified attention function supporting various implementations
from dataclasses import dataclass
import torch
from typing import Optional, Union
try:
import flash_attn
from flash_attn.flash_attn_interface import _flash_attn_forward
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from flash_attn.flash_attn_interface import flash_attn_func
except ImportError:
flash_attn = None
flash_attn_varlen_func = None
_flash_attn_forward = None
flash_attn_func = None
try:
from sageattention import sageattn_varlen, sageattn
except ImportError:
sageattn_varlen = None
sageattn = None
try:
import xformers.ops as xops
except ImportError:
xops = None
@dataclass
class AttentionParams:
attn_mode: Optional[str] = None
split_attn: bool = False
img_len: Optional[int] = None
attention_mask: Optional[torch.Tensor] = None
seqlens: Optional[torch.Tensor] = None
cu_seqlens: Optional[torch.Tensor] = None
max_seqlen: Optional[int] = None
@staticmethod
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
return AttentionParams(attn_mode, split_attn)
@staticmethod
def create_attention_params_from_mask(
attn_mode: Optional[str], split_attn: bool, img_len: Optional[int], attention_mask: Optional[torch.Tensor]
) -> "AttentionParams":
if attention_mask is None:
# No attention mask provided: assume all tokens are valid
return AttentionParams(attn_mode, split_attn, None, None, None, None, None)
else:
# Note: attention_mask is only for text tokens, not including image tokens
seqlens = attention_mask.sum(dim=1).to(torch.int32) + img_len # [B]
max_seqlen = attention_mask.shape[1] + img_len
if split_attn:
# cu_seqlens is not needed for split attention
return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, None, max_seqlen)
# Convert attention mask to cumulative sequence lengths for flash attention
batch_size = attention_mask.shape[0]
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=attention_mask.device)
for i in range(batch_size):
cu_seqlens[2 * i + 1] = i * max_seqlen + seqlens[i] # end of valid tokens for query
cu_seqlens[2 * i + 2] = (i + 1) * max_seqlen # end of all tokens for query
# Expand attention mask to include image tokens
attention_mask = torch.nn.functional.pad(attention_mask, (img_len, 0), value=1) # [B, img_len + L]
if attn_mode == "xformers":
seqlens_list = seqlens.cpu().tolist()
attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
seqlens_list, seqlens_list, device=attention_mask.device
)
elif attn_mode == "torch":
attention_mask = attention_mask[:, None, None, :].to(torch.bool) # [B, 1, 1, img_len + L]
return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, cu_seqlens, max_seqlen)
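As a worked example of what the mask-to-`cu_seqlens` conversion above produces, the following standalone snippet recomputes the same quantities with plain `torch` for a toy batch (values chosen for illustration):

```python
import torch

# Two text masks (1 = valid token) plus 4 image tokens per sample.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])
img_len = 4

seqlens = attention_mask.sum(dim=1).to(torch.int32) + img_len  # tensor([7, 9])
max_seqlen = attention_mask.shape[1] + img_len                 # 9

batch_size = attention_mask.shape[0]
cu_seqlens = torch.zeros(2 * batch_size + 1, dtype=torch.int32)
for i in range(batch_size):
    cu_seqlens[2 * i + 1] = i * max_seqlen + seqlens[i]  # end of valid tokens
    cu_seqlens[2 * i + 2] = (i + 1) * max_seqlen         # end of the padded sequence

print(seqlens.tolist(), max_seqlen, cu_seqlens.tolist())  # [7, 9] 9 [0, 7, 9, 18, 18]
```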
def attention(
qkv_or_q: Union[torch.Tensor, list],
k: Optional[torch.Tensor] = None,
v: Optional[torch.Tensor] = None,
seq_lens: Optional[list[int]] = None,
attn_mode: str = "torch",
attn_params: Optional[AttentionParams] = None,
drop_rate: float = 0.0,
) -> torch.Tensor:
"""
@@ -25,8 +95,7 @@ def attention(
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
k: Key tensor [B, L, H, D].
v: Value tensor [B, L, H, D].
seq_lens: Valid sequence length for each batch element.
attn_mode: Attention implementation ("torch" or "sageattn").
attn_params: Attention parameters including mask and sequence lengths.
drop_rate: Attention dropout rate.
Returns:
@@ -34,53 +103,158 @@ def attention(
"""
if isinstance(qkv_or_q, list):
q, k, v = qkv_or_q
q: torch.Tensor = q
qkv_or_q.clear()
del qkv_or_q
else:
q = qkv_or_q
q: torch.Tensor = qkv_or_q
del qkv_or_q
assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor"
if seq_lens is None:
seq_lens = [q.shape[1]] * q.shape[0]
if attn_params is None:
attn_params = AttentionParams.create_attention_params("torch", False)
# If split_attn is False, an attention mask is provided, and all sequence lengths are the same, we can trim the sequence
seqlen_trimmed = False
if not attn_params.split_attn and attn_params.attention_mask is not None and attn_params.seqlens is not None:
if torch.all(attn_params.seqlens == attn_params.seqlens[0]):
seqlen = attn_params.seqlens[0].item()
q = q[:, :seqlen]
k = k[:, :seqlen]
v = v[:, :seqlen]
max_seqlen = attn_params.max_seqlen
attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, False) # do not in-place modify
attn_params.max_seqlen = max_seqlen # keep max_seqlen for padding
seqlen_trimmed = True
# Determine tensor layout based on attention implementation
if attn_mode == "torch" or attn_mode == "sageattn":
transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA
if attn_params.attn_mode == "torch" or (
attn_params.attn_mode == "sageattn" and (attn_params.split_attn or attn_params.cu_seqlens is None)
):
transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA and sageattn with fixed length
# pad on sequence length dimension
pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, pad_to - x.shape[-2]), value=0)
else:
transpose_fn = lambda x: x # [B, L, H, D] for other implementations
# pad on sequence length dimension
pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_to - x.shape[-3]), value=0)
# Process each batch element with its valid sequence length
q_seq_len = q.shape[1]
q = [transpose_fn(q[i : i + 1, : seq_lens[i]]) for i in range(len(q))]
k = [transpose_fn(k[i : i + 1, : seq_lens[i]]) for i in range(len(k))]
v = [transpose_fn(v[i : i + 1, : seq_lens[i]]) for i in range(len(v))]
# Process each batch element with its valid sequence lengths
if attn_params.split_attn:
if attn_params.seqlens is None:
# If no seqlens provided, assume all tokens are valid
attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, True) # do not in-place modify
attn_params.seqlens = torch.tensor([q.shape[1]] * q.shape[0], device=q.device)
attn_params.max_seqlen = q.shape[1]
q = [transpose_fn(q[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(q))]
k = [transpose_fn(k[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(k))]
v = [transpose_fn(v[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(v))]
else:
q = transpose_fn(q)
k = transpose_fn(k)
v = transpose_fn(v)
if attn_mode == "torch":
x = []
for i in range(len(q)):
x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
q[i] = None
k[i] = None
v[i] = None
x.append(torch.nn.functional.pad(x_i, (0, 0, 0, q_seq_len - x_i.shape[2]), value=0)) # Pad to max seq len, B, H, L, D
x = torch.cat(x, dim=0)
del q, k, v
if attn_params.attn_mode == "torch":
if attn_params.split_attn:
x = []
for i in range(len(q)):
x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
q[i] = None
k[i] = None
v[i] = None
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D
x = torch.cat(x, dim=0)
del q, k, v
elif attn_mode == "xformers":
x = []
for i in range(len(q)):
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)
q[i] = None
k[i] = None
v[i] = None
x.append(torch.nn.functional.pad(x_i, (0, 0, 0, 0, 0, q_seq_len - x_i.shape[1]), value=0)) # B, L, H, D
x = torch.cat(x, dim=0)
del q, k, v
else:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_params.attention_mask, dropout_p=drop_rate)
del q, k, v
elif attn_params.attn_mode == "xformers":
if attn_params.split_attn:
x = []
for i in range(len(q)):
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)
q[i] = None
k[i] = None
v[i] = None
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D
x = torch.cat(x, dim=0)
del q, k, v
else:
x = xops.memory_efficient_attention(q, k, v, attn_bias=attn_params.attention_mask, p=drop_rate)
del q, k, v
elif attn_params.attn_mode == "sageattn":
if attn_params.split_attn:
x = []
for i in range(len(q)):
# HND seems to cause an error
x_i = sageattn(q[i], k[i], v[i]) # B, H, L, D. No dropout support
q[i] = None
k[i] = None
v[i] = None
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D
x = torch.cat(x, dim=0)
del q, k, v
elif attn_params.cu_seqlens is None: # all tokens are valid
x = sageattn(q, k, v) # B, L, H, D. No dropout support
del q, k, v
else:
# Reshape to [(bxs), a, d]
batch_size, seqlen = q.shape[0], q.shape[1]
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D]
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D]
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D]
# Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv. No dropout support
x = sageattn_varlen(
q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen
)
del q, k, v
# Reshape x with shape [(bxs), a, d] to [b, s, a, d]
x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D
elif attn_params.attn_mode == "flash":
if attn_params.split_attn:
x = []
for i in range(len(q)):
# HND seems to cause an error
x_i = flash_attn_func(q[i], k[i], v[i], drop_rate) # B, L, H, D
q[i] = None
k[i] = None
v[i] = None
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D
x = torch.cat(x, dim=0)
del q, k, v
elif attn_params.cu_seqlens is None: # all tokens are valid
x = flash_attn_func(q, k, v, drop_rate) # B, L, H, D
del q, k, v
else:
# Reshape to [(bxs), a, d]
batch_size, seqlen = q.shape[0], q.shape[1]
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D]
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D]
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D]
# Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv
x = flash_attn_varlen_func(
q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen, drop_rate
)
del q, k, v
# Reshape x with shape [(bxs), a, d] to [b, s, a, d]
x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D
else:
# Currently only PyTorch SDPA and xformers are implemented
raise ValueError(f"Unsupported attention mode: {attn_mode}")
raise ValueError(f"Unsupported attention mode: {attn_params.attn_mode}")
x = transpose_fn(x) # [B, L, H, D]
x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D]
if seqlen_trimmed:
x = torch.nn.functional.pad(x, (0, 0, 0, attn_params.max_seqlen - x.shape[1]), value=0) # pad back to max_seqlen
return x
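A minimal usage sketch of the resulting API (assuming `library.attention` is importable from this repository; tensor sizes are toy values):

```python
import torch
from library.attention import AttentionParams, attention  # module added in this commit

B, img_len, txt_len, H, D = 2, 4, 5, 4, 16
L = img_len + txt_len
q, k, v = (torch.randn(B, L, H, D) for _ in range(3))
text_mask = torch.tensor([[1, 1, 1, 0, 0],
                          [1, 1, 1, 1, 1]])  # mask covers text tokens only

params = AttentionParams.create_attention_params_from_mask(
    attn_mode="torch", split_attn=False, img_len=img_len, attention_mask=text_mask
)
out = attention([q, k, v], attn_params=params)
print(out.shape)  # torch.Size([2, 9, 64]) -> [B, L, H*D]
```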

View File

@@ -8,6 +8,7 @@ import torch.nn as nn
from accelerate import init_empty_weights
from library import custom_offloading_utils
from library.attention import AttentionParams
from library.fp8_optimization_utils import apply_fp8_monkey_patch
from library.lora_utils import load_safetensors_with_lora_and_fp8
from library.utils import setup_logging
@@ -50,7 +51,7 @@ class HYImageDiffusionTransformer(nn.Module):
attn_mode: Attention implementation mode ("torch", "xformers", "flash", or "sageattn").
split_attn: Process the batch one item at a time during attention to reduce VRAM usage.
"""
def __init__(self, attn_mode: str = "torch"):
def __init__(self, attn_mode: str = "torch", split_attn: bool = False):
super().__init__()
# Fixed architecture parameters for HunyuanImage-2.1
@@ -80,6 +81,7 @@ class HYImageDiffusionTransformer(nn.Module):
qk_norm_type: str = "rms" # RMS normalization type
self.attn_mode = attn_mode
self.split_attn = split_attn
# ByT5 character-level text encoder mapping
self.byt5_in = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.hidden_size, use_residual=False)
@@ -88,7 +90,7 @@ class HYImageDiffusionTransformer(nn.Module):
self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size)
# Text token refinement with cross-attention
self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2, attn_mode=self.attn_mode)
self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2)
# Timestep embedding for diffusion process
self.time_in = TimestepEmbedder(self.hidden_size, nn.SiLU)
@@ -110,7 +112,6 @@ class HYImageDiffusionTransformer(nn.Module):
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
attn_mode=self.attn_mode,
)
for _ in range(mm_double_blocks_depth)
]
@@ -126,7 +127,6 @@ class HYImageDiffusionTransformer(nn.Module):
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
attn_mode=self.attn_mode,
)
for _ in range(mm_single_blocks_depth)
]
@@ -339,22 +339,21 @@ class HYImageDiffusionTransformer(nn.Module):
# MeanFlow and guidance embedding not used in this configuration
# Process text tokens through refinement layers
txt_lens = text_mask.to(torch.bool).sum(dim=1).tolist()
txt = self.txt_in(txt, t, txt_lens)
txt_attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, 0, text_mask)
txt = self.txt_in(txt, t, txt_attn_params)
# Integrate character-level ByT5 features with word-level tokens
# Use variable length sequences with sequence lengths
byt5_txt = self.byt5_in(byt5_text_states)
txt, _, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)
txt, text_mask, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)
# Trim sequences to maximum length in the batch
img_seq_len = img.shape[1]
# print(f"img_seq_len: {img_seq_len}, txt_lens: {txt_lens}")
seq_lens = [img_seq_len + l for l in txt_lens]
max_txt_len = max(txt_lens)
# print(f"max_txt_len: {max_txt_len}, seq_lens: {seq_lens}, txt.shape: {txt.shape}")
txt = txt[:, :max_txt_len, :]
txt_seq_len = txt.shape[1]
text_mask = text_mask[:, :max_txt_len]
attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, img_seq_len, text_mask)
input_device = img.device
@@ -362,7 +361,7 @@ class HYImageDiffusionTransformer(nn.Module):
for index, block in enumerate(self.double_blocks):
if self.blocks_to_swap:
self.offloader_double.wait_for_block(index)
img, txt = block(img, txt, vec, freqs_cis, seq_lens)
img, txt = block(img, txt, vec, freqs_cis, attn_params)
if self.blocks_to_swap:
self.offloader_double.submit_move_blocks(self.double_blocks, index)
@@ -373,7 +372,7 @@ class HYImageDiffusionTransformer(nn.Module):
for index, block in enumerate(self.single_blocks):
if self.blocks_to_swap:
self.offloader_single.wait_for_block(index)
x = block(x, vec, txt_seq_len, freqs_cis, seq_lens)
x = block(x, vec, freqs_cis, attn_params)
if self.blocks_to_swap:
self.offloader_single.submit_move_blocks(self.single_blocks, index)
@@ -417,7 +416,7 @@ class HYImageDiffusionTransformer(nn.Module):
def create_model(attn_mode: str, split_attn: bool, dtype: Optional[torch.dtype]) -> HYImageDiffusionTransformer:
with init_empty_weights():
model = HYImageDiffusionTransformer(attn_mode=attn_mode)
model = HYImageDiffusionTransformer(attn_mode=attn_mode, split_attn=split_attn)
if dtype is not None:
model.to(dtype)
return model
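As a small usage sketch of the updated factory (assuming `create_model` is a module-level helper in `library.hunyuan_image_models`, as the hunk suggests; the weights are still empty/meta here and would normally be filled by the checkpoint loader):

```python
import torch
from library import hunyuan_image_models  # path as used elsewhere in this commit

# Build the DiT skeleton with the chosen attention backend and split_attn flag.
model = hunyuan_image_models.create_model(attn_mode="xformers", split_attn=True, dtype=torch.bfloat16)
print(model.attn_mode, model.split_attn)  # xformers True
```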

View File

@@ -7,7 +7,7 @@ import torch.nn as nn
from einops import rearrange
from library import custom_offloading_utils
from library.attention import attention
from library.attention import AttentionParams, attention
from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate
@@ -213,7 +213,6 @@ class IndividualTokenRefinerBlock(nn.Module):
qk_norm: QK normalization flag (must be False).
qk_norm_type: QK normalization type (only "layer" supported).
qkv_bias: Use bias in QKV projections.
attn_mode: Attention implementation mode.
"""
def __init__(
@@ -226,15 +225,12 @@ class IndividualTokenRefinerBlock(nn.Module):
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
attn_mode: str = "torch",
):
super().__init__()
assert qk_norm_type == "layer", "Only layer normalization supported for QK norm."
assert act_type == "silu", "Only SiLU activation supported."
assert not qk_norm, "QK normalization must be disabled."
self.attn_mode = attn_mode
self.heads_num = heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
@@ -253,19 +249,14 @@ class IndividualTokenRefinerBlock(nn.Module):
nn.Linear(hidden_size, 2 * hidden_size, bias=True),
)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor, # Combined timestep and context conditioning
txt_lens: list[int],
) -> torch.Tensor:
def forward(self, x: torch.Tensor, c: torch.Tensor, attn_params: AttentionParams) -> torch.Tensor:
"""
Apply self-attention and MLP with adaptive conditioning.
Args:
x: Input token embeddings [B, L, C].
c: Combined conditioning vector [B, C].
txt_lens: Valid sequence lengths for each batch element.
attn_params: Attention parameters including sequence lengths.
Returns:
Refined token embeddings [B, L, C].
@@ -273,10 +264,14 @@ class IndividualTokenRefinerBlock(nn.Module):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn_qkv(norm_x)
del norm_x
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
del qkv
q = self.self_attn_q_norm(q).to(v)
k = self.self_attn_k_norm(k).to(v)
attn = attention(q, k, v, seq_lens=txt_lens, attn_mode=self.attn_mode)
qkv = [q, k, v]
del q, k, v
attn = attention(qkv, attn_params=attn_params)
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
@@ -299,7 +294,6 @@ class IndividualTokenRefiner(nn.Module):
qk_norm: QK normalization flag.
qk_norm_type: QK normalization type.
qkv_bias: Use bias in QKV projections.
attn_mode: Attention implementation mode.
"""
def __init__(
@@ -313,7 +307,6 @@ class IndividualTokenRefiner(nn.Module):
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
attn_mode: str = "torch",
):
super().__init__()
self.blocks = nn.ModuleList(
@@ -327,26 +320,25 @@ class IndividualTokenRefiner(nn.Module):
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
)
for _ in range(depth)
]
)
def forward(self, x: torch.Tensor, c: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor:
def forward(self, x: torch.Tensor, c: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor:
"""
Apply sequential token refinement.
Args:
x: Input token embeddings [B, L, C].
c: Combined conditioning vector [B, C].
txt_lens: Valid sequence lengths for each batch element.
attn_params: Attention parameters including sequence lengths.
Returns:
Refined token embeddings [B, L, C].
"""
for block in self.blocks:
x = block(x, c, txt_lens)
x = block(x, c, attn_params)
return x
@@ -362,10 +354,9 @@ class SingleTokenRefiner(nn.Module):
hidden_size: Transformer hidden dimension.
heads_num: Number of attention heads.
depth: Number of refinement blocks.
attn_mode: Attention implementation mode.
"""
def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int, attn_mode: str = "torch"):
def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int):
# Fixed architecture parameters for HunyuanImage-2.1
mlp_drop_rate: float = 0.0 # No MLP dropout
act_type: str = "silu" # SiLU activation
@@ -389,17 +380,16 @@ class SingleTokenRefiner(nn.Module):
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
attn_mode=attn_mode,
)
def forward(self, x: torch.Tensor, t: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor:
def forward(self, x: torch.Tensor, t: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor:
"""
Refine text embeddings with timestep conditioning.
Args:
x: Input text embeddings [B, L, in_channels].
t: Diffusion timestep [B].
txt_lens: Valid sequence lengths for each batch element.
attn_params: Attention parameters including sequence lengths.
Returns:
Refined embeddings [B, L, hidden_size].
@@ -407,13 +397,14 @@ class SingleTokenRefiner(nn.Module):
timestep_aware_representations = self.t_embedder(t)
# Compute context-aware representations by averaging valid tokens
txt_lens = attn_params.seqlens  # seqlens hold text lengths only (img_len is 0 for SingleTokenRefiner)
context_aware_representations = torch.stack([x[i, : txt_lens[i]].mean(dim=0) for i in range(x.shape[0])], dim=0) # [B, C]
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
del timestep_aware_representations, context_aware_representations
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, txt_lens)
x = self.individual_token_refiner(x, c, attn_params)
return x
@@ -564,7 +555,6 @@ class MMDoubleStreamBlock(nn.Module):
qk_norm: QK normalization flag (must be True).
qk_norm_type: QK normalization type (only "rms" supported).
qkv_bias: Use bias in QKV projections.
attn_mode: Attention implementation mode.
"""
def __init__(
@@ -576,7 +566,6 @@ class MMDoubleStreamBlock(nn.Module):
qk_norm: bool = True,
qk_norm_type: str = "rms",
qkv_bias: bool = False,
attn_mode: str = "torch",
):
super().__init__()
@@ -584,7 +573,6 @@ class MMDoubleStreamBlock(nn.Module):
assert qk_norm_type == "rms", "Only RMS normalization supported."
assert qk_norm, "QK normalization must be enabled."
self.attn_mode = attn_mode
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
@@ -626,7 +614,7 @@ class MMDoubleStreamBlock(nn.Module):
self.cpu_offload_checkpointing = False
def _forward(
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# Extract modulation parameters for image and text streams
(img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
@@ -687,7 +675,7 @@ class MMDoubleStreamBlock(nn.Module):
qkv = [q, k, v]
del q, k, v
attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode)
attn = attention(qkv, attn_params=attn_params)
del qkv
# Split attention outputs back to separate streams
@@ -719,16 +707,16 @@ class MMDoubleStreamBlock(nn.Module):
return img, txt
def forward(
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.gradient_checkpointing and self.training:
forward_fn = self._forward
if self.cpu_offload_checkpointing:
forward_fn = custom_offloading_utils.cpu_offload_wrapper(forward_fn, self.img_attn_qkv.weight.device)
return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, seq_lens, use_reentrant=False)
return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, attn_params, use_reentrant=False)
else:
return self._forward(img, txt, vec, freqs_cis, seq_lens)
return self._forward(img, txt, vec, freqs_cis, attn_params)
class MMSingleStreamBlock(nn.Module):
@@ -746,7 +734,6 @@ class MMSingleStreamBlock(nn.Module):
qk_norm: QK normalization flag (must be True).
qk_norm_type: QK normalization type (only "rms" supported).
qk_scale: Attention scaling factor (computed automatically if None).
attn_mode: Attention implementation mode.
"""
def __init__(
@@ -758,7 +745,6 @@ class MMSingleStreamBlock(nn.Module):
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
attn_mode: str = "torch",
):
super().__init__()
@@ -766,7 +752,6 @@ class MMSingleStreamBlock(nn.Module):
assert qk_norm_type == "rms", "Only RMS normalization supported."
assert qk_norm, "QK normalization must be enabled."
self.attn_mode = attn_mode
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
@@ -805,9 +790,8 @@ class MMSingleStreamBlock(nn.Module):
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
seq_lens: list[int] = None,
attn_params: AttentionParams = None,
) -> torch.Tensor:
# Extract modulation parameters
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
@@ -828,12 +812,10 @@ class MMSingleStreamBlock(nn.Module):
k = self.k_norm(k).to(v)
# Separate image and text tokens
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_q, txt_q = q[:, : attn_params.img_len, :, :], q[:, attn_params.img_len :, :, :]
del q
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_k, txt_k = k[:, : attn_params.img_len, :, :], k[:, attn_params.img_len :, :, :]
del k
# img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
# del v
# Apply rotary position embeddings only to image tokens
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
@@ -848,7 +830,7 @@ class MMSingleStreamBlock(nn.Module):
# del img_v, txt_v
qkv = [q, k, v]
del q, k, v
attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode)
attn = attention(qkv, attn_params=attn_params)
del qkv
# Combine attention and MLP outputs, apply gating
@@ -865,18 +847,17 @@ class MMSingleStreamBlock(nn.Module):
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
seq_lens: list[int] = None,
attn_params: AttentionParams = None,
) -> torch.Tensor:
if self.gradient_checkpointing and self.training:
forward_fn = self._forward
if self.cpu_offload_checkpointing:
forward_fn = custom_offloading_utils.create_cpu_offloading_wrapper(forward_fn, self.linear1.weight.device)
return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, txt_len, freqs_cis, seq_lens, use_reentrant=False)
return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, freqs_cis, attn_params, use_reentrant=False)
else:
return self._forward(x, vec, txt_len, freqs_cis, seq_lens)
return self._forward(x, vec, freqs_cis, attn_params)
# endregion