feat: add vae_chunk_size argument for memory-efficient VAE decoding and processing

Kohya S
2025-09-21 11:09:37 +09:00
parent 8f20c37949
commit f41e9e2b58
4 changed files with 185 additions and 56 deletions

View File

@@ -88,7 +88,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
parser.add_argument("--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders")
parser.add_argument("--vae_enable_tiling", action="store_true", help="Enable tiling for VAE decoding")
parser.add_argument(
"--vae_chunk_size",
type=int,
default=None, # default is None (no chunking)
help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 16 is recommended if enabled"
" / メモリ使用量を減らすためのVAEデコードのチャンクサイズ。デフォルトはNoneチャンクなし。有効にする場合は16程度を推奨。",
)
parser.add_argument(
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
)
@@ -431,14 +437,10 @@ def merge_lora_weights(
# endregion
def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device, enable_tiling: bool = False) -> torch.Tensor:
def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device) -> torch.Tensor:
logger.info(f"Decoding image. Latent shape {latent.shape}, device {device}")
vae.to(device)
if enable_tiling:
vae.enable_tiling()
else:
vae.disable_tiling()
with torch.no_grad():
latent = latent / vae.scaling_factor # scale latent back to original range
pixels = vae.decode(latent.to(device, dtype=vae.dtype))
@@ -807,7 +809,7 @@ def save_output(
vae: HunyuanVAE2D,
latent: torch.Tensor,
device: torch.device,
original_base_names: Optional[List[str]] = None,
original_base_name: Optional[str] = None,
) -> None:
"""save output
@@ -816,7 +818,7 @@ def save_output(
vae: VAE model
latent: latent tensor
device: device to use
original_base_names: original base names (if latents are loaded from files)
original_base_name: original base name (if latents are loaded from files)
"""
height, width = latent.shape[-2], latent.shape[-1] # BCHW
height *= hunyuan_image_vae.VAE_SCALE_FACTOR
@@ -839,14 +841,14 @@ def save_output(
1, vae.latent_channels, height // hunyuan_image_vae.VAE_SCALE_FACTOR, width // hunyuan_image_vae.VAE_SCALE_FACTOR
)
image = decode_latent(vae, latent, device, args.vae_enable_tiling)
image = decode_latent(vae, latent, device)
if args.output_type == "images" or args.output_type == "latent_images":
# save images
if original_base_names is None or len(original_base_names) == 0:
if original_base_name is None:
original_name = ""
else:
original_name = f"_{original_base_names[0]}"
original_name = f"_{original_base_name}"
save_images(image, args, original_name)
@@ -919,7 +921,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
# 1. Prepare VAE
logger.info("Loading VAE for batch generation...")
vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
vae_for_batch.eval()
all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first
@@ -1057,7 +1059,7 @@ def process_interactive(args: argparse.Namespace) -> None:
shared_models = load_shared_models(args)
shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
vae.eval()
print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")
@@ -1185,9 +1187,9 @@ def main():
for i, latent in enumerate(latents_list):
args.seed = seeds[i]
vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True)
vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True, chunk_size=args.vae_chunk_size)
vae.eval()
save_output(args, vae, latent, device, original_base_names)
save_output(args, vae, latent, device, original_base_names[i])
elif args.from_file:
# Batch mode from file
@@ -1220,7 +1222,7 @@ def main():
clean_memory_on_device(device)
# Save latent and image
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
vae.eval()
save_output(args, vae, latent, device)
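For reference, a minimal usage sketch of the new flag's path through this file. The script module name here is hypothetical; only load_vae, decode_latent, and the chunk_size keyword come from this commit:

# Sketch: decode a saved latent with chunked VAE decoding (chunk_size=16 as recommended).
import torch
from hunyuan_image_vae import load_vae  # module referenced in this file
from generate_image import decode_latent  # hypothetical module name for the script above

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = load_vae("hunyuan_vae.safetensors", device="cpu", disable_mmap=True, chunk_size=16)
vae.eval()

latent = torch.load("latent.pt")  # hypothetical file; shape (1, C, H/32, W/32)
image = decode_latent(vae, latent, device)  # chunking happens inside the VAE's Conv2d layers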

View File

@@ -358,12 +358,11 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
)
vae = hunyuan_image_vae.load_vae(args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
vae = hunyuan_image_vae.load_vae(
args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors, chunk_size=args.vae_chunk_size
)
vae.to(dtype=torch.float16) # VAE is always fp16
vae.eval()
if args.vae_enable_tiling:
vae.enable_tiling()
logger.info("VAE tiling is enabled")
model_version = hunyuan_image_utils.MODEL_VERSION_2_1
return model_version, [text_encoder_vlm, text_encoder_byt5], vae, None # unet will be loaded later
@@ -674,9 +673,11 @@ def setup_parser() -> argparse.ArgumentParser:
"--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders / テキストエンコーダをCPUで推論する"
)
parser.add_argument(
"--vae_enable_tiling",
action="store_true",
help="Enable tiling for VAE decoding and encoding / VAEデコーディングとエンコーディングのタイルを有効にする",
"--vae_chunk_size",
type=int,
default=None, # default is None (no chunking)
help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 16 is recommended if enabled"
" / メモリ使用量を減らすためのVAEデコードのチャンクサイズ。デフォルトはNoneチャンクなし。有効にする場合は16程度を推奨。",
)
parser.add_argument(

View File

@@ -29,14 +29,20 @@ def swish(x: Tensor) -> Tensor:
class AttnBlock(nn.Module):
"""Self-attention block using scaled dot-product attention."""
def __init__(self, in_channels: int):
def __init__(self, in_channels: int, chunk_size: Optional[int] = None):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.q = Conv2d(in_channels, in_channels, kernel_size=1)
self.k = Conv2d(in_channels, in_channels, kernel_size=1)
self.v = Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1)
if chunk_size is None or chunk_size <= 0:
self.q = Conv2d(in_channels, in_channels, kernel_size=1)
self.k = Conv2d(in_channels, in_channels, kernel_size=1)
self.v = Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1)
else:
self.q = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
self.k = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
self.v = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
self.proj_out = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
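# Note: only the 1x1 projections are chunked; scaled dot-product attention itself still sees the full feature map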
def attention(self, x: Tensor) -> Tensor:
x = self.norm(x)
@@ -56,6 +62,87 @@ class AttnBlock(nn.Module):
return x + self.proj_out(self.attention(x))
class ChunkedConv2d(nn.Conv2d):
"""
Convolutional layer that processes input in chunks along the height dimension to reduce memory usage.
Parameters
----------
chunk_size : int, optional
Number of rows along the height dimension to process at a time. Default is 64.
"""
def __init__(self, *args, **kwargs):
# Pop chunk_size before nn.Conv2d.__init__ sees it; set it unconditionally so forward() always finds it
self.chunk_size = kwargs.pop("chunk_size", 64)
super().__init__(*args, **kwargs)
assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported."
assert self.dilation == (1, 1) and self.stride == (1, 1), "Only dilation=1 and stride=1 are supported."
assert self.groups == 1, "Only groups=1 is supported."
assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported."
assert (
self.padding[0] == self.padding[1] and self.padding[0] == self.kernel_size[0] // 2
), "Only kernel_size//2 padding is supported."
self.original_padding = self.padding
self.padding = (0, 0) # We handle padding manually in forward
def forward(self, x: Tensor) -> Tensor:
# If chunking is not needed, process normally. We chunk only along the height dimension (dim 2 of BCHW).
if self.chunk_size is None or x.shape[2] <= self.chunk_size:
self.padding = self.original_padding
x = super().forward(x)
self.padding = (0, 0)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return x
# Process input in chunks to reduce memory usage
org_shape = x.shape
# If kernel size is not 1, we need to use overlapping chunks
overlap = self.kernel_size[0] // 2 # 1 for kernel size 3
step = self.chunk_size - overlap
y = torch.zeros((org_shape[0], self.out_channels, org_shape[2], org_shape[3]), dtype=x.dtype, device=x.device)
yi = 0
i = 0
while i < org_shape[2]:
si = i if i == 0 else i - overlap
ei = i + self.chunk_size
# Check for the last chunk; if the remaining part would be small, fold it into this chunk
if ei > org_shape[2] or ei + step // 4 > org_shape[2]:
ei = org_shape[2]
chunk = x[:, :, : ei - si, :]
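# Narrow the input view: drop the rows this chunk consumed, keeping 2*overlap rows so the next chunk starts at row 0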
x = x[:, :, ei - si - overlap * 2 :, :]
# Pad the chunk so the result matches the original Conv2d with padding
if i == 0 and ei == org_shape[2]: # Single chunk covering the whole height
# Pad all sides
chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, overlap), mode="constant", value=0)
elif i == 0: # First chunk
# Pad except bottom
chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, 0), mode="constant", value=0)
elif ei == org_shape[2]: # Last chunk
# Pad except top
chunk = torch.nn.functional.pad(chunk, (overlap, overlap, 0, overlap), mode="constant", value=0)
else:
# Pad left and right only
chunk = torch.nn.functional.pad(chunk, (overlap, overlap), mode="constant", value=0)
chunk = super().forward(chunk)
y[:, :, yi : yi + chunk.shape[2], :] = chunk
yi += chunk.shape[2]
del chunk
if ei == org_shape[2]:
break
i += step
assert yi == org_shape[2], f"yi={yi}, org_shape[2]={org_shape[2]}"
if torch.cuda.is_available():
torch.cuda.empty_cache() # This helps reduce peak memory usage, but slows decoding down slightly
return y
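Because each interior chunk carries overlap extra rows of context on both sides and only step = chunk_size - overlap output rows are written per iteration, the chunked convolution should reproduce a full-height convolution exactly. A minimal self-check sketch, assuming the ChunkedConv2d class above is in scope:

# Sketch: confirm ChunkedConv2d matches a plain nn.Conv2d on a tall input.
import torch
from torch import nn

torch.manual_seed(0)
ref = nn.Conv2d(8, 8, kernel_size=3, stride=1, padding=1)
chunked = ChunkedConv2d(8, 8, kernel_size=3, stride=1, padding=1, chunk_size=16)
chunked.load_state_dict(ref.state_dict())  # same parameter keys: ChunkedConv2d is an nn.Conv2d

x = torch.randn(1, 8, 100, 100)  # height 100 forces several chunks of 16 rows
with torch.no_grad():
    assert torch.allclose(ref(x), chunked(x), atol=1e-6)  # identical sums, tolerance for float32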
class ResnetBlock(nn.Module):
"""
Residual block with two convolutions, group normalization, and swish activation.
@@ -69,19 +156,29 @@ class ResnetBlock(nn.Module):
Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if chunk_size is None or chunk_size <= 0:
self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
# Skip connection projection for channel dimension mismatch
if self.in_channels != self.out_channels:
self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# Skip connection projection for channel dimension mismatch
if self.in_channels != self.out_channels:
self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv1 = ChunkedConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
self.conv2 = ChunkedConv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
# Skip connection projection for channel dimension mismatch
if self.in_channels != self.out_channels:
self.nin_shortcut = ChunkedConv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0, chunk_size=chunk_size
)
def forward(self, x: Tensor) -> Tensor:
h = x
@@ -113,12 +210,17 @@ class Downsample(nn.Module):
Number of output channels (must be divisible by 4).
"""
def __init__(self, in_channels: int, out_channels: int):
def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
super().__init__()
factor = 4 # 2x2 spatial reduction factor
assert out_channels % factor == 0
self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
if chunk_size is None or chunk_size <= 0:
self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
else:
self.conv = ChunkedConv2d(
in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size
)
self.group_size = factor * in_channels // out_channels
def forward(self, x: Tensor) -> Tensor:
@@ -147,10 +249,15 @@ class Upsample(nn.Module):
Number of output channels.
"""
def __init__(self, in_channels: int, out_channels: int):
def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
super().__init__()
factor = 4 # 2x2 spatial expansion factor
self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
if chunk_size is None or chunk_size <= 0:
self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
else:
self.conv = ChunkedConv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
self.repeats = factor * out_channels // in_channels
def forward(self, x: Tensor) -> Tensor:
@@ -191,6 +298,7 @@ class Encoder(nn.Module):
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
ffactor_spatial: int,
chunk_size: Optional[int] = None,
):
super().__init__()
assert block_out_channels[-1] % (2 * z_channels) == 0
@@ -199,7 +307,12 @@ class Encoder(nn.Module):
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
if chunk_size is None or chunk_size <= 0:
self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
else:
self.conv_in = ChunkedConv2d(
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, chunk_size=chunk_size
)
self.down = nn.ModuleList()
block_in = block_out_channels[0]
@@ -211,7 +324,7 @@ class Encoder(nn.Module):
# Add residual blocks for this level
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size))
block_in = block_out
down = nn.Module()
@@ -222,20 +335,23 @@ class Encoder(nn.Module):
if add_spatial_downsample:
assert i_level < len(block_out_channels) - 1
block_out = block_out_channels[i_level + 1]
down.downsample = Downsample(block_in, block_out)
down.downsample = Downsample(block_in, block_out, chunk_size=chunk_size)
block_in = block_out
self.down.append(down)
# Middle blocks with attention
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
if chunk_size is None or chunk_size <= 0:
self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
else:
self.conv_out = ChunkedConv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
def forward(self, x: Tensor) -> Tensor:
# Initial convolution
@@ -291,6 +407,7 @@ class Decoder(nn.Module):
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
ffactor_spatial: int,
chunk_size: Optional[int] = None,
):
super().__init__()
assert block_out_channels[0] % z_channels == 0
@@ -300,13 +417,16 @@ class Decoder(nn.Module):
self.num_res_blocks = num_res_blocks
block_in = block_out_channels[0]
self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
if chunk_size is None or chunk_size <= 0:
self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
else:
self.conv_in = ChunkedConv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
# Middle blocks with attention
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
# Build upsampling blocks
self.up = nn.ModuleList()
@@ -316,7 +436,7 @@ class Decoder(nn.Module):
# Add residual blocks for this level (extra block for decoder)
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size))
block_in = block_out
up = nn.Module()
@@ -327,14 +447,17 @@ class Decoder(nn.Module):
if add_spatial_upsample:
assert i_level < len(block_out_channels) - 1
block_out = block_out_channels[i_level + 1]
up.upsample = Upsample(block_in, block_out)
up.upsample = Upsample(block_in, block_out, chunk_size=chunk_size)
block_in = block_out
self.up.append(up)
# Output layers
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
if chunk_size is None or chunk_size <= 0:
self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.conv_out = ChunkedConv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
def forward(self, z: Tensor) -> Tensor:
# Initial processing with skip connection
@@ -370,7 +493,7 @@ class HunyuanVAE2D(nn.Module):
with 32x spatial compression and optional memory-efficient tiling for large images.
"""
def __init__(self):
def __init__(self, chunk_size: Optional[int] = None):
super().__init__()
# Fixed configuration for Hunyuan Image-2.1
@@ -392,6 +515,7 @@ class HunyuanVAE2D(nn.Module):
block_out_channels=block_out_channels,
num_res_blocks=layers_per_block,
ffactor_spatial=ffactor_spatial,
chunk_size=chunk_size,
)
self.decoder = Decoder(
@@ -400,6 +524,7 @@ class HunyuanVAE2D(nn.Module):
block_out_channels=list(reversed(block_out_channels)),
num_res_blocks=layers_per_block,
ffactor_spatial=ffactor_spatial,
chunk_size=chunk_size,
)
# Spatial tiling configuration for memory efficiency
@@ -617,9 +742,9 @@ class HunyuanVAE2D(nn.Module):
return decoded
def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False) -> HunyuanVAE2D:
logger.info("Initializing VAE")
vae = HunyuanVAE2D()
def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False, chunk_size: Optional[int] = None) -> HunyuanVAE2D:
logger.info(f"Initializing VAE with chunk_size={chunk_size}")
vae = HunyuanVAE2D(chunk_size=chunk_size)
logger.info(f"Loading VAE from {vae_path}")
state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap)
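Since ChunkedConv2d subclasses nn.Conv2d, the state-dict keys are unchanged and the same checkpoint loads whether or not chunking is enabled; chunk_size is purely a construction-time switch. A rough sketch for measuring its effect on peak decode memory, assuming a CUDA device and the classes in this file (untrained weights are fine for measuring allocation):

# Sketch: compare peak decode memory with and without chunking (CUDA assumed).
import torch

def peak_decode_mb(chunk_size):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    vae = HunyuanVAE2D(chunk_size=chunk_size).to("cuda", dtype=torch.float16).eval()
    latent = torch.randn(1, vae.latent_channels, 32, 32, device="cuda", dtype=torch.float16)
    with torch.no_grad():
        vae.decode(latent)  # 32x32 latent -> 1024x1024 image at 32x spatial compression
    return torch.cuda.max_memory_allocated() / 2**20

print(f"no chunking: {peak_decode_mb(None):.0f} MiB, chunk_size=16: {peak_decode_mb(16):.0f} MiB")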

View File

@@ -626,6 +626,7 @@ class LatentsCachingStrategy:
for key in npz.files:
kwargs[key] = npz[key]
# TODO float() is needed if the VAE is bfloat16; remove it if the VAE is float16.
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)