diff --git a/library/sd3_models.py b/library/sd3_models.py index 15a5b1db..b09a57db 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -871,7 +871,8 @@ class MMDiT(nn.Module): # remove pos_embed to free up memory up to 0.4 GB self.pos_embed = None - # sort latent sizes in ascending order + # remove duplicates and sort latent sizes in ascending order + latent_sizes = list(set(latent_sizes)) latent_sizes = sorted(latent_sizes) patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes] diff --git a/sd3_train.py b/sd3_train.py index 40f8c7e1..f64e2da2 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -366,6 +366,7 @@ def train(args): if args.enable_scaled_pos_embed: resolutions = train_dataset_group.get_resolutions() latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates logger.info(f"Prepare scaled positional embeddings for resolutions: {resolutions}, sizes: {latent_sizes}") mmdit.enable_scaled_pos_embed(True, latent_sizes) diff --git a/sd3_train_network.py b/sd3_train_network.py index 9eeac05c..0739e094 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -73,6 +73,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): # set resolutions for positional embeddings if args.enable_scaled_pos_embed: latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent + latent_sizes = list(set(latent_sizes)) # remove duplicates logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}") mmdit.enable_scaled_pos_embed(True, latent_sizes)