Update Gemma2 text encoder attention target modules for training

This commit is contained in:
sdbds
2025-03-04 08:07:33 +08:00
parent 09c4710d1e
commit 5e45df722d

View File

@@ -462,7 +462,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
class LoRANetwork(torch.nn.Module):
LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"]
-    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"]
+    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2FlashAttention2", "Gemma2SdpaAttention", "Gemma2MLP"]
LORA_PREFIX_LUMINA = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder