Merge pull request #29 from kohya-ss/dev

rename another position_ids key (supports wd v1.4)
This commit is contained in:
Kohya S
2023-01-01 09:23:42 +09:00
committed by GitHub

View File

@@ -624,8 +624,16 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
# position_idsの追加
new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
# rename or add position_ids
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
if ANOTHER_POSITION_IDS_KEY in new_sd:
# waifu diffusion v1.4
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
del new_sd[ANOTHER_POSITION_IDS_KEY]
else:
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
return new_sd
# endregion