mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Fix crashing if image is too tall or wide.
This commit is contained in:
@@ -868,7 +868,7 @@ class MMDiT(nn.Module):
|
|||||||
self.use_scaled_pos_embed = use_scaled_pos_embed
|
self.use_scaled_pos_embed = use_scaled_pos_embed
|
||||||
|
|
||||||
if self.use_scaled_pos_embed:
|
if self.use_scaled_pos_embed:
|
||||||
# # remove pos_embed to free up memory up to 0.4 GB
|
# remove pos_embed to free up memory up to 0.4 GB
|
||||||
self.pos_embed = None
|
self.pos_embed = None
|
||||||
|
|
||||||
# sort latent sizes in ascending order
|
# sort latent sizes in ascending order
|
||||||
@@ -977,7 +977,7 @@ class MMDiT(nn.Module):
|
|||||||
# patched size
|
# patched size
|
||||||
h = (h + 1) // p
|
h = (h + 1) // p
|
||||||
w = (w + 1) // p
|
w = (w + 1) // p
|
||||||
if self.pos_embed is None:
|
if self.pos_embed is None: # should not happen
|
||||||
return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
|
return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
|
||||||
assert self.pos_embed_max_size is not None
|
assert self.pos_embed_max_size is not None
|
||||||
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
|
||||||
@@ -1016,13 +1016,20 @@ class MMDiT(nn.Module):
|
|||||||
if patched_size is None:
|
if patched_size is None:
|
||||||
raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
|
raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")
|
||||||
|
|
||||||
pos_embed_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
|
pos_embed = self.resolution_pos_embeds[patched_size]
|
||||||
|
pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
|
||||||
if h > pos_embed_size or w > pos_embed_size:
|
if h > pos_embed_size or w > pos_embed_size:
|
||||||
# fallback to normal pos_embed
|
# # fallback to normal pos_embed
|
||||||
|
# return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
|
||||||
|
# extend pos_embed size
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
|
f"Using normal pos_embed for size {h}x{w} as it exceeds the scaled pos_embed size {pos_embed_size}. Image is too tall or wide."
|
||||||
)
|
)
|
||||||
return self.cropped_pos_embed(h, w, device=device, random_crop=random_crop)
|
pos_embed_size = max(h, w)
|
||||||
|
pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
|
||||||
|
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
|
||||||
|
self.resolution_pos_embeds[patched_size] = pos_embed
|
||||||
|
logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")
|
||||||
|
|
||||||
if not random_crop:
|
if not random_crop:
|
||||||
top = (pos_embed_size - h) // 2
|
top = (pos_embed_size - h) // 2
|
||||||
@@ -1031,7 +1038,6 @@ class MMDiT(nn.Module):
|
|||||||
top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
|
top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
|
||||||
left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()
|
left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()
|
||||||
|
|
||||||
pos_embed = self.resolution_pos_embeds[patched_size]
|
|
||||||
if pos_embed.device != device:
|
if pos_embed.device != device:
|
||||||
pos_embed = pos_embed.to(device)
|
pos_embed = pos_embed.to(device)
|
||||||
# which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
|
# which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
|
||||||
|
|||||||
Reference in New Issue
Block a user