mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 05:44:56 +00:00
355 lines
12 KiB
Python
355 lines
12 KiB
Python
# 外部から簡単にupscalerを呼ぶためのスクリプト
|
||
# 単体で動くようにモデル定義も含めている
|
||
|
||
import argparse
|
||
import glob
|
||
import os
|
||
import cv2
|
||
from diffusers import AutoencoderKL
|
||
|
||
from typing import Dict, List
|
||
import numpy as np
|
||
|
||
import torch
|
||
from library.device_utils import init_ipex, get_preferred_device
|
||
init_ipex()
|
||
|
||
from torch import nn
|
||
from tqdm import tqdm
|
||
from PIL import Image
|
||
from library.utils import setup_logging
|
||
setup_logging()
|
||
import logging
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class ResidualBlock(nn.Module):
    """Basic residual block: relu((conv -> BN -> ReLU -> conv -> BN)(x) + x).

    The identity skip is added without any projection. The block therefore only
    works when ``out_channels`` equals ``in_channels`` and the convolutions keep
    the spatial size, as the defaults (kernel_size=3, stride=1, padding=1) do.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; ``None`` means ``in_channels``.
        kernel_size, stride, padding: forwarded to both ``nn.Conv2d`` layers.
    """

    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
        super().__init__()
        width = in_channels if out_channels is None else out_channels

        # The attribute names below are the state_dict keys of saved weights, so keep them stable.
        self.conv1 = nn.Conv2d(in_channels, width, kernel_size, stride, padding, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(width, width, kernel_size, stride, padding, bias=False)
        self.bn2 = nn.BatchNorm2d(width)

        # Applied after adding the skip; the original author noted it might be better before the addition.
        self.relu2 = nn.ReLU(inplace=True)

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize parameters of all submodules.

        - Conv2d: Kaiming-normal weights (fan_out, relu) and zero bias.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: N(0, 0.01) weights, zero bias.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + x)."""
        h = self.relu1(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h += x  # in-place on the freshly computed tensor; the caller's ``x`` is not modified
        return self.relu2(h)
|
||
|
||
|
||
class Upscaler(nn.Module):
    """Latent refinement network used for 2x upscaling.

    Input: a 4-channel latent, namely the VAE encoding of a Lanczos-upsampled image
    (see ``upscale``).

    Architecture:
    - a conv/BN/ReLU stem;
    - 20 ``ResidualBlock``s in 5 groups of 4, each group wrapped in its own skip
      connection;
    - two conv/BN stages;
    - a 1x1 conv down to 4 channels.

    The network predicts a correction that is added to its input. ``conv_final``
    is zero-initialized (weight and bias), so an untrained model is the identity.
    """

    # resblock1..resblock20, applied in groups of 4 with a skip around each group
    _NUM_RESBLOCKS = 20
    _RESBLOCKS_PER_GROUP = 4

    def __init__(self):
        super().__init__()

        # stem: the latent has 4 channels
        self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)

        # Resblocks: brute force with 20 of them. Adding blocks should widen the receptive
        # field more than adding channels would. They are registered as resblock1..resblock20
        # (in that order) so the state_dict keys match saved weights.
        for idx in range(1, self._NUM_RESBLOCKS + 1):
            setattr(self, f"resblock{idx}", ResidualBlock(128))

        # last convs
        self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU(inplace=True)  # defined but not applied in forward()

        # final conv: output 4 channels
        self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all submodules, then zero the final conv.

        - Conv2d: Kaiming-normal weights, zero bias.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: N(0, 0.01) weights, zero bias.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        # zero-init the final conv weights (the popular "zero conv"): the network starts as identity
        nn.init.constant_(self.conv_final.weight, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the predicted correction (same shape as ``x``)."""
        inp = x

        x = self.relu1(self.bn1(self.conv1(x)))

        # Adding a skip connection around each group of resblocks should improve
        # accuracy and training speed.
        for start in range(1, self._NUM_RESBLOCKS + 1, self._RESBLOCKS_PER_GROUP):
            residual = x
            for idx in range(start, start + self._RESBLOCKS_PER_GROUP):
                x = getattr(self, f"resblock{idx}")(x)
            x = x + residual

        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        # No ReLU here on purpose: the author felt the model is better without one before conv_final.
        x = self.conv_final(x)

        # the network estimates the difference between the input and the output
        return x + inp

    def support_latents(self) -> bool:
        """This upscaler works from images only; ``lowreso_latents`` is not used."""
        return False

    def upscale(
        self,
        vae: AutoencoderKL,
        lowreso_images: List[Image.Image],
        lowreso_latents: torch.Tensor,
        dtype: torch.dtype,
        width: int,
        height: int,
        batch_size: int = 1,
        vae_batch_size: int = 1,
    ):
        """Upscale images to ``(width, height)`` and return refined latents.

        Pipeline: each image is converted to RGB and Lanczos-resized to the target
        size, then encoded with ``vae``. The resulting latents are refined with this
        model.

        Args:
            vae: VAE used to encode the upsampled pixels. Batches are moved to ``vae.device``.
            lowreso_images: non-empty list of PIL images. Any mode is accepted
                (converted to RGB).
            lowreso_latents: ignored; kept for interface compatibility
                (``support_latents()`` is False).
            dtype: dtype of the pixel tensor fed to the VAE. Presumably it should match
                the VAE's and this model's dtype (TODO confirm with callers).
            width, height: target size in pixels. All images are resized to this one size.
            batch_size: batch size for this model's forward pass.
            vae_batch_size: batch size for VAE encoding.

        Returns:
            Refined latents (N, 4, ...), multiplied by 0.18215. The CLI path divides
            this factor back out before decoding.
        """
        # assertion
        assert lowreso_images is not None, "Upscaler requires lowreso image"
        assert len(lowreso_images) > 0, "Upscaler requires at least one lowreso image"

        # Make upsampled images with Lanczos. Converting to RGB first guarantees 3 channels
        # for grayscale, palette and RGBA inputs, which would otherwise break permute()/the VAE.
        upsampled_arrays = [
            np.array(lowreso_image.convert("RGB").resize((width, height), Image.LANCZOS)) for lowreso_image in lowreso_images
        ]

        # Convert to a (N, C, H, W) tensor. It stays on the CPU because it can be too large
        # for the GPU; batches are moved individually below.
        upsampled_images = torch.stack([torch.from_numpy(a).permute(2, 0, 1) for a in upsampled_arrays], dim=0).float()

        # Normalize to [-1, 1] in float32, then cast: avoids an extra rounding step and
        # half-precision arithmetic on the CPU.
        upsampled_images = (upsampled_images / 127.5 - 1.0).to(dtype)

        # encode upsampled images to latents in VAE-sized batches
        upsampled_latents = []
        for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
            batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
            with torch.no_grad():
                upsampled_latents.append(vae.encode(batch).latent_dist.sample())
        upsampled_latents = torch.cat(upsampled_latents, dim=0)

        # upscale (refine) latents with this model in batches
        logger.info("Upscaling latents...")
        upscaled_latents = []
        for i in range(0, upsampled_latents.shape[0], batch_size):
            with torch.no_grad():
                upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
        upscaled_latents = torch.cat(upscaled_latents, dim=0)

        return upscaled_latents * 0.18215
|
||
|
||
|
||
# external interface: returns a model
def create_upscaler(**kwargs):
    """Construct an ``Upscaler`` and load its weights.

    Keyword Args:
        weights: path to the weights file. A ``.safetensors`` suffix (any letter case)
            is read with safetensors; anything else goes through ``torch.load``.

    Returns:
        A new ``Upscaler`` with the state dict loaded (strict). Moving it to a device
        or dtype and calling ``eval()`` is left to the caller.
    """
    weights = kwargs["weights"]
    # fail clearly instead of a TypeError from os.path.splitext(None)
    assert weights is not None, "weights path is required"

    model = Upscaler()

    logger.info(f"Loading weights from {weights}...")
    if os.path.splitext(weights)[1].lower() == ".safetensors":
        from safetensors.torch import load_file

        sd = load_file(weights)
    else:
        # NOTE(review): torch.load unpickles the file and can execute arbitrary code.
        # Only load checkpoints from trusted sources. weights_only=True would be safer
        # on torch versions that support it.
        sd = torch.load(weights, map_location=torch.device("cpu"))
    model.load_state_dict(sd)
    return model
|
||
|
||
|
||
def _load_images_cropped_to_multiple_of_8(image_paths):
    """Open each path as an RGB image, cropping right/bottom so width and height are multiples of 8."""
    images = []
    for image_path in image_paths:
        # the context manager closes the file; convert() returns an independent, loaded image
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")

        width = image.width - image.width % 8
        height = image.height - image.height % 8
        if (width, height) != image.size:
            image = image.crop((0, 0, width, height))
        images.append(image)
    return images


