Fix to use SDPA instead of xformers

This commit is contained in:
Kohya S
2024-10-30 14:34:19 +09:00
parent 8c3c825b5f
commit 70a179e446

View File

@@ -645,7 +645,7 @@ class MMDiTBlock(nn.Module):
if self.x_block.x_block_self_attn:
x_q2, x_k2, x_v2 = x_qkv2
attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads)
attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode)
x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates)
else:
x = self.x_block.post_attention(x_attn_out, *x_intermediates)