modify to attn1/attn2 only

This commit is contained in:
ykume
2023-08-18 09:00:22 +09:00
parent 1e52fe6e09
commit b5db90c8a8
2 changed files with 10 additions and 2 deletions

View File

@@ -18,7 +18,11 @@ SKIP_CONV2D = False
# if True, only transformer_blocks are applied, and ResBlocks are not applied
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
ATTN1_2_ONLY = True
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
ATTN1_ETC_ONLY = False # True
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
@@ -203,6 +207,10 @@ class LoRAControlNet(torch.nn.Module):
if "emb_layers" in lora_name or ("attn2" in lora_name and ("to_k" in lora_name or "to_v" in lora_name)):
continue
if ATTN1_2_ONLY:
if not ("attn1" in lora_name or "attn2" in lora_name):
continue
if ATTN1_ETC_ONLY:
if "proj_out" in lora_name:
pass
@@ -368,7 +376,7 @@ if __name__ == "__main__":
unet.to("cuda").to(torch.float16)
print("create LoRA controlnet")
control_net = LoRAControlNet(unet, 64, 16, 1)
control_net = LoRAControlNet(unet, 64, 32, 1)
control_net.apply_to()
control_net.to("cuda")

View File

@@ -325,7 +325,7 @@ def train(args):
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/lora-control-net"
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-llite"
unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
if args.huggingface_repo_id is not None: