Dave Lage
2026-03-31 04:54:17 +00:00
committed by GitHub


@@ -18,6 +18,7 @@ import torch
 from einops import rearrange
 from torch import Tensor, nn
 from torch.utils.checkpoint import checkpoint
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from library import custom_offloading_utils
@@ -445,11 +446,13 @@ configs = {
 # region math
+kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
 def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
     q, k = apply_rope(q, k, pe)
-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+    with sdpa_kernel(kernels):
+        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
     x = rearrange(x, "B H L D -> B L (H D)")
     return x
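
For context, sdpa_kernel is the torch.nn.attention context manager (PyTorch 2.3+) that restricts which backends scaled_dot_product_attention may dispatch to; passing a list lets PyTorch fall through to the next available backend, ending with the always-available MATH implementation. A minimal, self-contained sketch of the same pattern follows; the tensor shapes are illustrative only and not part of this commit.

    import torch
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Backend preference list mirroring the commit: try the fused kernels first,
    # fall back to the reference MATH implementation if none is available.
    kernels = [SDPBackend.FLASH_ATTENTION, SDPBackend.CUDNN_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]

    # Illustrative shapes only: (batch, heads, seq_len, head_dim).
    q = torch.randn(1, 8, 128, 64)
    k = torch.randn(1, 8, 128, 64)
    v = torch.randn(1, 8, 128, 64)

    # Inside the context manager, SDPA may only dispatch to the listed backends.
    with sdpa_kernel(kernels):
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

    print(out.shape)  # torch.Size([1, 8, 128, 64])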