From 09c719c926c51f009bcb197ebaac5ed5e3df307b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 7 May 2023 18:09:08 +0900 Subject: [PATCH] add adaptive noise scale --- fine_tune.py | 5 ++--- library/custom_train_functions.py | 28 +++++++++++++++++++++++----- library/train_util.py | 11 +++++++++++ train_db.py | 5 ++--- train_network.py | 6 +++--- train_textual_inversion.py | 5 ++--- train_textual_inversion_XTI.py | 5 ++--- 7 files changed, 45 insertions(+), 20 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 45b27a41..05761501 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -21,7 +21,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset def train(args): @@ -305,8 +305,7 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) elif args.multires_noise_iterations: noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 70a33e11..2d387d15 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -348,10 +348,28 @@ def get_weighted_text_embeddings( # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 def pyramid_noise_like(noise, device, iterations=6, discount=0.3): b, c, w, h = noise.shape - u = torch.nn.Upsample(size=(w, h), mode='bilinear').to(device) + u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device) for i in range(iterations): - r = random.random()*2+2 # Rather than always going 2x, - w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i))) + r = random.random() * 2 + 2 # Rather than always going 2x, + w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i))) noise += u(torch.randn(b, c, w, h).to(device)) * discount**i - if w==1 or h==1: break # Lowest resolution is 1x1 - return noise/noise.std() # Scaled back to roughly unit variance + if w == 1 or h == 1: + break # Lowest resolution is 1x1 + return noise / noise.std() # Scaled back to roughly unit variance + + +# https://www.crosslabs.org//blog/diffusion-with-offset-noise +def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): + if noise_offset is None: + return noise + if adaptive_noise_scale is not None: + # latent shape: (batch_size, channels, height, width) + # abs mean value for each channel + latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True)) + + # multiply adaptive noise scale to the mean value and add it to the noise offset + noise_offset = noise_offset + adaptive_noise_scale * latent_mean + noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative + + noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + return noise diff --git a/library/train_util.py b/library/train_util.py index ea433979..9a421808 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2133,6 +2133,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=0.3, help="set discount value for multires noise (has no effect without --multires_noise_iterations) / Multires noiseのdiscount値を設定する(--multires_noise_iterations指定時のみ有効)", ) + parser.add_argument( + "--adaptive_noise_scale", + type=float, + default=None, + help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算する(Noneの場合は無効、デフォルト)", + ) parser.add_argument( "--lowram", action="store_true", @@ -2210,6 +2216,11 @@ def verify_training_args(args: argparse.Namespace): "noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にすることはできません" ) + if args.adaptive_noise_scale is not None and args.noise_offset is None: + raise ValueError( + "adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です" + ) + def add_dataset_arguments( parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool diff --git a/train_db.py b/train_db.py index c4b9cb19..55e2abab 100644 --- a/train_db.py +++ b/train_db.py @@ -23,7 +23,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset def train(args): @@ -271,8 +271,7 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) elif args.multires_noise_iterations: noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) diff --git a/train_network.py b/train_network.py index 96a23de0..43f70225 100644 --- a/train_network.py +++ b/train_network.py @@ -25,7 +25,7 @@ from library.config_util import ( ) import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset # TODO 他のスクリプトと共通化する @@ -585,11 +585,11 @@ def train(args): else: input_ids = batch["input_ids"].to(accelerator.device) encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) + # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) elif args.multires_noise_iterations: noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index c5d23acb..8da20477 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -20,7 +20,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, pyramid_noise_like +from library.custom_train_functions import apply_snr_weight, pyramid_noise_like, apply_noise_offset imagenet_templates_small = [ "a photo of a {}", @@ -387,8 +387,7 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) elif args.multires_noise_iterations: noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 7debee82..35874802 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -20,7 +20,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, pyramid_noise_like +from library.custom_train_functions import apply_snr_weight, pyramid_noise_like, apply_noise_offset from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI imagenet_templates_small = [ @@ -426,8 +426,7 @@ def train(args): # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) if args.noise_offset: - # https://www.crosslabs.org//blog/diffusion-with-offset-noise - noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) elif args.multires_noise_iterations: noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)