From f41e9e2b587e6700edbd98ddf03624612cfcf445 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 21 Sep 2025 11:09:37 +0900 Subject: [PATCH] feat: add vae_chunk_size argument for memory-efficient VAE decoding and processing --- hunyuan_image_minimal_inference.py | 34 ++--- hunyuan_image_train_network.py | 15 +-- library/hunyuan_image_vae.py | 191 ++++++++++++++++++++++++----- library/strategy_base.py | 1 + 4 files changed, 185 insertions(+), 56 deletions(-) diff --git a/hunyuan_image_minimal_inference.py b/hunyuan_image_minimal_inference.py index 85023383..711e911f 100644 --- a/hunyuan_image_minimal_inference.py +++ b/hunyuan_image_minimal_inference.py @@ -88,7 +88,13 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") parser.add_argument("--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders") - parser.add_argument("--vae_enable_tiling", action="store_true", help="Enable tiling for VAE decoding") + parser.add_argument( + "--vae_chunk_size", + type=int, + default=None, # default is None (no chunking) + help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 16 is recommended if enabled" + " / メモリ使用量を減らすためのVAEデコードのチャンクサイズ。デフォルトはNone(チャンクなし)。有効にする場合は16程度を推奨。", + ) parser.add_argument( "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" ) @@ -431,14 +437,10 @@ def merge_lora_weights( # endregion -def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device, enable_tiling: bool = False) -> torch.Tensor: +def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device) -> torch.Tensor: logger.info(f"Decoding image. 
Latent shape {latent.shape}, device {device}") vae.to(device) - if enable_tiling: - vae.enable_tiling() - else: - vae.disable_tiling() with torch.no_grad(): latent = latent / vae.scaling_factor # scale latent back to original range pixels = vae.decode(latent.to(device, dtype=vae.dtype)) @@ -807,7 +809,7 @@ def save_output( vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device, - original_base_names: Optional[List[str]] = None, + original_base_name: Optional[str] = None, ) -> None: """save output @@ -816,7 +818,7 @@ def save_output( vae: VAE model latent: latent tensor device: device to use - original_base_names: original base names (if latents are loaded from files) + original_base_name: original base name (if latents are loaded from files) """ height, width = latent.shape[-2], latent.shape[-1] # BCTHW height *= hunyuan_image_vae.VAE_SCALE_FACTOR @@ -839,14 +841,14 @@ def save_output( 1, vae.latent_channels, height // hunyuan_image_vae.VAE_SCALE_FACTOR, width // hunyuan_image_vae.VAE_SCALE_FACTOR ) - image = decode_latent(vae, latent, device, args.vae_enable_tiling) + image = decode_latent(vae, latent, device) if args.output_type == "images" or args.output_type == "latent_images": # save images - if original_base_names is None or len(original_base_names) == 0: + if original_base_name is None: original_name = "" else: - original_name = f"_{original_base_names[0]}" + original_name = f"_{original_base_name}" save_images(image, args, original_name) @@ -919,7 +921,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> # 1. 
Prepare VAE logger.info("Loading VAE for batch generation...") - vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True) + vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size) vae_for_batch.eval() all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first @@ -1057,7 +1059,7 @@ def process_interactive(args: argparse.Namespace) -> None: shared_models = load_shared_models(args) shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode - vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True) + vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size) vae.eval() print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):") @@ -1185,9 +1187,9 @@ def main(): for i, latent in enumerate(latents_list): args.seed = seeds[i] - vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True) + vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True, chunk_size=args.vae_chunk_size) vae.eval() - save_output(args, vae, latent, device, original_base_names) + save_output(args, vae, latent, device, original_base_names[i]) elif args.from_file: # Batch mode from file @@ -1220,7 +1222,7 @@ def main(): clean_memory_on_device(device) # Save latent and video - vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True) + vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size) vae.eval() save_output(args, vae, latent, device) diff --git a/hunyuan_image_train_network.py b/hunyuan_image_train_network.py index 07e072e7..228c9dbc 100644 --- a/hunyuan_image_train_network.py +++ b/hunyuan_image_train_network.py @@ -358,12 +358,11 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer): args.byt5, dtype=torch.float16, device=vl_device, 
disable_mmap=args.disable_mmap_load_safetensors ) - vae = hunyuan_image_vae.load_vae(args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors) + vae = hunyuan_image_vae.load_vae( + args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors, chunk_size=args.vae_chunk_size + ) vae.to(dtype=torch.float16) # VAE is always fp16 vae.eval() - if args.vae_enable_tiling: - vae.enable_tiling() - logger.info("VAE tiling is enabled") model_version = hunyuan_image_utils.MODEL_VERSION_2_1 return model_version, [text_encoder_vlm, text_encoder_byt5], vae, None # unet will be loaded later @@ -674,9 +673,11 @@ def setup_parser() -> argparse.ArgumentParser: "--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders / テキストエンコーダをCPUで推論する" ) parser.add_argument( - "--vae_enable_tiling", - action="store_true", - help="Enable tiling for VAE decoding and encoding / VAEデコーディングとエンコーディングのタイルを有効にする", + "--vae_chunk_size", + type=int, + default=None, # default is None (no chunking) + help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 
class ChunkedConv2d(nn.Conv2d):
    """Drop-in ``nn.Conv2d`` that convolves large inputs in horizontal bands.

    The input is split into overlapping chunks along the height axis (dim 2)
    and each band is convolved separately, which bounds the size of the
    intermediate buffers and so reduces peak memory usage during VAE
    encode/decode. Bands overlap by ``kernel_size // 2`` rows, so the output
    is numerically identical to a single full convolution.

    Only the configuration used by this VAE is supported: stride=1,
    dilation=1, groups=1, square kernel, 'zeros' padding of kernel_size//2.

    NOTE: forward() temporarily rebinds ``self.padding``, so an instance is
    not re-entrant; do not call the same instance concurrently.

    Parameters
    ----------
    chunk_size : int, optional
        Number of input rows per band. Default is 64.
    """

    def __init__(self, *args, **kwargs):
        # Pop chunk_size before nn.Conv2d.__init__, which rejects unknown
        # kwargs. Always assign the attribute: forward() reads it
        # unconditionally (previously it was set only when the kwarg was
        # present, raising AttributeError otherwise).
        self.chunk_size = kwargs.pop("chunk_size", 64)
        super().__init__(*args, **kwargs)
        assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported."
        assert self.dilation == (1, 1) and self.stride == (1, 1), "Only dilation=1 and stride=1 are supported."
        assert self.groups == 1, "Only groups=1 is supported."
        assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported."
        assert (
            self.padding[0] == self.padding[1] and self.padding[0] == self.kernel_size[0] // 2
        ), "Only kernel_size//2 padding is supported."
        self.original_padding = self.padding
        self.padding = (0, 0)  # padding is applied manually per band in forward()

    def forward(self, x: Tensor) -> Tensor:
        # Small inputs: run one ordinary convolution. Chunking is along the
        # height axis, so the height (dim 2) is what must be compared against
        # chunk_size (comparing dim 1, the channel count, skips chunking for
        # low-channel layers exactly where large images need it).
        if self.chunk_size is None or x.shape[2] <= self.chunk_size:
            self.padding = self.original_padding
            x = super().forward(x)
            self.padding = (0, 0)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return x

        # Process the input band-by-band to reduce peak memory usage.
        org_shape = x.shape
        height = org_shape[2]

        # Bands must overlap by kernel_size // 2 rows so the seams match a
        # full convolution exactly (overlap == 0 for 1x1 kernels).
        overlap = self.kernel_size[0] // 2
        step = self.chunk_size - overlap
        y = torch.zeros((org_shape[0], self.out_channels, height, org_shape[3]), dtype=x.dtype, device=x.device)

        yi = 0  # next output row to fill
        i = 0  # index of the first output row produced by the current band
        while i < height:
            si = 0 if i == 0 else i - overlap  # first input row of the band
            ei = i + self.chunk_size  # one past the last input row of the band
            # Absorb a too-small tail into the current band rather than
            # emitting a tiny final band.
            if ei > height or ei + step // 4 > height:
                ei = height

            chunk = x[:, :, : ei - si, :]
            # Drop the rows already consumed so x shrinks as we go, keeping
            # `overlap` rows on each side of the cut for the next band.
            x = x[:, :, ei - si - overlap * 2 :, :]

            # Re-create nn.Conv2d's zero padding by hand: left/right always,
            # top only for the first band, bottom only for the last band. A
            # band that is both first AND last gets both (previously only the
            # top was padded in that case, producing an output `overlap` rows
            # short and tripping the assertion below for heights just above
            # chunk_size).
            pad_top = overlap if i == 0 else 0
            pad_bottom = overlap if ei == height else 0
            chunk = torch.nn.functional.pad(chunk, (overlap, overlap, pad_top, pad_bottom), mode="constant", value=0)

            chunk = super().forward(chunk)
            y[:, :, yi : yi + chunk.shape[2], :] = chunk
            yi += chunk.shape[2]
            del chunk

            if ei == height:
                break
            i += step

        assert yi == height, f"yi={yi}, height={height}"

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # lowers peak memory, at a small speed cost
        return y
""" - def __init__(self, in_channels: int, out_channels: int): + def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) - self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if chunk_size is None or chunk_size <= 0: + self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - # Skip connection projection for channel dimension mismatch - if self.in_channels != self.out_channels: - self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + # Skip connection projection for channel dimension mismatch + if self.in_channels != self.out_channels: + self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv1 = ChunkedConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) + self.conv2 = ChunkedConv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) + + # Skip connection projection for channel dimension mismatch + if self.in_channels != self.out_channels: + self.nin_shortcut = ChunkedConv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0, chunk_size=chunk_size + ) def forward(self, x: Tensor) -> Tensor: h = x @@ -113,12 +210,17 @@ class Downsample(nn.Module): Number of output channels (must be divisible by 4). 
""" - def __init__(self, in_channels: int, out_channels: int): + def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None): super().__init__() factor = 4 # 2x2 spatial reduction factor assert out_channels % factor == 0 - self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1) + if chunk_size is None or chunk_size <= 0: + self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1) + else: + self.conv = ChunkedConv2d( + in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size + ) self.group_size = factor * in_channels // out_channels def forward(self, x: Tensor) -> Tensor: @@ -147,10 +249,15 @@ class Upsample(nn.Module): Number of output channels. """ - def __init__(self, in_channels: int, out_channels: int): + def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None): super().__init__() factor = 4 # 2x2 spatial expansion factor - self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1) + + if chunk_size is None or chunk_size <= 0: + self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1) + else: + self.conv = ChunkedConv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) + self.repeats = factor * out_channels // in_channels def forward(self, x: Tensor) -> Tensor: @@ -191,6 +298,7 @@ class Encoder(nn.Module): block_out_channels: Tuple[int, ...], num_res_blocks: int, ffactor_spatial: int, + chunk_size: Optional[int] = None, ): super().__init__() assert block_out_channels[-1] % (2 * z_channels) == 0 @@ -199,7 +307,12 @@ class Encoder(nn.Module): self.block_out_channels = block_out_channels self.num_res_blocks = num_res_blocks - self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + if chunk_size is None or chunk_size <= 0: + self.conv_in = 
Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1) + else: + self.conv_in = ChunkedConv2d( + in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, chunk_size=chunk_size + ) self.down = nn.ModuleList() block_in = block_out_channels[0] @@ -211,7 +324,7 @@ class Encoder(nn.Module): # Add residual blocks for this level for _ in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size)) block_in = block_out down = nn.Module() @@ -222,20 +335,23 @@ class Encoder(nn.Module): if add_spatial_downsample: assert i_level < len(block_out_channels) - 1 block_out = block_out_channels[i_level + 1] - down.downsample = Downsample(block_in, block_out) + down.downsample = Downsample(block_in, block_out, chunk_size=chunk_size) block_in = block_out self.down.append(down) # Middle blocks with attention self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size) + self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size) # Output layers self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) - self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + if chunk_size is None or chunk_size <= 0: + self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + else: + self.conv_out = ChunkedConv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) def forward(self, x: Tensor) -> Tensor: # Initial convolution @@ -291,6 +407,7 @@ class 
Decoder(nn.Module): block_out_channels: Tuple[int, ...], num_res_blocks: int, ffactor_spatial: int, + chunk_size: Optional[int] = None, ): super().__init__() assert block_out_channels[0] % z_channels == 0 @@ -300,13 +417,16 @@ class Decoder(nn.Module): self.num_res_blocks = num_res_blocks block_in = block_out_channels[0] - self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + if chunk_size is None or chunk_size <= 0: + self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + else: + self.conv_in = ChunkedConv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) # Middle blocks with attention self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) - self.mid.attn_1 = AttnBlock(block_in) - self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size) + self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size) # Build upsampling blocks self.up = nn.ModuleList() @@ -316,7 +436,7 @@ class Decoder(nn.Module): # Add residual blocks for this level (extra block for decoder) for _ in range(self.num_res_blocks + 1): - block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size)) block_in = block_out up = nn.Module() @@ -327,14 +447,17 @@ class Decoder(nn.Module): if add_spatial_upsample: assert i_level < len(block_out_channels) - 1 block_out = block_out_channels[i_level + 1] - up.upsample = Upsample(block_in, block_out) + up.upsample = Upsample(block_in, block_out, chunk_size=chunk_size) block_in = block_out self.up.append(up) # Output layers self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) 
- self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + if chunk_size is None or chunk_size <= 0: + self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.conv_out = ChunkedConv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size) def forward(self, z: Tensor) -> Tensor: # Initial processing with skip connection @@ -370,7 +493,7 @@ class HunyuanVAE2D(nn.Module): with 32x spatial compression and optional memory-efficient tiling for large images. """ - def __init__(self): + def __init__(self, chunk_size: Optional[int] = None): super().__init__() # Fixed configuration for Hunyuan Image-2.1 @@ -392,6 +515,7 @@ class HunyuanVAE2D(nn.Module): block_out_channels=block_out_channels, num_res_blocks=layers_per_block, ffactor_spatial=ffactor_spatial, + chunk_size=chunk_size, ) self.decoder = Decoder( @@ -400,6 +524,7 @@ class HunyuanVAE2D(nn.Module): block_out_channels=list(reversed(block_out_channels)), num_res_blocks=layers_per_block, ffactor_spatial=ffactor_spatial, + chunk_size=chunk_size, ) # Spatial tiling configuration for memory efficiency @@ -617,9 +742,9 @@ class HunyuanVAE2D(nn.Module): return decoded -def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False) -> HunyuanVAE2D: - logger.info("Initializing VAE") - vae = HunyuanVAE2D() +def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False, chunk_size: Optional[int] = None) -> HunyuanVAE2D: + logger.info(f"Initializing VAE with chunk_size={chunk_size}") + vae = HunyuanVAE2D(chunk_size=chunk_size) logger.info(f"Loading VAE from {vae_path}") state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap) diff --git a/library/strategy_base.py b/library/strategy_base.py index fad79682..e88d273f 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -626,6 +626,7 @@ class LatentsCachingStrategy: for key in npz.files: kwargs[key] = npz[key] + # 
TODO float() is needed if vae is in bfloat16. Remove it if vae is float16. kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy() kwargs["original_size" + key_reso_suffix] = np.array(original_size) kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)