Merge branch 'kohya-ss:dev' into dev

青龍聖者@bdsqlsz authored 2023-09-15 15:37:16 +08:00, committed by GitHub
3 changed files with 48 additions and 4 deletions

File 1 of 3

@@ -258,6 +258,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
if "text_model.embeddings.position_ids" not in te1_sd:
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
print("text encoder 1:", info1)

File 2 of 3

@@ -37,7 +37,7 @@ from diffusers import (
from einops import rearrange
from tqdm import tqdm
from torchvision import transforms
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPTextConfig
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@@ -61,6 +61,8 @@ SCHEDLER_SCHEDULE = "scaled_linear"
LATENT_CHANNELS = 4
DOWNSAMPLING_FACTOR = 8
CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
# region モジュール入れ替え部 / module replacement section
"""
高速化のためのモジュール入れ替え / module replacement for speed-up
@@ -320,6 +322,10 @@ class PipelineLike:
self.scheduler = scheduler
self.safety_checker = None
self.clip_vision_model: CLIPVisionModelWithProjection = None
self.clip_vision_processor: CLIPImageProcessor = None
self.clip_vision_strength = 0.0
# Textual Inversion
self.token_replacements_list = []
for _ in range(len(self.text_encoders)):
@@ -535,6 +541,21 @@ class PipelineLike:
num_sub_prompts = len(text_pool) // batch_size
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
if init_image is not None and self.clip_vision_model is not None:
print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
clip_vision_embeddings = self.clip_vision_model(pixel_values=pixel_values, output_hidden_states=True, return_dict=True)
clip_vision_embeddings = clip_vision_embeddings.image_embeds
if len(clip_vision_embeddings) == 1 and batch_size > 1:
clip_vision_embeddings = clip_vision_embeddings.repeat((batch_size, 1))
clip_vision_embeddings = clip_vision_embeddings * self.clip_vision_strength
assert clip_vision_embeddings.shape == text_pool.shape, f"{clip_vision_embeddings.shape} != {text_pool.shape}"
text_pool = clip_vision_embeddings # replace: same as ComfyUI (?)
c_vector = torch.cat([text_pool, c_vector], dim=1)
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
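What the hunk does at the tensor level: the scaled image embedding from the CLIP Vision model replaces text_pool, the pooled text embedding that SDXL concatenates into its extra conditioning vector (the in-code comment suggests this mirrors ComfyUI's behaviour). A minimal sketch with dummy tensors; the 1280 width is an assumption matching SDXL's pooled embedding:

import torch

batch_size, pooled_dim = 2, 1280                    # 1280 assumed: width of SDXL's pooled text embedding
text_pool = torch.randn(batch_size, pooled_dim)     # pooled output of text encoder 2
image_embeds = torch.randn(1, pooled_dim)           # CLIPVisionModelWithProjection(...).image_embeds for one init image
clip_vision_strength = 1.0

if len(image_embeds) == 1 and batch_size > 1:
    image_embeds = image_embeds.repeat((batch_size, 1))  # one init image broadcast over the batch
image_embeds = image_embeds * clip_vision_strength
assert image_embeds.shape == text_pool.shape
text_pool = image_embeds                             # the image embedding stands in for the pooled text embedding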
@@ -1767,6 +1788,19 @@ def main(args):
init_images = load_images(args.image_path)
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
print(f"loaded {len(init_images)} images for img2img")
# CLIP Vision
if args.clip_vision_strength is not None:
print(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
vision_model.to(device, dtype)
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
pipe.clip_vision_model = vision_model
pipe.clip_vision_processor = processor
pipe.clip_vision_strength = args.clip_vision_strength
print(f"CLIP Vision model loaded.")
else:
init_images = None
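On the loading step above: passing projection_dim=1280 presumably keeps image_embeds the same width as the pooled text embedding it later replaces, and the laion/CLIP-ViT-bigG-14-laion2B-39B-b160k repository supplies matching weights and preprocessing. A standalone sketch of the encode path, where the blank PIL image stands in for the real init image:

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"

# projection_dim=1280 keeps image_embeds the same width as the pooled text embedding it replaces.
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)

init_image = Image.new("RGB", (512, 512))   # stand-in for the img2img init image
pixel_values = processor(init_image, return_tensors="pt")["pixel_values"]
with torch.no_grad():
    image_embeds = vision_model(pixel_values=pixel_values).image_embeds   # shape (1, 1280)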
@@ -2656,6 +2690,12 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*",
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
)
parser.add_argument(
"--clip_vision_strength",
type=float,
default=None,
help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する",
)
# parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )
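A small self-contained sketch of how the new flag gates the feature: with the default of None, CLIP Vision conditioning stays disabled unless a strength is passed explicitly (the parser setup is copied from the hunk above; the rest is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--clip_vision_strength",
    type=float,
    default=None,
    help="enable CLIP Vision Conditioning for img2img with this strength",
)

args = parser.parse_args([])                               # flag omitted: conditioning stays disabled
assert args.clip_vision_strength is None

args = parser.parse_args(["--clip_vision_strength", "1.0"])
if args.clip_vision_strength is not None:                  # same gate used in main() above
    print(f"apply clip_vision_strength={args.clip_vision_strength}")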

File 3 of 3

@@ -51,7 +51,7 @@ def merge(args):
print(f"Model {model} does not exist")
exit()
-assert len(args.models) == len(args.ratios) or args.ratios is None, "ratios must be the same length as models"
+assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
# load and merge
ratio = 1.0 / len(args.models) # default
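Why the operand order matters (a standalone demonstration, not part of the commit): `or` short-circuits left to right, so the None check must come first, otherwise len(args.ratios) is evaluated on None and raises TypeError in the default case:

models = ["a.safetensors", "b.safetensors"]    # illustrative paths
ratios = None                                  # the default: merge with equal ratios

# Old operand order: len(ratios) was evaluated first and raised TypeError when ratios is None.
# assert len(models) == len(ratios) or ratios is None, "ratios must be the same length as models"
# New order short-circuits on the None check, so the default case passes as intended.
assert ratios is None or len(models) == len(ratios), "ratios must be the same length as models"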
@@ -113,13 +113,13 @@ def merge(args):
# add supplementary keys' value (including VAE and TextEncoder)
if len(supplementary_key_ratios) > 0:
print("add first model's value")
-with safe_open(model, framework="pt", device=args.device) as f:
+with safe_open(args.models[0], framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
_, new_key = replace_text_encoder_key(key)
if new_key not in supplementary_key_ratios:
continue
-if is_unet_key(new_key): # not VAE or TextEncoder
+if is_unet_key(new_key):  # not VAE or TextEncoder
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
value = f.get_tensor(key) # original key
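On the safe_open change: supplementary keys (weights present in only some of the inputs, typically VAE or text encoder tensors) are meant to be copied from the first model being merged, but `model` was a leftover loop variable still pointing at the last path from the earlier existence check. A minimal sketch of the corrected lookup, with made-up file and key names:

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# Placeholder inputs written to disk so the sketch runs end to end (paths and keys are made up).
models = ["first.safetensors", "second.safetensors"]
save_file({"first_stage_model.conv_in.weight": torch.ones(3), "shared.weight": torch.ones(3)}, models[0])
save_file({"shared.weight": torch.ones(3)}, models[1])

# Keys missing from some inputs get their value from the *first* model, not from `model`,
# which after the existence-check loop would still reference the last path.
supplementary_key_ratios = {"first_stage_model.conv_in.weight": 1.0}

with safe_open(models[0], framework="pt", device="cpu") as f:
    for key in f.keys():
        if key in supplementary_key_ratios:
            value = f.get_tensor(key)
            print(key, value.shape)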