enable i2i with highres fix, add slicing VAE

This commit is contained in:
Kohya S
2023-05-10 23:09:25 +09:00
parent 437501cde3
commit 1b4bdff331
2 changed files with 727 additions and 17 deletions

View File

@@ -955,7 +955,7 @@ class PipelineLike:
if torch.cuda.is_available():
torch.cuda.empty_cache()
init_latents = []
for i in tqdm(range(0, batch_size, vae_batch_size)):
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
init_latent_dist = self.vae.encode(
init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)
).latent_dist
@@ -2091,7 +2091,7 @@ def main(args):
dtype = torch.float32
highres_fix = args.highres_fix_scale is not None
assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
@@ -2250,7 +2250,27 @@ def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
# custom pipelineをコピったやつを生成する
if args.vae_slices:
from library.slicing_vae import SlicingAutoencoderKL
sli_vae = SlicingAutoencoderKL(
act_fn="silu",
block_out_channels=(128, 256, 512, 512),
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
in_channels=3,
latent_channels=4,
layers_per_block=2,
norm_num_groups=32,
out_channels=3,
sample_size=512,
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
num_slices=args.vae_slices,
)
sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする
vae = sli_vae
del sli_vae
vae.to(dtype).to(device)
text_encoder.to(dtype).to(device)
unet.to(dtype).to(device)
if clip_model is not None:
@@ -2262,7 +2282,7 @@ def main(args):
if args.network_module:
networks = []
network_default_muls = []
network_pre_calc=args.network_pre_calc
network_pre_calc = args.network_pre_calc
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
@@ -2592,12 +2612,18 @@ def main(args):
# 画像サイズにオプション指定があるときはリサイズする
if args.W is not None and args.H is not None:
# highres fix を考慮に入れる
w, h = args.W, args.H
if highres_fix:
w = int(w * args.highres_fix_scale + 0.5)
h = int(h * args.highres_fix_scale + 0.5)
if init_images is not None:
print(f"resize img2img source images to {args.W}*{args.H}")
init_images = resize_images(init_images, (args.W, args.H))
print(f"resize img2img source images to {w}*{h}")
init_images = resize_images(init_images, (w, h))
if mask_images is not None:
print(f"resize img2img mask images to {args.W}*{args.H}")
mask_images = resize_images(mask_images, (args.W, args.H))
print(f"resize img2img mask images to {w}*{h}")
mask_images = resize_images(mask_images, (w, h))
regional_network = False
if networks and mask_images:
@@ -2671,13 +2697,15 @@ def main(args):
width_1st = width_1st - width_1st % 32
height_1st = height_1st - height_1st % 32
strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength
ext_1st = BatchDataExt(
width_1st,
height_1st,
args.highres_fix_steps,
ext.scale,
ext.negative_scale,
ext.strength,
strength_1st,
ext.network_muls,
ext.num_sub_prompts,
)
@@ -2827,7 +2855,7 @@ def main(args):
n.set_multiplier(m)
if regional_network:
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
if not regional_network and network_pre_calc:
for n in networks:
n.restore_weights()
@@ -3032,14 +3060,16 @@ def main(args):
if init_images is not None:
init_image = init_images[global_step % len(init_images)]
# img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する
# 32単位に丸めたやつにresizeされるので踏襲する
width, height = init_image.size
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
print(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
if not highres_fix:
width, height = init_image.size
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
print(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
if mask_images is not None:
mask_image = mask_images[global_step % len(mask_images)]
@@ -3141,6 +3171,13 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率",
)
parser.add_argument(
"--vae_slices",
type=int,
default=None,
help=
"number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しないデフォルト、指定すると遅くなる。16か32程度を推奨"
)
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
parser.add_argument(
"--sampler",
@@ -3218,7 +3255,9 @@ def setup_parser() -> argparse.ArgumentParser:
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument("--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
)
parser.add_argument(
"--textual_inversion_embeddings",
type=str,
@@ -3276,6 +3315,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数"
)
parser.add_argument(
"--highres_fix_strength",
type=float,
default=None,
help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ",
)
parser.add_argument(
"--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する"
)