mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Enable img2img (i2i) together with highres fix; add a slicing VAE (--vae_slices option)
This commit is contained in:
@@ -955,7 +955,7 @@ class PipelineLike:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
init_latents = []
|
||||
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
||||
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
|
||||
init_latent_dist = self.vae.encode(
|
||||
init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)
|
||||
).latent_dist
|
||||
@@ -2091,7 +2091,7 @@ def main(args):
|
||||
dtype = torch.float32
|
||||
|
||||
highres_fix = args.highres_fix_scale is not None
|
||||
assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
||||
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
||||
|
||||
if args.v_parameterization and not args.v2:
|
||||
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
||||
@@ -2250,7 +2250,27 @@ def main(args):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
|
||||
|
||||
# custom pipelineをコピったやつを生成する
|
||||
if args.vae_slices:
|
||||
from library.slicing_vae import SlicingAutoencoderKL
|
||||
|
||||
sli_vae = SlicingAutoencoderKL(
|
||||
act_fn="silu",
|
||||
block_out_channels=(128, 256, 512, 512),
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
in_channels=3,
|
||||
latent_channels=4,
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
out_channels=3,
|
||||
sample_size=512,
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
num_slices=args.vae_slices,
|
||||
)
|
||||
sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする
|
||||
vae = sli_vae
|
||||
del sli_vae
|
||||
vae.to(dtype).to(device)
|
||||
|
||||
text_encoder.to(dtype).to(device)
|
||||
unet.to(dtype).to(device)
|
||||
if clip_model is not None:
|
||||
@@ -2262,7 +2282,7 @@ def main(args):
|
||||
if args.network_module:
|
||||
networks = []
|
||||
network_default_muls = []
|
||||
network_pre_calc=args.network_pre_calc
|
||||
network_pre_calc = args.network_pre_calc
|
||||
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
@@ -2592,12 +2612,18 @@ def main(args):
|
||||
|
||||
# 画像サイズにオプション指定があるときはリサイズする
|
||||
if args.W is not None and args.H is not None:
|
||||
# highres fix を考慮に入れる
|
||||
w, h = args.W, args.H
|
||||
if highres_fix:
|
||||
w = int(w * args.highres_fix_scale + 0.5)
|
||||
h = int(h * args.highres_fix_scale + 0.5)
|
||||
|
||||
if init_images is not None:
|
||||
print(f"resize img2img source images to {args.W}*{args.H}")
|
||||
init_images = resize_images(init_images, (args.W, args.H))
|
||||
print(f"resize img2img source images to {w}*{h}")
|
||||
init_images = resize_images(init_images, (w, h))
|
||||
if mask_images is not None:
|
||||
print(f"resize img2img mask images to {args.W}*{args.H}")
|
||||
mask_images = resize_images(mask_images, (args.W, args.H))
|
||||
print(f"resize img2img mask images to {w}*{h}")
|
||||
mask_images = resize_images(mask_images, (w, h))
|
||||
|
||||
regional_network = False
|
||||
if networks and mask_images:
|
||||
@@ -2671,13 +2697,15 @@ def main(args):
|
||||
width_1st = width_1st - width_1st % 32
|
||||
height_1st = height_1st - height_1st % 32
|
||||
|
||||
strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength
|
||||
|
||||
ext_1st = BatchDataExt(
|
||||
width_1st,
|
||||
height_1st,
|
||||
args.highres_fix_steps,
|
||||
ext.scale,
|
||||
ext.negative_scale,
|
||||
ext.strength,
|
||||
strength_1st,
|
||||
ext.network_muls,
|
||||
ext.num_sub_prompts,
|
||||
)
|
||||
@@ -2827,7 +2855,7 @@ def main(args):
|
||||
n.set_multiplier(m)
|
||||
if regional_network:
|
||||
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
||||
|
||||
|
||||
if not regional_network and network_pre_calc:
|
||||
for n in networks:
|
||||
n.restore_weights()
|
||||
@@ -3032,14 +3060,16 @@ def main(args):
|
||||
if init_images is not None:
|
||||
init_image = init_images[global_step % len(init_images)]
|
||||
|
||||
# img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する
|
||||
# 32単位に丸めたやつにresizeされるので踏襲する
|
||||
width, height = init_image.size
|
||||
width = width - width % 32
|
||||
height = height - height % 32
|
||||
if width != init_image.size[0] or height != init_image.size[1]:
|
||||
print(
|
||||
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
|
||||
)
|
||||
if not highres_fix:
|
||||
width, height = init_image.size
|
||||
width = width - width % 32
|
||||
height = height - height % 32
|
||||
if width != init_image.size[0] or height != init_image.size[1]:
|
||||
print(
|
||||
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
|
||||
)
|
||||
|
||||
if mask_images is not None:
|
||||
mask_image = mask_images[global_step % len(mask_images)]
|
||||
@@ -3141,6 +3171,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_slices",
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
"number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨"
|
||||
)
|
||||
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
||||
parser.add_argument(
|
||||
"--sampler",
|
||||
@@ -3218,7 +3255,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
)
|
||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||
parser.add_argument("--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する")
|
||||
parser.add_argument(
|
||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--textual_inversion_embeddings",
|
||||
type=str,
|
||||
@@ -3276,6 +3315,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_strength",
|
||||
type=float,
|
||||
default=None,
|
||||
help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user