mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Add workaround for clip's bug for pooled output
This commit is contained in:
@@ -18,7 +18,7 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput,
|
||||
from diffusers.utils import logging
|
||||
from PIL import Image
|
||||
|
||||
from library import sdxl_model_util, sdxl_train_util
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
|
||||
|
||||
try:
|
||||
@@ -210,7 +210,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos
|
||||
return tokens, weights
|
||||
|
||||
|
||||
def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, device):
|
||||
def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, eos_token_id, device):
|
||||
if not is_sdxl_text_encoder2:
|
||||
# text_encoder1: same as SD1/2
|
||||
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
|
||||
@@ -220,7 +220,8 @@ def get_hidden_states(text_encoder, input_ids, is_sdxl_text_encoder2: bool, devi
|
||||
# text_encoder2
|
||||
enc_out = text_encoder(input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=True)
|
||||
hidden_states = enc_out["hidden_states"][-2] # penuultimate layer
|
||||
pool = enc_out["text_embeds"]
|
||||
# pool = enc_out["text_embeds"]
|
||||
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], input_ids, eos_token_id)
|
||||
hidden_states = hidden_states.to(device)
|
||||
if pool is not None:
|
||||
pool = pool.to(device)
|
||||
@@ -261,7 +262,7 @@ def get_unweighted_text_embeddings(
|
||||
text_input_chunk[j, 1] = eos
|
||||
|
||||
text_embedding, current_text_pool = get_hidden_states(
|
||||
pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, pipe.device
|
||||
pipe.text_encoder, text_input_chunk, is_sdxl_text_encoder2, eos, pipe.device
|
||||
)
|
||||
if text_pool is None:
|
||||
text_pool = current_text_pool
|
||||
@@ -280,7 +281,7 @@ def get_unweighted_text_embeddings(
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, pipe.device)
|
||||
text_embeddings, text_pool = get_hidden_states(pipe.text_encoder, text_input, is_sdxl_text_encoder2, eos, pipe.device)
|
||||
return text_embeddings, text_pool
|
||||
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTokenizer
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
import transformers
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
from diffusers import (
|
||||
@@ -2868,7 +2868,7 @@ def verify_training_args(args: argparse.Namespace):
|
||||
raise ValueError(
|
||||
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
|
||||
)
|
||||
|
||||
|
||||
if args.v_pred_like_loss and args.v_parameterization:
|
||||
raise ValueError(
|
||||
"v_pred_like_loss cannot be enabled with v_parameterization / v_pred_like_lossはv_parameterizationが有効なときには有効にできません"
|
||||
@@ -3733,8 +3733,50 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
|
||||
return encoder_hidden_states
|
||||
|
||||
|
||||
def pool_workaround(
|
||||
text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
|
||||
):
|
||||
r"""
|
||||
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
|
||||
instead of the hidden states for the EOS token
|
||||
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
|
||||
|
||||
Original code from CLIP's pooling function:
|
||||
|
||||
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||
\# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
||||
]
|
||||
"""
|
||||
|
||||
# input_ids: b*n,77
|
||||
# find index for EOS token
|
||||
eos_token_index = torch.where(input_ids == eos_token_id)[1]
|
||||
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
|
||||
print(eos_token_index)
|
||||
print(input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1))
|
||||
|
||||
# get hidden states for EOS token
|
||||
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index]
|
||||
|
||||
# apply projection
|
||||
pooled_output = text_encoder.text_projection(pooled_output)
|
||||
|
||||
return pooled_output
|
||||
|
||||
|
||||
def get_hidden_states_sdxl(
|
||||
max_token_length, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
|
||||
max_token_length: int,
|
||||
input_ids1: torch.Tensor,
|
||||
input_ids2: torch.Tensor,
|
||||
tokenizer1: CLIPTokenizer,
|
||||
tokenizer2: CLIPTokenizer,
|
||||
text_encoder1: CLIPTextModel,
|
||||
text_encoder2: CLIPTextModelWithProjection,
|
||||
weight_dtype: Optional[str] = None,
|
||||
):
|
||||
# input_ids: b,n,77 -> b*n, 77
|
||||
b_size = input_ids1.size()[0]
|
||||
@@ -3748,7 +3790,10 @@ def get_hidden_states_sdxl(
|
||||
# text_encoder2
|
||||
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
||||
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
||||
pool2 = enc_out["text_embeds"]
|
||||
|
||||
# pool2 = enc_out["text_embeds"]
|
||||
pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
||||
print(f"original pool2: {enc_out['text_embeds']}, fixed: {pool2}")
|
||||
|
||||
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
||||
n_size = 1 if max_token_length is None else max_token_length // 75
|
||||
|
||||
@@ -94,7 +94,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
|
||||
replace_vae_attn_to_memory_efficient()
|
||||
elif xformers:
|
||||
# replace_vae_attn_to_xformers() # 解像度によってxformersがエラーを出す?
|
||||
vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う
|
||||
vae.set_use_memory_efficient_attention_xformers(True) # とりあえずこっちを使う
|
||||
elif sdpa:
|
||||
replace_vae_attn_to_sdpa()
|
||||
|
||||
@@ -960,6 +960,8 @@ def get_unweighted_text_embeddings(
|
||||
text_embedding = enc_out["hidden_states"][-2]
|
||||
if pool is None:
|
||||
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
|
||||
if pool is not None:
|
||||
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos)
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
@@ -978,6 +980,8 @@ def get_unweighted_text_embeddings(
|
||||
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||
text_embeddings = enc_out["hidden_states"][-2]
|
||||
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
|
||||
if pool is not None:
|
||||
pool = train_util.pool_workaround(text_encoder, enc_out["last_hidden_state"], text_input_chunk, eos)
|
||||
return text_embeddings, pool
|
||||
|
||||
|
||||
|
||||
@@ -213,7 +213,7 @@ if __name__ == "__main__":
|
||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||
# print("hidden_states2", text_embedding2_penu.shape)
|
||||
text_embedding2_pool = enc_out["text_embeds"]
|
||||
text_embedding2_pool = enc_out["text_embeds"] # do not suport Textual Inversion
|
||||
|
||||
# 連結して終了 concat and finish
|
||||
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
|
||||
|
||||
Reference in New Issue
Block a user