remove duplicate resolution for scaled pos embed

This commit is contained in:
Kohya S
2024-11-01 21:43:47 +09:00
parent 9aa6f52ac3
commit 82daa98fe8
3 changed files with 4 additions and 1 deletions

View File

@@ -871,7 +871,8 @@ class MMDiT(nn.Module):
# remove pos_embed to free up memory up to 0.4 GB
self.pos_embed = None
# sort latent sizes in ascending order
# remove duplicates and sort latent sizes in ascending order
latent_sizes = list(set(latent_sizes))
latent_sizes = sorted(latent_sizes)
patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]

View File

@@ -366,6 +366,7 @@ def train(args):
if args.enable_scaled_pos_embed:
resolutions = train_dataset_group.get_resolutions()
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent
latent_sizes = list(set(latent_sizes)) # remove duplicates
logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}")
mmdit.enable_scaled_pos_embed(True, latent_sizes)

View File

@@ -73,6 +73,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
# set resolutions for positional embeddings
if args.enable_scaled_pos_embed:
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
latent_sizes = list(set(latent_sizes)) # remove duplicates
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
mmdit.enable_scaled_pos_embed(True, latent_sizes)