From 0ecfd91a208f845f45fdc39cae83d6be71e23720 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 13 Sep 2023 17:59:14 +0900 Subject: [PATCH 1/3] fix VAE becomes last one --- tools/merge_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/merge_models.py b/tools/merge_models.py index dd04ea46..391bfe67 100644 --- a/tools/merge_models.py +++ b/tools/merge_models.py @@ -51,7 +51,7 @@ def merge(args): print(f"Model {model} does not exist") exit() - assert len(args.models) == len(args.ratios) or args.ratios is None, "ratios must be the same length as models" + assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models" # load and merge ratio = 1.0 / len(args.models) # default @@ -113,13 +113,13 @@ def merge(args): # add supplementary keys' value (including VAE and TextEncoder) if len(supplementary_key_ratios) > 0: print("add first model's value") - with safe_open(model, framework="pt", device=args.device) as f: + with safe_open(args.models[0], framework="pt", device=args.device) as f: for key in tqdm(f.keys()): _, new_key = replace_text_encoder_key(key) if new_key not in supplementary_key_ratios: continue - if is_unet_key(new_key): # not VAE or TextEncoder + if is_unet_key(new_key): # not VAE or TextEncoder print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}") value = f.get_tensor(key) # original key From 90c47140b8d969c7ba55b1b85e0d518826a9b464 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 13 Sep 2023 17:59:34 +0900 Subject: [PATCH 2/3] add support model without position_ids --- library/sdxl_model_util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index 6647b439..2f0154ca 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -258,6 +258,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) elif k.startswith("conditioner.embedders.1.model."): te2_sd[k] = state_dict.pop(k) + + # 一部のposition_idsがないモデルへの対応 / add position_ids for some models + if "text_model.embeddings.position_ids" not in te1_sd: + te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32 print("text encoder 1:", info1) From d337bbf8a08bce18e302b7e8403c70d58d632610 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 13 Sep 2023 20:58:37 +0900 Subject: [PATCH 3/3] get pool from CLIPVisionModel in img2img --- sdxl_gen_img.py | 42 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index c506ad3f..7d9c68bf 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -37,7 +37,7 @@ from diffusers import ( from einops import rearrange from tqdm import tqdm from torchvision import transforms -from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig +from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor import PIL from PIL import Image from PIL.PngImagePlugin import PngInfo @@ -61,6 +61,8 @@ SCHEDLER_SCHEDULE = "scaled_linear" LATENT_CHANNELS = 4 DOWNSAMPLING_FACTOR = 8 +CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + # region モジュール入れ替え部 """ 高速化のためのモジュール入れ替え @@ -320,6 +322,10 @@ class PipelineLike: self.scheduler = scheduler self.safety_checker = None + self.clip_vision_model: CLIPVisionModelWithProjection = None + self.clip_vision_processor: CLIPImageProcessor = None + self.clip_vision_strength = 0.0 + # Textual Inversion self.token_replacements_list = [] for _ in range(len(self.text_encoders)): @@ -535,6 +541,21 @@ class PipelineLike: num_sub_prompts = len(text_pool) // batch_size text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt + if init_image is not None and self.clip_vision_model is not None: + print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}") + vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device) + pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype) + + clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True) + clip_vision_embeddings = clip_vision_embeddings.image_embeds + + if len(clip_vision_embeddings) == 1 and batch_size > 1: + clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1)) + + clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength + assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}" + text_pool = clip_vision_embeddings # replace: same as ComfyUI (?) + c_vector = torch.cat([text_pool, c_vector], dim=1) uc_vector = torch.cat([uncond_pool, uc_vector], dim=1) @@ -1767,6 +1788,19 @@ def main(args): init_images = load_images(args.image_path) assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}" print(f"loaded {len(init_images)} images for img2img") + + # CLIP Vision + if args.clip_vision_strength is not None: + print(f"load CLIP Vision model: {CLIP_VISION_MODEL}") + vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280) + vision_model.to(device, dtype) + processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL) + + pipe.clip_vision_model = vision_model + pipe.clip_vision_processor = processor + pipe.clip_vision_strength = args.clip_vision_strength + print(f"CLIP Vision model loaded.") + else: init_images = None @@ -2656,6 +2690,12 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率", ) + parser.add_argument( + "--clip_vision_strength", + type=float, + default=None, + help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する", + ) # # parser.add_argument( # "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像" # )