From 6d2d8dfd2f6d739620de686ae4b2c9e76bc75708 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Tue, 18 Jul 2023 23:17:23 +0900
Subject: [PATCH] add zero_terminal_snr option

---
 fine_tune.py                      |  4 ++--
 library/custom_train_functions.py | 36 +++++++++++++++++++++++++++++++
 library/sdxl_train_util.py        |  7 +++---
 library/train_util.py             | 13 ++++++++++-
 sdxl_train.py                     |  2 ++
 train_db.py                       |  2 ++
 train_network.py                  |  3 +++
 train_textual_inversion.py        |  2 ++
 train_textual_inversion_XTI.py    |  2 ++
 9 files changed, 65 insertions(+), 6 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 58a6cda0..a906b238 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -23,8 +23,6 @@ from library.custom_train_functions import (
     apply_snr_weight,
     get_weighted_text_embeddings,
     prepare_scheduler_for_custom_training,
-    pyramid_noise_like,
-    apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
 )
 
@@ -273,6 +271,8 @@ def train(args):
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
+    if args.zero_terminal_snr:
+        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
     prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
 
     if accelerator.is_main_process:
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index eacc23d8..5b6106fb 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -18,6 +18,42 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
     noise_scheduler.all_snr = all_snr.to(device)
 
 
+def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
+    # fix betas: zero terminal SNR
+    print("fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
+
+    def enforce_zero_terminal_snr(betas):
+        # Convert betas to alphas_bar_sqrt
+        alphas = 1 - betas
+        alphas_bar = alphas.cumprod(0)
+        alphas_bar_sqrt = alphas_bar.sqrt()
+
+        # Store old values.
+        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+        # Shift so last timestep is zero.
+        alphas_bar_sqrt -= alphas_bar_sqrt_T
+        # Scale so first timestep is back to old value.
+        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+        # Convert alphas_bar_sqrt to betas
+        alphas_bar = alphas_bar_sqrt**2
+        alphas = alphas_bar[1:] / alphas_bar[:-1]
+        alphas = torch.cat([alphas_bar[0:1], alphas])
+        betas = 1 - alphas
+        return betas
+
+    betas = noise_scheduler.betas
+    betas = enforce_zero_terminal_snr(betas)
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+
+    # print("original:", noise_scheduler.betas)
+    # print("fixed:", betas)
+
+    noise_scheduler.betas = betas
+    noise_scheduler.alphas = alphas
+    noise_scheduler.alphas_cumprod = alphas_cumprod
 
 
 def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
     snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index 48785ca6..100b1cf8 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -343,9 +343,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
 
 
 def verify_sdxl_training_args(args: argparse.Namespace):
-    assert (
-        not args.v2 and not args.v_parameterization
-    ), "v2 or v_parameterization cannot be enabled in SDXL training / SDXL学習ではv2とv_parameterizationを有効にすることはできません"
+    assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
+    if args.v_parameterization:
+        print("v_parameterization may not work as expected in SDXL training / SDXL学習ではv_parameterizationは想定外の動作になります")
+
     if args.clip_skip is not None:
         print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
diff --git a/library/train_util.py b/library/train_util.py
index 785dc0f9..e6e5c3c4 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2750,6 +2750,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算する(Noneの場合は無効、デフォルト)",
     )
+    parser.add_argument(
+        "--zero_terminal_snr",
+        action="store_true",
+        help="fix noise scheduler betas to enforce zero terminal SNR / noise schedulerのbetasを修正して、zero terminal SNRを強制する",
+    )
     parser.add_argument(
         "--min_timestep",
         type=int,
@@ -2825,7 +2830,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
 
 def verify_training_args(args: argparse.Namespace):
     if args.v_parameterization and not args.v2:
-        print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
+        print("v_parameterization should be used with v2, not v1 or SDXL / v1やsdxlでv_parameterizationを使用することは想定されていません")
 
     if args.v2 and args.clip_skip is not None:
         print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
@@ -2856,6 +2861,12 @@ def verify_training_args(args: argparse.Namespace):
             "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
         )
 
+    if args.zero_terminal_snr and not args.v_parameterization:
+        print(
+            "zero_terminal_snr is enabled, but v_parameterization is not enabled; training results may be unexpected"
+            + " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
+        )
+
 
 def add_dataset_arguments(
     parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
diff --git a/sdxl_train.py b/sdxl_train.py
index 8459671c..630d8832 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -350,6 +350,8 @@ def train(args):
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
+    if args.zero_terminal_snr:
+        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
     prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
 
     if accelerator.is_main_process:
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
diff --git a/train_db.py b/train_db.py
index 439f4b9d..7571efc3 100644
--- a/train_db.py
+++ b/train_db.py
@@ -246,6 +246,8 @@ def train(args):
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
+    if args.zero_terminal_snr:
+        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
     prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
 
     if accelerator.is_main_process:
         accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
diff --git a/train_network.py b/train_network.py
index f7ee451b..a873537c 100644
--- a/train_network.py
+++ b/train_network.py
@@ -487,6 +487,7 @@ class NetworkTrainer:
             "ss_multires_noise_iterations": args.multires_noise_iterations,
             "ss_multires_noise_discount": args.multires_noise_discount,
             "ss_adaptive_noise_scale": args.adaptive_noise_scale,
+            "ss_zero_terminal_snr": args.zero_terminal_snr,
             "ss_training_comment": args.training_comment,  # will not be updated after training
             "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
             "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
@@ -670,6 +671,8 @@ class NetworkTrainer:
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
         )
+        if args.zero_terminal_snr:
+            custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
         prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
 
         if accelerator.is_main_process:
             accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 7be8ba80..e227a13b 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -487,6 +487,8 @@ class TextualInversionTrainer:
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
         )
+        if args.zero_terminal_snr:
+            custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
         prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
 
         if accelerator.is_main_process:
             accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 0e91c71c..ba5c7d03 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -384,6 +384,8 @@ def train(args):
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
     )
+    if args.zero_terminal_snr:
+        custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
     prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
 
     if accelerator.is_main_process:
         accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
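
All the script hunks apply the same pattern: construct the DDPMScheduler, rescale its betas when --zero_terminal_snr is set, then call prepare_scheduler_for_custom_training so the precomputed SNR table reflects the fixed schedule. The following is a minimal verification sketch, not part of the patch, assuming only that `diffusers` is installed and this repository's `library` package is importable:

    from diffusers import DDPMScheduler

    from library import custom_train_functions

    # Same scheduler configuration the patched training scripts construct.
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )

    # Before the fix, the terminal SNR is small but nonzero, so the model
    # never trains on a pure-noise timestep.
    ac = noise_scheduler.alphas_cumprod
    print("terminal SNR before fix:", (ac[-1] / (1.0 - ac[-1])).item())

    # Apply the new helper; it rescales the betas so that alphas_cumprod[-1]
    # becomes exactly zero (https://arxiv.org/abs/2305.08891).
    custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
    print("alphas_cumprod[-1] after fix:", noise_scheduler.alphas_cumprod[-1].item())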
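As the new warning in verify_training_args indicates, --zero_terminal_snr is meant to be combined with --v_parameterization, i.e. with a v-prediction model. A hypothetical invocation of one of the patched scripts; the checkpoint path and the elided arguments are placeholders, not part of the patch:

    accelerate launch train_network.py \
        --pretrained_model_name_or_path=<v2 v-prediction checkpoint> \
        --v2 --v_parameterization \
        --zero_terminal_snr \
        <...usual dataset, network, and optimizer arguments...>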