From c6d52fdea47e550c756c0eb87ae254e72cdbeb9d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 4 Aug 2023 08:38:27 +0900 Subject: [PATCH] Add workaround for clip's bug for pooled output --- library/sdxl_lpw_stable_diffusion.py | 11 +++--- library/train_util.py | 53 +++++++++++++++++++++++++--- sdxl_gen_img.py | 6 +++- sdxl_minimal_inference.py | 2 +- 4 files changed, 61 insertions(+), 11 deletions(-) diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index 7f88469f..e03ee405 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -18,7 +18,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, from diffusers.utils import logging from PIL import Image -from library import sdxl_model_util, sdxl_train_util +from library import sdxl_model_util, sdxl_train_util, train_util try: @@ -210,7 +210,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos return tokens, weights -def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, device): +def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device): if not is_sdxl_text_encoder2: # text_encoder1: same as SD1/2 enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True) @@ -220,7 +220,8 @@ def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, devi # text_encoder2 enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True) hidden_states = enc_out["hidden_states"][-2] # penuultimate layer - pool = enc_out["text_embeds"] + # pool = enc_out["text_embeds"] + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id) hidden_states = hidden_states.to(device) if pool is not None: pool = pool.to(device) @@ -261,7 +262,7 @@ def get_unweighted_text_embeddings( text_input_chunk[j, 1] = eos 
text_embedding, current_text_pool = get_hidden_states( - pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, pipe.device + pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device ) if text_pool is None: text_pool = current_text_pool @@ -280,7 +281,7 @@ def get_unweighted_text_embeddings( text_embeddings.append(text_embedding) text_embeddings = torch.concat(text_embeddings, axis=1) else: - text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, pipe.device) + text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device) return text_embeddings, text_pool diff --git a/library/train_util.py b/library/train_util.py index 1353173e..4d7f6727 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -34,7 +34,7 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer from torchvision import transforms -from transformers import CLIPTokenizer +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( @@ -2868,7 +2868,7 @@ def verify_training_args(args: argparse.Namespace): raise ValueError( "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" ) - + if args.v_pred_like_loss and args.v_parameterization: raise ValueError( "v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません" @@ -3733,8 +3733,50 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod return encoder_hidden_states +def pool_workaround( + text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int +): + r""" + workaround for CLIP's pooling 
bug: it returns the hidden states for the max token id as the pooled output + instead of the hidden states for the EOS token + If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output + + Original code from CLIP's pooling function: + + \# text_embeds.shape = [batch_size, sequence_length, transformer.width] + \# take features from the eot embedding (eot_token is the highest number in each sequence) + \# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + """ + + # input_ids: b*n,77 + # find index for EOS token + eos_token_index = torch.where(input_ids == eos_token_id)[1] + eos_token_index = eos_token_index.to(device=last_hidden_state.device) + print(eos_token_index) + print(input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1)) + + # get hidden states for EOS token + pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index] + + # apply projection + pooled_output = text_encoder.text_projection(pooled_output) + + return pooled_output + + def get_hidden_states_sdxl( - max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None + max_token_length: int, + input_ids1: torch.Tensor, + input_ids2: torch.Tensor, + tokenizer1: CLIPTokenizer, + tokenizer2: CLIPTokenizer, + text_encoder1: CLIPTextModel, + text_encoder2: CLIPTextModelWithProjection, + weight_dtype: Optional[str] = None, ): # input_ids: b,n,77 -> b*n, 77 b_size = input_ids1.size()[0] @@ -3748,7 +3790,10 @@ def get_hidden_states_sdxl( # text_encoder2 enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True) hidden_states2 = enc_out["hidden_states"][-2] # 
penuultimate layer - pool2 = enc_out["text_embeds"] + + # pool2 = enc_out["text_embeds"] + pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + print(f"original pool2: {enc_out['text_embeds']}, fixed: {pool2}") # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 n_size = 1 if max_token_length is None else max_token_length // 75 diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index d3d7f074..2544c689 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -94,7 +94,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform replace_vae_attn_to_memory_efficient() elif xformers: # replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す? - vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う + vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う elif sdpa: replace_vae_attn_to_sdpa() @@ -960,6 +960,8 @@ def get_unweighted_text_embeddings( text_embedding = enc_out["hidden_states"][-2] if pool is None: pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos) if no_boseos_middle: if i == 0: @@ -978,6 +980,8 @@ def get_unweighted_text_embeddings( enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) text_embeddings = enc_out["hidden_states"][-2] pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this + if pool is not None: + pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input, eos) return text_embeddings, pool diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 1a950902..c3044493 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -213,7 +213,7 @@ if __name__ == "__main__": enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True) text_embedding2_penu = enc_out["hidden_states"][-2] # 
print("hidden_states2", text_embedding2_penu.shape) - text_embedding2_pool = enc_out["text_embeds"] + text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion # 連結して終了 concat and finish text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)