mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
fix: correct tensor indexing in HunyuanVAE2D class for blending and encoding functions
This commit is contained in:
@@ -449,7 +449,7 @@ class HunyuanVAE2D(nn.Module):
|
||||
"""
|
||||
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
- b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
|
||||
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
@@ -467,7 +467,7 @@ class HunyuanVAE2D(nn.Module):
|
||||
"""
|
||||
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
- b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
|
||||
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -478,9 +478,14 @@ class HunyuanVAE2D(nn.Module):
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
- Input tensor of shape (B, C, T, H, W).
|
||||
+ Input tensor of shape (B, C, T, H, W) or (B, C, H, W).
|
||||
"""
|
||||
- B, C, T, H, W = x.shape
|
||||
+ # Handle 5D input (B, C, T, H, W) by removing time dimension
|
||||
+ original_ndim = x.ndim
|
||||
+ if original_ndim == 5:
|
||||
+ x = x.squeeze(2)
|
||||
|
||||
+ B, C, H, W = x.shape
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
@@ -489,7 +494,7 @@ class HunyuanVAE2D(nn.Module):
|
||||
for i in range(0, H, overlap_size):
|
||||
row = []
|
||||
for j in range(0, W, overlap_size):
|
||||
- tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
@@ -502,7 +507,7 @@ class HunyuanVAE2D(nn.Module):
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
- result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
||||
+ result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=-1))
|
||||
|
||||
moments = torch.cat(result_rows, dim=-2)
|
||||
|
||||
Reference in New Issue
Block a user