mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
feat: add multi backend attention and related update for HI2.1 models and scripts
This commit is contained in:
@@ -126,7 +126,8 @@ accelerate launch --num_cpu_threads_per_process 1 hunyuan_image_train_network.py
|
||||
--learning_rate=1e-4 \
|
||||
--optimizer_type="AdamW8bit" \
|
||||
--lr_scheduler="constant" \
|
||||
--sdpa \
|
||||
--attn_mode="torch" \
|
||||
--split_attn \
|
||||
--max_train_epochs=10 \
|
||||
--save_every_n_epochs=1 \
|
||||
--mixed_precision="bf16" \
|
||||
@@ -175,6 +176,10 @@ The script adds HunyuanImage-2.1 specific arguments. For common arguments (like
|
||||
|
||||
#### Memory/Speed Related
|
||||
|
||||
* `--attn_mode=<choice>`
|
||||
- Specifies the attention implementation to use. Options are `torch`, `xformers`, `flash`, `sageattn`. Default is `torch` (use scaled dot product attention). Each library must be installed separately other than `torch`. If using `xformers`, also specify `--split_attn` if the batch size is more than 1.
|
||||
* `--split_attn`
|
||||
- Splits the batch during attention computation to process one item at a time, reducing VRAM usage by avoiding attention mask computation. Can improve speed when using `torch`. Required when using `xformers` with batch size greater than 1.
|
||||
* `--fp8_scaled`
|
||||
- Enables training the DiT model in scaled FP8 format. This can significantly reduce VRAM usage (can run with as little as 8GB VRAM when combined with `--blocks_to_swap`), but the training results may vary. This is a newer alternative to the unsupported `--fp8_base` option.
|
||||
* `--fp8_vl`
|
||||
@@ -429,6 +434,7 @@ python hunyuan_image_minimal_inference.py \
|
||||
--vae "<path to hunyuan_image_2.1_vae_fp16.safetensors>" \
|
||||
--lora_weight "<path to your trained LoRA>" \
|
||||
--lora_multiplier 1.0 \
|
||||
--attn_mode "torch" \
|
||||
--prompt "A cute cartoon penguin in a snowy landscape" \
|
||||
--image_size 2048 2048 \
|
||||
--infer_steps 50 \
|
||||
@@ -445,6 +451,8 @@ python hunyuan_image_minimal_inference.py \
|
||||
- `--guidance_scale`: CFG scale (default: 3.5)
|
||||
- `--flow_shift`: Flow matching shift parameter (default: 5.0)
|
||||
|
||||
`--split_attn` is not supported (since inference is done one at a time).
|
||||
|
||||
<details>
|
||||
<summary>日本語</summary>
|
||||
|
||||
@@ -457,6 +465,8 @@ python hunyuan_image_minimal_inference.py \
|
||||
- `--guidance_scale`: CFGスケール(推奨: 3.5)
|
||||
- `--flow_shift`: Flow Matchingシフトパラメータ(デフォルト: 5.0)
|
||||
|
||||
`--split_attn`はサポートされていません(1件ずつ推論するため)。
|
||||
|
||||
</details>
|
||||
|
||||
## 9. Related Tools / 関連ツール
|
||||
|
||||
@@ -96,7 +96,7 @@ def parse_args() -> argparse.Namespace:
|
||||
"--attn_mode",
|
||||
type=str,
|
||||
default="torch",
|
||||
choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3",
|
||||
choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "sdpa" for backward compatibility
|
||||
help="attention mode",
|
||||
)
|
||||
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
|
||||
@@ -130,6 +130,9 @@ def parse_args() -> argparse.Namespace:
|
||||
if args.lycoris and not lycoris_available:
|
||||
raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS")
|
||||
|
||||
if args.attn_mode == "sdpa":
|
||||
args.attn_mode = "torch" # backward compatibility
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@@ -265,7 +268,7 @@ def load_dit_model(
|
||||
device,
|
||||
args.dit,
|
||||
args.attn_mode,
|
||||
False,
|
||||
True, # enable split_attn to trim masked tokens
|
||||
loading_device,
|
||||
loading_weight_dtype,
|
||||
args.fp8_scaled and not args.lycoris,
|
||||
|
||||
@@ -379,18 +379,19 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
loading_dtype = None if args.fp8_scaled else weight_dtype
|
||||
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
|
||||
split_attn = True
|
||||
|
||||
attn_mode = "torch"
|
||||
if args.xformers:
|
||||
attn_mode = "xformers"
|
||||
logger.info("xformers is enabled for attention")
|
||||
if args.attn_mode is not None:
|
||||
attn_mode = args.attn_mode
|
||||
|
||||
logger.info(f"Loading DiT model with attn_mode: {attn_mode}, split_attn: {args.split_attn}, fp8_scaled: {args.fp8_scaled}")
|
||||
model = hunyuan_image_models.load_hunyuan_image_model(
|
||||
accelerator.device,
|
||||
args.pretrained_model_name_or_path,
|
||||
attn_mode,
|
||||
split_attn,
|
||||
args.split_attn,
|
||||
loading_device,
|
||||
loading_dtype,
|
||||
args.fp8_scaled,
|
||||
@@ -674,6 +675,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Enable tiling for VAE decoding and encoding / VAEデコーディングとエンコーディングのタイルを有効にする",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attn_mode",
|
||||
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
|
||||
default=None,
|
||||
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
|
||||
" / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split_attn",
|
||||
action="store_true",
|
||||
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -684,5 +698,8 @@ if __name__ == "__main__":
|
||||
train_util.verify_command_line_training_args(args)
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
if args.attn_mode == "sdpa":
|
||||
args.attn_mode = "torch" # backward compatibility
|
||||
|
||||
trainer = HunyuanImageNetworkTrainer()
|
||||
trainer.train(args)
|
||||
|
||||
@@ -1,18 +1,88 @@
|
||||
# Unified attention function supporting various implementations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
except ImportError:
|
||||
flash_attn = None
|
||||
flash_attn_varlen_func = None
|
||||
_flash_attn_forward = None
|
||||
flash_attn_func = None
|
||||
|
||||
try:
|
||||
from sageattention import sageattn_varlen, sageattn
|
||||
except ImportError:
|
||||
sageattn_varlen = None
|
||||
sageattn = None
|
||||
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
except ImportError:
|
||||
xops = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionParams:
|
||||
attn_mode: Optional[str] = None
|
||||
split_attn: bool = False
|
||||
img_len: Optional[int] = None
|
||||
attention_mask: Optional[torch.Tensor] = None
|
||||
seqlens: Optional[torch.Tensor] = None
|
||||
cu_seqlens: Optional[torch.Tensor] = None
|
||||
max_seqlen: Optional[int] = None
|
||||
|
||||
@staticmethod
|
||||
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
|
||||
return AttentionParams(attn_mode, split_attn)
|
||||
|
||||
@staticmethod
|
||||
def create_attention_params_from_mask(
|
||||
attn_mode: Optional[str], split_attn: bool, img_len: Optional[int], attention_mask: Optional[torch.Tensor]
|
||||
) -> "AttentionParams":
|
||||
if attention_mask is None:
|
||||
# No attention mask provided: assume all tokens are valid
|
||||
return AttentionParams(attn_mode, split_attn, None, None, None, None, None)
|
||||
else:
|
||||
# Note: attention_mask is only for text tokens, not including image tokens
|
||||
seqlens = attention_mask.sum(dim=1).to(torch.int32) + img_len # [B]
|
||||
max_seqlen = attention_mask.shape[1] + img_len
|
||||
|
||||
if split_attn:
|
||||
# cu_seqlens is not needed for split attention
|
||||
return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, None, max_seqlen)
|
||||
|
||||
# Convert attention mask to cumulative sequence lengths for flash attention
|
||||
batch_size = attention_mask.shape[0]
|
||||
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=attention_mask.device)
|
||||
for i in range(batch_size):
|
||||
cu_seqlens[2 * i + 1] = i * max_seqlen + seqlens[i] # end of valid tokens for query
|
||||
cu_seqlens[2 * i + 2] = (i + 1) * max_seqlen # end of all tokens for query
|
||||
|
||||
# Expand attention mask to include image tokens
|
||||
attention_mask = torch.nn.functional.pad(attention_mask, (img_len, 0), value=1) # [B, img_len + L]
|
||||
|
||||
if attn_mode == "xformers":
|
||||
seqlens_list = seqlens.cpu().tolist()
|
||||
attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
|
||||
seqlens_list, seqlens_list, device=attention_mask.device
|
||||
)
|
||||
elif attn_mode == "torch":
|
||||
attention_mask = attention_mask[:, None, None, :].to(torch.bool) # [B, 1, 1, img_len + L]
|
||||
|
||||
return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, cu_seqlens, max_seqlen)
|
||||
|
||||
|
||||
def attention(
|
||||
qkv_or_q: Union[torch.Tensor, list],
|
||||
k: Optional[torch.Tensor] = None,
|
||||
v: Optional[torch.Tensor] = None,
|
||||
seq_lens: Optional[list[int]] = None,
|
||||
attn_mode: str = "torch",
|
||||
attn_params: Optional[AttentionParams] = None,
|
||||
drop_rate: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -25,8 +95,7 @@ def attention(
|
||||
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
|
||||
k: Key tensor [B, L, H, D].
|
||||
v: Value tensor [B, L, H, D].
|
||||
seq_lens: Valid sequence length for each batch element.
|
||||
attn_mode: Attention implementation ("torch" or "sageattn").
|
||||
attn_param: Attention parameters including mask and sequence lengths.
|
||||
drop_rate: Attention dropout rate.
|
||||
|
||||
Returns:
|
||||
@@ -34,53 +103,158 @@ def attention(
|
||||
"""
|
||||
if isinstance(qkv_or_q, list):
|
||||
q, k, v = qkv_or_q
|
||||
q: torch.Tensor = q
|
||||
qkv_or_q.clear()
|
||||
del qkv_or_q
|
||||
else:
|
||||
q = qkv_or_q
|
||||
q: torch.Tensor = qkv_or_q
|
||||
del qkv_or_q
|
||||
assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor"
|
||||
if seq_lens is None:
|
||||
seq_lens = [q.shape[1]] * q.shape[0]
|
||||
if attn_params is None:
|
||||
attn_params = AttentionParams.create_attention_params("torch", False)
|
||||
|
||||
# If split attn is False, attention mask is provided and all sequence lengths are same, we can trim the sequence
|
||||
seqlen_trimmed = False
|
||||
if not attn_params.split_attn and attn_params.attention_mask is not None and attn_params.seqlens is not None:
|
||||
if torch.all(attn_params.seqlens == attn_params.seqlens[0]):
|
||||
seqlen = attn_params.seqlens[0].item()
|
||||
q = q[:, :seqlen]
|
||||
k = k[:, :seqlen]
|
||||
v = v[:, :seqlen]
|
||||
max_seqlen = attn_params.max_seqlen
|
||||
attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, False) # do not in-place modify
|
||||
attn_params.max_seqlen = max_seqlen # keep max_seqlen for padding
|
||||
seqlen_trimmed = True
|
||||
|
||||
# Determine tensor layout based on attention implementation
|
||||
if attn_mode == "torch" or attn_mode == "sageattn":
|
||||
transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA
|
||||
if attn_params.attn_mode == "torch" or (
|
||||
attn_params.attn_mode == "sageattn" and (attn_params.split_attn or attn_params.cu_seqlens is None)
|
||||
):
|
||||
transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA and sageattn with fixed length
|
||||
# pad on sequence length dimension
|
||||
pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, pad_to - x.shape[-2]), value=0)
|
||||
else:
|
||||
transpose_fn = lambda x: x # [B, L, H, D] for other implementations
|
||||
# pad on sequence length dimension
|
||||
pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_to - x.shape[-3]), value=0)
|
||||
|
||||
# Process each batch element with its valid sequence length
|
||||
q_seq_len = q.shape[1]
|
||||
q = [transpose_fn(q[i : i + 1, : seq_lens[i]]) for i in range(len(q))]
|
||||
k = [transpose_fn(k[i : i + 1, : seq_lens[i]]) for i in range(len(k))]
|
||||
v = [transpose_fn(v[i : i + 1, : seq_lens[i]]) for i in range(len(v))]
|
||||
# Process each batch element with its valid sequence lengths
|
||||
if attn_params.split_attn:
|
||||
if attn_params.seqlens is None:
|
||||
# If no seqlens provided, assume all tokens are valid
|
||||
attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, True) # do not in-place modify
|
||||
attn_params.seqlens = torch.tensor([q.shape[1]] * q.shape[0], device=q.device)
|
||||
attn_params.max_seqlen = q.shape[1]
|
||||
q = [transpose_fn(q[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(q))]
|
||||
k = [transpose_fn(k[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(k))]
|
||||
v = [transpose_fn(v[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(v))]
|
||||
else:
|
||||
q = transpose_fn(q)
|
||||
k = transpose_fn(k)
|
||||
v = transpose_fn(v)
|
||||
|
||||
if attn_mode == "torch":
|
||||
x = []
|
||||
for i in range(len(q)):
|
||||
x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
|
||||
q[i] = None
|
||||
k[i] = None
|
||||
v[i] = None
|
||||
x.append(torch.nn.functional.pad(x_i, (0, 0, 0, q_seq_len - x_i.shape[2]), value=0)) # Pad to max seq len, B, H, L, D
|
||||
x = torch.cat(x, dim=0)
|
||||
del q, k, v
|
||||
if attn_params.attn_mode == "torch":
|
||||
if attn_params.split_attn:
|
||||
x = []
|
||||
for i in range(len(q)):
|
||||
x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
|
||||
q[i] = None
|
||||
k[i] = None
|
||||
v[i] = None
|
||||
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D
|
||||
x = torch.cat(x, dim=0)
|
||||
del q, k, v
|
||||
|
||||
elif attn_mode == "xformers":
|
||||
x = []
|
||||
for i in range(len(q)):
|
||||
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)
|
||||
q[i] = None
|
||||
k[i] = None
|
||||
v[i] = None
|
||||
x.append(torch.nn.functional.pad(x_i, (0, 0, 0, 0, 0, q_seq_len - x_i.shape[1]), value=0)) # B, L, H, D
|
||||
x = torch.cat(x, dim=0)
|
||||
del q, k, v
|
||||
else:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_params.attention_mask, dropout_p=drop_rate)
|
||||
del q, k, v
|
||||
|
||||
elif attn_params.attn_mode == "xformers":
|
||||
if attn_params.split_attn:
|
||||
x = []
|
||||
for i in range(len(q)):
|
||||
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)
|
||||
q[i] = None
|
||||
k[i] = None
|
||||
v[i] = None
|
||||
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D
|
||||
x = torch.cat(x, dim=0)
|
||||
del q, k, v
|
||||
|
||||
else:
|
||||
x = xops.memory_efficient_attention(q, k, v, attn_bias=attn_params.attention_mask, p=drop_rate)
|
||||
del q, k, v
|
||||
|
||||
elif attn_params.attn_mode == "sageattn":
|
||||
if attn_params.split_attn:
|
||||
x = []
|
||||
for i in range(len(q)):
|
||||
# HND seems to cause an error
|
||||
x_i = sageattn(q[i], k[i], v[i]) # B, H, L, D. No dropout support
|
||||
q[i] = None
|
||||
k[i] = None
|
||||
v[i] = None
|
||||
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D
|
||||
x = torch.cat(x, dim=0)
|
||||
del q, k, v
|
||||
elif attn_params.cu_seqlens is None: # all tokens are valid
|
||||
x = sageattn(q, k, v) # B, L, H, D. No dropout support
|
||||
del q, k, v
|
||||
else:
|
||||
# Reshape to [(bxs), a, d]
|
||||
batch_size, seqlen = q.shape[0], q.shape[1]
|
||||
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D]
|
||||
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D]
|
||||
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D]
|
||||
|
||||
# Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv. No dropout support
|
||||
x = sageattn_varlen(
|
||||
q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen
|
||||
)
|
||||
del q, k, v
|
||||
|
||||
# Reshape x with shape [(bxs), a, d] to [b, s, a, d]
|
||||
x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D
|
||||
|
||||
elif attn_params.attn_mode == "flash":
|
||||
if attn_params.split_attn:
|
||||
x = []
|
||||
for i in range(len(q)):
|
||||
# HND seems to cause an error
|
||||
x_i = flash_attn_func(q[i], k[i], v[i], drop_rate) # B, L, H, D
|
||||
q[i] = None
|
||||
k[i] = None
|
||||
v[i] = None
|
||||
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D
|
||||
x = torch.cat(x, dim=0)
|
||||
del q, k, v
|
||||
elif attn_params.cu_seqlens is None: # all tokens are valid
|
||||
x = flash_attn_func(q, k, v, drop_rate) # B, L, H, D
|
||||
del q, k, v
|
||||
else:
|
||||
# Reshape to [(bxs), a, d]
|
||||
batch_size, seqlen = q.shape[0], q.shape[1]
|
||||
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D]
|
||||
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D]
|
||||
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D]
|
||||
|
||||
# Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv
|
||||
x = flash_attn_varlen_func(
|
||||
q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen, drop_rate
|
||||
)
|
||||
del q, k, v
|
||||
|
||||
# Reshape x with shape [(bxs), a, d] to [b, s, a, d]
|
||||
x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D
|
||||
|
||||
else:
|
||||
# Currently only PyTorch SDPA and xformers are implemented
|
||||
raise ValueError(f"Unsupported attention mode: {attn_mode}")
|
||||
raise ValueError(f"Unsupported attention mode: {attn_params.attn_mode}")
|
||||
|
||||
x = transpose_fn(x) # [B, L, H, D]
|
||||
x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D]
|
||||
|
||||
if seqlen_trimmed:
|
||||
x = torch.nn.functional.pad(x, (0, 0, 0, attn_params.max_seqlen - x.shape[1]), value=0) # pad back to max_seqlen
|
||||
|
||||
return x
|
||||
|
||||
@@ -8,6 +8,7 @@ import torch.nn as nn
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from library import custom_offloading_utils
|
||||
from library.attention import AttentionParams
|
||||
from library.fp8_optimization_utils import apply_fp8_monkey_patch
|
||||
from library.lora_utils import load_safetensors_with_lora_and_fp8
|
||||
from library.utils import setup_logging
|
||||
@@ -50,7 +51,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
attn_mode: Attention implementation mode ("torch" or "sageattn").
|
||||
"""
|
||||
|
||||
def __init__(self, attn_mode: str = "torch"):
|
||||
def __init__(self, attn_mode: str = "torch", split_attn: bool = False):
|
||||
super().__init__()
|
||||
|
||||
# Fixed architecture parameters for HunyuanImage-2.1
|
||||
@@ -80,6 +81,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
qk_norm_type: str = "rms" # RMS normalization type
|
||||
|
||||
self.attn_mode = attn_mode
|
||||
self.split_attn = split_attn
|
||||
|
||||
# ByT5 character-level text encoder mapping
|
||||
self.byt5_in = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.hidden_size, use_residual=False)
|
||||
@@ -88,7 +90,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size)
|
||||
|
||||
# Text token refinement with cross-attention
|
||||
self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2, attn_mode=self.attn_mode)
|
||||
self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2)
|
||||
|
||||
# Timestep embedding for diffusion process
|
||||
self.time_in = TimestepEmbedder(self.hidden_size, nn.SiLU)
|
||||
@@ -110,7 +112,6 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=self.attn_mode,
|
||||
)
|
||||
for _ in range(mm_double_blocks_depth)
|
||||
]
|
||||
@@ -126,7 +127,6 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
mlp_act_type=mlp_act_type,
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
attn_mode=self.attn_mode,
|
||||
)
|
||||
for _ in range(mm_single_blocks_depth)
|
||||
]
|
||||
@@ -339,22 +339,21 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
# MeanFlow and guidance embedding not used in this configuration
|
||||
|
||||
# Process text tokens through refinement layers
|
||||
txt_lens = text_mask.to(torch.bool).sum(dim=1).tolist()
|
||||
txt = self.txt_in(txt, t, txt_lens)
|
||||
txt_attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, 0, text_mask)
|
||||
txt = self.txt_in(txt, t, txt_attn_params)
|
||||
|
||||
# Integrate character-level ByT5 features with word-level tokens
|
||||
# Use variable length sequences with sequence lengths
|
||||
byt5_txt = self.byt5_in(byt5_text_states)
|
||||
txt, _, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)
|
||||
txt, text_mask, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)
|
||||
|
||||
# Trim sequences to maximum length in the batch
|
||||
img_seq_len = img.shape[1]
|
||||
# print(f"img_seq_len: {img_seq_len}, txt_lens: {txt_lens}")
|
||||
seq_lens = [img_seq_len + l for l in txt_lens]
|
||||
max_txt_len = max(txt_lens)
|
||||
# print(f"max_txt_len: {max_txt_len}, seq_lens: {seq_lens}, txt.shape: {txt.shape}")
|
||||
txt = txt[:, :max_txt_len, :]
|
||||
txt_seq_len = txt.shape[1]
|
||||
text_mask = text_mask[:, :max_txt_len]
|
||||
|
||||
attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, img_seq_len, text_mask)
|
||||
|
||||
input_device = img.device
|
||||
|
||||
@@ -362,7 +361,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
for index, block in enumerate(self.double_blocks):
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_double.wait_for_block(index)
|
||||
img, txt = block(img, txt, vec, freqs_cis, seq_lens)
|
||||
img, txt = block(img, txt, vec, freqs_cis, attn_params)
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_double.submit_move_blocks(self.double_blocks, index)
|
||||
|
||||
@@ -373,7 +372,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
for index, block in enumerate(self.single_blocks):
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_single.wait_for_block(index)
|
||||
x = block(x, vec, txt_seq_len, freqs_cis, seq_lens)
|
||||
x = block(x, vec, freqs_cis, attn_params)
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_single.submit_move_blocks(self.single_blocks, index)
|
||||
|
||||
@@ -417,7 +416,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
||||
|
||||
def create_model(attn_mode: str, split_attn: bool, dtype: Optional[torch.dtype]) -> HYImageDiffusionTransformer:
|
||||
with init_empty_weights():
|
||||
model = HYImageDiffusionTransformer(attn_mode=attn_mode)
|
||||
model = HYImageDiffusionTransformer(attn_mode=attn_mode, split_attn=split_attn)
|
||||
if dtype is not None:
|
||||
model.to(dtype)
|
||||
return model
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
from library import custom_offloading_utils
|
||||
from library.attention import attention
|
||||
from library.attention import AttentionParams, attention
|
||||
from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate
|
||||
from library.attention import attention
|
||||
|
||||
@@ -213,7 +213,6 @@ class IndividualTokenRefinerBlock(nn.Module):
|
||||
qk_norm: QK normalization flag (must be False).
|
||||
qk_norm_type: QK normalization type (only "layer" supported).
|
||||
qkv_bias: Use bias in QKV projections.
|
||||
attn_mode: Attention implementation mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -226,15 +225,12 @@ class IndividualTokenRefinerBlock(nn.Module):
|
||||
qk_norm: bool = False,
|
||||
qk_norm_type: str = "layer",
|
||||
qkv_bias: bool = True,
|
||||
attn_mode: str = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
assert qk_norm_type == "layer", "Only layer normalization supported for QK norm."
|
||||
assert act_type == "silu", "Only SiLU activation supported."
|
||||
assert not qk_norm, "QK normalization must be disabled."
|
||||
|
||||
self.attn_mode = attn_mode
|
||||
|
||||
self.heads_num = heads_num
|
||||
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
||||
|
||||
@@ -253,19 +249,14 @@ class IndividualTokenRefinerBlock(nn.Module):
|
||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
c: torch.Tensor, # Combined timestep and context conditioning
|
||||
txt_lens: list[int],
|
||||
) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, c: torch.Tensor, attn_params: AttentionParams) -> torch.Tensor:
|
||||
"""
|
||||
Apply self-attention and MLP with adaptive conditioning.
|
||||
|
||||
Args:
|
||||
x: Input token embeddings [B, L, C].
|
||||
c: Combined conditioning vector [B, C].
|
||||
txt_lens: Valid sequence lengths for each batch element.
|
||||
attn_params: Attention parameters including sequence lengths.
|
||||
|
||||
Returns:
|
||||
Refined token embeddings [B, L, C].
|
||||
@@ -273,10 +264,14 @@ class IndividualTokenRefinerBlock(nn.Module):
|
||||
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn_qkv(norm_x)
|
||||
del norm_x
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
del qkv
|
||||
q = self.self_attn_q_norm(q).to(v)
|
||||
k = self.self_attn_k_norm(k).to(v)
|
||||
attn = attention(q, k, v, seq_lens=txt_lens, attn_mode=self.attn_mode)
|
||||
qkv = [q, k, v]
|
||||
del q, k, v
|
||||
attn = attention(qkv, attn_params=attn_params)
|
||||
|
||||
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
||||
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
||||
@@ -299,7 +294,6 @@ class IndividualTokenRefiner(nn.Module):
|
||||
qk_norm: QK normalization flag.
|
||||
qk_norm_type: QK normalization type.
|
||||
qkv_bias: Use bias in QKV projections.
|
||||
attn_mode: Attention implementation mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -313,7 +307,6 @@ class IndividualTokenRefiner(nn.Module):
|
||||
qk_norm: bool = False,
|
||||
qk_norm_type: str = "layer",
|
||||
qkv_bias: bool = True,
|
||||
attn_mode: str = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList(
|
||||
@@ -327,26 +320,25 @@ class IndividualTokenRefiner(nn.Module):
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
)
|
||||
for _ in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, c: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, c: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor:
|
||||
"""
|
||||
Apply sequential token refinement.
|
||||
|
||||
Args:
|
||||
x: Input token embeddings [B, L, C].
|
||||
c: Combined conditioning vector [B, C].
|
||||
txt_lens: Valid sequence lengths for each batch element.
|
||||
attn_params: Attention parameters including sequence lengths.
|
||||
|
||||
Returns:
|
||||
Refined token embeddings [B, L, C].
|
||||
"""
|
||||
for block in self.blocks:
|
||||
x = block(x, c, txt_lens)
|
||||
x = block(x, c, attn_params)
|
||||
return x
|
||||
|
||||
|
||||
@@ -362,10 +354,9 @@ class SingleTokenRefiner(nn.Module):
|
||||
hidden_size: Transformer hidden dimension.
|
||||
heads_num: Number of attention heads.
|
||||
depth: Number of refinement blocks.
|
||||
attn_mode: Attention implementation mode.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int, attn_mode: str = "torch"):
|
||||
def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int):
|
||||
# Fixed architecture parameters for HunyuanImage-2.1
|
||||
mlp_drop_rate: float = 0.0 # No MLP dropout
|
||||
act_type: str = "silu" # SiLU activation
|
||||
@@ -389,17 +380,16 @@ class SingleTokenRefiner(nn.Module):
|
||||
qk_norm=qk_norm,
|
||||
qk_norm_type=qk_norm_type,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_mode=attn_mode,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor:
|
||||
def forward(self, x: torch.Tensor, t: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor:
|
||||
"""
|
||||
Refine text embeddings with timestep conditioning.
|
||||
|
||||
Args:
|
||||
x: Input text embeddings [B, L, in_channels].
|
||||
t: Diffusion timestep [B].
|
||||
txt_lens: Valid sequence lengths for each batch element.
|
||||
attn_params: Attention parameters including sequence lengths.
|
||||
|
||||
Returns:
|
||||
Refined embeddings [B, L, hidden_size].
|
||||
@@ -407,13 +397,14 @@ class SingleTokenRefiner(nn.Module):
|
||||
timestep_aware_representations = self.t_embedder(t)
|
||||
|
||||
# Compute context-aware representations by averaging valid tokens
|
||||
txt_lens = attn_params.seqlens # img_len is not used for SingleTokenRefiner
|
||||
context_aware_representations = torch.stack([x[i, : txt_lens[i]].mean(dim=0) for i in range(x.shape[0])], dim=0) # [B, C]
|
||||
|
||||
context_aware_representations = self.c_embedder(context_aware_representations)
|
||||
c = timestep_aware_representations + context_aware_representations
|
||||
del timestep_aware_representations, context_aware_representations
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, txt_lens)
|
||||
x = self.individual_token_refiner(x, c, attn_params)
|
||||
return x
|
||||
|
||||
|
||||
@@ -564,7 +555,6 @@ class MMDoubleStreamBlock(nn.Module):
|
||||
qk_norm: QK normalization flag (must be True).
|
||||
qk_norm_type: QK normalization type (only "rms" supported).
|
||||
qkv_bias: Use bias in QKV projections.
|
||||
attn_mode: Attention implementation mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -576,7 +566,6 @@ class MMDoubleStreamBlock(nn.Module):
|
||||
qk_norm: bool = True,
|
||||
qk_norm_type: str = "rms",
|
||||
qkv_bias: bool = False,
|
||||
attn_mode: str = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -584,7 +573,6 @@ class MMDoubleStreamBlock(nn.Module):
|
||||
assert qk_norm_type == "rms", "Only RMS normalization supported."
|
||||
assert qk_norm, "QK normalization must be enabled."
|
||||
|
||||
self.attn_mode = attn_mode
|
||||
self.heads_num = heads_num
|
||||
head_dim = hidden_size // heads_num
|
||||
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
||||
@@ -626,7 +614,7 @@ class MMDoubleStreamBlock(nn.Module):
|
||||
self.cpu_offload_checkpointing = False
|
||||
|
||||
def _forward(
|
||||
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None
|
||||
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Extract modulation parameters for image and text streams
|
||||
(img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
|
||||
@@ -687,7 +675,7 @@ class MMDoubleStreamBlock(nn.Module):
|
||||
|
||||
qkv = [q, k, v]
|
||||
del q, k, v
|
||||
attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode)
|
||||
attn = attention(qkv, attn_params=attn_params)
|
||||
del qkv
|
||||
|
||||
# Split attention outputs back to separate streams
|
||||
@@ -719,16 +707,16 @@ class MMDoubleStreamBlock(nn.Module):
|
||||
return img, txt
|
||||
|
||||
def forward(
|
||||
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None
|
||||
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
forward_fn = self._forward
|
||||
if self.cpu_offload_checkpointing:
|
||||
forward_fn = custom_offloading_utils.cpu_offload_wrapper(forward_fn, self.img_attn_qkv.weight.device)
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, seq_lens, use_reentrant=False)
|
||||
return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, attn_params, use_reentrant=False)
|
||||
else:
|
||||
return self._forward(img, txt, vec, freqs_cis, seq_lens)
|
||||
return self._forward(img, txt, vec, freqs_cis, attn_params)
|
||||
|
||||
|
||||
class MMSingleStreamBlock(nn.Module):
|
||||
@@ -746,7 +734,6 @@ class MMSingleStreamBlock(nn.Module):
|
||||
qk_norm: QK normalization flag (must be True).
|
||||
qk_norm_type: QK normalization type (only "rms" supported).
|
||||
qk_scale: Attention scaling factor (computed automatically if None).
|
||||
attn_mode: Attention implementation mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -758,7 +745,6 @@ class MMSingleStreamBlock(nn.Module):
|
||||
qk_norm: bool = True,
|
||||
qk_norm_type: str = "rms",
|
||||
qk_scale: float = None,
|
||||
attn_mode: str = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -766,7 +752,6 @@ class MMSingleStreamBlock(nn.Module):
|
||||
assert qk_norm_type == "rms", "Only RMS normalization supported."
|
||||
assert qk_norm, "QK normalization must be enabled."
|
||||
|
||||
self.attn_mode = attn_mode
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
head_dim = hidden_size // heads_num
|
||||
@@ -805,9 +790,8 @@ class MMSingleStreamBlock(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
txt_len: int,
|
||||
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||
seq_lens: list[int] = None,
|
||||
attn_params: AttentionParams = None,
|
||||
) -> torch.Tensor:
|
||||
# Extract modulation parameters
|
||||
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
||||
@@ -828,12 +812,10 @@ class MMSingleStreamBlock(nn.Module):
|
||||
k = self.k_norm(k).to(v)
|
||||
|
||||
# Separate image and text tokens
|
||||
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||
img_q, txt_q = q[:, : attn_params.img_len, :, :], q[:, attn_params.img_len :, :, :]
|
||||
del q
|
||||
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
img_k, txt_k = k[:, : attn_params.img_len, :, :], k[:, attn_params.img_len :, :, :]
|
||||
del k
|
||||
# img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
|
||||
# del v
|
||||
|
||||
# Apply rotary position embeddings only to image tokens
|
||||
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
||||
@@ -848,7 +830,7 @@ class MMSingleStreamBlock(nn.Module):
|
||||
# del img_v, txt_v
|
||||
qkv = [q, k, v]
|
||||
del q, k, v
|
||||
attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode)
|
||||
attn = attention(qkv, attn_params=attn_params)
|
||||
del qkv
|
||||
|
||||
# Combine attention and MLP outputs, apply gating
|
||||
@@ -865,18 +847,17 @@ class MMSingleStreamBlock(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
txt_len: int,
|
||||
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||
seq_lens: list[int] = None,
|
||||
attn_params: AttentionParams = None,
|
||||
) -> torch.Tensor:
|
||||
if self.gradient_checkpointing and self.training:
|
||||
forward_fn = self._forward
|
||||
if self.cpu_offload_checkpointing:
|
||||
forward_fn = custom_offloading_utils.create_cpu_offloading_wrapper(forward_fn, self.linear1.weight.device)
|
||||
|
||||
return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, txt_len, freqs_cis, seq_lens, use_reentrant=False)
|
||||
return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, freqs_cis, attn_params, use_reentrant=False)
|
||||
else:
|
||||
return self._forward(x, vec, txt_len, freqs_cis, seq_lens)
|
||||
return self._forward(x, vec, freqs_cis, attn_params)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user