diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index b9cb6057..807e0aec 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -10,10 +10,10 @@ from library import sdxl_original_unet VAE_SCALE_FACTOR = 0.13025 -MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9" +MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" # Diffusersの設定を読み込むための参照モデル -DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-0.9" # アクセス権が必要 +DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0" DIFFUSERS_SDXL_UNET_CONFIG = { "act_fn": "silu", diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index e8099681..b4eb0cf7 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -61,15 +61,15 @@ def svd(args): else: print(f"loading original SDXL model : {args.model_org}") text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_org, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu" ) text_encoders_o = [text_encoder_o1, text_encoder_o2] print(f"loading original SDXL model : {args.model_tuned}") text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_tuned, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu" ) text_encoders_t = [text_encoder_t1, text_encoder_t2] - model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9 + model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0 # create LoRA network to extract weights: Use dim (rank) as alpha if args.conv_dim is None: diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py index 74f455b2..0608c01f 100644 --- a/networks/sdxl_merge_lora.py +++ b/networks/sdxl_merge_lora.py @@ -234,7 +234,7 @@ def merge(args): unet, logit_scale, ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.sd_model, "cpu") + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu") merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype) diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 68f0e5db..209e71a7 100644 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -1294,7 +1294,7 @@ def main(args): args.ckpt = files[0] (_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model( - args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, dtype + args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype ) # xformers、Hypernetwork対応 diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 72ffe97f..5c8a0bd8 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -112,7 +112,7 @@ if __name__ == "__main__": # 本体RAMが少ない場合はGPUにロードするといいかも # If the main RAM is small, it may be better to load it on the GPU text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint( - sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu" + sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.ckpt_path, "cpu" ) # Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている diff --git a/sdxl_train.py b/sdxl_train.py index b57e2f5c..2ca14931 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -151,7 +151,7 @@ def train(args): else: save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors" use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります" + # assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります" # Diffusers版のxformers使用フラグを設定する関数 def set_diffusers_xformers_flag(model, valid): diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 7dbdd413..e3254be0 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -32,13 +32,13 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): unet, logit_scale, ckpt_info, - ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype) + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) self.load_stable_diffusion_format = load_stable_diffusion_format self.logit_scale = logit_scale self.ckpt_info = ckpt_info - return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet def load_tokenizer(self, args): tokenizer = sdxl_train_util.load_tokenizers(args) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index a5c91a27..1ddfd92b 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -28,13 +28,13 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine unet, logit_scale, ckpt_info, - ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype) + ) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype) self.load_stable_diffusion_format = load_stable_diffusion_format self.logit_scale = logit_scale self.ckpt_info = ckpt_info - return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet + return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet def load_tokenizer(self, args): tokenizer = sdxl_train_util.load_tokenizers(args)