From 61ec60a8932410be5793a3b9e7abab35245d6b53 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 14 Jan 2023 21:24:09 +0900 Subject: [PATCH] move convert_vae to inline, restore comments --- library/model_util.py | 53 ++++++++++++++++++++++--------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/library/model_util.py b/library/model_util.py index 96c7bbf2..6a1e656a 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -632,7 +632,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): del new_sd[ANOTHER_POSITION_IDS_KEY] else: position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) - + new_sd["text_model.embeddings.position_ids"] = position_ids return new_sd @@ -864,7 +864,6 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path): return checkpoint, state_dict - # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) @@ -887,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None): vae = AutoencoderKL(**vae_config) info = vae.load_state_dict(converted_vae_checkpoint) - print("loadint vae:", info) + print("loading vae:", info) # convert text_model if v2: @@ -1089,11 +1088,6 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod VAE_PREFIX = "first_stage_model." -def convert_vae(vae_sd, vae_config): - converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) - def load_vae(vae_id, dtype): print(f"load VAE: {vae_id}") if os.path.isdir(vae_id) or not os.path.isfile(vae_id): @@ -1109,27 +1103,34 @@ def load_vae(vae_id, dtype): # local vae_config = create_vae_diffusers_config() - if vae_id.endswith(".bin"): return convert_vae(torch.load(vae_id, map_location="cpu"), vae_config) + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") + else: + # StableDiffusion + vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) + else torch.load(vae_id, map_location="cpu")) + vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model - vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id) - else torch.load(vae_id, map_location="cpu")) - vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model - - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd - - return convert_vae(vae_sd, vae_config) + # Convert the VAE model. + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae # endregion