def convert_ldm_clip_checkpoint_v1(checkpoint, max_length=77):
    """Extract the CLIP text-encoder state dict from an LDM (SD v1) checkpoint.

    Copies every entry whose key starts with ``cond_stage_model.transformer``
    into a new dict with the ``cond_stage_model.transformer.`` prefix stripped,
    producing a state dict loadable into the standalone text encoder.

    Args:
        checkpoint: full LDM checkpoint state dict (str key -> tensor).
        max_length: maximum token-sequence length of the text encoder
            (77 for SD v1's CLIP); used only to synthesize ``position_ids``
            when the checkpoint lacks them.

    Returns:
        dict: the text-encoder state dict.  If the source checkpoint is
        missing ``text_model.embeddings.position_ids`` (some pruned/invalid
        checkpoints drop it), a canonical ``[[0, 1, ..., max_length-1]]``
        tensor is inserted so ``load_state_dict`` does not fail.
    """
    prefix = "cond_stage_model.transformer."
    text_model_dict = {
        key[len(prefix):]: value
        for key, value in checkpoint.items()
        if key.startswith("cond_stage_model.transformer")
    }

    # support checkpoint without position_ids (invalid checkpoint):
    # restore the canonical [0 .. max_length-1] row expected by the embedding.
    # NOTE(review): newer transformers versions no longer persist this buffer —
    # harmless here, but confirm against the transformers version in use.
    pos_key = "text_model.embeddings.position_ids"
    if pos_key not in text_model_dict:
        text_model_dict[pos_key] = torch.arange(max_length).unsqueeze(0)

    return text_model_dict