From b18d0992914d21e794d28000e3e5be53066e9164 Mon Sep 17 00:00:00 2001 From: Pam Date: Tue, 2 May 2023 09:42:17 +0500 Subject: [PATCH] Multi-Resolution Noise --- fine_tune.py | 4 +++- library/custom_train_functions.py | 13 +++++++++++++ library/train_util.py | 12 ++++++++++++ train_db.py | 4 +++- train_network.py | 6 +++++- train_textual_inversion.py | 4 +++- train_textual_inversion_XTI.py | 4 +++- 7 files changed, 42 insertions(+), 5 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index b6a8d1d7..f0641e85 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -21,7 +21,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like def train(args): @@ -304,6 +304,8 @@ def train(args): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 7eb829fa..70a33e11 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -1,5 +1,6 @@ import torch import argparse +import random import re from typing import List, Optional, Union @@ -342,3 +343,15 @@ def get_weighted_text_embeddings( text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) return text_embeddings + + +# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2 +def pyramid_noise_like(noise, device, iterations=6, discount=0.3): + b, c, w, h = noise.shape + u = torch.nn.Upsample(size=(w, h), mode='bilinear').to(device) + for i in range(iterations): + r = random.random()*2+2 # Rather than always going 2x, + w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i))) + noise += u(torch.randn(b, c, w, h).to(device)) * discount**i + if w==1 or h==1: break # Lowest resolution is 1x1 + return noise/noise.std() # Scaled back to roughly unit variance diff --git a/library/train_util.py b/library/train_util.py index 8c6e3437..2c107237 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2119,6 +2119,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)", ) + parser.add_argument( + "--multires_noise_iterations", + type=int, + default=None, + help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended)" + ) + parser.add_argument( + "--multires_noise_discount", + type=float, + default=0.3, + help="set discount value for multires noise (has no effect without --multires_noise_iterations)" + ) parser.add_argument( "--lowram", action="store_true", diff --git a/train_db.py b/train_db.py index 178d5cb4..4d054e9a 100644 --- a/train_db.py +++ b/train_db.py @@ -23,7 +23,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like def train(args): @@ -270,6 +270,8 @@ def train(args): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): diff --git a/train_network.py b/train_network.py index 5c4d5ad1..60007433 100644 --- a/train_network.py +++ b/train_network.py @@ -26,7 +26,7 @@ from library.config_util import ( ) import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like # TODO 他のスクリプトと共通化する @@ -366,6 +366,8 @@ def train(args): "ss_seed": args.seed, "ss_lowram": args.lowram, "ss_noise_offset": args.noise_offset, + "ss_multires_noise_iterations": args.multires_noise_iterations, + "ss_multires_noise_discount": args.multires_noise_discount, "ss_training_comment": args.training_comment, # will not be updated after training "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(), "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""), @@ -612,6 +614,8 @@ def train(args): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index fb6b6053..d77a8878 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -20,7 +20,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight, pyramid_noise_like imagenet_templates_small = [ "a photo of a {}", @@ -386,6 +386,8 @@ def train(args): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 69ec3eb1..27c5c2df 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -20,7 +20,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight, pyramid_noise_like from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI imagenet_templates_small = [ @@ -425,6 +425,8 @@ def train(args): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device) + elif args.multires_noise_iterations: + noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)