mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
* fix: update extend-exclude list in _typos.toml to include configs * fix: exclude anima tests from pytest * feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE * fix: update default value for --discrete_flow_shift in anima training guide * feat: add Qwen-Image VAE * feat: simplify encode_tokens * feat: use unified attention module, add wrapper for state dict compatibility * feat: loading with dynamic fp8 optimization and LoRA support * feat: add anima minimal inference script (WIP) * format: format * feat: simplify target module selection by regular expression patterns * feat: kept caption dropout rate in cache and handle in training script * feat: update train_llm_adapter and verbose default values to string type * fix: use strategy instead of using tokenizers directly * feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock * feat: support 5d tensor in get_noisy_model_input_and_timesteps * feat: update loss calculation to support 5d tensor * fix: update argument names in anima_train_utils to align with other archtectures * feat: simplify Anima training script and update empty caption handling * feat: support LoRA format without `net.` prefix * fix: update to work fp8_scaled option * feat: add regex-based learning rates and dimensions handling in create_network * fix: improve regex matching for module selection and learning rates in LoRANetwork * fix: update logging message for regex match in LoRANetwork * fix: keep latents 4D except DiT call * feat: enhance block swap functionality for inference and training in Anima model * feat: refactor Anima training script * feat: optimize VAE processing by adjusting tensor dimensions and data types * fix: wait all block trasfer before siwtching offloader mode * feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude! * feat: support LORA for Qwen3 * feat: update Anima SAI model spec metadata handling * fix: remove unused code * feat: split CFG processing in do_sample function to reduce memory usage * feat: add VAE chunking and caching options to reduce memory usage * feat: optimize RMSNorm forward method and remove unused torch_attention_op * Update library/strategy_anima.py Use torch.all instead of all. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/safetensors_utils.py Fix duplicated new_key for concat_hook. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_minimal_inference.py Remove unused code. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update anima_train.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update library/anima_train_utils.py Remove unused import. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: review with Copilot * feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet) * feat: add process_escape function to handle escape sequences in prompts * feat: enhance LoRA weight handling in model loading and add text encoder loading function * feat: improve ComfyUI conversion script with prefix constants and module name adjustments * feat: update caption dropout documentation to clarify cache regeneration requirement * feat: add clarification on learning rate adjustments * feat: add note on PyTorch version requirement to prevent NaN loss --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
269 lines
11 KiB
Python
269 lines
11 KiB
Python
# Unified attention function supporting various implementations
|
|
|
|
from dataclasses import dataclass
|
|
import torch
|
|
from typing import Optional, Union
|
|
|
|
try:
|
|
import flash_attn
|
|
from flash_attn.flash_attn_interface import _flash_attn_forward
|
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
|
from flash_attn.flash_attn_interface import flash_attn_func
|
|
except ImportError:
|
|
flash_attn = None
|
|
flash_attn_varlen_func = None
|
|
_flash_attn_forward = None
|
|
flash_attn_func = None
|
|
|
|
try:
|
|
from sageattention import sageattn_varlen, sageattn
|
|
except ImportError:
|
|
sageattn_varlen = None
|
|
sageattn = None
|
|
|
|
try:
|
|
import xformers.ops as xops
|
|
except ImportError:
|
|
xops = None
|
|
|
|
|
|
@dataclass
|
|
class AttentionParams:
|
|
attn_mode: Optional[str] = None
|
|
split_attn: bool = False
|
|
img_len: Optional[int] = None
|
|
attention_mask: Optional[torch.Tensor] = None
|
|
seqlens: Optional[torch.Tensor] = None
|
|
cu_seqlens: Optional[torch.Tensor] = None
|
|
max_seqlen: Optional[int] = None
|
|
|
|
@property
|
|
def supports_fp32(self) -> bool:
|
|
return self.attn_mode not in ["flash"]
|
|
|
|
@property
|
|
def requires_same_dtype(self) -> bool:
|
|
return self.attn_mode in ["xformers"]
|
|
|
|
@staticmethod
|
|
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
|
|
return AttentionParams(attn_mode, split_attn)
|
|
|
|
@staticmethod
|
|
def create_attention_params_from_mask(
|
|
attn_mode: Optional[str], split_attn: bool, img_len: Optional[int], attention_mask: Optional[torch.Tensor]
|
|
) -> "AttentionParams":
|
|
if attention_mask is None:
|
|
# No attention mask provided: assume all tokens are valid
|
|
return AttentionParams(attn_mode, split_attn, None, None, None, None, None)
|
|
else:
|
|
# Note: attention_mask is only for text tokens, not including image tokens
|
|
seqlens = attention_mask.sum(dim=1).to(torch.int32) + img_len # [B]
|
|
max_seqlen = attention_mask.shape[1] + img_len
|
|
|
|
if split_attn:
|
|
# cu_seqlens is not needed for split attention
|
|
return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, None, max_seqlen)
|
|
|
|
# Convert attention mask to cumulative sequence lengths for flash attention
|
|
batch_size = attention_mask.shape[0]
|
|
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=attention_mask.device)
|
|
for i in range(batch_size):
|
|
cu_seqlens[2 * i + 1] = i * max_seqlen + seqlens[i] # end of valid tokens for query
|
|
cu_seqlens[2 * i + 2] = (i + 1) * max_seqlen # end of all tokens for query
|
|
|
|
# Expand attention mask to include image tokens
|
|
attention_mask = torch.nn.functional.pad(attention_mask, (img_len, 0), value=1) # [B, img_len + L]
|
|
|
|
if attn_mode == "xformers":
|
|
seqlens_list = seqlens.cpu().tolist()
|
|
attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
|
|
seqlens_list, seqlens_list, device=attention_mask.device
|
|
)
|
|
elif attn_mode == "torch":
|
|
attention_mask = attention_mask[:, None, None, :].to(torch.bool) # [B, 1, 1, img_len + L]
|
|
|
|
return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, cu_seqlens, max_seqlen)
|
|
|
|
|
|
def attention(
|
|
qkv_or_q: Union[torch.Tensor, list],
|
|
k: Optional[torch.Tensor] = None,
|
|
v: Optional[torch.Tensor] = None,
|
|
attn_params: Optional[AttentionParams] = None,
|
|
drop_rate: float = 0.0,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Compute scaled dot-product attention with variable sequence lengths.
|
|
|
|
Handles batches with different sequence lengths by splitting and
|
|
processing each sequence individually.
|
|
|
|
Args:
|
|
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
|
|
k: Key tensor [B, L, H, D].
|
|
v: Value tensor [B, L, H, D].
|
|
attn_params: Attention parameters including mask and sequence lengths.
|
|
drop_rate: Attention dropout rate.
|
|
|
|
Returns:
|
|
Attention output tensor [B, L, H*D].
|
|
"""
|
|
if isinstance(qkv_or_q, list):
|
|
q, k, v = qkv_or_q
|
|
q: torch.Tensor = q
|
|
qkv_or_q.clear()
|
|
del qkv_or_q
|
|
else:
|
|
q: torch.Tensor = qkv_or_q
|
|
del qkv_or_q
|
|
assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor"
|
|
if attn_params is None:
|
|
attn_params = AttentionParams.create_attention_params("torch", False)
|
|
|
|
# If split attn is False, attention mask is provided and all sequence lengths are same, we can trim the sequence
|
|
seqlen_trimmed = False
|
|
if not attn_params.split_attn and attn_params.attention_mask is not None and attn_params.seqlens is not None:
|
|
if torch.all(attn_params.seqlens == attn_params.seqlens[0]):
|
|
seqlen = attn_params.seqlens[0].item()
|
|
q = q[:, :seqlen]
|
|
k = k[:, :seqlen]
|
|
v = v[:, :seqlen]
|
|
max_seqlen = attn_params.max_seqlen
|
|
attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, False) # do not in-place modify
|
|
attn_params.max_seqlen = max_seqlen # keep max_seqlen for padding
|
|
seqlen_trimmed = True
|
|
|
|
# Determine tensor layout based on attention implementation
|
|
if attn_params.attn_mode == "torch" or (
|
|
attn_params.attn_mode == "sageattn" and (attn_params.split_attn or attn_params.cu_seqlens is None)
|
|
):
|
|
transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA and sageattn with fixed length
|
|
# pad on sequence length dimension
|
|
pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, pad_to - x.shape[-2]), value=0)
|
|
else:
|
|
transpose_fn = lambda x: x # [B, L, H, D] for other implementations
|
|
# pad on sequence length dimension
|
|
pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_to - x.shape[-3]), value=0)
|
|
|
|
# Process each batch element with its valid sequence lengths
|
|
if attn_params.split_attn:
|
|
if attn_params.seqlens is None:
|
|
# If no seqlens provided, assume all tokens are valid
|
|
attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, True) # do not in-place modify
|
|
attn_params.seqlens = torch.tensor([q.shape[1]] * q.shape[0], device=q.device)
|
|
attn_params.max_seqlen = q.shape[1]
|
|
q = [transpose_fn(q[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(q))]
|
|
k = [transpose_fn(k[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(k))]
|
|
v = [transpose_fn(v[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(v))]
|
|
else:
|
|
q = transpose_fn(q)
|
|
k = transpose_fn(k)
|
|
v = transpose_fn(v)
|
|
|
|
if attn_params.attn_mode == "torch":
|
|
if attn_params.split_attn:
|
|
x = []
|
|
for i in range(len(q)):
|
|
x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
|
|
q[i] = None
|
|
k[i] = None
|
|
v[i] = None
|
|
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D
|
|
x = torch.cat(x, dim=0)
|
|
del q, k, v
|
|
|
|
else:
|
|
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_params.attention_mask, dropout_p=drop_rate)
|
|
del q, k, v
|
|
|
|
elif attn_params.attn_mode == "xformers":
|
|
if attn_params.split_attn:
|
|
x = []
|
|
for i in range(len(q)):
|
|
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)
|
|
q[i] = None
|
|
k[i] = None
|
|
v[i] = None
|
|
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D
|
|
x = torch.cat(x, dim=0)
|
|
del q, k, v
|
|
|
|
else:
|
|
x = xops.memory_efficient_attention(q, k, v, attn_bias=attn_params.attention_mask, p=drop_rate)
|
|
del q, k, v
|
|
|
|
elif attn_params.attn_mode == "sageattn":
|
|
if attn_params.split_attn:
|
|
x = []
|
|
for i in range(len(q)):
|
|
# HND seems to cause an error
|
|
x_i = sageattn(q[i], k[i], v[i]) # B, H, L, D. No dropout support
|
|
q[i] = None
|
|
k[i] = None
|
|
v[i] = None
|
|
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D
|
|
x = torch.cat(x, dim=0)
|
|
del q, k, v
|
|
elif attn_params.cu_seqlens is None: # all tokens are valid
|
|
x = sageattn(q, k, v) # B, L, H, D. No dropout support
|
|
del q, k, v
|
|
else:
|
|
# Reshape to [(bxs), a, d]
|
|
batch_size, seqlen = q.shape[0], q.shape[1]
|
|
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D]
|
|
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D]
|
|
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D]
|
|
|
|
# Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv. No dropout support
|
|
x = sageattn_varlen(
|
|
q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen
|
|
)
|
|
del q, k, v
|
|
|
|
# Reshape x with shape [(bxs), a, d] to [b, s, a, d]
|
|
x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D
|
|
|
|
elif attn_params.attn_mode == "flash":
|
|
if attn_params.split_attn:
|
|
x = []
|
|
for i in range(len(q)):
|
|
# HND seems to cause an error
|
|
x_i = flash_attn_func(q[i], k[i], v[i], drop_rate) # B, L, H, D
|
|
q[i] = None
|
|
k[i] = None
|
|
v[i] = None
|
|
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D
|
|
x = torch.cat(x, dim=0)
|
|
del q, k, v
|
|
elif attn_params.cu_seqlens is None: # all tokens are valid
|
|
x = flash_attn_func(q, k, v, drop_rate) # B, L, H, D
|
|
del q, k, v
|
|
else:
|
|
# Reshape to [(bxs), a, d]
|
|
batch_size, seqlen = q.shape[0], q.shape[1]
|
|
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D]
|
|
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D]
|
|
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D]
|
|
|
|
# Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv
|
|
x = flash_attn_varlen_func(
|
|
q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen, drop_rate
|
|
)
|
|
del q, k, v
|
|
|
|
# Reshape x with shape [(bxs), a, d] to [b, s, a, d]
|
|
x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D
|
|
|
|
else:
|
|
# Currently only PyTorch SDPA and xformers are implemented
|
|
raise ValueError(f"Unsupported attention mode: {attn_params.attn_mode}")
|
|
|
|
x = transpose_fn(x) # [B, L, H, D]
|
|
x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D]
|
|
|
|
if seqlen_trimmed:
|
|
x = torch.nn.functional.pad(x, (0, 0, 0, attn_params.max_seqlen - x.shape[1]), value=0) # pad back to max_seqlen
|
|
|
|
return x
|