Kohya-ss-sd-scripts/library/attention.py
Kohya S. 34e7138b6a Add/modify some implementation for anima (#2261)
* fix: update extend-exclude list in _typos.toml to include configs

* fix: exclude anima tests from pytest

* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE

* fix: update default value for --discrete_flow_shift in anima training guide

* feat: add Qwen-Image VAE

* feat: simplify encode_tokens

* feat: use unified attention module, add wrapper for state dict compatibility

* feat: loading with dynamic fp8 optimization and LoRA support

* feat: add anima minimal inference script (WIP)

* format: format

* feat: simplify target module selection by regular expression patterns

* feat: keep caption dropout rate in cache and handle it in training script

* feat: update train_llm_adapter and verbose default values to string type

* fix: use strategy instead of using tokenizers directly

* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

* feat: support 5d tensor in get_noisy_model_input_and_timesteps

* feat: update loss calculation to support 5d tensor

* fix: update argument names in anima_train_utils to align with other architectures

* feat: simplify Anima training script and update empty caption handling

* feat: support LoRA format without `net.` prefix

* fix: make the fp8_scaled option work

* feat: add regex-based learning rates and dimensions handling in create_network

* fix: improve regex matching for module selection and learning rates in LoRANetwork

* fix: update logging message for regex match in LoRANetwork

* fix: keep latents 4D except DiT call

* feat: enhance block swap functionality for inference and training in Anima model

* feat: refactor Anima training script

* feat: optimize VAE processing by adjusting tensor dimensions and data types

* fix: wait for all block transfers before switching offloader mode

* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!

* feat: support LoRA for Qwen3

* feat: update Anima SAI model spec metadata handling

* fix: remove unused code

* feat: split CFG processing in do_sample function to reduce memory usage

* feat: add VAE chunking and caching options to reduce memory usage

* feat: optimize RMSNorm forward method and remove unused torch_attention_op

* Update library/strategy_anima.py

Use torch.all instead of all.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/safetensors_utils.py

Fix duplicated new_key for concat_hook.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_minimal_inference.py

Remove unused code.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_train.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/anima_train_utils.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: review with Copilot

* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)

* feat: add process_escape function to handle escape sequences in prompts

* feat: enhance LoRA weight handling in model loading and add text encoder loading function

* feat: improve ComfyUI conversion script with prefix constants and module name adjustments

* feat: update caption dropout documentation to clarify cache regeneration requirement

* feat: add clarification on learning rate adjustments

* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-13 08:15:06 +09:00

269 lines
11 KiB
Python

# Unified attention function supporting various implementations
from dataclasses import dataclass
import torch
from typing import Optional, Union

try:
    import flash_attn
    from flash_attn.flash_attn_interface import _flash_attn_forward
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
    from flash_attn.flash_attn_interface import flash_attn_func
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None
    flash_attn_func = None

try:
    from sageattention import sageattn_varlen, sageattn
except ImportError:
    sageattn_varlen = None
    sageattn = None

try:
    import xformers.ops as xops
except ImportError:
    xops = None


@dataclass
class AttentionParams:
    attn_mode: Optional[str] = None
    split_attn: bool = False
    img_len: Optional[int] = None
    attention_mask: Optional[torch.Tensor] = None
    seqlens: Optional[torch.Tensor] = None
    cu_seqlens: Optional[torch.Tensor] = None
    max_seqlen: Optional[int] = None
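
    # Note: flash-attn kernels do not accept fp32 inputs, and the xformers path is treated here as
    # requiring q, k and v to share a dtype; the two properties below let callers decide whether
    # inputs need casting before attention is computed.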

    @property
    def supports_fp32(self) -> bool:
        return self.attn_mode not in ["flash"]

    @property
    def requires_same_dtype(self) -> bool:
        return self.attn_mode in ["xformers"]

    @staticmethod
    def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
        return AttentionParams(attn_mode, split_attn)

    @staticmethod
    def create_attention_params_from_mask(
        attn_mode: Optional[str], split_attn: bool, img_len: Optional[int], attention_mask: Optional[torch.Tensor]
    ) -> "AttentionParams":
        if attention_mask is None:
            # No attention mask provided: assume all tokens are valid
            return AttentionParams(attn_mode, split_attn, None, None, None, None, None)
        else:
            # Note: attention_mask is only for text tokens, not including image tokens
            seqlens = attention_mask.sum(dim=1).to(torch.int32) + img_len  # [B]
            max_seqlen = attention_mask.shape[1] + img_len

            if split_attn:
                # cu_seqlens is not needed for split attention
                return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, None, max_seqlen)

            # Convert attention mask to cumulative sequence lengths for flash attention
            batch_size = attention_mask.shape[0]
            cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=attention_mask.device)
            for i in range(batch_size):
                cu_seqlens[2 * i + 1] = i * max_seqlen + seqlens[i]  # end of valid tokens for query
                cu_seqlens[2 * i + 2] = (i + 1) * max_seqlen  # end of all tokens for query

            # Expand attention mask to include image tokens
            attention_mask = torch.nn.functional.pad(attention_mask, (img_len, 0), value=1)  # [B, img_len + L]

            if attn_mode == "xformers":
                seqlens_list = seqlens.cpu().tolist()
                attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                    seqlens_list, seqlens_list, device=attention_mask.device
                )
            elif attn_mode == "torch":
                attention_mask = attention_mask[:, None, None, :].to(torch.bool)  # [B, 1, 1, img_len + L]

            return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, cu_seqlens, max_seqlen)
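
# Illustrative note (values made up): for a batch of two text masks of width 3 with 3 and 2 valid
# tokens and img_len = 4, create_attention_params_from_mask() above yields max_seqlen = 7,
# seqlens = [7, 6] and cu_seqlens = [0, 7, 7, 13, 14] -- for each sample, one entry marks the end
# of its valid tokens and the next marks the end of its padded slot, which is the packed layout
# expected by the varlen flash/sage kernels.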


def attention(
    qkv_or_q: Union[torch.Tensor, list],
    k: Optional[torch.Tensor] = None,
    v: Optional[torch.Tensor] = None,
    attn_params: Optional[AttentionParams] = None,
    drop_rate: float = 0.0,
) -> torch.Tensor:
    """
    Compute scaled dot-product attention with variable sequence lengths.

    Handles batches with different sequence lengths by splitting and
    processing each sequence individually.

    Args:
        qkv_or_q: Query tensor [B, L, H, D], or a list [q, k, v] of such tensors.
        k: Key tensor [B, L, H, D].
        v: Value tensor [B, L, H, D].
        attn_params: Attention parameters including mask and sequence lengths.
        drop_rate: Attention dropout rate.

    Returns:
        Attention output tensor [B, L, H*D].
    """
    if isinstance(qkv_or_q, list):
        q, k, v = qkv_or_q
        q: torch.Tensor = q
        qkv_or_q.clear()
        del qkv_or_q
    else:
        q: torch.Tensor = qkv_or_q
        del qkv_or_q
        assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor"

    if attn_params is None:
        attn_params = AttentionParams.create_attention_params("torch", False)

    # If split_attn is False, an attention mask is provided, and all sequence lengths are the same,
    # we can trim the sequences to the common valid length instead of applying the mask
    seqlen_trimmed = False
    if not attn_params.split_attn and attn_params.attention_mask is not None and attn_params.seqlens is not None:
        if torch.all(attn_params.seqlens == attn_params.seqlens[0]):
            seqlen = attn_params.seqlens[0].item()
            q = q[:, :seqlen]
            k = k[:, :seqlen]
            v = v[:, :seqlen]
            max_seqlen = attn_params.max_seqlen
            attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, False)  # do not in-place modify
            attn_params.max_seqlen = max_seqlen  # keep max_seqlen for padding
            seqlen_trimmed = True

    # Determine tensor layout based on attention implementation
    if attn_params.attn_mode == "torch" or (
        attn_params.attn_mode == "sageattn" and (attn_params.split_attn or attn_params.cu_seqlens is None)
    ):
        transpose_fn = lambda x: x.transpose(1, 2)  # [B, H, L, D] for SDPA and sageattn with fixed length
        # pad on sequence length dimension
        pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, pad_to - x.shape[-2]), value=0)
    else:
        transpose_fn = lambda x: x  # [B, L, H, D] for other implementations
        # pad on sequence length dimension
        pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_to - x.shape[-3]), value=0)

    # Process each batch element with its valid sequence lengths
    if attn_params.split_attn:
        if attn_params.seqlens is None:
            # If no seqlens provided, assume all tokens are valid
            attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, True)  # do not in-place modify
            attn_params.seqlens = torch.tensor([q.shape[1]] * q.shape[0], device=q.device)
            attn_params.max_seqlen = q.shape[1]
        q = [transpose_fn(q[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(q))]
        k = [transpose_fn(k[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(k))]
        v = [transpose_fn(v[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(v))]
    else:
        q = transpose_fn(q)
        k = transpose_fn(k)
        v = transpose_fn(v)

    if attn_params.attn_mode == "torch":
        if attn_params.split_attn:
            x = []
            for i in range(len(q)):
                x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
                q[i] = None
                k[i] = None
                v[i] = None
                x.append(pad_fn(x_i, attn_params.max_seqlen))  # B, H, L, D
            x = torch.cat(x, dim=0)
            del q, k, v
        else:
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_params.attention_mask, dropout_p=drop_rate)
            del q, k, v
    elif attn_params.attn_mode == "xformers":
        if attn_params.split_attn:
            x = []
            for i in range(len(q)):
                x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)
                q[i] = None
                k[i] = None
                v[i] = None
                x.append(pad_fn(x_i, attn_params.max_seqlen))  # B, L, H, D
            x = torch.cat(x, dim=0)
            del q, k, v
        else:
            x = xops.memory_efficient_attention(q, k, v, attn_bias=attn_params.attention_mask, p=drop_rate)
            del q, k, v
    elif attn_params.attn_mode == "sageattn":
        if attn_params.split_attn:
            x = []
            for i in range(len(q)):
                # HND seems to cause an error
                x_i = sageattn(q[i], k[i], v[i])  # B, H, L, D. No dropout support
                q[i] = None
                k[i] = None
                v[i] = None
                x.append(pad_fn(x_i, attn_params.max_seqlen))  # B, H, L, D
            x = torch.cat(x, dim=0)
            del q, k, v
        elif attn_params.cu_seqlens is None:  # all tokens are valid
            x = sageattn(q, k, v)  # B, L, H, D. No dropout support
            del q, k, v
        else:
            # Reshape to [(bxs), a, d]
            batch_size, seqlen = q.shape[0], q.shape[1]
            q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])  # [B*L, H, D]
            k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])  # [B*L, H, D]
            v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])  # [B*L, H, D]
            # Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv. No dropout support
            x = sageattn_varlen(
                q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen
            )
            del q, k, v
            # Reshape x with shape [(bxs), a, d] to [b, s, a, d]
            x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1])  # B, L, H, D
    elif attn_params.attn_mode == "flash":
        if attn_params.split_attn:
            x = []
            for i in range(len(q)):
                # HND seems to cause an error
                x_i = flash_attn_func(q[i], k[i], v[i], drop_rate)  # B, L, H, D
                q[i] = None
                k[i] = None
                v[i] = None
                x.append(pad_fn(x_i, attn_params.max_seqlen))  # B, L, H, D
            x = torch.cat(x, dim=0)
            del q, k, v
        elif attn_params.cu_seqlens is None:  # all tokens are valid
            x = flash_attn_func(q, k, v, drop_rate)  # B, L, H, D
            del q, k, v
        else:
            # Reshape to [(bxs), a, d]
            batch_size, seqlen = q.shape[0], q.shape[1]
            q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])  # [B*L, H, D]
            k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])  # [B*L, H, D]
            v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])  # [B*L, H, D]
            # Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv
            x = flash_attn_varlen_func(
                q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen, drop_rate
            )
            del q, k, v
            # Reshape x with shape [(bxs), a, d] to [b, s, a, d]
            x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1])  # B, L, H, D
    else:
        # Only torch (SDPA), xformers, sageattn, and flash are implemented
        raise ValueError(f"Unsupported attention mode: {attn_params.attn_mode}")

    x = transpose_fn(x)  # [B, L, H, D]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # [B, L, H*D]

    if seqlen_trimmed:
        x = torch.nn.functional.pad(x, (0, 0, 0, attn_params.max_seqlen - x.shape[1]), value=0)  # pad back to max_seqlen
    return x
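

# Minimal usage sketch (illustrative only, not part of the original module): build AttentionParams
# from a text-token mask plus an image-token count, then run the unified attention() call with the
# PyTorch SDPA backend. The shapes and values below are made up for demonstration.
if __name__ == "__main__":
    batch, heads, head_dim = 2, 4, 8
    img_len, text_len = 10, 6
    seq_len = img_len + text_len

    # Text-token mask only (image tokens are always valid); the second sample has two padded text tokens
    text_mask = torch.ones(batch, text_len)
    text_mask[1, 4:] = 0

    params = AttentionParams.create_attention_params_from_mask(
        attn_mode="torch", split_attn=False, img_len=img_len, attention_mask=text_mask
    )

    q = torch.randn(batch, seq_len, heads, head_dim)
    k = torch.randn(batch, seq_len, heads, head_dim)
    v = torch.randn(batch, seq_len, heads, head_dim)

    out = attention([q, k, v], attn_params=params)
    print(out.shape)  # expected: torch.Size([2, 16, 32]), i.e. [B, L, H*D]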