From 734333d0c9eec3f20582c9c16f6d148cb1ec2596 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 28 Feb 2025 23:52:29 +0900 Subject: [PATCH] feat: enhance merging logic for safetensors models to handle key prefixes correctly --- tools/merge_sd3_safetensors.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/tools/merge_sd3_safetensors.py b/tools/merge_sd3_safetensors.py index bef7c9b9..960cf6e7 100644 --- a/tools/merge_sd3_safetensors.py +++ b/tools/merge_sd3_safetensors.py @@ -53,22 +53,30 @@ def merge_safetensors( # 3. Load and merge each model with memory management # DiT/MMDiT - prefix: model.diffusion_model. + # This state dict may have VAE keys. logger.info(f"Loading DiT model from {dit_path}") dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True) logger.info(f"Adding DiT model with {len(dit_state_dict)} keys") for key, value in dit_state_dict.items(): - merged_state_dict[f"model.diffusion_model.{key}"] = value + if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."): + merged_state_dict[key] = value + else: + merged_state_dict[f"model.diffusion_model.{key}"] = value # Free memory del dit_state_dict gc.collect() # VAE - prefix: first_stage_model. + # May be omitted if VAE is already included in DiT model. if vae_path: logger.info(f"Loading VAE model from {vae_path}") vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True) logger.info(f"Adding VAE model with {len(vae_state_dict)} keys") for key, value in vae_state_dict.items(): - merged_state_dict[f"first_stage_model.{key}"] = value + if key.startswith("first_stage_model."): + merged_state_dict[key] = value + else: + merged_state_dict[f"first_stage_model.{key}"] = value # Free memory del vae_state_dict gc.collect() @@ -79,7 +87,10 @@ def merge_safetensors( clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True) logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys") for key, value in clip_l_state_dict.items(): - merged_state_dict[f"text_encoders.clip_l.{key}"] = value + if key.startswith("text_encoders.clip_l.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.clip_l.transformer.{key}"] = value # Free memory del clip_l_state_dict gc.collect() @@ -90,7 +101,10 @@ def merge_safetensors( clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True) logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys") for key, value in clip_g_state_dict.items(): - merged_state_dict[f"text_encoders.clip_g.{key}"] = value + if key.startswith("text_encoders.clip_g.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.clip_g.transformer.{key}"] = value # Free memory del clip_g_state_dict gc.collect() @@ -101,7 +115,10 @@ def merge_safetensors( t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True) logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys") for key, value in t5xxl_state_dict.items(): - merged_state_dict[f"text_encoders.t5xxl.{key}"] = value + if key.startswith("text_encoders.t5xxl.transformer."): + merged_state_dict[key] = value + else: + merged_state_dict[f"text_encoders.t5xxl.transformer.{key}"] = value # Free memory del t5xxl_state_dict gc.collect() @@ -115,7 +132,7 @@ def merge_safetensors( def main(): parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file") parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model") - parser.add_argument("--vae", help="Path to the VAE model") + parser.add_argument("--vae", help="Path to the VAE model. May be omitted if VAE is included in DiT model") parser.add_argument("--clip_l", help="Path to the CLIP-L model") parser.add_argument("--clip_g", help="Path to the CLIP-G model") parser.add_argument("--t5xxl", help="Path to the T5-XXL model")