Kohya-ss-sd-scripts/tools/convert_diffusers20_original_sd.py
Yuta Hayashibe 5f6bf29e52
Replace print with logger if they are logs (#905)
* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-02-04 18:14:34 +09:00

164 lines
7.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
import argparse
import os
import torch
from diffusers import StableDiffusionPipeline
import library.model_util as model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def _resolve_save_dtype(args):
    """Return the torch dtype to save with, or None to keep the dtype the models were loaded in.

    Candidates are checked in the order fp16, bf16, float. The first match wins, whether it
    came from a legacy flag (--fp16/--bf16/--float) or from --save_precision_as.
    """
    if args.fp16 or args.save_precision_as == "fp16":
        return torch.float16
    if args.bf16 or args.save_precision_as == "bf16":
        return torch.bfloat16
    if args.float or args.save_precision_as == "float":
        return torch.float
    return None


def _load_models(args, is_load_ckpt):
    """Load the models and decide the SD version.

    Returns (text_encoder, vae, unet, v2_model). The models come either from a checkpoint
    file (is_load_ckpt=True) or from a Diffusers model via StableDiffusionPipeline.
    """
    if is_load_ckpt:
        # For checkpoints the version comes from --v2 (convert() has verified that exactly
        # one of --v1/--v2 was given).
        v2_model = args.v2
        text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
            v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
        )
        return text_encoder, vae, unet, v2_model

    # --fp16 only affects loading for Diffusers input; None lets from_pretrained pick its default.
    load_dtype = torch.float16 if args.fp16 else None
    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
    )
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    unet = pipe.unet

    if args.v1 == args.v2:
        # Neither (or both) flags given: infer the version from the U-Net's
        # cross_attention_dim (1024 is treated as v2, anything else as v1).
        v2_model = unet.config.cross_attention_dim == 1024
        logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
    else:
        v2_model = not args.v1
    return text_encoder, vae, unet, v2_model


def convert(args):
    """Convert a Stable Diffusion model between Diffusers and original checkpoint formats.

    Input format: if ``args.model_to_load`` is an existing file it is loaded as a checkpoint.
    Otherwise it is passed to ``StableDiffusionPipeline.from_pretrained`` as a Diffusers model.

    Output format: if ``args.model_to_save`` has a file extension, a single checkpoint file is
    written. Otherwise the model is saved as a Diffusers directory.

    Args:
        args: the namespace produced by ``setup_parser().parse_args()``.

    Raises:
        ValueError: if a checkpoint is loaded without exactly one of ``--v1`` / ``--v2``.
    """
    save_dtype = _resolve_save_dtype(args)

    is_load_ckpt = os.path.isfile(args.model_to_load)
    # Any extension at all selects checkpoint output; no extension selects a Diffusers directory.
    is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0

    # Explicit check instead of `assert`, so validation still runs under `python -O`.
    if is_load_ckpt and args.v1 == args.v2:
        raise ValueError("v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です")

    # Load the models.
    msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
    logger.info(f"loading {msg}: {args.model_to_load}")
    text_encoder, vae, unet, v2_model = _load_models(args, is_load_ckpt)

    # Convert and save. save_dtype is only applied to checkpoint output.
    msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
    logger.info(f"converting and saving as {msg}: {args.model_to_save}")
    if is_save_ckpt:
        # The source path is forwarded only when it was itself a checkpoint; presumably the
        # saver reuses it (see model_util.save_stable_diffusion_checkpoint) - confirm there.
        original_model = args.model_to_load if is_load_ckpt else None
        # NOTE(review): eval() executes arbitrary Python from --metadata. That is acceptable for
        # a local CLI where users pass their own input, but never route untrusted strings here;
        # ast.literal_eval would be enough for the documented dict-literal format.
        metadata = None if args.metadata is None else eval(args.metadata)
        key_count = model_util.save_stable_diffusion_checkpoint(
            v2_model,
            args.model_to_save,
            text_encoder,
            unet,
            original_model,
            args.epoch,
            args.global_step,
            metadata,
            save_dtype=save_dtype,
            vae=vae,
        )
        logger.info(f"model saved. total converted state_dict keys: {key_count}")
    else:
        # --reference_model may be None; the log below reports 'default model' in that case.
        logger.info(
            f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
        )
        model_util.save_diffusers_checkpoint(
            v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
        )
        logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this converter.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace is consumed by ``convert()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む"
    )
    parser.add_argument(
        "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
    )
    parser.add_argument(
        "--unet_use_linear_projection",
        action="store_true",
        help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする（stabilityaiのモデルと合わせる）",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み（Diffusers形式のみ対応）、保存する（checkpointのみ対応）",
    )
    parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する（checkpointのみ対応）")
    parser.add_argument(
        "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する（checkpointのみ対応）"
    )
    parser.add_argument(
        "--save_precision_as",
        type=str,
        default="no",
        # "no" (the default) is accepted explicitly too; it leaves the precision to the flags above.
        choices=["no", "fp16", "bf16", "float"],
        help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください",
    )
    parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値")
    parser.add_argument(
        "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
    )
    parser.add_argument(
        "--metadata",
        type=str,
        default=None,
        help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written into the model, as a Python dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="読み込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
    )
    parser.add_argument(
        "--reference_model",
        type=str,
        default=None,
        help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`",
    )
    parser.add_argument(
        "--use_safetensors",
        action="store_true",
        help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Diffusersモデルをsafetensors形式で保存する（checkpointは拡張子で自動判定）",
    )
    # Positional arguments are required, so no default is given.
    parser.add_argument(
        "model_to_load",
        type=str,
        help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ",
    )
    parser.add_argument(
        "model_to_save",
        type=str,
        help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusersモデルとして保存",
    )
    return parser
if __name__ == "__main__":
    # Script entry point: build the CLI parser, parse sys.argv, and run the conversion.
    convert(setup_parser().parse_args())