diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index 3f6c9b41..431c183d 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -462,7 +462,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei class LoRANetwork(torch.nn.Module): LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"] - TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2MLP"] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2FlashAttention2", "Gemma2SdpaAttention", "Gemma2MLP"] LORA_PREFIX_LUMINA = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder