mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Support for checkpoint files with a mysterious prefix "model.diffusion_model."
This commit is contained in:
@@ -73,6 +73,10 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
|
||||
with safe_open(ckpt_path, framework="pt") as f:
|
||||
keys.extend(f.keys())
|
||||
|
||||
# if the key has annoying prefix, remove it
|
||||
if keys[0].startswith("model.diffusion_model."):
|
||||
keys = [key.replace("model.diffusion_model.", "") for key in keys]
|
||||
|
||||
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
||||
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
||||
|
||||
@@ -141,6 +145,13 @@ def load_flow_model(
|
||||
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
||||
logger.info("Converted Diffusers to BFL")
|
||||
|
||||
# if the key has annoying prefix, remove it
|
||||
for key in list(sd.keys()):
|
||||
new_key = key.replace("model.diffusion_model.", "")
|
||||
if new_key == key:
|
||||
break # the model doesn't have annoying prefix
|
||||
sd[new_key] = sd.pop(key)
|
||||
|
||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Flux: {info}")
|
||||
return is_schnell, model
|
||||
|
||||
Reference in New Issue
Block a user