Remove unused attention, fix typo

rockerBOO
2025-02-18 01:21:18 -05:00
parent 98efbc3bb7
commit bd16bd13ae


@@ -467,13 +467,6 @@ class JointAttention(nn.Module):
         return self.out(output)
-def attention(q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
-    x = rearrange(x, "B H L D -> B L (H D)")
-    return x
 def apply_rope(
     x_in: torch.Tensor,
     freqs_cis: torch.Tensor,
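
For reference, the deleted helper was a thin standalone wrapper around PyTorch's fused attention kernel that folds the head dimension back into the channel dimension. A minimal, self-contained sketch of the same pattern (the function name and shapes here are illustrative, not taken from the file):

from typing import Optional

import torch
from einops import rearrange
from torch import Tensor

def fused_attention(q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
    # q, k, v: (batch, heads, seq_len, head_dim); attn_mask broadcasts over (batch, heads, q_len, k_len)
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # fold heads back into the channel axis: (batch, seq_len, heads * head_dim)
    return rearrange(x, "B H L D -> B L (H D)")

q = k = v = torch.randn(2, 8, 16, 64)
out = fused_attention(q, k, v)  # shape (2, 16, 512)

JointAttention.forward presumably performs the same computation inline, which would be why the free function could be dropped as unused.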
@@ -965,8 +958,6 @@ class NextDiT(nn.Module):
        Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:
        return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
        """
        bsz, channels, height, width = x.shape
        pH = pW = self.patch_size
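
The shape unpacking shown above sets up the patchify step, where each image is cut into non-overlapping pH x pW patches that become the token sequence. A hypothetical sketch of that operation with einops, under assumed shapes rather than this method's exact code:

import torch
from einops import rearrange

x = torch.randn(1, 16, 64, 64)  # (bsz, channels, height, width), e.g. packed latents
bsz, channels, height, width = x.shape
pH = pW = 2  # stand-in for self.patch_size

# flatten each pH x pW patch into one token: (bsz, num_patches, pH * pW * channels)
tokens = rearrange(x, "b c (h ph) (w pw) -> b (h w) (ph pw c)", ph=pH, pw=pW)
print(tokens.shape)  # torch.Size([1, 1024, 64])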
@@ -993,7 +984,7 @@ class NextDiT(nn.Module):
             position_ids[i, cap_len:seq_len, 1] = row_ids
             position_ids[i, cap_len:seq_len, 2] = col_ids
-        # Get combinded rotary embeddings
+        # Get combined rotary embeddings
         freqs_cis = self.rope_embedder(position_ids)
         # Create separate rotary embeddings for captions and images
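
For context on the hunk above: each token carries a three-component position id (the loop fills in row and column indices for the image tokens that follow the caption), and rope_embedder maps those ids to the combined rotary frequencies freqs_cis, which are then split into caption and image parts. A minimal sketch of how such complex-valued frequencies are typically applied to queries and keys; this is the generic rotary-embedding pattern, not necessarily the body of apply_rope in this file:

import torch

def apply_rope_sketch(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x_in: (batch, seq_len, heads, head_dim); freqs_cis: complex, broadcastable to
    # (batch, seq_len, 1, head_dim // 2). Pair adjacent channels, rotate each pair in
    # the complex plane, then flatten back to real values.
    x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
    x_rotated = torch.view_as_real(x * freqs_cis).flatten(-2)
    return x_rotated.type_as(x_in)

x = torch.randn(2, 10, 8, 64)
angles = torch.rand(2, 10, 1, 32) * 6.283
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # unit-magnitude complex rotations
out = apply_rope_sketch(x, freqs_cis)  # same shape as x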