From b090d15f7d72324ba81575cb453002a935f5bcce Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sat, 20 Sep 2025 19:45:33 +0900 Subject: [PATCH] feat: add multi backend attention and related update for HI2.1 models and scripts --- docs/hunyuan_image_train_network.md | 12 +- hunyuan_image_minimal_inference.py | 7 +- hunyuan_image_train_network.py | 23 ++- library/attention.py | 244 ++++++++++++++++++++++++---- library/hunyuan_image_models.py | 27 ++- library/hunyuan_image_modules.py | 75 ++++----- 6 files changed, 286 insertions(+), 102 deletions(-) diff --git a/docs/hunyuan_image_train_network.md b/docs/hunyuan_image_train_network.md index 3d49fbdf..667b4fec 100644 --- a/docs/hunyuan_image_train_network.md +++ b/docs/hunyuan_image_train_network.md @@ -126,7 +126,8 @@ accelerate launch --num_cpu_threads_per_process 1 hunyuan_image_train_network.py --learning_rate=1e-4 \ --optimizer_type="AdamW8bit" \ --lr_scheduler="constant" \ - --sdpa \ + --attn_mode="torch" \ + --split_attn \ --max_train_epochs=10 \ --save_every_n_epochs=1 \ --mixed_precision="bf16" \ @@ -175,6 +176,10 @@ The script adds HunyuanImage-2.1 specific arguments. For common arguments (like #### Memory/Speed Related +* `--attn_mode=` + - Specifies the attention implementation to use. Options are `torch`, `xformers`, `flash`, `sageattn`. Default is `torch` (use scaled dot product attention). Each library must be installed separately other than `torch`. If using `xformers`, also specify `--split_attn` if the batch size is more than 1. +* `--split_attn` + - Splits the batch during attention computation to process one item at a time, reducing VRAM usage by avoiding attention mask computation. Can improve speed when using `torch`. Required when using `xformers` with batch size greater than 1. * `--fp8_scaled` - Enables training the DiT model in scaled FP8 format. This can significantly reduce VRAM usage (can run with as little as 8GB VRAM when combined with `--blocks_to_swap`), but the training results may vary. This is a newer alternative to the unsupported `--fp8_base` option. * `--fp8_vl` @@ -429,6 +434,7 @@ python hunyuan_image_minimal_inference.py \ --vae "" \ --lora_weight "" \ --lora_multiplier 1.0 \ + --attn_mode "torch" \ --prompt "A cute cartoon penguin in a snowy landscape" \ --image_size 2048 2048 \ --infer_steps 50 \ @@ -445,6 +451,8 @@ python hunyuan_image_minimal_inference.py \ - `--guidance_scale`: CFG scale (default: 3.5) - `--flow_shift`: Flow matching shift parameter (default: 5.0) +`--split_attn` is not supported (since inference is done one at a time). +
日本語 @@ -457,6 +465,8 @@ python hunyuan_image_minimal_inference.py \ - `--guidance_scale`: CFGスケール(推奨: 3.5) - `--flow_shift`: Flow Matchingシフトパラメータ(デフォルト: 5.0) +`--split_attn`はサポートされていません(1件ずつ推論するため)。 +
## 9. Related Tools / 関連ツール diff --git a/hunyuan_image_minimal_inference.py b/hunyuan_image_minimal_inference.py index 00356a37..85023383 100644 --- a/hunyuan_image_minimal_inference.py +++ b/hunyuan_image_minimal_inference.py @@ -96,7 +96,7 @@ def parse_args() -> argparse.Namespace: "--attn_mode", type=str, default="torch", - choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3", + choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "sdpa" for backward compatibility help="attention mode", ) parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model") @@ -130,6 +130,9 @@ def parse_args() -> argparse.Namespace: if args.lycoris and not lycoris_available: raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS") + if args.attn_mode == "sdpa": + args.attn_mode = "torch" # backward compatibility + return args @@ -265,7 +268,7 @@ def load_dit_model( device, args.dit, args.attn_mode, - False, + True, # enable split_attn to trim masked tokens loading_device, loading_weight_dtype, args.fp8_scaled and not args.lycoris, diff --git a/hunyuan_image_train_network.py b/hunyuan_image_train_network.py index 60aa2178..6b102a9a 100644 --- a/hunyuan_image_train_network.py +++ b/hunyuan_image_train_network.py @@ -379,18 +379,19 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): loading_dtype = None if args.fp8_scaled else weight_dtype loading_device = "cpu" if self.is_swapping_blocks else accelerator.device - split_attn = True attn_mode = "torch" if args.xformers: attn_mode = "xformers" - logger.info("xformers is enabled for attention") + if args.attn_mode is not None: + attn_mode = args.attn_mode + logger.info(f"Loading DiT model with attn_mode: {attn_mode}, split_attn: {args.split_attn}, fp8_scaled: {args.fp8_scaled}") model = hunyuan_image_models.load_hunyuan_image_model( accelerator.device, args.pretrained_model_name_or_path, attn_mode, - split_attn, + args.split_attn, loading_device, loading_dtype, args.fp8_scaled, @@ -674,6 +675,19 @@ def setup_parser() -> argparse.ArgumentParser: help="Enable tiling for VAE decoding and encoding / VAEデコーディングとエンコーディングのタイルを有効にする", ) + parser.add_argument( + "--attn_mode", + choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility + default=None, + help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa." + " / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。", + ) + parser.add_argument( + "--split_attn", + action="store_true", + help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する", + ) + return parser @@ -684,5 +698,8 @@ if __name__ == "__main__": train_util.verify_command_line_training_args(args) args = train_util.read_config_from_file(args, parser) + if args.attn_mode == "sdpa": + args.attn_mode = "torch" # backward compatibility + trainer = HunyuanImageNetworkTrainer() trainer.train(args) diff --git a/library/attention.py b/library/attention.py index f1e7c0b0..d3b8441e 100644 --- a/library/attention.py +++ b/library/attention.py @@ -1,18 +1,88 @@ +# Unified attention function supporting various implementations + +from dataclasses import dataclass import torch from typing import Optional, Union +try: + import flash_attn + from flash_attn.flash_attn_interface import _flash_attn_forward + from flash_attn.flash_attn_interface import flash_attn_varlen_func + from flash_attn.flash_attn_interface import flash_attn_func +except ImportError: + flash_attn = None + flash_attn_varlen_func = None + _flash_attn_forward = None + flash_attn_func = None + +try: + from sageattention import sageattn_varlen, sageattn +except ImportError: + sageattn_varlen = None + sageattn = None + try: import xformers.ops as xops except ImportError: xops = None +@dataclass +class AttentionParams: + attn_mode: Optional[str] = None + split_attn: bool = False + img_len: Optional[int] = None + attention_mask: Optional[torch.Tensor] = None + seqlens: Optional[torch.Tensor] = None + cu_seqlens: Optional[torch.Tensor] = None + max_seqlen: Optional[int] = None + + @staticmethod + def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams": + return AttentionParams(attn_mode, split_attn) + + @staticmethod + def create_attention_params_from_mask( + attn_mode: Optional[str], split_attn: bool, img_len: Optional[int], attention_mask: Optional[torch.Tensor] + ) -> "AttentionParams": + if attention_mask is None: + # No attention mask provided: assume all tokens are valid + return AttentionParams(attn_mode, split_attn, None, None, None, None, None) + else: + # Note: attention_mask is only for text tokens, not including image tokens + seqlens = attention_mask.sum(dim=1).to(torch.int32) + img_len # [B] + max_seqlen = attention_mask.shape[1] + img_len + + if split_attn: + # cu_seqlens is not needed for split attention + return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, None, max_seqlen) + + # Convert attention mask to cumulative sequence lengths for flash attention + batch_size = attention_mask.shape[0] + cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=attention_mask.device) + for i in range(batch_size): + cu_seqlens[2 * i + 1] = i * max_seqlen + seqlens[i] # end of valid tokens for query + cu_seqlens[2 * i + 2] = (i + 1) * max_seqlen # end of all tokens for query + + # Expand attention mask to include image tokens + attention_mask = torch.nn.functional.pad(attention_mask, (img_len, 0), value=1) # [B, img_len + L] + + if attn_mode == "xformers": + seqlens_list = seqlens.cpu().tolist() + attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens( + seqlens_list, seqlens_list, device=attention_mask.device + ) + elif attn_mode == "torch": + attention_mask = attention_mask[:, None, None, :].to(torch.bool) # [B, 1, 1, img_len + L] + + return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, cu_seqlens, max_seqlen) + + def attention( qkv_or_q: Union[torch.Tensor, list], k: Optional[torch.Tensor] = None, v: Optional[torch.Tensor] = None, - seq_lens: Optional[list[int]] = None, - attn_mode: str = "torch", + attn_params: Optional[AttentionParams] = None, drop_rate: float = 0.0, ) -> torch.Tensor: """ @@ -25,8 +95,7 @@ def attention( qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors. k: Key tensor [B, L, H, D]. v: Value tensor [B, L, H, D]. - seq_lens: Valid sequence length for each batch element. - attn_mode: Attention implementation ("torch" or "sageattn"). + attn_param: Attention parameters including mask and sequence lengths. drop_rate: Attention dropout rate. Returns: @@ -34,53 +103,158 @@ def attention( """ if isinstance(qkv_or_q, list): q, k, v = qkv_or_q + q: torch.Tensor = q qkv_or_q.clear() del qkv_or_q else: - q = qkv_or_q + q: torch.Tensor = qkv_or_q del qkv_or_q assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor" - if seq_lens is None: - seq_lens = [q.shape[1]] * q.shape[0] + if attn_params is None: + attn_params = AttentionParams.create_attention_params("torch", False) + + # If split attn is False, attention mask is provided and all sequence lengths are same, we can trim the sequence + seqlen_trimmed = False + if not attn_params.split_attn and attn_params.attention_mask is not None and attn_params.seqlens is not None: + if torch.all(attn_params.seqlens == attn_params.seqlens[0]): + seqlen = attn_params.seqlens[0].item() + q = q[:, :seqlen] + k = k[:, :seqlen] + v = v[:, :seqlen] + max_seqlen = attn_params.max_seqlen + attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, False) # do not in-place modify + attn_params.max_seqlen = max_seqlen # keep max_seqlen for padding + seqlen_trimmed = True # Determine tensor layout based on attention implementation - if attn_mode == "torch" or attn_mode == "sageattn": - transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA + if attn_params.attn_mode == "torch" or ( + attn_params.attn_mode == "sageattn" and (attn_params.split_attn or attn_params.cu_seqlens is None) + ): + transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA and sageattn with fixed length + # pad on sequence length dimension + pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, pad_to - x.shape[-2]), value=0) else: transpose_fn = lambda x: x # [B, L, H, D] for other implementations + # pad on sequence length dimension + pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_to - x.shape[-3]), value=0) - # Process each batch element with its valid sequence length - q_seq_len = q.shape[1] - q = [transpose_fn(q[i : i + 1, : seq_lens[i]]) for i in range(len(q))] - k = [transpose_fn(k[i : i + 1, : seq_lens[i]]) for i in range(len(k))] - v = [transpose_fn(v[i : i + 1, : seq_lens[i]]) for i in range(len(v))] + # Process each batch element with its valid sequence lengths + if attn_params.split_attn: + if attn_params.seqlens is None: + # If no seqlens provided, assume all tokens are valid + attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, True) # do not in-place modify + attn_params.seqlens = torch.tensor([q.shape[1]] * q.shape[0], device=q.device) + attn_params.max_seqlen = q.shape[1] + q = [transpose_fn(q[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(q))] + k = [transpose_fn(k[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(k))] + v = [transpose_fn(v[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(v))] + else: + q = transpose_fn(q) + k = transpose_fn(k) + v = transpose_fn(v) - if attn_mode == "torch": - x = [] - for i in range(len(q)): - x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate) - q[i] = None - k[i] = None - v[i] = None - x.append(torch.nn.functional.pad(x_i, (0, 0, 0, q_seq_len - x_i.shape[2]), value=0)) # Pad to max seq len, B, H, L, D - x = torch.cat(x, dim=0) - del q, k, v + if attn_params.attn_mode == "torch": + if attn_params.split_attn: + x = [] + for i in range(len(q)): + x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate) + q[i] = None + k[i] = None + v[i] = None + x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D + x = torch.cat(x, dim=0) + del q, k, v - elif attn_mode == "xformers": - x = [] - for i in range(len(q)): - x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate) - q[i] = None - k[i] = None - v[i] = None - x.append(torch.nn.functional.pad(x_i, (0, 0, 0, 0, 0, q_seq_len - x_i.shape[1]), value=0)) # B, L, H, D - x = torch.cat(x, dim=0) - del q, k, v + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_params.attention_mask, dropout_p=drop_rate) + del q, k, v + + elif attn_params.attn_mode == "xformers": + if attn_params.split_attn: + x = [] + for i in range(len(q)): + x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate) + q[i] = None + k[i] = None + v[i] = None + x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D + x = torch.cat(x, dim=0) + del q, k, v + + else: + x = xops.memory_efficient_attention(q, k, v, attn_bias=attn_params.attention_mask, p=drop_rate) + del q, k, v + + elif attn_params.attn_mode == "sageattn": + if attn_params.split_attn: + x = [] + for i in range(len(q)): + # HND seems to cause an error + x_i = sageattn(q[i], k[i], v[i]) # B, H, L, D. No dropout support + q[i] = None + k[i] = None + v[i] = None + x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D + x = torch.cat(x, dim=0) + del q, k, v + elif attn_params.cu_seqlens is None: # all tokens are valid + x = sageattn(q, k, v) # B, L, H, D. No dropout support + del q, k, v + else: + # Reshape to [(bxs), a, d] + batch_size, seqlen = q.shape[0], q.shape[1] + q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D] + k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D] + v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D] + + # Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv. No dropout support + x = sageattn_varlen( + q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen + ) + del q, k, v + + # Reshape x with shape [(bxs), a, d] to [b, s, a, d] + x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D + + elif attn_params.attn_mode == "flash": + if attn_params.split_attn: + x = [] + for i in range(len(q)): + # HND seems to cause an error + x_i = flash_attn_func(q[i], k[i], v[i], drop_rate) # B, L, H, D + q[i] = None + k[i] = None + v[i] = None + x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D + x = torch.cat(x, dim=0) + del q, k, v + elif attn_params.cu_seqlens is None: # all tokens are valid + x = flash_attn_func(q, k, v, drop_rate) # B, L, H, D + del q, k, v + else: + # Reshape to [(bxs), a, d] + batch_size, seqlen = q.shape[0], q.shape[1] + q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D] + k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D] + v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D] + + # Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv + x = flash_attn_varlen_func( + q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen, drop_rate + ) + del q, k, v + + # Reshape x with shape [(bxs), a, d] to [b, s, a, d] + x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D else: # Currently only PyTorch SDPA and xformers are implemented - raise ValueError(f"Unsupported attention mode: {attn_mode}") + raise ValueError(f"Unsupported attention mode: {attn_params.attn_mode}") x = transpose_fn(x) # [B, L, H, D] x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D] + + if seqlen_trimmed: + x = torch.nn.functional.pad(x, (0, 0, 0, attn_params.max_seqlen - x.shape[1]), value=0) # pad back to max_seqlen + return x diff --git a/library/hunyuan_image_models.py b/library/hunyuan_image_models.py index 356ce4b4..fc320dfc 100644 --- a/library/hunyuan_image_models.py +++ b/library/hunyuan_image_models.py @@ -8,6 +8,7 @@ import torch.nn as nn from accelerate import init_empty_weights from library import custom_offloading_utils +from library.attention import AttentionParams from library.fp8_optimization_utils import apply_fp8_monkey_patch from library.lora_utils import load_safetensors_with_lora_and_fp8 from library.utils import setup_logging @@ -50,7 +51,7 @@ class HYImageDiffusionTransformer(nn.Module): attn_mode: Attention implementation mode ("torch" or "sageattn"). """ - def __init__(self, attn_mode: str = "torch"): + def __init__(self, attn_mode: str = "torch", split_attn: bool = False): super().__init__() # Fixed architecture parameters for HunyuanImage-2.1 @@ -80,6 +81,7 @@ class HYImageDiffusionTransformer(nn.Module): qk_norm_type: str = "rms" # RMS normalization type self.attn_mode = attn_mode + self.split_attn = split_attn # ByT5 character-level text encoder mapping self.byt5_in = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.hidden_size, use_residual=False) @@ -88,7 +90,7 @@ class HYImageDiffusionTransformer(nn.Module): self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size) # Text token refinement with cross-attention - self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2, attn_mode=self.attn_mode) + self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2) # Timestep embedding for diffusion process self.time_in = TimestepEmbedder(self.hidden_size, nn.SiLU) @@ -110,7 +112,6 @@ class HYImageDiffusionTransformer(nn.Module): qk_norm=qk_norm, qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, - attn_mode=self.attn_mode, ) for _ in range(mm_double_blocks_depth) ] @@ -126,7 +127,6 @@ class HYImageDiffusionTransformer(nn.Module): mlp_act_type=mlp_act_type, qk_norm=qk_norm, qk_norm_type=qk_norm_type, - attn_mode=self.attn_mode, ) for _ in range(mm_single_blocks_depth) ] @@ -339,22 +339,21 @@ class HYImageDiffusionTransformer(nn.Module): # MeanFlow and guidance embedding not used in this configuration # Process text tokens through refinement layers - txt_lens = text_mask.to(torch.bool).sum(dim=1).tolist() - txt = self.txt_in(txt, t, txt_lens) + txt_attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, 0, text_mask) + txt = self.txt_in(txt, t, txt_attn_params) # Integrate character-level ByT5 features with word-level tokens # Use variable length sequences with sequence lengths byt5_txt = self.byt5_in(byt5_text_states) - txt, _, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask) + txt, text_mask, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask) # Trim sequences to maximum length in the batch img_seq_len = img.shape[1] - # print(f"img_seq_len: {img_seq_len}, txt_lens: {txt_lens}") - seq_lens = [img_seq_len + l for l in txt_lens] max_txt_len = max(txt_lens) - # print(f"max_txt_len: {max_txt_len}, seq_lens: {seq_lens}, txt.shape: {txt.shape}") txt = txt[:, :max_txt_len, :] - txt_seq_len = txt.shape[1] + text_mask = text_mask[:, :max_txt_len] + + attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, img_seq_len, text_mask) input_device = img.device @@ -362,7 +361,7 @@ class HYImageDiffusionTransformer(nn.Module): for index, block in enumerate(self.double_blocks): if self.blocks_to_swap: self.offloader_double.wait_for_block(index) - img, txt = block(img, txt, vec, freqs_cis, seq_lens) + img, txt = block(img, txt, vec, freqs_cis, attn_params) if self.blocks_to_swap: self.offloader_double.submit_move_blocks(self.double_blocks, index) @@ -373,7 +372,7 @@ class HYImageDiffusionTransformer(nn.Module): for index, block in enumerate(self.single_blocks): if self.blocks_to_swap: self.offloader_single.wait_for_block(index) - x = block(x, vec, txt_seq_len, freqs_cis, seq_lens) + x = block(x, vec, freqs_cis, attn_params) if self.blocks_to_swap: self.offloader_single.submit_move_blocks(self.single_blocks, index) @@ -417,7 +416,7 @@ class HYImageDiffusionTransformer(nn.Module): def create_model(attn_mode: str, split_attn: bool, dtype: Optional[torch.dtype]) -> HYImageDiffusionTransformer: with init_empty_weights(): - model = HYImageDiffusionTransformer(attn_mode=attn_mode) + model = HYImageDiffusionTransformer(attn_mode=attn_mode, split_attn=split_attn) if dtype is not None: model.to(dtype) return model diff --git a/library/hunyuan_image_modules.py b/library/hunyuan_image_modules.py index 555cb487..1953a783 100644 --- a/library/hunyuan_image_modules.py +++ b/library/hunyuan_image_modules.py @@ -7,7 +7,7 @@ import torch.nn as nn from einops import rearrange from library import custom_offloading_utils -from library.attention import attention +from library.attention import AttentionParams, attention from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate from library.attention import attention @@ -213,7 +213,6 @@ class IndividualTokenRefinerBlock(nn.Module): qk_norm: QK normalization flag (must be False). qk_norm_type: QK normalization type (only "layer" supported). qkv_bias: Use bias in QKV projections. - attn_mode: Attention implementation mode. """ def __init__( @@ -226,15 +225,12 @@ class IndividualTokenRefinerBlock(nn.Module): qk_norm: bool = False, qk_norm_type: str = "layer", qkv_bias: bool = True, - attn_mode: str = "torch", ): super().__init__() assert qk_norm_type == "layer", "Only layer normalization supported for QK norm." assert act_type == "silu", "Only SiLU activation supported." assert not qk_norm, "QK normalization must be disabled." - self.attn_mode = attn_mode - self.heads_num = heads_num mlp_hidden_dim = int(hidden_size * mlp_width_ratio) @@ -253,19 +249,14 @@ class IndividualTokenRefinerBlock(nn.Module): nn.Linear(hidden_size, 2 * hidden_size, bias=True), ) - def forward( - self, - x: torch.Tensor, - c: torch.Tensor, # Combined timestep and context conditioning - txt_lens: list[int], - ) -> torch.Tensor: + def forward(self, x: torch.Tensor, c: torch.Tensor, attn_params: AttentionParams) -> torch.Tensor: """ Apply self-attention and MLP with adaptive conditioning. Args: x: Input token embeddings [B, L, C]. c: Combined conditioning vector [B, C]. - txt_lens: Valid sequence lengths for each batch element. + attn_params: Attention parameters including sequence lengths. Returns: Refined token embeddings [B, L, C]. @@ -273,10 +264,14 @@ class IndividualTokenRefinerBlock(nn.Module): gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) norm_x = self.norm1(x) qkv = self.self_attn_qkv(norm_x) + del norm_x q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) + del qkv q = self.self_attn_q_norm(q).to(v) k = self.self_attn_k_norm(k).to(v) - attn = attention(q, k, v, seq_lens=txt_lens, attn_mode=self.attn_mode) + qkv = [q, k, v] + del q, k, v + attn = attention(qkv, attn_params=attn_params) x = x + apply_gate(self.self_attn_proj(attn), gate_msa) x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) @@ -299,7 +294,6 @@ class IndividualTokenRefiner(nn.Module): qk_norm: QK normalization flag. qk_norm_type: QK normalization type. qkv_bias: Use bias in QKV projections. - attn_mode: Attention implementation mode. """ def __init__( @@ -313,7 +307,6 @@ class IndividualTokenRefiner(nn.Module): qk_norm: bool = False, qk_norm_type: str = "layer", qkv_bias: bool = True, - attn_mode: str = "torch", ): super().__init__() self.blocks = nn.ModuleList( @@ -327,26 +320,25 @@ class IndividualTokenRefiner(nn.Module): qk_norm=qk_norm, qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, - attn_mode=attn_mode, ) for _ in range(depth) ] ) - def forward(self, x: torch.Tensor, c: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor: + def forward(self, x: torch.Tensor, c: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor: """ Apply sequential token refinement. Args: x: Input token embeddings [B, L, C]. c: Combined conditioning vector [B, C]. - txt_lens: Valid sequence lengths for each batch element. + attn_params: Attention parameters including sequence lengths. Returns: Refined token embeddings [B, L, C]. """ for block in self.blocks: - x = block(x, c, txt_lens) + x = block(x, c, attn_params) return x @@ -362,10 +354,9 @@ class SingleTokenRefiner(nn.Module): hidden_size: Transformer hidden dimension. heads_num: Number of attention heads. depth: Number of refinement blocks. - attn_mode: Attention implementation mode. """ - def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int, attn_mode: str = "torch"): + def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int): # Fixed architecture parameters for HunyuanImage-2.1 mlp_drop_rate: float = 0.0 # No MLP dropout act_type: str = "silu" # SiLU activation @@ -389,17 +380,16 @@ class SingleTokenRefiner(nn.Module): qk_norm=qk_norm, qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, - attn_mode=attn_mode, ) - def forward(self, x: torch.Tensor, t: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor: + def forward(self, x: torch.Tensor, t: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor: """ Refine text embeddings with timestep conditioning. Args: x: Input text embeddings [B, L, in_channels]. t: Diffusion timestep [B]. - txt_lens: Valid sequence lengths for each batch element. + attn_params: Attention parameters including sequence lengths. Returns: Refined embeddings [B, L, hidden_size]. @@ -407,13 +397,14 @@ class SingleTokenRefiner(nn.Module): timestep_aware_representations = self.t_embedder(t) # Compute context-aware representations by averaging valid tokens + txt_lens = attn_params.seqlens # img_len is not used for SingleTokenRefiner context_aware_representations = torch.stack([x[i, : txt_lens[i]].mean(dim=0) for i in range(x.shape[0])], dim=0) # [B, C] context_aware_representations = self.c_embedder(context_aware_representations) c = timestep_aware_representations + context_aware_representations del timestep_aware_representations, context_aware_representations x = self.input_embedder(x) - x = self.individual_token_refiner(x, c, txt_lens) + x = self.individual_token_refiner(x, c, attn_params) return x @@ -564,7 +555,6 @@ class MMDoubleStreamBlock(nn.Module): qk_norm: QK normalization flag (must be True). qk_norm_type: QK normalization type (only "rms" supported). qkv_bias: Use bias in QKV projections. - attn_mode: Attention implementation mode. """ def __init__( @@ -576,7 +566,6 @@ class MMDoubleStreamBlock(nn.Module): qk_norm: bool = True, qk_norm_type: str = "rms", qkv_bias: bool = False, - attn_mode: str = "torch", ): super().__init__() @@ -584,7 +573,6 @@ class MMDoubleStreamBlock(nn.Module): assert qk_norm_type == "rms", "Only RMS normalization supported." assert qk_norm, "QK normalization must be enabled." - self.attn_mode = attn_mode self.heads_num = heads_num head_dim = hidden_size // heads_num mlp_hidden_dim = int(hidden_size * mlp_width_ratio) @@ -626,7 +614,7 @@ class MMDoubleStreamBlock(nn.Module): self.cpu_offload_checkpointing = False def _forward( - self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None + self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None ) -> Tuple[torch.Tensor, torch.Tensor]: # Extract modulation parameters for image and text streams (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk( @@ -687,7 +675,7 @@ class MMDoubleStreamBlock(nn.Module): qkv = [q, k, v] del q, k, v - attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode) + attn = attention(qkv, attn_params=attn_params) del qkv # Split attention outputs back to separate streams @@ -719,16 +707,16 @@ class MMDoubleStreamBlock(nn.Module): return img, txt def forward( - self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None + self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None ) -> Tuple[torch.Tensor, torch.Tensor]: if self.gradient_checkpointing and self.training: forward_fn = self._forward if self.cpu_offload_checkpointing: forward_fn = custom_offloading_utils.cpu_offload_wrapper(forward_fn, self.img_attn_qkv.weight.device) - return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, seq_lens, use_reentrant=False) + return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, attn_params, use_reentrant=False) else: - return self._forward(img, txt, vec, freqs_cis, seq_lens) + return self._forward(img, txt, vec, freqs_cis, attn_params) class MMSingleStreamBlock(nn.Module): @@ -746,7 +734,6 @@ class MMSingleStreamBlock(nn.Module): qk_norm: QK normalization flag (must be True). qk_norm_type: QK normalization type (only "rms" supported). qk_scale: Attention scaling factor (computed automatically if None). - attn_mode: Attention implementation mode. """ def __init__( @@ -758,7 +745,6 @@ class MMSingleStreamBlock(nn.Module): qk_norm: bool = True, qk_norm_type: str = "rms", qk_scale: float = None, - attn_mode: str = "torch", ): super().__init__() @@ -766,7 +752,6 @@ class MMSingleStreamBlock(nn.Module): assert qk_norm_type == "rms", "Only RMS normalization supported." assert qk_norm, "QK normalization must be enabled." - self.attn_mode = attn_mode self.hidden_size = hidden_size self.heads_num = heads_num head_dim = hidden_size // heads_num @@ -805,9 +790,8 @@ class MMSingleStreamBlock(nn.Module): self, x: torch.Tensor, vec: torch.Tensor, - txt_len: int, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, - seq_lens: list[int] = None, + attn_params: AttentionParams = None, ) -> torch.Tensor: # Extract modulation parameters mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) @@ -828,12 +812,10 @@ class MMSingleStreamBlock(nn.Module): k = self.k_norm(k).to(v) # Separate image and text tokens - img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] + img_q, txt_q = q[:, : attn_params.img_len, :, :], q[:, attn_params.img_len :, :, :] del q - img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] + img_k, txt_k = k[:, : attn_params.img_len, :, :], k[:, attn_params.img_len :, :, :] del k - # img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :] - # del v # Apply rotary position embeddings only to image tokens img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False) @@ -848,7 +830,7 @@ class MMSingleStreamBlock(nn.Module): # del img_v, txt_v qkv = [q, k, v] del q, k, v - attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode) + attn = attention(qkv, attn_params=attn_params) del qkv # Combine attention and MLP outputs, apply gating @@ -865,18 +847,17 @@ class MMSingleStreamBlock(nn.Module): self, x: torch.Tensor, vec: torch.Tensor, - txt_len: int, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None, - seq_lens: list[int] = None, + attn_params: AttentionParams = None, ) -> torch.Tensor: if self.gradient_checkpointing and self.training: forward_fn = self._forward if self.cpu_offload_checkpointing: forward_fn = custom_offloading_utils.create_cpu_offloading_wrapper(forward_fn, self.linear1.weight.device) - return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, txt_len, freqs_cis, seq_lens, use_reentrant=False) + return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, freqs_cis, attn_params, use_reentrant=False) else: - return self._forward(x, vec, txt_len, freqs_cis, seq_lens) + return self._forward(x, vec, freqs_cis, attn_params) # endregion