def _decode_latents_to_bgr(vae, latents, vae_batch_size):
    """Decode unscaled latents in batches and return uint8 (N, H, W, C) arrays with channels reversed for cv2.

    Assumes the VAE decodes to RGB in [-1, 1] (TODO confirm), so the channel reversal yields BGR.
    """
    decoded = []
    for i in tqdm(range(0, latents.shape[0], vae_batch_size)):
        with torch.no_grad():
            batch = vae.decode(latents[i : i + vae_batch_size]).sample
        decoded.append(batch.to("cpu"))

    # convert to float32 before scaling so the [-1, 1] -> [0, 255] mapping is not done in half precision
    pixels = torch.cat(decoded, dim=0).float().permute(0, 2, 3, 1).numpy()
    pixels = ((pixels + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    return pixels[..., ::-1]


# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
    """Upscale every image matching ``args.image_pattern`` by 2x and save the results.

    Per image:
    1. crop to a multiple of 8;
    2. Lanczos-upsample and VAE-encode (inside ``Upscaler.upscale``);
    3. refine with the upscaler;
    4. decode;
    5. write to ``args.output_dir`` as ``<name>_upscaled<ext>``.

    With ``--debug``, plain Lanczos 2x images are also written as ``<name>_lanczos4<ext>``.
    Images are upscaled in groups of identical size, so each image is enlarged to
    twice its own resolution.

    Args:
        args: parsed CLI arguments (vae_path, weights, image_pattern, output_dir,
            batch_size, vae_batch_size, debug).
    """
    DEVICE = get_preferred_device()
    us_dtype = torch.float16  # TODO: support fp32/bf16
    os.makedirs(args.output_dir, exist_ok=True)

    # validate inputs before loading any model
    assert args.vae_path is not None, "VAE path is required"
    assert args.image_pattern is not None, "Image pattern is required"
    image_paths = glob.glob(args.image_pattern)
    if not image_paths:
        logger.warning(f"No images found for pattern: {args.image_pattern}")
        return

    # load VAE with Diffusers
    logger.info(f"Loading VAE from {args.vae_path}...")
    vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
    vae.to(DEVICE, dtype=us_dtype)

    # prepare model
    logger.info("Preparing model...")
    upscaler: Upscaler = create_upscaler(weights=args.weights)
    upscaler.eval()
    upscaler.to(DEVICE, dtype=us_dtype)

    # load images
    images = _load_images_cropped_to_multiple_of_8(image_paths)

    # debug output: plain Lanczos 2x for comparison
    if args.debug:
        for image, image_path in zip(images, image_paths):
            image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
            basename_wo_ext, ext = os.path.splitext(os.path.basename(image_path))
            dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
            image_debug.save(dest_file_name)

    # Upscaler.upscale resizes every image of one call to a single target size, so
    # group images by size; otherwise all would be stretched to one image's dimensions.
    size_groups = {}
    for index, image in enumerate(images):
        size_groups.setdefault(image.size, []).append(index)

    for (width, height), indices in size_groups.items():
        # upscale
        logger.info("Upscaling...")
        upscaled_latents = upscaler.upscale(
            vae,
            [images[i] for i in indices],
            None,
            us_dtype,
            width * 2,
            height * 2,
            batch_size=args.batch_size,
            vae_batch_size=args.vae_batch_size,
        )
        upscaled_latents /= 0.18215  # undo the scaling applied at the end of upscale()

        # decode with batch
        logger.info("Decoding...")
        upscaled_images = _decode_latents_to_bgr(vae, upscaled_latents, args.vae_batch_size)

        # save images
        for index, image in zip(indices, upscaled_images):
            basename_wo_ext, ext = os.path.splitext(os.path.basename(image_paths[index]))
            dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
            # cv2.imwrite reports failure (e.g. unsupported extension) only via its return value
            if not cv2.imwrite(dest_file_name, np.ascontiguousarray(image)):
                logger.warning(f"Failed to write {dest_file_name}")
|
||
|
||
|
||
# CLI entry point: parse arguments and upscale the matching images.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # These three are mandatory: upscale_images/create_upscaler cannot run without them,
    # so let argparse report a missing one up front.
    parser.add_argument(
        "--vae_path", type=str, required=True, help="VAE path (Diffusers model directory; its `vae` subfolder is loaded)"
    )
    parser.add_argument("--weights", type=str, required=True, help="Weights path (.safetensors or torch checkpoint)")
    parser.add_argument("--image_pattern", type=str, required=True, help="Image pattern (glob), e.g. 'inputs/*.png'")
    parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
    parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
    parser.add_argument("--debug", action="store_true", help="Debug mode (also save plain Lanczos 2x images)")

    args = parser.parse_args()
    upscale_images(args)
|