From e76ea7cd7deecb09d8f5022f28aa6146ef81e2fd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 30 Mar 2023 22:28:55 +0900 Subject: [PATCH] fix not working --- gen_img_diffusers.py | 97 +++++++++++++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 32 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 4dbe5f90..225de33c 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -531,11 +531,14 @@ class PipelineLike: replacer.append(self.token_replacements_XTI[r][layer]) else: replacer = replacer_ - new_tokens.extend(replacer) - else: - new_tokens.append(token) + new_tokens.extend(replacer) + else: + new_tokens.append(token) return new_tokens + def add_token_replacement_XTI(self, target_token_id, rep_token_ids): + self.token_replacements_XTI[target_token_id] = rep_token_ids + def set_control_nets(self, ctrl_nets): self.control_nets = ctrl_nets @@ -756,7 +759,7 @@ class PipelineLike: f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) - + if not self.token_replacements_XTI: text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( pipe=self, @@ -770,8 +773,8 @@ class PipelineLike: if negative_scale is not None: _, real_uncond_embeddings, _ = get_weighted_text_embeddings( pipe=self, - prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 - uncond_prompt=[""]*batch_size, + prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須 + uncond_prompt=[""] * batch_size, max_embeddings_multiples=max_embeddings_multiples, clip_skip=self.clip_skip, **kwargs, @@ -779,7 +782,24 @@ class PipelineLike: if self.token_replacements_XTI: text_embeddings_concat = [] - for layer in ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']: + for layer in [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ]: text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( pipe=self, prompt=prompt, @@ -801,14 +821,6 @@ class PipelineLike: text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - **kwargs, - ) # CLIP guidanceで使用するembeddingsを取得する if self.clip_guidance_scale > 0: @@ -1716,7 +1728,7 @@ def parse_prompt_attention(text): return res -def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int): +def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None): r""" Tokenize a list of prompts and return its tokens with weights of each token. No padding, starting or ending token is included. @@ -1732,7 +1744,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: # tokenize and discard the starting and the ending token token = pipe.tokenizer(word).input_ids[1:-1] - token = pipe.replace_token(token) + token = pipe.replace_token(token, layer=layer) text_token += token # copy the weight by length of token @@ -1879,11 +1891,11 @@ def get_weighted_text_embeddings( prompt = [prompt] if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer) if uncond_prompt is not None: if isinstance(uncond_prompt, str): uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer) else: prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids] prompt_weights = [[1.0] * len(token) for token in prompt_tokens] @@ -2262,7 +2274,7 @@ def main(args): metadata = f.metadata() if metadata is not None: print(f"metadata for: {network_weight}: {metadata}") - + network = imported_module.create_network_from_weights( network_mul, network_weight, vae, text_encoder, unet, **net_kwargs ) @@ -2335,9 +2347,10 @@ def main(args): if args.diffusers_xformers: pipe.enable_xformers_memory_efficient_attention() - diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI - diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI - diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI + if args.XTI_embeddings: + diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI + diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI # Textual Inversionを処理する if args.textual_inversion_embeddings: @@ -2386,32 +2399,52 @@ def main(args): token_embeds[token_id] = embed if args.XTI_embeddings: - XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11'] + XTI_layers = [ + "IN01", + "IN02", + "IN04", + "IN05", + "IN07", + "IN08", + "MID", + "OUT03", + "OUT04", + "OUT05", + "OUT06", + "OUT07", + "OUT08", + "OUT09", + "OUT10", + "OUT11", + ] token_ids_embeds_XTI = [] for embeds_file in args.XTI_embeddings: if model_util.is_safetensors(embeds_file): from safetensors.torch import load_file + data = load_file(embeds_file) else: data = torch.load(embeds_file, map_location="cpu") if set(data.keys()) != set(XTI_layers): raise ValueError("NOT XTI") embeds = torch.concat(list(data.values())) - num_vectors_per_token = data['MID'].size()[0] + num_vectors_per_token = data["MID"].size()[0] token_string = os.path.splitext(os.path.basename(embeds_file))[0] token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)] - + # add new word to tokenizer, count is num_vectors_per_token num_added_tokens = tokenizer.add_tokens(token_strings) - assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" + assert ( + num_added_tokens == num_vectors_per_token + ), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}" token_ids = tokenizer.convert_tokens_to_ids(token_strings) print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}") - #if num_vectors_per_token > 1: + # if num_vectors_per_token > 1: pipe.add_token_replacement(token_ids[0], token_ids) - + token_strings_XTI = [] for layer_name in XTI_layers: token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings] @@ -2423,7 +2456,7 @@ def main(args): for i, layer_name in enumerate(XTI_layers): t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens pipe.add_token_replacement_XTI(t, t_XTI_dic) - + text_encoder.resize_token_embeddings(len(tokenizer)) token_embeds = text_encoder.get_input_embeddings().weight.data for token_ids, embeds in token_ids_embeds_XTI: @@ -3090,8 +3123,8 @@ def setup_parser() -> argparse.ArgumentParser: "--XTI_embeddings", type=str, default=None, - nargs='*', - help='Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings' + nargs="*", + help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings", ) parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う") parser.add_argument(