diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 889b4c4c..2c36329e 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -99,12 +99,6 @@ from library.original_unet import FlashAttentionFunction
 
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
 
-# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
-TOKENIZER_PATH = "openai/clip-vit-large-patch14"
-V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # ここからtokenizerだけ使う
-
-DEFAULT_TOKEN_LENGTH = 75
-
 # scheduler:
 SCHEDULER_LINEAR_START = 0.00085
 SCHEDULER_LINEAR_END = 0.0120
@@ -2066,6 +2060,17 @@ def main(args):
         tokenizer = loading_pipe.tokenizer
         del loading_pipe
 
+        # Diffusers U-Net to original U-Net
+        original_unet = UNet2DConditionModel(
+            unet.config.sample_size,
+            unet.config.attention_head_dim,
+            unet.config.cross_attention_dim,
+            unet.config.use_linear_projection,
+            unet.config.upcast_attention,
+        )
+        original_unet.load_state_dict(unet.state_dict())
+        unet = original_unet
+
     # VAEを読み込む
     if args.vae is not None:
         vae = model_util.load_vae(args.vae, dtype)
diff --git a/library/model_util.py b/library/model_util.py
index d59f5ef4..63a395f8 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -933,10 +933,31 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
     else:
         converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
 
-        logging.set_verbosity_error()  # don't show annoying warning
-        text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
-        logging.set_verbosity_warning()
-
+        # logging.set_verbosity_error()  # don't show annoying warning
+        # text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
+        # logging.set_verbosity_warning()
+        # print(f"config: {text_model.config}")
+        cfg = CLIPTextConfig(
+            vocab_size=49408,
+            hidden_size=768,
+            intermediate_size=3072,
+            num_hidden_layers=12,
+            num_attention_heads=12,
+            max_position_embeddings=77,
+            hidden_act="quick_gelu",
+            layer_norm_eps=1e-05,
+            dropout=0.0,
+            attention_dropout=0.0,
+            initializer_range=0.02,
+            initializer_factor=1.0,
+            pad_token_id=1,
+            bos_token_id=0,
+            eos_token_id=2,
+            model_type="clip_text_model",
+            projection_dim=768,
+            torch_dtype="float32",
+        )
+        text_model = CLIPTextModel._from_config(cfg)
     info = text_model.load_state_dict(converted_text_encoder_checkpoint)
     print("loading text encoder:", info)
 
diff --git a/library/train_util.py b/library/train_util.py
index 30380262..f13d7252 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -36,7 +36,6 @@ from torch.optim import Optimizer
 from torchvision import transforms
 from transformers import CLIPTokenizer
 import transformers
-import diffusers
 from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 from diffusers import (
     StableDiffusionPipeline,
@@ -52,6 +51,7 @@ from diffusers import (
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
 )
+from library.original_unet import UNet2DConditionModel
 from huggingface_hub import hf_hub_download
 import albumentations as albu
 import numpy as np
@@ -2947,11 +2947,26 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
             print(
                 f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
             )
+            raise ex
 
         text_encoder = pipe.text_encoder
         vae = pipe.vae
         unet = pipe.unet
         del pipe
+        # Diffusers U-Net to original U-Net
+        # TODO it may be better to convert here to the same format as the v2 *.ckpt/*.safetensors
+        # print(f"unet config: {unet.config}")
+        original_unet = UNet2DConditionModel(
+            unet.config.sample_size,
+            unet.config.attention_head_dim,
+            unet.config.cross_attention_dim,
+            unet.config.use_linear_projection,
+            unet.config.upcast_attention,
+        )
+        original_unet.load_state_dict(unet.state_dict())
+        unet = original_unet
+        print("U-Net converted to original U-Net")
+
     # VAEを読み込む
     if args.vae is not None:
         vae = model_util.load_vae(args.vae, weight_dtype)
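
Note on the U-Net swap: both gen_img_diffusers.py and train_util.py now rebuild the U-Net as library.original_unet.UNet2DConditionModel from five values of the Diffusers config, then copy the weights over. The strict load_state_dict() call can only succeed because the original implementation keeps its state_dict keys compatible with Diffusers, so no key remapping is needed. A minimal sketch of the same pattern as a standalone helper (the helper name is illustrative, not part of the patch):

```python
from library.original_unet import UNet2DConditionModel

def to_original_unet(diffusers_unet) -> UNet2DConditionModel:
    # The original U-Net's constructor takes only these five config values;
    # the rest of the SD v1/v2 architecture is fixed in original_unet.py.
    original_unet = UNet2DConditionModel(
        diffusers_unet.config.sample_size,
        diffusers_unet.config.attention_head_dim,
        diffusers_unet.config.cross_attention_dim,
        diffusers_unet.config.use_linear_projection,
        diffusers_unet.config.upcast_attention,
    )
    # load_state_dict defaults to strict=True, so this raises immediately
    # if the two implementations' parameter names ever drift apart.
    original_unet.load_state_dict(diffusers_unet.state_dict())
    return original_unet
```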
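The model_util.py change removes the runtime dependency on the Hugging Face Hub when loading a *.ckpt/*.safetensors file: instead of downloading openai/clip-vit-large-patch14 just to get a correctly shaped CLIPTextModel, it instantiates the same architecture locally from a hard-coded CLIPTextConfig and fills it from the converted checkpoint. A sketch of why this is equivalent, assuming (as the patch does) that the config values mirror the published clip-vit-large-patch14 text-encoder config:

```python
from transformers import CLIPTextConfig, CLIPTextModel

# Build the v1 text-encoder architecture without network access;
# _from_config() only allocates randomly initialized weights.
cfg = CLIPTextConfig(
    vocab_size=49408,
    hidden_size=768,
    intermediate_size=3072,
    num_hidden_layers=12,
    num_attention_heads=12,
    max_position_embeddings=77,
    hidden_act="quick_gelu",
)
text_model = CLIPTextModel._from_config(cfg)

# The actual weights then come entirely from the converted LDM checkpoint,
# as in load_models_from_stable_diffusion_checkpoint():
# info = text_model.load_state_dict(converted_text_encoder_checkpoint)
```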