mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
feat: add vae_chunk_size argument for memory-efficient VAE decoding and processing
This commit is contained in:
@@ -88,7 +88,13 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
|
||||
|
||||
parser.add_argument("--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders")
|
||||
parser.add_argument("--vae_enable_tiling", action="store_true", help="Enable tiling for VAE decoding")
|
||||
parser.add_argument(
|
||||
"--vae_chunk_size",
|
||||
type=int,
|
||||
default=None, # default is None (no chunking)
|
||||
help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 16 is recommended if enabled"
|
||||
" / メモリ使用量を減らすためのVAEデコードのチャンクサイズ。デフォルトはNone(チャンクなし)。有効にする場合は16程度を推奨。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
|
||||
)
|
||||
@@ -431,14 +437,10 @@ def merge_lora_weights(
|
||||
# endregion
|
||||
|
||||
|
||||
def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device, enable_tiling: bool = False) -> torch.Tensor:
|
||||
def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device) -> torch.Tensor:
|
||||
logger.info(f"Decoding image. Latent shape {latent.shape}, device {device}")
|
||||
|
||||
vae.to(device)
|
||||
if enable_tiling:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
with torch.no_grad():
|
||||
latent = latent / vae.scaling_factor # scale latent back to original range
|
||||
pixels = vae.decode(latent.to(device, dtype=vae.dtype))
|
||||
@@ -807,7 +809,7 @@ def save_output(
|
||||
vae: HunyuanVAE2D,
|
||||
latent: torch.Tensor,
|
||||
device: torch.device,
|
||||
original_base_names: Optional[List[str]] = None,
|
||||
original_base_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""save output
|
||||
|
||||
@@ -816,7 +818,7 @@ def save_output(
|
||||
vae: VAE model
|
||||
latent: latent tensor
|
||||
device: device to use
|
||||
original_base_names: original base names (if latents are loaded from files)
|
||||
original_base_name: original base name (if latents are loaded from files)
|
||||
"""
|
||||
height, width = latent.shape[-2], latent.shape[-1] # BCTHW
|
||||
height *= hunyuan_image_vae.VAE_SCALE_FACTOR
|
||||
@@ -839,14 +841,14 @@ def save_output(
|
||||
1, vae.latent_channels, height // hunyuan_image_vae.VAE_SCALE_FACTOR, width // hunyuan_image_vae.VAE_SCALE_FACTOR
|
||||
)
|
||||
|
||||
image = decode_latent(vae, latent, device, args.vae_enable_tiling)
|
||||
image = decode_latent(vae, latent, device)
|
||||
|
||||
if args.output_type == "images" or args.output_type == "latent_images":
|
||||
# save images
|
||||
if original_base_names is None or len(original_base_names) == 0:
|
||||
if original_base_name is None:
|
||||
original_name = ""
|
||||
else:
|
||||
original_name = f"_{original_base_names[0]}"
|
||||
original_name = f"_{original_base_name}"
|
||||
save_images(image, args, original_name)
|
||||
|
||||
|
||||
@@ -919,7 +921,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
|
||||
|
||||
# 1. Prepare VAE
|
||||
logger.info("Loading VAE for batch generation...")
|
||||
vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
|
||||
vae_for_batch.eval()
|
||||
|
||||
all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first
|
||||
@@ -1057,7 +1059,7 @@ def process_interactive(args: argparse.Namespace) -> None:
|
||||
shared_models = load_shared_models(args)
|
||||
shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode
|
||||
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
|
||||
vae.eval()
|
||||
|
||||
print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")
|
||||
@@ -1185,9 +1187,9 @@ def main():
|
||||
for i, latent in enumerate(latents_list):
|
||||
args.seed = seeds[i]
|
||||
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True)
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True, chunk_size=args.vae_chunk_size)
|
||||
vae.eval()
|
||||
save_output(args, vae, latent, device, original_base_names)
|
||||
save_output(args, vae, latent, device, original_base_names[i])
|
||||
|
||||
elif args.from_file:
|
||||
# Batch mode from file
|
||||
@@ -1220,7 +1222,7 @@ def main():
|
||||
clean_memory_on_device(device)
|
||||
|
||||
# Save latent and video
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
|
||||
vae.eval()
|
||||
save_output(args, vae, latent, device)
|
||||
|
||||
|
||||
@@ -358,12 +358,11 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
|
||||
args.byt5, dtype=torch.float16, device=vl_device, disable_mmap=args.disable_mmap_load_safetensors
|
||||
)
|
||||
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
|
||||
vae = hunyuan_image_vae.load_vae(
|
||||
args.vae, "cpu", disable_mmap=args.disable_mmap_load_safetensors, chunk_size=args.vae_chunk_size
|
||||
)
|
||||
vae.to(dtype=torch.float16) # VAE is always fp16
|
||||
vae.eval()
|
||||
if args.vae_enable_tiling:
|
||||
vae.enable_tiling()
|
||||
logger.info("VAE tiling is enabled")
|
||||
|
||||
model_version = hunyuan_image_utils.MODEL_VERSION_2_1
|
||||
return model_version, [text_encoder_vlm, text_encoder_byt5], vae, None # unet will be loaded later
|
||||
@@ -674,9 +673,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders / テキストエンコーダをCPUで推論する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_enable_tiling",
|
||||
action="store_true",
|
||||
help="Enable tiling for VAE decoding and encoding / VAEデコーディングとエンコーディングのタイルを有効にする",
|
||||
"--vae_chunk_size",
|
||||
type=int,
|
||||
default=None, # default is None (no chunking)
|
||||
help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 16 is recommended if enabled"
|
||||
" / メモリ使用量を減らすためのVAEデコードのチャンクサイズ。デフォルトはNone(チャンクなし)。有効にする場合は16程度を推奨。",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
||||
@@ -29,14 +29,20 @@ def swish(x: Tensor) -> Tensor:
|
||||
class AttnBlock(nn.Module):
|
||||
"""Self-attention block using scaled dot-product attention."""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
def __init__(self, in_channels: int, chunk_size: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.q = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.k = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.v = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
if chunk_size is None or chunk_size <= 0:
|
||||
self.q = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.k = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.v = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1)
|
||||
else:
|
||||
self.q = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
|
||||
self.k = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
|
||||
self.v = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
|
||||
self.proj_out = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
|
||||
|
||||
def attention(self, x: Tensor) -> Tensor:
|
||||
x = self.norm(x)
|
||||
@@ -56,6 +62,87 @@ class AttnBlock(nn.Module):
|
||||
return x + self.proj_out(self.attention(x))
|
||||
|
||||
|
||||
class ChunkedConv2d(nn.Conv2d):
|
||||
"""
|
||||
Convolutional layer that processes input in chunks to reduce memory usage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
chunk_size : int, optional
|
||||
Size of chunks to process at a time. Default is 64.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if "chunk_size" in kwargs:
|
||||
self.chunk_size = kwargs.pop("chunk_size", 64)
|
||||
super().__init__(*args, **kwargs)
|
||||
assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported."
|
||||
assert self.dilation == (1, 1) and self.stride == (1, 1), "Only dilation=1 and stride=1 are supported."
|
||||
assert self.groups == 1, "Only groups=1 is supported."
|
||||
assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported."
|
||||
assert (
|
||||
self.padding[0] == self.padding[1] and self.padding[0] == self.kernel_size[0] // 2
|
||||
), "Only kernel_size//2 padding is supported."
|
||||
self.original_padding = self.padding
|
||||
self.padding = (0, 0) # We handle padding manually in forward
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# If chunking is not needed, process normally. We chunk only along height dimension.
|
||||
if self.chunk_size is None or x.shape[1] <= self.chunk_size:
|
||||
self.padding = self.original_padding
|
||||
x = super().forward(x)
|
||||
self.padding = (0, 0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
return x
|
||||
|
||||
# Process input in chunks to reduce memory usage
|
||||
org_shape = x.shape
|
||||
|
||||
# If kernel size is not 1, we need to use overlapping chunks
|
||||
overlap = self.kernel_size[0] // 2 # 1 for kernel size 3
|
||||
step = self.chunk_size - overlap
|
||||
y = torch.zeros((org_shape[0], self.out_channels, org_shape[2], org_shape[3]), dtype=x.dtype, device=x.device)
|
||||
yi = 0
|
||||
i = 0
|
||||
while i < org_shape[2]:
|
||||
si = i if i == 0 else i - overlap
|
||||
ei = i + self.chunk_size
|
||||
|
||||
# Check last chunk. If remaining part is small, include it in last chunk
|
||||
if ei > org_shape[2] or ei + step // 4 > org_shape[2]:
|
||||
ei = org_shape[2]
|
||||
|
||||
chunk = x[:, :, : ei - si, :]
|
||||
x = x[:, :, ei - si - overlap * 2 :, :]
|
||||
|
||||
# Pad chunk if needed: This is as the original Conv2d with padding
|
||||
if i == 0: # First chunk
|
||||
# Pad except bottom
|
||||
chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, 0), mode="constant", value=0)
|
||||
elif ei == org_shape[2]: # Last chunk
|
||||
# Pad except top
|
||||
chunk = torch.nn.functional.pad(chunk, (overlap, overlap, 0, overlap), mode="constant", value=0)
|
||||
else:
|
||||
# Pad left and right only
|
||||
chunk = torch.nn.functional.pad(chunk, (overlap, overlap), mode="constant", value=0)
|
||||
|
||||
chunk = super().forward(chunk)
|
||||
y[:, :, yi : yi + chunk.shape[2], :] = chunk
|
||||
yi += chunk.shape[2]
|
||||
del chunk
|
||||
|
||||
if ei == org_shape[2]:
|
||||
break
|
||||
i += step
|
||||
|
||||
assert yi == org_shape[2], f"yi={yi}, org_shape[2]={org_shape[2]}"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache() # This helps reduce peak memory usage, but slows down a bit
|
||||
return y
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
"""
|
||||
Residual block with two convolutions, group normalization, and swish activation.
|
||||
@@ -69,19 +156,29 @@ class ResnetBlock(nn.Module):
|
||||
Number of output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if chunk_size is None or chunk_size <= 0:
|
||||
self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# Skip connection projection for channel dimension mismatch
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
# Skip connection projection for channel dimension mismatch
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv1 = ChunkedConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
|
||||
self.conv2 = ChunkedConv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
|
||||
|
||||
# Skip connection projection for channel dimension mismatch
|
||||
if self.in_channels != self.out_channels:
|
||||
self.nin_shortcut = ChunkedConv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0, chunk_size=chunk_size
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
h = x
|
||||
@@ -113,12 +210,17 @@ class Downsample(nn.Module):
|
||||
Number of output channels (must be divisible by 4).
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
|
||||
super().__init__()
|
||||
factor = 4 # 2x2 spatial reduction factor
|
||||
assert out_channels % factor == 0
|
||||
|
||||
self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
||||
if chunk_size is None or chunk_size <= 0:
|
||||
self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.conv = ChunkedConv2d(
|
||||
in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size
|
||||
)
|
||||
self.group_size = factor * in_channels // out_channels
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
@@ -147,10 +249,15 @@ class Upsample(nn.Module):
|
||||
Number of output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int):
|
||||
def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
|
||||
super().__init__()
|
||||
factor = 4 # 2x2 spatial expansion factor
|
||||
self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if chunk_size is None or chunk_size <= 0:
|
||||
self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.conv = ChunkedConv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
|
||||
|
||||
self.repeats = factor * out_channels // in_channels
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
@@ -191,6 +298,7 @@ class Encoder(nn.Module):
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_res_blocks: int,
|
||||
ffactor_spatial: int,
|
||||
chunk_size: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
assert block_out_channels[-1] % (2 * z_channels) == 0
|
||||
@@ -199,7 +307,12 @@ class Encoder(nn.Module):
|
||||
self.block_out_channels = block_out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
if chunk_size is None or chunk_size <= 0:
|
||||
self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.conv_in = ChunkedConv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, chunk_size=chunk_size
|
||||
)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
block_in = block_out_channels[0]
|
||||
@@ -211,7 +324,7 @@ class Encoder(nn.Module):
|
||||
|
||||
# Add residual blocks for this level
|
||||
for _ in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size))
|
||||
block_in = block_out
|
||||
|
||||
down = nn.Module()
|
||||
@@ -222,20 +335,23 @@ class Encoder(nn.Module):
|
||||
if add_spatial_downsample:
|
||||
assert i_level < len(block_out_channels) - 1
|
||||
block_out = block_out_channels[i_level + 1]
|
||||
down.downsample = Downsample(block_in, block_out)
|
||||
down.downsample = Downsample(block_in, block_out, chunk_size=chunk_size)
|
||||
block_in = block_out
|
||||
|
||||
self.down.append(down)
|
||||
|
||||
# Middle blocks with attention
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
|
||||
self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
|
||||
|
||||
# Output layers
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||
if chunk_size is None or chunk_size <= 0:
|
||||
self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.conv_out = ChunkedConv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# Initial convolution
|
||||
@@ -291,6 +407,7 @@ class Decoder(nn.Module):
|
||||
block_out_channels: Tuple[int, ...],
|
||||
num_res_blocks: int,
|
||||
ffactor_spatial: int,
|
||||
chunk_size: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
assert block_out_channels[0] % z_channels == 0
|
||||
@@ -300,13 +417,16 @@ class Decoder(nn.Module):
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
block_in = block_out_channels[0]
|
||||
self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
if chunk_size is None or chunk_size <= 0:
|
||||
self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.conv_in = ChunkedConv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
|
||||
|
||||
# Middle blocks with attention
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.attn_1 = AttnBlock(block_in)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
|
||||
self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
|
||||
|
||||
# Build upsampling blocks
|
||||
self.up = nn.ModuleList()
|
||||
@@ -316,7 +436,7 @@ class Decoder(nn.Module):
|
||||
|
||||
# Add residual blocks for this level (extra block for decoder)
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
||||
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size))
|
||||
block_in = block_out
|
||||
|
||||
up = nn.Module()
|
||||
@@ -327,14 +447,17 @@ class Decoder(nn.Module):
|
||||
if add_spatial_upsample:
|
||||
assert i_level < len(block_out_channels) - 1
|
||||
block_out = block_out_channels[i_level + 1]
|
||||
up.upsample = Upsample(block_in, block_out)
|
||||
up.upsample = Upsample(block_in, block_out, chunk_size=chunk_size)
|
||||
block_in = block_out
|
||||
|
||||
self.up.append(up)
|
||||
|
||||
# Output layers
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
||||
self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if chunk_size is None or chunk_size <= 0:
|
||||
self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.conv_out = ChunkedConv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
|
||||
|
||||
def forward(self, z: Tensor) -> Tensor:
|
||||
# Initial processing with skip connection
|
||||
@@ -370,7 +493,7 @@ class HunyuanVAE2D(nn.Module):
|
||||
with 32x spatial compression and optional memory-efficient tiling for large images.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, chunk_size: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
# Fixed configuration for Hunyuan Image-2.1
|
||||
@@ -392,6 +515,7 @@ class HunyuanVAE2D(nn.Module):
|
||||
block_out_channels=block_out_channels,
|
||||
num_res_blocks=layers_per_block,
|
||||
ffactor_spatial=ffactor_spatial,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
@@ -400,6 +524,7 @@ class HunyuanVAE2D(nn.Module):
|
||||
block_out_channels=list(reversed(block_out_channels)),
|
||||
num_res_blocks=layers_per_block,
|
||||
ffactor_spatial=ffactor_spatial,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# Spatial tiling configuration for memory efficiency
|
||||
@@ -617,9 +742,9 @@ class HunyuanVAE2D(nn.Module):
|
||||
return decoded
|
||||
|
||||
|
||||
def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False) -> HunyuanVAE2D:
|
||||
logger.info("Initializing VAE")
|
||||
vae = HunyuanVAE2D()
|
||||
def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False, chunk_size: Optional[int] = None) -> HunyuanVAE2D:
|
||||
logger.info(f"Initializing VAE with chunk_size={chunk_size}")
|
||||
vae = HunyuanVAE2D(chunk_size=chunk_size)
|
||||
|
||||
logger.info(f"Loading VAE from {vae_path}")
|
||||
state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap)
|
||||
|
||||
@@ -626,6 +626,7 @@ class LatentsCachingStrategy:
|
||||
for key in npz.files:
|
||||
kwargs[key] = npz[key]
|
||||
|
||||
# TODO float() is needed if vae is in bfloat16. Remove it if vae is float16.
|
||||
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
|
||||
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
|
||||
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
|
||||
|
||||
Reference in New Issue
Block a user