From f19233887477746d3386c305794a559f6e4b503b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Jan 2023 09:17:16 +0900 Subject: [PATCH] rename another position_ids key (supports wd v1.4) --- library/model_util.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index f3453025..398b6404 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -624,8 +624,16 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): new_sd[key_pfx + "k_proj" + key_suffix] = values[1] new_sd[key_pfx + "v_proj" + key_suffix] = values[2] - # position_idsの追加 - new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64) + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids return new_sd # endregion