Merge branch 'dev' into dev_device_support

This commit is contained in:
Kohya S
2024-02-12 13:01:54 +09:00
62 changed files with 1387 additions and 993 deletions

View File

@@ -16,7 +16,10 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
@@ -41,18 +44,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -63,7 +66,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -90,7 +93,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -98,7 +101,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
print("load model")
logger.info("load model")
if args.sdxl:
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
else:
@@ -152,7 +155,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
print(f"Skipping {image_info.latents_npz} because it already exists.")
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)

View File

@@ -16,7 +16,10 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
@@ -48,18 +51,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -70,7 +73,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -95,14 +98,14 @@ def cache_to_disk(args: argparse.Namespace) -> None:
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, _ = train_util.prepare_dtype(args)
# モデルを読み込む
print("load model")
logger.info("load model")
if args.sdxl:
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
text_encoders = [text_encoder1, text_encoder2]
@@ -147,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1

View File

@@ -1,6 +1,10 @@
import argparse
import cv2
import logging
from library.utils import setup_logging
setup_logging()
logger = logging.getLogger(__name__)
def canny(args):
img = cv2.imread(args.input)
@@ -10,7 +14,7 @@ def canny(args):
# canny_img = 255 - canny_img
cv2.imwrite(args.output, canny_img)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -6,7 +6,10 @@ import torch
from diffusers import StableDiffusionPipeline
import library.model_util as model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def convert(args):
# 引数を確認する
@@ -30,7 +33,7 @@ def convert(args):
# モデルを読み込む
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
print(f"loading {msg}: {args.model_to_load}")
logger.info(f"loading {msg}: {args.model_to_load}")
if is_load_ckpt:
v2_model = args.v2
@@ -48,13 +51,13 @@ def convert(args):
if args.v1 == args.v2:
# 自動判定する
v2_model = unet.config.cross_attention_dim == 1024
print("checking model version: model is " + ("v2" if v2_model else "v1"))
logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
else:
v2_model = not args.v1
# 変換して保存する
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
print(f"converting and saving as {msg}: {args.model_to_save}")
logger.info(f"converting and saving as {msg}: {args.model_to_save}")
if is_save_ckpt:
original_model = args.model_to_load if is_load_ckpt else None
@@ -70,15 +73,15 @@ def convert(args):
save_dtype=save_dtype,
vae=vae,
)
print(f"model saved. total converted state_dict keys: {key_count}")
logger.info(f"model saved. total converted state_dict keys: {key_count}")
else:
print(
logger.info(
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
)
model_util.save_diffusers_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -15,6 +15,10 @@ import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
KP_REYE = 11
KP_LEYE = 19
@@ -24,7 +28,7 @@ SCORE_THRES = 0.90
def detect_faces(detector, image, min_size):
preds = detector(image) # bgr
# print(len(preds))
# logger.info(len(preds))
faces = []
for pred in preds:
@@ -78,7 +82,7 @@ def process(args):
assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
# アニメ顔検出モデルを読み込む
print("loading face detector.")
logger.info("loading face detector.")
detector = create_detector('yolov3')
# cropの引数を解析する
@@ -97,7 +101,7 @@ def process(args):
crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
# 画像を処理する
print("processing.")
logger.info("processing.")
output_extension = ".png"
os.makedirs(args.dst_dir, exist_ok=True)
@@ -111,7 +115,7 @@ def process(args):
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
if image.shape[2] == 4:
print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
h, w = image.shape[:2]
@@ -144,11 +148,11 @@ def process(args):
# 顔サイズを基準にリサイズする
scale = args.resize_face_size / face_size
if scale < cur_crop_width / w:
print(
logger.warning(
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい顔が相対的に大きすぎるので顔サイズが変わります: {path}")
scale = cur_crop_width / w
if scale < cur_crop_height / h:
print(
logger.warning(
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい顔が相対的に大きすぎるので顔サイズが変わります: {path}")
scale = cur_crop_height / h
elif crop_h_ratio is not None:
@@ -157,10 +161,10 @@ def process(args):
else:
# 切り出しサイズ指定あり
if w < cur_crop_width:
print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
scale = cur_crop_width / w
if h < cur_crop_height:
print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
scale = cur_crop_height / h
if args.resize_fit:
scale = max(cur_crop_width / w, cur_crop_height / h)
@@ -198,7 +202,7 @@ def process(args):
face_img = face_img[y:y + cur_crop_height]
# # debug
# print(path, cx, cy, angle)
# logger.info(path, cx, cy, angle)
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
# cv2.imshow("image", crp)
# if cv2.waitKey() == 27:

View File

@@ -17,7 +17,10 @@ init_ipex()
from torch import nn
from tqdm import tqdm
from PIL import Image
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
@@ -219,7 +222,7 @@ class Upscaler(nn.Module):
upsampled_images = upsampled_images / 127.5 - 1.0
# convert upsample images to latents with batch size
# print("Encoding upsampled (LANCZOS4) images...")
# logger.info("Encoding upsampled (LANCZOS4) images...")
upsampled_latents = []
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
@@ -230,7 +233,7 @@ class Upscaler(nn.Module):
upsampled_latents = torch.cat(upsampled_latents, dim=0)
# upscale (refine) latents with this model with batch size
print("Upscaling latents...")
logger.info("Upscaling latents...")
upscaled_latents = []
for i in range(0, upsampled_latents.shape[0], batch_size):
with torch.no_grad():
@@ -245,7 +248,7 @@ def create_upscaler(**kwargs):
weights = kwargs["weights"]
model = Upscaler()
print(f"Loading weights from {weights}...")
logger.info(f"Loading weights from {weights}...")
if os.path.splitext(weights)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -264,14 +267,14 @@ def upscale_images(args: argparse.Namespace):
# load VAE with Diffusers
assert args.vae_path is not None, "VAE path is required"
print(f"Loading VAE from {args.vae_path}...")
logger.info(f"Loading VAE from {args.vae_path}...")
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
vae.to(DEVICE, dtype=us_dtype)
# prepare model
print("Preparing model...")
logger.info("Preparing model...")
upscaler: Upscaler = create_upscaler(weights=args.weights)
# print("Loading weights from", args.weights)
# logger.info("Loading weights from", args.weights)
# upscaler.load_state_dict(torch.load(args.weights))
upscaler.eval()
upscaler.to(DEVICE, dtype=us_dtype)
@@ -306,14 +309,14 @@ def upscale_images(args: argparse.Namespace):
image_debug.save(dest_file_name)
# upscale
print("Upscaling...")
logger.info("Upscaling...")
upscaled_latents = upscaler.upscale(
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
)
upscaled_latents /= 0.18215
# decode with batch
print("Decoding...")
logger.info("Decoding...")
upscaled_images = []
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
with torch.no_grad():

View File

@@ -5,7 +5,10 @@ import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def is_unet_key(key):
# VAE or TextEncoder, the last one is for SDXL
@@ -45,10 +48,10 @@ def merge(args):
# check if all models are safetensors
for model in args.models:
if not model.endswith("safetensors"):
print(f"Model {model} is not a safetensors model")
logger.info(f"Model {model} is not a safetensors model")
exit()
if not os.path.isfile(model):
print(f"Model {model} does not exist")
logger.info(f"Model {model} does not exist")
exit()
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
@@ -65,7 +68,7 @@ def merge(args):
if merged_sd is None:
# load first model
print(f"Loading model {model}, ratio = {ratio}...")
logger.info(f"Loading model {model}, ratio = {ratio}...")
merged_sd = {}
with safe_open(model, framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
@@ -81,11 +84,11 @@ def merge(args):
value = ratio * value.to(dtype) # first model's value * ratio
merged_sd[key] = value
print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
continue
# load other models
print(f"Loading model {model}, ratio = {ratio}...")
logger.info(f"Loading model {model}, ratio = {ratio}...")
with safe_open(model, framework="pt", device=args.device) as f:
model_keys = f.keys()
@@ -93,7 +96,7 @@ def merge(args):
_, new_key = replace_text_encoder_key(key)
if new_key not in merged_sd:
if args.show_skipped and new_key not in first_model_keys:
print(f"Skip: {new_key}")
logger.info(f"Skip: {new_key}")
continue
value = f.get_tensor(key)
@@ -104,7 +107,7 @@ def merge(args):
for key in merged_sd.keys():
if key in model_keys:
continue
print(f"Key {key} not in model {model}, use first model's value")
logger.warning(f"Key {key} not in model {model}, use first model's value")
if key in supplementary_key_ratios:
supplementary_key_ratios[key] += ratio
else:
@@ -112,7 +115,7 @@ def merge(args):
# add supplementary keys' value (including VAE and TextEncoder)
if len(supplementary_key_ratios) > 0:
print("add first model's value")
logger.info("add first model's value")
with safe_open(args.models[0], framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
_, new_key = replace_text_encoder_key(key)
@@ -120,7 +123,7 @@ def merge(args):
continue
if is_unet_key(new_key): # not VAE or TextEncoder
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
value = f.get_tensor(key) # original key
@@ -134,7 +137,7 @@ def merge(args):
if not output_file.endswith(".safetensors"):
output_file = output_file + ".safetensors"
print(f"Saving to {output_file}...")
logger.info(f"Saving to {output_file}...")
# convert to save_dtype
for k in merged_sd.keys():
@@ -142,7 +145,7 @@ def merge(args):
save_file(merged_sd, output_file)
print("Done!")
logger.info("Done!")
if __name__ == "__main__":

View File

@@ -7,7 +7,10 @@ from safetensors.torch import load_file
from library.original_unet import UNet2DConditionModel, SampleOutput
import library.model_util as model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class ControlNetInfo(NamedTuple):
unet: Any
@@ -51,7 +54,7 @@ def load_control_net(v2, unet, model):
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
# state dictを読み込む
print(f"ControlNet: loading control SD model : {model}")
logger.info(f"ControlNet: loading control SD model : {model}")
if model_util.is_safetensors(model):
ctrl_sd_sd = load_file(model)
@@ -61,7 +64,7 @@ def load_control_net(v2, unet, model):
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
is_difference = "difference" in ctrl_sd_sd
print("ControlNet: loading difference:", is_difference)
logger.info(f"ControlNet: loading difference: {is_difference}")
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
# またTransfer Controlの元weightとなる
@@ -89,13 +92,13 @@ def load_control_net(v2, unet, model):
# ControlNetのU-Netを作成する
ctrl_unet = UNet2DConditionModel(**unet_config)
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
print("ControlNet: loading Control U-Net:", info)
logger.info(f"ControlNet: loading Control U-Net: {info}")
# U-Net以外のControlNetを作成する
# TODO support middle only
ctrl_net = ControlNet()
info = ctrl_net.load_state_dict(zero_conv_sd)
print("ControlNet: loading ControlNet:", info)
logger.info("ControlNet: loading ControlNet: {info}")
ctrl_unet.to(unet.device, dtype=unet.dtype)
ctrl_net.to(unet.device, dtype=unet.dtype)
@@ -117,7 +120,7 @@ def load_preprocess(prep_type: str):
return canny
print("Unsupported prep type:", prep_type)
logger.info(f"Unsupported prep type: {prep_type}")
return None
@@ -174,7 +177,7 @@ def call_unet_and_control_net(
cnet_idx = step % cnet_cnt
cnet_info = control_nets[cnet_idx]
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
return original_unet(sample, timestep, encoder_hidden_states)
@@ -192,7 +195,7 @@ def call_unet_and_control_net(
# ControlNet
cnet_outs_list = []
for i, cnet_info in enumerate(control_nets):
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
continue
guided_hint = guided_hints[i]
@@ -232,7 +235,7 @@ def unet_forward(
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
print("Forward upsample size to force interpolation output size.")
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time

View File

@@ -6,7 +6,10 @@ import shutil
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
# Split the max_resolution string by "," and strip any whitespaces
@@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
proc = "Resized" if current_pixels > max_pixels else "Saved"
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
# If other files with same basename, copy them with resolution suffix
if copy_associated_files:
@@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
continue
for max_resolution in max_resolutions:
new_asoc_file = base + '+' + max_resolution + ext
print(f"Copy {asoc_file} as {new_asoc_file}")
logger.info(f"Copy {asoc_file} as {new_asoc_file}")
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))

View File

@@ -1,6 +1,10 @@
import json
import argparse
from safetensors import safe_open
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
@@ -10,10 +14,10 @@ with safe_open(args.model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:
print("No metadata found")
logger.error("No metadata found")
else:
# metadata is json dict, but not pretty printed
# sort by key and pretty print
print(json.dumps(metadata, indent=4, sort_keys=True))