mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
move convert_vae to inline, restore comments
This commit is contained in:
@@ -632,7 +632,7 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
||||
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
||||
else:
|
||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
||||
|
||||
|
||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||
return new_sd
|
||||
|
||||
@@ -864,7 +864,6 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path):
|
||||
return checkpoint, state_dict
|
||||
|
||||
|
||||
|
||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
|
||||
@@ -887,7 +886,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
info = vae.load_state_dict(converted_vae_checkpoint)
|
||||
print("loadint vae:", info)
|
||||
print("loading vae:", info)
|
||||
|
||||
# convert text_model
|
||||
if v2:
|
||||
@@ -1089,11 +1088,6 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
|
||||
VAE_PREFIX = "first_stage_model."
|
||||
|
||||
|
||||
def convert_vae(vae_sd, vae_config):
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
def load_vae(vae_id, dtype):
|
||||
print(f"load VAE: {vae_id}")
|
||||
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
||||
@@ -1109,27 +1103,34 @@ def load_vae(vae_id, dtype):
|
||||
# local
|
||||
vae_config = create_vae_diffusers_config()
|
||||
|
||||
if vae_id.endswith(".bin"): return convert_vae(torch.load(vae_id, map_location="cpu"), vae_config)
|
||||
if vae_id.endswith(".bin"):
|
||||
# SD 1.5 VAE on Huggingface
|
||||
converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
|
||||
else:
|
||||
# StableDiffusion
|
||||
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
||||
else torch.load(vae_id, map_location="cpu"))
|
||||
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
||||
|
||||
vae_model = (load_file(vae_id, "cpu") if is_safetensors(vae_id)
|
||||
else torch.load(vae_id, map_location="cpu"))
|
||||
vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model
|
||||
|
||||
full_model = False
|
||||
for vae_key in vae_sd:
|
||||
if vae_key.startswith(VAE_PREFIX):
|
||||
full_model = True
|
||||
break
|
||||
if not full_model:
|
||||
sd = {}
|
||||
for key, value in vae_sd.items():
|
||||
sd[VAE_PREFIX + key] = value
|
||||
vae_sd = sd
|
||||
del sd
|
||||
# vae only or full model
|
||||
full_model = False
|
||||
for vae_key in vae_sd:
|
||||
if vae_key.startswith(VAE_PREFIX):
|
||||
full_model = True
|
||||
break
|
||||
if not full_model:
|
||||
sd = {}
|
||||
for key, value in vae_sd.items():
|
||||
sd[VAE_PREFIX + key] = value
|
||||
vae_sd = sd
|
||||
del sd
|
||||
|
||||
|
||||
return convert_vae(vae_sd, vae_config)
|
||||
# Convert the VAE model.
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
return vae
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
Reference in New Issue
Block a user