add adaptive noise scale

This commit is contained in:
Kohya S
2023-05-07 18:09:08 +09:00
parent e54b6311ef
commit 09c719c926
7 changed files with 45 additions and 20 deletions

View File

@@ -21,7 +21,7 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
def train(args):
@@ -305,8 +305,7 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

View File

@@ -348,10 +348,28 @@ def get_weighted_text_embeddings(
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.3):
b, c, w, h = noise.shape
u = torch.nn.Upsample(size=(w, h), mode='bilinear').to(device)
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for i in range(iterations):
r = random.random()*2+2 # Rather than always going 2x,
w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
r = random.random() * 2 + 2 # Rather than always going 2x,
w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
noise += u(torch.randn(b, c, w, h).to(device)) * discount**i
if w==1 or h==1: break # Lowest resolution is 1x1
return noise/noise.std() # Scaled back to roughly unit variance
if w == 1 or h == 1:
break # Lowest resolution is 1x1
return noise / noise.std() # Scaled back to roughly unit variance
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
if noise_offset is None:
return noise
if adaptive_noise_scale is not None:
# latent shape: (batch_size, channels, height, width)
# abs mean value for each channel
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
# multiply adaptive noise scale to the mean value and add it to the noise offset
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
return noise

View File

@@ -2133,6 +2133,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=0.3,
help="set discount value for multires noise (has no effect without --multires_noise_iterations) / Multires noiseのdiscount値を設定する--multires_noise_iterations指定時のみ有効",
)
parser.add_argument(
"--adaptive_noise_scale",
type=float,
default=None,
help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算するNoneの場合は無効、デフォルト",
)
parser.add_argument(
"--lowram",
action="store_true",
@@ -2210,6 +2216,11 @@ def verify_training_args(args: argparse.Namespace):
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にすることはできません"
)
if args.adaptive_noise_scale is not None and args.noise_offset is None:
raise ValueError(
"adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です"
)
def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool

View File

@@ -23,7 +23,7 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
def train(args):
@@ -271,8 +271,7 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

View File

@@ -25,7 +25,7 @@ from library.config_util import (
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
# TODO 他のスクリプトと共通化する
@@ -585,11 +585,11 @@ def train(args):
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

View File

@@ -20,7 +20,7 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like, apply_noise_offset
imagenet_templates_small = [
"a photo of a {}",
@@ -387,8 +387,7 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

View File

@@ -20,7 +20,7 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like, apply_noise_offset
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
imagenet_templates_small = [
@@ -426,8 +426,7 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)