From bca6a44974414f6dc2f7e32423938dbe09bf50af Mon Sep 17 00:00:00 2001 From: hkinghuang <178854663@qq.com> Date: Mon, 15 May 2023 11:16:08 +0800 Subject: [PATCH] Perlin noise --- library/custom_train_functions.py | 6 +++--- library/train_util.py | 6 ++++++ train_db.py | 4 +++- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index a2303a87..d9d85d45 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -410,9 +410,9 @@ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5): amplitude *= persistence return noise -def perlin_noise(noise, device): +def perlin_noise(noise, device,octaves): b, c, w, h = noise.shape() - perlin = lambda : rand_perlin_2d_octaves(device,(w,h),(4,4),1) + perlin = lambda : rand_perlin_2d_octaves(device,(w,h),(4,4),octaves) noise_perlin_r = torch.rand(noise.shape, device=device) + perlin() noise_perlin_g = torch.rand(noise.shape, device=device) + perlin() noise_perlin_b = torch.rand(noise.shape, device=device) + perlin() @@ -420,7 +420,7 @@ def perlin_noise(noise, device): (noise_perlin_r, noise_perlin_g, noise_perlin_b), - 2) + 1) return noise_perlin diff --git a/library/train_util.py b/library/train_util.py index 2a55a446..3539a5bd 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2127,6 +2127,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)", ) + parser.add_argument( + "--perlin_noise", + type=int, + default=None, + help="enable perlin noise and set the octaves", + ) parser.add_argument( "--multires_noise_discount", type=float, diff --git a/train_db.py b/train_db.py index 11af9f6b..5425a488 100644 --- a/train_db.py +++ b/train_db.py @@ -23,7 +23,7 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset +from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset,perlin_noise def train(args): @@ -274,6 +274,8 @@ def train(args): noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale) elif args.multires_noise_iterations: noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount) + elif args.perlin_noise: + noise = perlin_noise(noise,latents.device,args.perlin_noise) # Get the text embedding for conditioning with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):