diff --git a/library/flux_utils.py b/library/flux_utils.py index 7a1ec37b..4403835f 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -73,6 +73,10 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int with safe_open(ckpt_path, framework="pt") as f: keys.extend(f.keys()) + # if the key has annoying prefix, remove it + if keys[0].startswith("model.diffusion_model."): + keys = [key.replace("model.diffusion_model.", "") for key in keys] + is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys) @@ -141,6 +145,13 @@ def load_flow_model( sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks) logger.info("Converted Diffusers to BFL") + # if the key has annoying prefix, remove it + for key in list(sd.keys()): + new_key = key.replace("model.diffusion_model.", "") + if new_key == key: + break # the model doesn't have annoying prefix + sd[new_key] = sd.pop(key) + info = model.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Flux: {info}") return is_schnell, model