From b5db90c8a848203f028e2b0d2c50a4d3f4dfd882 Mon Sep 17 00:00:00 2001 From: ykume Date: Fri, 18 Aug 2023 09:00:22 +0900 Subject: [PATCH] modify to attn1/attn2 only --- networks/lora_control_net.py | 10 +++++++++- sdxl_train_lora_control_net.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/networks/lora_control_net.py b/networks/lora_control_net.py index 0dd2a0a1..120ab0ac 100644 --- a/networks/lora_control_net.py +++ b/networks/lora_control_net.py @@ -18,7 +18,11 @@ SKIP_CONV2D = False # if True, only transformer_blocks are applied, and ResBlocks are not applied TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks +# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc. +ATTN1_2_ONLY = True + # Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2 +# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY ATTN1_ETC_ONLY = False # True # transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用 @@ -203,6 +207,10 @@ class LoRAControlNet(torch.nn.Module): if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)): continue + if ATTN1_2_ONLY: + if not ("attn1" in lora_name or "attn2" in lora_name): + continue + if ATTN1_ETC_ONLY: if "proj_out" in lora_name: pass @@ -368,7 +376,7 @@ if __name__ == "__main__": unet.to("cuda").to(torch.float16) print("create LoRA controlnet") - control_net = LoRAControlNet(unet, 64, 16, 1) + control_net = LoRAControlNet(unet, 64, 32, 1) control_net.apply_to() control_net.to("cuda") diff --git a/sdxl_train_lora_control_net.py b/sdxl_train_lora_control_net.py index e0ec3a6a..b6fb1dec 100644 --- a/sdxl_train_lora_control_net.py +++ b/sdxl_train_lora_control_net.py @@ -325,7 +325,7 @@ def train(args): accelerator.print(f"\nsaving checkpoint: {ckpt_file}") sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False) - sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/lora-control-net" + sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-llite" unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata) if args.huggingface_repo_id is not None: