Add workaround for clip's bug for pooled output

This commit is contained in:
Kohya S
2023-08-04 08:38:27 +09:00
parent cf6832896f
commit c6d52fdea4
4 changed files with 61 additions and 11 deletions

View File

@@ -18,7 +18,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput,
from diffusers.utils import logging
from PIL import Image
from library import sdxl_model_util, sdxl_train_util
from library import sdxl_model_util, sdxl_train_util, train_util
try:
@@ -210,7 +210,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos
return tokens, weights
def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, device):
def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
if not is_sdxl_text_encoder2:
# text_encoder1: same as SD1/2
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
@@ -220,7 +220,8 @@ def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, devi
# text_encoder2
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
hidden_states = enc_out["hidden_states"][-2]  # penultimate layer
pool = enc_out["text_embeds"]
# pool = enc_out["text_embeds"]
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
hidden_states = hidden_states.to(device)
if pool is not None:
pool = pool.to(device)
@@ -261,7 +262,7 @@ def get_unweighted_text_embeddings(
text_input_chunk[j, 1] = eos
text_embedding, current_text_pool = get_hidden_states(
pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, pipe.device
pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
)
if text_pool is None:
text_pool = current_text_pool
@@ -280,7 +281,7 @@ def get_unweighted_text_embeddings(
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, pipe.device)
text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
return text_embeddings, text_pool

View File

@@ -34,7 +34,7 @@ import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
import transformers
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import (
@@ -2868,7 +2868,7 @@ def verify_training_args(args: argparse.Namespace):
raise ValueError(
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
)
if args.v_pred_like_loss and args.v_parameterization:
raise ValueError(
"v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
@@ -3733,8 +3733,50 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
return encoder_hidden_states
def pool_workaround(
    text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
):
    r"""
    Workaround for CLIP's pooling bug: it pools the hidden state at the position of the
    token with the *highest id* instead of the position of the EOS token.

    With Textual Inversion, newly added tokens receive ids larger than the EOS token id,
    so CLIP's built-in argmax pooling picks the wrong position; we must locate the EOS
    token explicitly and pool there.

    Original code from CLIP's pooling function::

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]

    Args:
        text_encoder: CLIP text encoder whose ``text_projection`` is applied to the pooled state.
        last_hidden_state: encoder output; assumed shape (b*n, 77, width) — per the callers in this commit.
        input_ids: token ids; assumed shape (b*n, 77).
        eos_token_id: id of the EOS token to pool at.

    Returns:
        Projected pooled output of shape (batch, projection_dim).
    """
    # Index of the FIRST EOS occurrence in each row. argmax over the boolean mask returns
    # the first True position per row, which stays correct even when a sequence contains
    # EOS more than once (e.g. EOS used as padding) — torch.where(...)[1] would then yield
    # more indices than rows and break the batched indexing below.
    eos_token_index = (input_ids == eos_token_id).int().argmax(dim=1)
    eos_token_index = eos_token_index.to(device=last_hidden_state.device)

    # Gather the hidden state at the EOS position for each sequence in the batch.
    pooled_output = last_hidden_state[
        torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
    ]

    # Apply the text projection, mirroring what CLIPTextModelWithProjection does to its
    # (buggy) pooled output.
    pooled_output = text_encoder.text_projection(pooled_output)

    return pooled_output
def get_hidden_states_sdxl(
max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
max_token_length: int,
input_ids1: torch.Tensor,
input_ids2: torch.Tensor,
tokenizer1: CLIPTokenizer,
tokenizer2: CLIPTokenizer,
text_encoder1: CLIPTextModel,
text_encoder2: CLIPTextModelWithProjection,
weight_dtype: Optional[str] = None,
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
@@ -3748,7 +3790,10 @@ def get_hidden_states_sdxl(
# text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
pool2 = enc_out["text_embeds"]
# pool2 = enc_out["text_embeds"]
pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
print(f"original pool2: {enc_out['text_embeds']}, fixed: {pool2}")
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75

View File

@@ -94,7 +94,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
replace_vae_attn_to_memory_efficient()
elif xformers:
# replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す
vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う
vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う
elif sdpa:
replace_vae_attn_to_sdpa()
@@ -960,6 +960,8 @@ def get_unweighted_text_embeddings(
text_embedding = enc_out["hidden_states"][-2]
if pool is None:
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
if pool is not None:
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos)
if no_boseos_middle:
if i == 0:
@@ -978,6 +980,8 @@ def get_unweighted_text_embeddings(
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-2]
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
if pool is not None:
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos)
return text_embeddings, pool

View File

@@ -213,7 +213,7 @@ if __name__ == "__main__":
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
# print("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"]
text_embedding2_pool = enc_out["text_embeds"]  # does not support Textual Inversion
# 連結して終了 concat and finish
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)