fix face_crop_aug not working on finetune method, prepare upscaler

This commit is contained in:
Kohya S
2023-04-22 10:41:36 +09:00
parent 220436244c
commit 884e6bff5d
3 changed files with 403 additions and 10 deletions

View File

@@ -945,7 +945,7 @@ class PipelineLike:
# encode the init image into latents and scale the latents
init_image = init_image.to(device=self.device, dtype=latents_dtype)
if init_image.size()[2:] == (height // 8, width // 8):
if init_image.size()[1:] == (height // 8, width // 8):
init_latents = init_image
else:
if vae_batch_size >= batch_size:
@@ -1015,7 +1015,7 @@ class PipelineLike:
if self.control_nets:
if reginonal_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
else:
text_emb_last = text_embeddings
noise_pred = original_control_net.call_unet_and_control_net(
@@ -2318,6 +2318,22 @@ def main(args):
else:
networks = []
# upscalerの指定があれば取得する
upscaler = None
if args.highres_fix_upscaler:
print("import upscaler module:", args.highres_fix_upscaler)
imported_module = importlib.import_module(args.highres_fix_upscaler)
us_kwargs = {}
if args.highres_fix_upscaler_args:
for net_arg in args.highres_fix_upscaler_args.split(";"):
key, value = net_arg.split("=")
us_kwargs[key] = value
print("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device)
# ControlNetの処理
control_nets: List[ControlNetInfo] = []
if args.control_net_models:
@@ -2590,7 +2606,7 @@ def main(args):
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
np_mask = np.full(size, 255, dtype=np.uint8)
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
network.set_region(i, i == len(networks) - 1, mask)
mask_images = None
@@ -2639,6 +2655,8 @@ def main(args):
# highres_fixの処理
if highres_fix and not highres_1st:
# 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
print("process 1st stage")
batch_1st = []
for _, base, ext in batch:
@@ -2657,12 +2675,32 @@ def main(args):
ext.network_muls,
ext.num_sub_prompts,
)
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
print("process 2nd stage")
if args.highres_fix_latents_upscaling:
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
if upscaler:
# upscalerを使って画像を拡大する
lowreso_imgs = None if is_1st_latent else images_1st
lowreso_latents = None if not is_1st_latent else images_1st
# 戻り値はPIL.Image.Imageかtorch.Tensorのlatents
batch_size = len(images_1st)
vae_batch_size = (
batch_size
if args.vae_batch_size is None
else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size)
)
vae_batch_size = int(vae_batch_size)
images_1st = upscaler.upscale(
vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size
)
elif args.highres_fix_latents_upscaling:
# latentを拡大する
org_dtype = images_1st.dtype
if images_1st.dtype == torch.bfloat16:
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
@@ -2671,10 +2709,12 @@ def main(args):
) # , antialias=True)
images_1st = images_1st.to(org_dtype)
else:
# 画像をLANCZOSで拡大する
images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st]
batch_2nd = []
for i, (bd, image) in enumerate(zip(batch, images_1st)):
if not args.highres_fix_latents_upscaling:
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext)
batch_2nd.append(bd_2nd)
batch = batch_2nd
@@ -3229,6 +3269,16 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
)
parser.add_argument(
"--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
)
parser.add_argument(
"--highres_fix_upscaler_args",
type=str,
default=None,
help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
)
parser.add_argument(
"--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
)

View File

@@ -845,9 +845,10 @@ class BaseDataset(torch.utils.data.Dataset):
# 画像サイズはsizeより大きいのでリサイズする
face_size = max(face_w, face_h)
size = min(self.height, self.width) # 短いほう
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
if min_scale >= max_scale: # range指定がmin==max
scale = min_scale
else:
@@ -872,7 +873,7 @@ class BaseDataset(torch.utils.data.Dataset):
else:
# range指定があるときのみ、すこしだけランダムにわりと適当
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
if face_size > self.size // 10 and face_size >= 40:
if face_size > size // 10 and face_size >= 40:
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
p1 = max(0, min(p1, length - target_size))

342
tools/latent_upscaler.py Normal file
View File

