mirror of https://github.com/kohya-ss/sd-scripts.git
Remove unused attention, fix typo
@@ -467,13 +467,6 @@ class JointAttention(nn.Module):
         return self.out(output)
 
 
-def attention(q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
-    x = rearrange(x, "B H L D -> B L (H D)")
-
-    return x
-
-
 def apply_rope(
     x_in: torch.Tensor,
     freqs_cis: torch.Tensor,
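For reference, the removed helper was only a thin wrapper around PyTorch's fused attention kernel followed by a head merge; per the commit title it was no longer called anywhere, so dropping it changes no behavior. A minimal, self-contained sketch of the same behavior (tensor shapes are illustrative, not taken from the repository):

    import torch
    from einops import rearrange

    def attention(q, k, v, attn_mask=None):
        # fused attention over (B, H, L, D) inputs
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        # merge heads back into the channel axis: (B, H, L, D) -> (B, L, H * D)
        return rearrange(x, "B H L D -> B L (H D)")

    q = k = v = torch.randn(2, 8, 16, 64)   # batch=2, heads=8, seq=16, head_dim=64
    out = attention(q, k, v)
    assert out.shape == (2, 16, 8 * 64)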
@@ -965,8 +958,6 @@ class NextDiT(nn.Module):
             Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:
 
                 return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
-
-
         """
         bsz, channels, height, width = x.shape
         pH = pW = self.patch_size
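The shapes unpacked in this hunk (bsz, channels, height, width and a square patch size) correspond to the usual patchify step before embedding. A hedged sketch of that reshaping, assuming a rearrange-based patchify (the actual NextDiT embedding code is not reproduced here):

    import torch
    from einops import rearrange

    x = torch.randn(2, 4, 64, 64)   # (bsz, channels, height, width)
    pH = pW = 2                     # stand-in for self.patch_size
    # split the image into non-overlapping pH x pW patches and flatten each one
    patches = rearrange(x, "b c (h ph) (w pw) -> b (h w) (ph pw c)", ph=pH, pw=pW)
    print(patches.shape)            # torch.Size([2, 1024, 16])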
@@ -993,7 +984,7 @@ class NextDiT(nn.Module):
             position_ids[i, cap_len:seq_len, 1] = row_ids
             position_ids[i, cap_len:seq_len, 2] = col_ids
 
-        # Get combinded rotary embeddings
+        # Get combined rotary embeddings
         freqs_cis = self.rope_embedder(position_ids)
 
         # Create separate rotary embeddings for captions and images
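The indexing above implies a 3-axis position-id layout: one axis for caption token positions and two for the image patch row/column. A hypothetical sketch of how such ids could be filled for one sample (all names and sizes here are illustrative, not taken from the repository):

    import torch

    cap_len, h_tokens, w_tokens = 4, 3, 3
    seq_len = cap_len + h_tokens * w_tokens

    position_ids = torch.zeros(1, seq_len, 3, dtype=torch.long)
    position_ids[0, :cap_len, 0] = torch.arange(cap_len)            # caption axis
    row_ids = torch.arange(h_tokens).repeat_interleave(w_tokens)    # patch rows
    col_ids = torch.arange(w_tokens).repeat(h_tokens)               # patch cols
    position_ids[0, cap_len:seq_len, 1] = row_ids
    position_ids[0, cap_len:seq_len, 2] = col_ids

The rope_embedder then turns these per-axis positions into the combined rotary embeddings mentioned in the corrected comment.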