Fix crashing if image is too tall or wide.

This commit is contained in:
Kohya S
2024-10-31 21:39:07 +09:00
parent 9e23368e3d
commit 830df4abcc

View File

@@ -868,7 +868,7 @@ class MMDiT(nn.Module):
self.use_scaled_pos_embed = use_scaled_pos_embed
if self.use_scaled_pos_embed:
# # remove pos_embed to free up memory up to 0.4 GB
# remove pos_embed to free up memory up to 0.4 GB
self.pos_embed = None
# sort latent sizes in ascending order
@@ -977,7 +977,7 @@ class MMDiT(nn.Module):
# patched size
h = (h + 1) // p
w = (w + 1) // p
if self.pos_embed is None:
if self.pos_embed is None: # should not happen
return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
assert self.pos_embed_max_size is not None
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
@@ -1016,13 +1016,20 @@ class MMDiT(nn.Module):
if patched_size is None:
raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
pos_embed = self.resolution_pos_embeds[patched_size]
pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
if h > pos_embed_size or w > pos_embed_size:
# fallback to normal pos_embed
# # fallback to normal pos_embed
# return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
# extend pos_embed size
logger.warning(
f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
)
return self.cropped_pos_embed(h, w, device=device, random_crop=random_crop)
pos_embed_size = max(h, w)
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
self.resolution_pos_embeds[patched_size] = pos_embed
logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
if not random_crop:
top = (pos_embed_size - h) // 2
@@ -1031,7 +1038,6 @@ class MMDiT(nn.Module):
top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()
pos_embed = self.resolution_pos_embeds[patched_size]
if pos_embed.device != device:
pos_embed = pos_embed.to(device)
# which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.