fix: correct tensor indexing in HunyuanVAE2D class for blending and encoding functions

This commit is contained in:
Kohya S
2025-09-17 21:54:25 +09:00
parent cbe2a9da45
commit f5b004009e

View File

@@ -449,7 +449,7 @@ class HunyuanVAE2D(nn.Module):
"""
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
@@ -467,7 +467,7 @@ class HunyuanVAE2D(nn.Module):
"""
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
@@ -478,9 +478,14 @@ class HunyuanVAE2D(nn.Module):
Parameters
----------
x : torch.Tensor
Input tensor of shape (B, C, T, H, W).
Input tensor of shape (B, C, T, H, W) or (B, C, H, W).
"""
B, C, T, H, W = x.shape
# Handle 5D input (B, C, T, H, W) by removing time dimension
original_ndim = x.ndim
if original_ndim == 5:
x = x.squeeze(2)
B, C, H, W = x.shape
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
@@ -489,7 +494,7 @@ class HunyuanVAE2D(nn.Module):
for i in range(0, H, overlap_size):
row = []
for j in range(0, W, overlap_size):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
@@ -502,7 +507,7 @@ class HunyuanVAE2D(nn.Module):
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)