@@ -0,0 +1,342 @@
# 外部から簡単にupscalerを呼ぶためのスクリプト
# 単体で動くようにモデル定義も含めている
import argparse
import glob
import os
import cv2
from diffusers import AutoencoderKL
from typing import Dict, List
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from PIL import Image
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
super(ResidualBlock, self).__init__()
if out_channels is None:
out_channels = in_channels
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu2 = nn.ReLU(inplace=True) # このReLUはresidualに足す前にかけるほうがいいかも
# initialize weights
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu2(out)
return out
class Upscaler(nn.Module):
def __init__(self):
super(Upscaler, self).__init__()
# define layers
# latent has 4 channels
self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.bn1 = nn.BatchNorm2d(128)
self.relu1 = nn.ReLU(inplace=True)
# resblocks
# 数の暴力で20個次元数を増やすよりもブロックを増やしたほうがreceptive fieldが広がるはずだぞ
self.resblock1 = ResidualBlock(128)
self.resblock2 = ResidualBlock(128)
self.resblock3 = ResidualBlock(128)
self.resblock4 = ResidualBlock(128)
self.resblock5 = ResidualBlock(128)
self.resblock6 = ResidualBlock(128)
self.resblock7 = ResidualBlock(128)
self.resblock8 = ResidualBlock(128)
self.resblock9 = ResidualBlock(128)
self.resblock10 = ResidualBlock(128)
self.resblock11 = ResidualBlock(128)
self.resblock12 = ResidualBlock(128)
self.resblock13 = ResidualBlock(128)
self.resblock14 = ResidualBlock(128)
self.resblock15 = ResidualBlock(128)
self.resblock16 = ResidualBlock(128)
self.resblock17 = ResidualBlock(128)
self.resblock18 = ResidualBlock(128)
self.resblock19 = ResidualBlock(128)
self.resblock20 = ResidualBlock(128)
# last convs
self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.bn2 = nn.BatchNorm2d(64)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.bn3 = nn.BatchNorm2d(64)
self.relu3 = nn.ReLU(inplace=True)
# final conv: output 4 channels
self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
# initialize weights
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
# initialize final conv weights to 0: 流行りのzero conv
nn.init.constant_(self.conv_final.weight, 0)
def forward(self, x):
inp = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
# いくつかのresblockを通した後に、residualを足すことで精度向上と学習速度向上が見込めるはず
residual = x
x = self.resblock1(x)
x = self.resblock2(x)
x = self.resblock3(x)
x = self.resblock4(x)
x = x + residual
residual = x
x = self.resblock5(x)
x = self.resblock6(x)
x = self.resblock7(x)
x = self.resblock8(x)
x = x + residual
residual = x
x = self.resblock9(x)
x = self.resblock10(x)
x = self.resblock11(x)
x = self.resblock12(x)
x = x + residual
residual = x
x = self.resblock13(x)
x = self.resblock14(x)
x = self.resblock15(x)
x = self.resblock16(x)
x = x + residual
residual = x
x = self.resblock17(x)
x = self.resblock18(x)
x = self.resblock19(x)
x = self.resblock20(x)
x = x + residual
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.conv3(x)
x = self.bn3(x)
# ここにreluを入れないほうがいい気がする
x = self.conv_final(x)
# network estimates the difference between the input and the output
x = x + inp
return x
def support_latents(self) -> bool:
return False
def upscale(
self,
vae: AutoencoderKL,
lowreso_images: List[Image.Image],
lowreso_latents: torch.Tensor,
dtype: torch.dtype,
width: int,
height: int,
batch_size: int = 1,
vae_batch_size: int = 1,
):
# assertion
assert lowreso_images is not None, "Upscaler requires lowreso image"
# make upsampled image with lanczos4
upsampled_images = []
for lowreso_image in lowreso_images:
upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
upsampled_images.append(upsampled_image)
# convert to tensor: this tensor is too large to be converted to cuda
upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
upsampled_images = torch.stack(upsampled_images, dim=0)
upsampled_images = upsampled_images.to(dtype)
# normalize to [-1, 1]
upsampled_images = upsampled_images / 127.5 - 1.0
# convert upsample images to latents with batch size
# print("Encoding upsampled (LANCZOS4) images...")
upsampled_latents = []
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
with torch.no_grad():
batch = vae.encode(batch).latent_dist.sample()
upsampled_latents.append(batch)
upsampled_latents = torch.cat(upsampled_latents, dim=0)
# upscale (refine) latents with this model with batch size
print("Upscaling latents...")
upscaled_latents = []
for i in range(0, upsampled_latents.shape[0], batch_size):
with torch.no_grad():
upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
upscaled_latents = torch.cat(upscaled_latents, dim=0)
return upscaled_latents * 0.18215
# external interface: returns a model
def create_upscaler(**kwargs):
weights = kwargs["weights"]
model = Upscaler()
print(f"Loading weights from {weights}...")
model.load_state_dict(torch.load(weights, map_location=torch.device("cpu")))
return model
# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
us_dtype = torch.float16 # TODO: support fp32/bf16
os.makedirs(args.output_dir, exist_ok=True)
# load VAE with Diffusers
assert args.vae_path is not None, "VAE path is required"
print(f"Loading VAE from {args.vae_path}...")
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
vae.to(DEVICE, dtype=us_dtype)
# prepare model
print("Preparing model...")
upscaler: Upscaler = create_upscaler(weights=args.weights)
# print("Loading weights from", args.weights)
# upscaler.load_state_dict(torch.load(args.weights))
upscaler.eval()
upscaler.to(DEVICE, dtype=us_dtype)
# load images
image_paths = glob.glob(args.image_pattern)
images = []
for image_path in image_paths:
image = Image.open(image_path)
image = image.convert("RGB")
# make divisible by 8
width = image.width
height = image.height
if width % 8 != 0:
width = width - (width % 8)
if height % 8 != 0:
height = height - (height % 8)
if width != image.width or height != image.height:
image = image.crop((0, 0, width, height))
images.append(image)
# debug output
if args.debug:
for image, image_path in zip(images, image_paths):
image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
basename = os.path.basename(image_path)
basename_wo_ext, ext = os.path.splitext(basename)
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
image_debug.save(dest_file_name)
# upscale
print("Upscaling...")
upscaled_latents = upscaler.upscale(
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
)
upscaled_latents /= 0.18215
# decode with batch
print("Decoding...")
upscaled_images = []
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
with torch.no_grad():
batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
batch = batch.to("cpu")
upscaled_images.append(batch)
upscaled_images = torch.cat(upscaled_images, dim=0)
# tensor to numpy
upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
upscaled_images = (upscaled_images + 1.0) * 127.5
upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
upscaled_images = upscaled_images[..., ::-1]
# save images
for i, image in enumerate(upscaled_images):
basename = os.path.basename(image_paths[i])
basename_wo_ext, ext = os.path.splitext(basename)
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
cv2.imwrite(dest_file_name, image)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
parser.add_argument("--weights", type=str, default=None, help="Weights path")
parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
parser.add_argument("--debug", action="store_true", help="Debug mode")
args = parser.parse_args()
upscale_images(args)