From 7ea38f90d7bbcee5523d1b945a66a577b42fa696 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 7 Aug 2023 23:40:49 +0900 Subject: [PATCH 1/3] add merge script --- tools/merge_models.py | 168 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 tools/merge_models.py diff --git a/tools/merge_models.py b/tools/merge_models.py new file mode 100644 index 00000000..dd04ea46 --- /dev/null +++ b/tools/merge_models.py @@ -0,0 +1,168 @@ +import argparse +import os + +import torch +from safetensors import safe_open +from safetensors.torch import load_file, save_file +from tqdm import tqdm + + +def is_unet_key(key): + # VAE or TextEncoder, the last one is for SDXL + return not ("first_stage_model" in key or "cond_stage_model" in key or "conditioner." in key) + + +TEXT_ENCODER_KEY_REPLACEMENTS = [ + ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), + ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), + ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), +] + + +# support for models with different text encoder keys +def replace_text_encoder_key(key): + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + if key.startswith(rep_from): + return True, rep_to + key[len(rep_from) :] + return False, key + + +def merge(args): + if args.precision == "fp16": + dtype = torch.float16 + elif args.precision == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float + + if args.saving_precision == "fp16": + save_dtype = torch.float16 + elif args.saving_precision == "bf16": + save_dtype = torch.bfloat16 + else: + save_dtype = torch.float + + # check if all models are safetensors + for model in args.models: + if not model.endswith("safetensors"): + print(f"Model {model} is not a safetensors model") + exit() + if not os.path.isfile(model): + print(f"Model {model} does not exist") + exit() + + assert len(args.models) == len(args.ratios) or args.ratios is None, "ratios must be the same length as models" + + # load and merge + ratio = 1.0 / len(args.models) # default + supplementary_key_ratios = {} # [key] = ratio, for keys not in all models, add later + + merged_sd = None + first_model_keys = set() # check missing keys in other models + for i, model in enumerate(args.models): + if args.ratios is not None: + ratio = args.ratios[i] + + if merged_sd is None: + # load first model + print(f"Loading model {model}, ratio = {ratio}...") + merged_sd = {} + with safe_open(model, framework="pt", device=args.device) as f: + for key in tqdm(f.keys()): + value = f.get_tensor(key) + _, key = replace_text_encoder_key(key) + + first_model_keys.add(key) + + if not is_unet_key(key) and args.unet_only: + supplementary_key_ratios[key] = 1.0 # use first model's value for VAE or TextEncoder + continue + + value = ratio * value.to(dtype) # first model's value * ratio + merged_sd[key] = value + + print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else "")) + continue + + # load other models + print(f"Loading model {model}, ratio = {ratio}...") + + with safe_open(model, framework="pt", device=args.device) as f: + model_keys = f.keys() + for key in tqdm(model_keys): + _, new_key = replace_text_encoder_key(key) + if new_key not in merged_sd: + if args.show_skipped and new_key not in first_model_keys: + print(f"Skip: {new_key}") + continue + + value = f.get_tensor(key) + merged_sd[new_key] = merged_sd[new_key] + ratio 
* value.to(dtype) + + # enumerate keys not in this model + model_keys = set(model_keys) + for key in merged_sd.keys(): + if key in model_keys: + continue + print(f"Key {key} not in model {model}, use first model's value") + if key in supplementary_key_ratios: + supplementary_key_ratios[key] += ratio + else: + supplementary_key_ratios[key] = ratio + + # add supplementary keys' value (including VAE and TextEncoder) + if len(supplementary_key_ratios) > 0: + print("add first model's value") + with safe_open(model, framework="pt", device=args.device) as f: + for key in tqdm(f.keys()): + _, new_key = replace_text_encoder_key(key) + if new_key not in supplementary_key_ratios: + continue + + if is_unet_key(new_key): # not VAE or TextEncoder + print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") + + value = f.get_tensor(key) # original key + + if new_key not in merged_sd: + merged_sd[new_key] = supplementary_key_ratios[new_key] * value.to(dtype) + else: + merged_sd[new_key] = merged_sd[new_key] + supplementary_key_ratios[new_key] * value.to(dtype) + + # save + output_file = args.output + if not output_file.endswith(".safetensors"): + output_file = output_file + ".safetensors" + + print(f"Saving to {output_file}...") + + # convert to save_dtype + for k in merged_sd.keys(): + merged_sd[k] = merged_sd[k].to(save_dtype) + + save_file(merged_sd, output_file) + + print("Done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Merge models") + parser.add_argument("--models", nargs="+", type=str, help="Models to merge") + parser.add_argument("--output", type=str, help="Output model") + parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0") + parser.add_argument("--unet_only", action="store_true", help="Only merge unet") + parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu") + parser.add_argument( + "--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float" + ) + parser.add_argument( + "--saving_precision", + type=str, + default="float", + choices=["float", "fp16", "bf16"], + help="Saving precision, default is float", + ) + parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)") + + args = parser.parse_args() + merge(args) From 6f80fe17fcac026ab85004d5a701835c14f5da84 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 8 Aug 2023 21:03:16 +0900 Subject: [PATCH 2/3] fix crashing in saving lora with clipskip --- library/sai_model_spec.py | 12 ++++++++---- library/train_util.py | 20 ++++++++++---------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/library/sai_model_spec.py b/library/sai_model_spec.py index 88c2cb77..472686ba 100644 --- a/library/sai_model_spec.py +++ b/library/sai_model_spec.py @@ -194,8 +194,8 @@ def build_metadata( # comma separated to tuple if isinstance(reso, str): reso = tuple(map(int, reso.split(","))) - if len(reso) == 1: - reso = (reso[0], reso[0]) + if len(reso) == 1: + reso = (reso[0], reso[0]) else: # resolution is defined in dataset, so use default if sdxl: @@ -215,7 +215,11 @@ def build_metadata( metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON if timesteps is not None: - metadata["modelspec.timestep_range"] = timesteps + if isinstance(timesteps, str) or isinstance(timesteps, int): + timesteps = (timesteps, timesteps) + if len(timesteps) == 1: + timesteps = (timesteps[0], 
timesteps[0]) + metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}" else: del metadata["modelspec.timestep_range"] @@ -228,7 +232,7 @@ def build_metadata( # assert all([v is not None for v in metadata.values()]), metadata if not all([v is not None for v in metadata.values()]): print(f"Internal error: some metadata values are None: {metadata}") - + return metadata diff --git a/library/train_util.py b/library/train_util.py index dbfe41e8..34e477ed 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2521,7 +2521,7 @@ def get_sai_model_spec( sdxl: bool, lora: bool, textual_inversion: bool, - is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA + is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA ): timestamp = time.time() @@ -2546,15 +2546,15 @@ def get_sai_model_spec( lora, textual_inversion, timestamp, - title, - reso, - is_stable_diffusion_ckpt, - args.metadata_author, - args.metadata_description, - args.metadata_license, - args.metadata_tags, - timesteps, - args.clip_skip, # None or int + title=title, + reso=reso, + is_stable_diffusion_ckpt=is_stable_diffusion_ckpt, + author=args.metadata_author, + description=args.metadata_description, + license=args.metadata_license, + tags=args.metadata_tags, + timesteps=timesteps, + clip_skip=args.clip_skip, # None or int ) return metadata From b83ce0c3529f564b52d684bd60f0d5a7734dc658 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 8 Aug 2023 21:09:08 +0900 Subject: [PATCH 3/3] modify import #368 --- networks/extract_lora_from_models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index b4eb0cf7..eed30350 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -9,9 +9,7 @@ import time import torch from safetensors.torch import load_file, save_file from tqdm import tqdm -from library import sai_model_spec -import library.model_util as model_util -import library.sdxl_model_util as sdxl_model_util +from library import sai_model_spec, model_util, sdxl_model_util import lora
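
A minimal usage sketch for the new tools/merge_models.py from PATCH 1/3 (the checkpoint names and ratios below are placeholders, not part of the patch; the script only accepts safetensors files, and when --ratios is given it must list one value per model):

    python tools/merge_models.py \
        --models model_a.safetensors model_b.safetensors \
        --ratios 0.7 0.3 \
        --output merged_model \
        --precision fp16 --saving_precision fp16 \
        --unet_only

With --unet_only, only U-Net keys are merged; VAE and Text Encoder weights are taken from the first model. Keys missing from a later model fall back to the first model's values, as reported by the "use first model's value" message.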
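
PATCH 2/3 changes how modelspec.timestep_range is written in library/sai_model_spec.py: a scalar or single-element timesteps value is expanded to a (min, max) pair and serialized as "min,max". A standalone sketch of that normalization (the value 1000 is only an example):

    # mirrors the normalization added to build_metadata in PATCH 2/3
    timesteps = 1000                                   # may arrive as int, str, or a 1-/2-element sequence
    if isinstance(timesteps, str) or isinstance(timesteps, int):
        timesteps = (timesteps, timesteps)
    if len(timesteps) == 1:
        timesteps = (timesteps[0], timesteps[0])
    timestep_range = f"{timesteps[0]},{timesteps[1]}"  # -> "1000,1000"

The companion change in library/train_util.py passes title, reso, timesteps, clip_skip, and the metadata fields to build_metadata as keyword arguments, so each value binds to the intended parameter when clip_skip is set.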