fix not working

This commit is contained in:
Kohya S
2023-03-30 22:28:55 +09:00
parent 5fc80b7a5b
commit e76ea7cd7d

View File

@@ -531,11 +531,14 @@ class PipelineLike:
replacer.append(self.token_replacements_XTI[r][layer])
else:
replacer = replacer_
new_tokens.extend(replacer)
else:
new_tokens.append(token)
new_tokens.extend(replacer)
else:
new_tokens.append(token)
return new_tokens
def add_token_replacement_XTI(self, target_token_id, rep_token_ids):
self.token_replacements_XTI[target_token_id] = rep_token_ids
def set_control_nets(self, ctrl_nets):
self.control_nets = ctrl_nets
@@ -756,7 +759,7 @@ class PipelineLike:
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
if not self.token_replacements_XTI:
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
@@ -770,8 +773,8 @@ class PipelineLike:
if negative_scale is not None:
_, real_uncond_embeddings, _ = get_weighted_text_embeddings(
pipe=self,
prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須
uncond_prompt=[""]*batch_size,
prompt=prompt, # こちらのトークン長に合わせてuncondを作るので75トークン超で必須
uncond_prompt=[""] * batch_size,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
**kwargs,
@@ -779,7 +782,24 @@ class PipelineLike:
if self.token_replacements_XTI:
text_embeddings_concat = []
for layer in ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']:
for layer in [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]:
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
@@ -801,14 +821,6 @@ class PipelineLike:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
**kwargs,
)
# CLIP guidanceで使用するembeddingsを取得する
if self.clip_guidance_scale > 0:
@@ -1716,7 +1728,7 @@ def parse_prompt_attention(text):
return res
def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int):
def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
@@ -1732,7 +1744,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
token = pipe.replace_token(token)
token = pipe.replace_token(token, layer=layer)
text_token += token
# copy the weight by length of token
@@ -1879,11 +1891,11 @@ def get_weighted_text_embeddings(
prompt = [prompt]
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer)
else:
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
@@ -2262,7 +2274,7 @@ def main(args):
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
)
@@ -2335,9 +2347,10 @@ def main(args):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
# Textual Inversionを処理する
if args.textual_inversion_embeddings:
@@ -2386,32 +2399,52 @@ def main(args):
token_embeds[token_id] = embed
if args.XTI_embeddings:
XTI_layers = ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']
XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
token_ids_embeds_XTI = []
for embeds_file in args.XTI_embeddings:
if model_util.is_safetensors(embeds_file):
from safetensors.torch import load_file
data = load_file(embeds_file)
else:
data = torch.load(embeds_file, map_location="cpu")
if set(data.keys()) != set(XTI_layers):
raise ValueError("NOT XTI")
embeds = torch.concat(list(data.values()))
num_vectors_per_token = data['MID'].size()[0]
num_vectors_per_token = data["MID"].size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens = tokenizer.add_tokens(token_strings)
assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
assert (
num_added_tokens == num_vectors_per_token
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
#if num_vectors_per_token > 1:
# if num_vectors_per_token > 1:
pipe.add_token_replacement(token_ids[0], token_ids)
token_strings_XTI = []
for layer_name in XTI_layers:
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
@@ -2423,7 +2456,7 @@ def main(args):
for i, layer_name in enumerate(XTI_layers):
t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens
pipe.add_token_replacement_XTI(t, t_XTI_dic)
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
for token_ids, embeds in token_ids_embeds_XTI:
@@ -3090,8 +3123,8 @@ def setup_parser() -> argparse.ArgumentParser:
"--XTI_embeddings",
type=str,
default=None,
nargs='*',
help='Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings'
nargs="*",
help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings",
)
parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
parser.add_argument(