mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
update sdxl ver in lora metadata from v0-9 to v1-0
This commit is contained in:
@@ -10,10 +10,10 @@ from library import sdxl_original_unet
|
||||
|
||||
|
||||
VAE_SCALE_FACTOR = 0.13025
|
||||
MODEL_VERSION_SDXL_BASE_V0_9 = "sdxl_base_v0-9"
|
||||
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
|
||||
|
||||
# Diffusersの設定を読み込むための参照モデル
|
||||
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-0.9" # アクセス権が必要
|
||||
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
DIFFUSERS_SDXL_UNET_CONFIG = {
|
||||
"act_fn": "silu",
|
||||
|
||||
@@ -61,15 +61,15 @@ def svd(args):
|
||||
else:
|
||||
print(f"loading original SDXL model : {args.model_org}")
|
||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_org, "cpu"
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
|
||||
)
|
||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||
print(f"loading original SDXL model : {args.model_tuned}")
|
||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.model_tuned, "cpu"
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
|
||||
)
|
||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9
|
||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||
if args.conv_dim is None:
|
||||
|
||||
@@ -234,7 +234,7 @@ def merge(args):
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.sd_model, "cpu")
|
||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
|
||||
|
||||
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
|
||||
|
||||
|
||||
@@ -1294,7 +1294,7 @@ def main(args):
|
||||
args.ckpt = files[0]
|
||||
|
||||
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, dtype
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
||||
)
|
||||
|
||||
# xformers、Hypernetwork対応
|
||||
|
||||
@@ -112,7 +112,7 @@ if __name__ == "__main__":
|
||||
# 本体RAMが少ない場合はGPUにロードするといいかも
|
||||
# If the main RAM is small, it may be better to load it on the GPU
|
||||
text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt_path, "cpu"
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.ckpt_path, "cpu"
|
||||
)
|
||||
|
||||
# Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている
|
||||
|
||||
@@ -151,7 +151,7 @@ def train(args):
|
||||
else:
|
||||
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
|
||||
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
|
||||
assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります"
|
||||
# assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります"
|
||||
|
||||
# Diffusers版のxformers使用フラグを設定する関数
|
||||
def set_diffusers_xformers_flag(model, valid):
|
||||
|
||||
@@ -32,13 +32,13 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
||||
|
||||
self.load_stable_diffusion_format = load_stable_diffusion_format
|
||||
self.logit_scale = logit_scale
|
||||
self.ckpt_info = ckpt_info
|
||||
|
||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet
|
||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
||||
|
||||
def load_tokenizer(self, args):
|
||||
tokenizer = sdxl_train_util.load_tokenizers(args)
|
||||
|
||||
@@ -28,13 +28,13 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
unet,
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, weight_dtype)
|
||||
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
|
||||
|
||||
self.load_stable_diffusion_format = load_stable_diffusion_format
|
||||
self.logit_scale = logit_scale
|
||||
self.ckpt_info = ckpt_info
|
||||
|
||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, [text_encoder1, text_encoder2], vae, unet
|
||||
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
|
||||
|
||||
def load_tokenizer(self, args):
|
||||
tokenizer = sdxl_train_util.load_tokenizers(args)
|
||||
|
||||
Reference in New Issue
Block a user