mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Merge branch 'dev' into dev_device_support
This commit is contained in:
@@ -16,7 +16,10 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
@@ -41,18 +44,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
print("Using DreamBooth method.")
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -63,7 +66,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Training with captions.")
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -90,7 +93,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
@@ -98,7 +101,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
print("load model")
|
||||
logger.info("load model")
|
||||
if args.sdxl:
|
||||
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||
else:
|
||||
@@ -152,7 +155,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
|
||||
if args.skip_existing:
|
||||
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
|
||||
print(f"Skipping {image_info.latents_npz} because it already exists.")
|
||||
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_infos.append(image_info)
|
||||
|
||||
@@ -16,7 +16,10 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
@@ -48,18 +51,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
print("Using DreamBooth method.")
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -70,7 +73,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Training with captions.")
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -95,14 +98,14 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, _ = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
print("load model")
|
||||
logger.info("load model")
|
||||
if args.sdxl:
|
||||
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||
text_encoders = [text_encoder1, text_encoder2]
|
||||
@@ -147,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
|
||||
if args.skip_existing:
|
||||
if os.path.exists(image_info.text_encoder_outputs_npz):
|
||||
print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
|
||||
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_info.input_ids1 = input_ids1
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import argparse
|
||||
import cv2
|
||||
|
||||
import logging
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def canny(args):
|
||||
img = cv2.imread(args.input)
|
||||
@@ -10,7 +14,7 @@ def canny(args):
|
||||
# canny_img = 255 - canny_img
|
||||
|
||||
cv2.imwrite(args.output, canny_img)
|
||||
print("done!")
|
||||
logger.info("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -6,7 +6,10 @@ import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
import library.model_util as model_util
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def convert(args):
|
||||
# 引数を確認する
|
||||
@@ -30,7 +33,7 @@ def convert(args):
|
||||
|
||||
# モデルを読み込む
|
||||
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
||||
print(f"loading {msg}: {args.model_to_load}")
|
||||
logger.info(f"loading {msg}: {args.model_to_load}")
|
||||
|
||||
if is_load_ckpt:
|
||||
v2_model = args.v2
|
||||
@@ -48,13 +51,13 @@ def convert(args):
|
||||
if args.v1 == args.v2:
|
||||
# 自動判定する
|
||||
v2_model = unet.config.cross_attention_dim == 1024
|
||||
print("checking model version: model is " + ("v2" if v2_model else "v1"))
|
||||
logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
|
||||
else:
|
||||
v2_model = not args.v1
|
||||
|
||||
# 変換して保存する
|
||||
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
||||
print(f"converting and saving as {msg}: {args.model_to_save}")
|
||||
logger.info(f"converting and saving as {msg}: {args.model_to_save}")
|
||||
|
||||
if is_save_ckpt:
|
||||
original_model = args.model_to_load if is_load_ckpt else None
|
||||
@@ -70,15 +73,15 @@ def convert(args):
|
||||
save_dtype=save_dtype,
|
||||
vae=vae,
|
||||
)
|
||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||
logger.info(f"model saved. total converted state_dict keys: {key_count}")
|
||||
else:
|
||||
print(
|
||||
logger.info(
|
||||
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
|
||||
)
|
||||
model_util.save_diffusers_checkpoint(
|
||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||
)
|
||||
print("model saved.")
|
||||
logger.info("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -15,6 +15,10 @@ import os
|
||||
from anime_face_detector import create_detector
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KP_REYE = 11
|
||||
KP_LEYE = 19
|
||||
@@ -24,7 +28,7 @@ SCORE_THRES = 0.90
|
||||
|
||||
def detect_faces(detector, image, min_size):
|
||||
preds = detector(image) # bgr
|
||||
# print(len(preds))
|
||||
# logger.info(len(preds))
|
||||
|
||||
faces = []
|
||||
for pred in preds:
|
||||
@@ -78,7 +82,7 @@ def process(args):
|
||||
assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
|
||||
|
||||
# アニメ顔検出モデルを読み込む
|
||||
print("loading face detector.")
|
||||
logger.info("loading face detector.")
|
||||
detector = create_detector('yolov3')
|
||||
|
||||
# cropの引数を解析する
|
||||
@@ -97,7 +101,7 @@ def process(args):
|
||||
crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
|
||||
|
||||
# 画像を処理する
|
||||
print("processing.")
|
||||
logger.info("processing.")
|
||||
output_extension = ".png"
|
||||
|
||||
os.makedirs(args.dst_dir, exist_ok=True)
|
||||
@@ -111,7 +115,7 @@ def process(args):
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
||||
if image.shape[2] == 4:
|
||||
print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
|
||||
logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
|
||||
image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
|
||||
|
||||
h, w = image.shape[:2]
|
||||
@@ -144,11 +148,11 @@ def process(args):
|
||||
# 顔サイズを基準にリサイズする
|
||||
scale = args.resize_face_size / face_size
|
||||
if scale < cur_crop_width / w:
|
||||
print(
|
||||
logger.warning(
|
||||
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
||||
scale = cur_crop_width / w
|
||||
if scale < cur_crop_height / h:
|
||||
print(
|
||||
logger.warning(
|
||||
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
||||
scale = cur_crop_height / h
|
||||
elif crop_h_ratio is not None:
|
||||
@@ -157,10 +161,10 @@ def process(args):
|
||||
else:
|
||||
# 切り出しサイズ指定あり
|
||||
if w < cur_crop_width:
|
||||
print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
|
||||
logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
|
||||
scale = cur_crop_width / w
|
||||
if h < cur_crop_height:
|
||||
print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
|
||||
logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
|
||||
scale = cur_crop_height / h
|
||||
if args.resize_fit:
|
||||
scale = max(cur_crop_width / w, cur_crop_height / h)
|
||||
@@ -198,7 +202,7 @@ def process(args):
|
||||
face_img = face_img[y:y + cur_crop_height]
|
||||
|
||||
# # debug
|
||||
# print(path, cx, cy, angle)
|
||||
# logger.info(path, cx, cy, angle)
|
||||
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
|
||||
# cv2.imshow("image", crp)
|
||||
# if cv2.waitKey() == 27:
|
||||
|
||||
@@ -17,7 +17,10 @@ init_ipex()
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
||||
@@ -219,7 +222,7 @@ class Upscaler(nn.Module):
|
||||
upsampled_images = upsampled_images / 127.5 - 1.0
|
||||
|
||||
# convert upsample images to latents with batch size
|
||||
# print("Encoding upsampled (LANCZOS4) images...")
|
||||
# logger.info("Encoding upsampled (LANCZOS4) images...")
|
||||
upsampled_latents = []
|
||||
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
|
||||
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
|
||||
@@ -230,7 +233,7 @@ class Upscaler(nn.Module):
|
||||
upsampled_latents = torch.cat(upsampled_latents, dim=0)
|
||||
|
||||
# upscale (refine) latents with this model with batch size
|
||||
print("Upscaling latents...")
|
||||
logger.info("Upscaling latents...")
|
||||
upscaled_latents = []
|
||||
for i in range(0, upsampled_latents.shape[0], batch_size):
|
||||
with torch.no_grad():
|
||||
@@ -245,7 +248,7 @@ def create_upscaler(**kwargs):
|
||||
weights = kwargs["weights"]
|
||||
model = Upscaler()
|
||||
|
||||
print(f"Loading weights from {weights}...")
|
||||
logger.info(f"Loading weights from {weights}...")
|
||||
if os.path.splitext(weights)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -264,14 +267,14 @@ def upscale_images(args: argparse.Namespace):
|
||||
|
||||
# load VAE with Diffusers
|
||||
assert args.vae_path is not None, "VAE path is required"
|
||||
print(f"Loading VAE from {args.vae_path}...")
|
||||
logger.info(f"Loading VAE from {args.vae_path}...")
|
||||
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
|
||||
vae.to(DEVICE, dtype=us_dtype)
|
||||
|
||||
# prepare model
|
||||
print("Preparing model...")
|
||||
logger.info("Preparing model...")
|
||||
upscaler: Upscaler = create_upscaler(weights=args.weights)
|
||||
# print("Loading weights from", args.weights)
|
||||
# logger.info("Loading weights from", args.weights)
|
||||
# upscaler.load_state_dict(torch.load(args.weights))
|
||||
upscaler.eval()
|
||||
upscaler.to(DEVICE, dtype=us_dtype)
|
||||
@@ -306,14 +309,14 @@ def upscale_images(args: argparse.Namespace):
|
||||
image_debug.save(dest_file_name)
|
||||
|
||||
# upscale
|
||||
print("Upscaling...")
|
||||
logger.info("Upscaling...")
|
||||
upscaled_latents = upscaler.upscale(
|
||||
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
|
||||
)
|
||||
upscaled_latents /= 0.18215
|
||||
|
||||
# decode with batch
|
||||
print("Decoding...")
|
||||
logger.info("Decoding...")
|
||||
upscaled_images = []
|
||||
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
|
||||
with torch.no_grad():
|
||||
|
||||
@@ -5,7 +5,10 @@ import torch
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def is_unet_key(key):
|
||||
# VAE or TextEncoder, the last one is for SDXL
|
||||
@@ -45,10 +48,10 @@ def merge(args):
|
||||
# check if all models are safetensors
|
||||
for model in args.models:
|
||||
if not model.endswith("safetensors"):
|
||||
print(f"Model {model} is not a safetensors model")
|
||||
logger.info(f"Model {model} is not a safetensors model")
|
||||
exit()
|
||||
if not os.path.isfile(model):
|
||||
print(f"Model {model} does not exist")
|
||||
logger.info(f"Model {model} does not exist")
|
||||
exit()
|
||||
|
||||
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
|
||||
@@ -65,7 +68,7 @@ def merge(args):
|
||||
|
||||
if merged_sd is None:
|
||||
# load first model
|
||||
print(f"Loading model {model}, ratio = {ratio}...")
|
||||
logger.info(f"Loading model {model}, ratio = {ratio}...")
|
||||
merged_sd = {}
|
||||
with safe_open(model, framework="pt", device=args.device) as f:
|
||||
for key in tqdm(f.keys()):
|
||||
@@ -81,11 +84,11 @@ def merge(args):
|
||||
value = ratio * value.to(dtype) # first model's value * ratio
|
||||
merged_sd[key] = value
|
||||
|
||||
print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
|
||||
logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
|
||||
continue
|
||||
|
||||
# load other models
|
||||
print(f"Loading model {model}, ratio = {ratio}...")
|
||||
logger.info(f"Loading model {model}, ratio = {ratio}...")
|
||||
|
||||
with safe_open(model, framework="pt", device=args.device) as f:
|
||||
model_keys = f.keys()
|
||||
@@ -93,7 +96,7 @@ def merge(args):
|
||||
_, new_key = replace_text_encoder_key(key)
|
||||
if new_key not in merged_sd:
|
||||
if args.show_skipped and new_key not in first_model_keys:
|
||||
print(f"Skip: {new_key}")
|
||||
logger.info(f"Skip: {new_key}")
|
||||
continue
|
||||
|
||||
value = f.get_tensor(key)
|
||||
@@ -104,7 +107,7 @@ def merge(args):
|
||||
for key in merged_sd.keys():
|
||||
if key in model_keys:
|
||||
continue
|
||||
print(f"Key {key} not in model {model}, use first model's value")
|
||||
logger.warning(f"Key {key} not in model {model}, use first model's value")
|
||||
if key in supplementary_key_ratios:
|
||||
supplementary_key_ratios[key] += ratio
|
||||
else:
|
||||
@@ -112,7 +115,7 @@ def merge(args):
|
||||
|
||||
# add supplementary keys' value (including VAE and TextEncoder)
|
||||
if len(supplementary_key_ratios) > 0:
|
||||
print("add first model's value")
|
||||
logger.info("add first model's value")
|
||||
with safe_open(args.models[0], framework="pt", device=args.device) as f:
|
||||
for key in tqdm(f.keys()):
|
||||
_, new_key = replace_text_encoder_key(key)
|
||||
@@ -120,7 +123,7 @@ def merge(args):
|
||||
continue
|
||||
|
||||
if is_unet_key(new_key): # not VAE or TextEncoder
|
||||
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
|
||||
logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
|
||||
|
||||
value = f.get_tensor(key) # original key
|
||||
|
||||
@@ -134,7 +137,7 @@ def merge(args):
|
||||
if not output_file.endswith(".safetensors"):
|
||||
output_file = output_file + ".safetensors"
|
||||
|
||||
print(f"Saving to {output_file}...")
|
||||
logger.info(f"Saving to {output_file}...")
|
||||
|
||||
# convert to save_dtype
|
||||
for k in merged_sd.keys():
|
||||
@@ -142,7 +145,7 @@ def merge(args):
|
||||
|
||||
save_file(merged_sd, output_file)
|
||||
|
||||
print("Done!")
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -7,7 +7,10 @@ from safetensors.torch import load_file
|
||||
from library.original_unet import UNet2DConditionModel, SampleOutput
|
||||
|
||||
import library.model_util as model_util
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ControlNetInfo(NamedTuple):
|
||||
unet: Any
|
||||
@@ -51,7 +54,7 @@ def load_control_net(v2, unet, model):
|
||||
|
||||
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
|
||||
# state dictを読み込む
|
||||
print(f"ControlNet: loading control SD model : {model}")
|
||||
logger.info(f"ControlNet: loading control SD model : {model}")
|
||||
|
||||
if model_util.is_safetensors(model):
|
||||
ctrl_sd_sd = load_file(model)
|
||||
@@ -61,7 +64,7 @@ def load_control_net(v2, unet, model):
|
||||
|
||||
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
||||
is_difference = "difference" in ctrl_sd_sd
|
||||
print("ControlNet: loading difference:", is_difference)
|
||||
logger.info(f"ControlNet: loading difference: {is_difference}")
|
||||
|
||||
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
|
||||
# またTransfer Controlの元weightとなる
|
||||
@@ -89,13 +92,13 @@ def load_control_net(v2, unet, model):
|
||||
# ControlNetのU-Netを作成する
|
||||
ctrl_unet = UNet2DConditionModel(**unet_config)
|
||||
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
|
||||
print("ControlNet: loading Control U-Net:", info)
|
||||
logger.info(f"ControlNet: loading Control U-Net: {info}")
|
||||
|
||||
# U-Net以外のControlNetを作成する
|
||||
# TODO support middle only
|
||||
ctrl_net = ControlNet()
|
||||
info = ctrl_net.load_state_dict(zero_conv_sd)
|
||||
print("ControlNet: loading ControlNet:", info)
|
||||
logger.info("ControlNet: loading ControlNet: {info}")
|
||||
|
||||
ctrl_unet.to(unet.device, dtype=unet.dtype)
|
||||
ctrl_net.to(unet.device, dtype=unet.dtype)
|
||||
@@ -117,7 +120,7 @@ def load_preprocess(prep_type: str):
|
||||
|
||||
return canny
|
||||
|
||||
print("Unsupported prep type:", prep_type)
|
||||
logger.info(f"Unsupported prep type: {prep_type}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -174,7 +177,7 @@ def call_unet_and_control_net(
|
||||
cnet_idx = step % cnet_cnt
|
||||
cnet_info = control_nets[cnet_idx]
|
||||
|
||||
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
if cnet_info.ratio < current_ratio:
|
||||
return original_unet(sample, timestep, encoder_hidden_states)
|
||||
|
||||
@@ -192,7 +195,7 @@ def call_unet_and_control_net(
|
||||
# ControlNet
|
||||
cnet_outs_list = []
|
||||
for i, cnet_info in enumerate(control_nets):
|
||||
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
if cnet_info.ratio < current_ratio:
|
||||
continue
|
||||
guided_hint = guided_hints[i]
|
||||
@@ -232,7 +235,7 @@ def unet_forward(
|
||||
upsample_size = None
|
||||
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
print("Forward upsample size to force interpolation output size.")
|
||||
logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# 1. time
|
||||
|
||||
@@ -6,7 +6,10 @@ import shutil
|
||||
import math
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
|
||||
# Split the max_resolution string by "," and strip any whitespaces
|
||||
@@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
|
||||
|
||||
proc = "Resized" if current_pixels > max_pixels else "Saved"
|
||||
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||
logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||
|
||||
# If other files with same basename, copy them with resolution suffix
|
||||
if copy_associated_files:
|
||||
@@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
continue
|
||||
for max_resolution in max_resolutions:
|
||||
new_asoc_file = base + '+' + max_resolution + ext
|
||||
print(f"Copy {asoc_file} as {new_asoc_file}")
|
||||
logger.info(f"Copy {asoc_file} as {new_asoc_file}")
|
||||
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import json
|
||||
import argparse
|
||||
from safetensors import safe_open
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, required=True)
|
||||
@@ -10,10 +14,10 @@ with safe_open(args.model, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
|
||||
if metadata is None:
|
||||
print("No metadata found")
|
||||
logger.error("No metadata found")
|
||||
else:
|
||||
# metadata is json dict, but not pretty printed
|
||||
# sort by key and pretty print
|
||||
print(json.dumps(metadata, indent=4, sort_keys=True))
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user