diff --git a/library/lumina_models.py b/library/lumina_models.py index e82f3b2c..36c3b979 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -467,13 +467,6 @@ class JointAttention(nn.Module): return self.out(output) -def attention(q: Tensor, k: Tensor, v: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: - x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) - x = rearrange(x, "B H L D -> B L (H D)") - - return x - - def apply_rope( x_in: torch.Tensor, freqs_cis: torch.Tensor, @@ -965,8 +958,6 @@ class NextDiT(nn.Module): Tuple[Tensor, Tensor, Tensor, List[int], List[int]]: return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths - - """ bsz, channels, height, width = x.shape pH = pW = self.patch_size @@ -993,7 +984,7 @@ class NextDiT(nn.Module): position_ids[i, cap_len:seq_len, 1] = row_ids position_ids[i, cap_len:seq_len, 2] = col_ids - # Get combinded rotary embeddings + # Get combined rotary embeddings freqs_cis = self.rope_embedder(position_ids) # Create separate rotary embeddings for captions and images