move convert_vae to inline, restore comments

This commit is contained in:
Kohya S
2023-01-14 21:24:09 +09:00
parent 199a3cbae4
commit 61ec60a893

View File

@@ -632,7 +632,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
del new_sd[ANOTHER_POSITION_IDS_KEY]
else:
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
return new_sd
@@ -864,7 +864,6 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
return checkpoint, state_dict
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
@@ -887,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
vae = AutoencoderKL(**vae_config)
info = vae.load_state_dict(converted_vae_checkpoint)
print("loadint vae:", info)
print("loading vae:", info)
# convert text_model
if v2:
@@ -1089,11 +1088,6 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
VAE_PREFIX = "first_stage_model."
def convert_vae(vae_sd, vae_config):
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
def load_vae(vae_id, dtype):
print(f"load VAE: {vae_id}")
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
@@ -1109,27 +1103,34 @@ def load_vae(vae_id, dtype):
# local
vae_config = create_vae_diffusers_config()
if vae_id.endswith(".bin"): return convert_vae(torch.load(vae_id, map_location="cpu"), vae_config)
if vae_id.endswith(".bin"):
# SD 1.5 VAE on Huggingface
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
else:
# StableDiffusion
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
else torch.load(vae_id, map_location="cpu"))
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
else torch.load(vae_id, map_location="cpu"))
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
full_model = False
for vae_key in vae_sd:
if vae_key.startswith(VAE_PREFIX):
full_model = True
break
if not full_model:
sd = {}
for key, value in vae_sd.items():
sd[VAE_PREFIX + key] = value
vae_sd = sd
del sd
# vae only or full model
full_model = False
for vae_key in vae_sd:
if vae_key.startswith(VAE_PREFIX):
full_model = True
break
if not full_model:
sd = {}
for key, value in vae_sd.items():
sd[VAE_PREFIX + key] = value
vae_sd = sd
del sd
return convert_vae(vae_sd, vae_config)
# Convert the VAE model.
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
return vae
# endregion