diff --git a/fine_tune.py b/fine_tune.py index 9d42c873..442bd132 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -21,7 +21,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like def train(args): @@ -307,6 +307,8 @@ def train(args): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 7eb829fa..70a33e11 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,5 +1,6 @@ import torch import argparse +import random import re from typing import List, Optional, Union @@ -342,3 +343,15 @@ def get_weighted_text_embeddings( text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) return text_embeddings + + +# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 +def pyramid_noise_like(noise, device, iterations=6, discount=0.3): + b, c, w, h = noise.shape + u = torch.nn.Upsample(size=(w, h), mode='bilinear').to(device) + for i in range(iterations): + r = random.random()*2+2 # Rather than always going 2x, + w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i))) + noise += u(torch.randn(b, c, w, h).to(device)) * discount**i + if w==1 or h==1: break # Lowest resolution is 1x1 + return noise/noise.std() # Scaled back to roughly unit variance diff --git a/library/train_util.py b/library/train_util.py index cac4cdc5..f25e6065 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2121,6 +2121,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)", ) + parser.add_argument( + "--multires_noise_iterations", + type=int, + default=None, + help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended)" + ) + parser.add_argument( + "--multires_noise_discount", + type=float, + default=0.3, + help="set discount value for multires noise (has no effect without --multires_noise_iterations)" + ) parser.add_argument( "--lowram", action="store_true", diff --git a/train_db.py b/train_db.py index ad7a317e..90ee1bb1 100644 --- a/train_db.py +++ b/train_db.py @@ -23,7 +23,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like def train(args): @@ -273,6 +273,8 @@ def train(args): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): diff --git a/train_network.py b/train_network.py index 3f95c5f7..4c4cc281 100644 --- a/train_network.py +++ b/train_network.py @@ -25,7 +25,7 @@ from library.config_util import ( ) import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like # TODO 他のスクリプトと共通化する @@ -342,6 +342,8 @@ def train(args): "ss_seed": args.seed, "ss_lowram": args.lowram, "ss_noise_offset": args.noise_offset, + "ss_multires_noise_iterations": args.multires_noise_iterations, + "ss_multires_noise_discount": args.multires_noise_discount, "ss_training_comment": args.training_comment, # will not be updated after training "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), @@ -588,6 +590,8 @@ def train(args): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index c11a199f..301aae7a 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -20,7 +20,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight, pyramid_noise_like imagenet_templates_small = [ "a photo of a {}", @@ -389,6 +389,8 @@ def train(args): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 5342a695..2aa6cd7f 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -20,7 +20,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight, pyramid_noise_like from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI imagenet_templates_small = [ @@ -428,6 +428,8 @@ def train(args): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)