From 8fc30f820595f80ec3f09738cc4cf01f441c41b7 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 21 Oct 2024 07:34:33 -0400
Subject: [PATCH 1/5] Fix training for V-pred and ztSNR

1) Updates debiased estimation loss function for V-pred.
2) Prevents now-deprecated scaling of loss if ztSNR is enabled.
---
 fine_tune.py | 4 ++--
 library/custom_train_functions.py | 7 +++++--
 library/train_util.py | 5 +++++
 sdxl_train.py | 4 ++--
 sdxl_train_control_net_lllite.py | 4 ++--
 sdxl_train_control_net_lllite_old.py | 4 ++--
 train_db.py | 4 ++--
 train_network.py | 4 ++--
 train_textual_inversion.py | 4 ++--
 train_textual_inversion_XTI.py | 4 ++--
 10 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index b556672d..19a35229 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -383,10 +383,10 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # mean over batch dimension
             else:
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 2a513dc5..faf44304 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -96,10 +96,13 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
     return loss


-def apply_debiased_estimation(loss, timesteps, noise_scheduler):
+def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
     snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
     snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
-    weight = 1 / torch.sqrt(snr_t)
+    if v_prediction:
+        weight = 1 / (snr_t + 1)
+    else:
+        weight = 1 / torch.sqrt(snr_t)
     loss = weight * loss
     return loss

diff --git a/library/train_util.py b/library/train_util.py
index 27910dc9..adb983d2 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3731,6 +3731,11 @@ def verify_training_args(args: argparse.Namespace):
         raise ValueError(
             "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
         )
+
+    if args.scale_v_pred_loss_like_noise_pred and args.zero_terminal_snr:
+        raise ValueError(
+            "zero_terminal_snr enabled. scale_v_pred_loss_like_noise_pred will not be used / zero_terminal_snrが有効です。scale_v_pred_loss_like_noise_predは使用されません"
+        )

     if args.v_pred_like_loss and args.v_parameterization:
         raise ValueError(
diff --git a/sdxl_train.py b/sdxl_train.py
index e0a8f2b2..44ee9233 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -725,12 +725,12 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # mean over batch dimension
             else:
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 5ff060a9..436f0e19 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -474,12 +474,12 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 292a0463..8fba9eba 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -434,12 +434,12 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

diff --git a/train_db.py b/train_db.py
index 2c7f0258..d5a94a56 100644
--- a/train_db.py
+++ b/train_db.py
@@ -370,10 +370,10 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

diff --git a/train_network.py b/train_network.py
index 044ec3aa..790fbfc9 100644
--- a/train_network.py
+++ b/train_network.py
@@ -993,12 +993,12 @@ class NetworkTrainer:

                     if args.min_snr_gamma:
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                    if args.scale_v_pred_loss_like_noise_pred:
+                    if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
                         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                     if args.debiased_estimation_loss:
-                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                     loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 96e7bd50..10b34db5 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -598,12 +598,12 @@ class TextualInversionTrainer:

                     if args.min_snr_gamma:
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                    if args.scale_v_pred_loss_like_noise_pred:
+                    if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
                         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                     if args.debiased_estimation_loss:
-                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                     loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index efb59137..084b90c6 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -483,10 +483,10 @@ def train(args):
                 loss = loss * loss_weights
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)

                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

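Note (illustrative sketch, not part of the patch series): after PATCH 1/5, apply_debiased_estimation weights the per-sample loss by 1 / sqrt(SNR) for epsilon prediction and by 1 / (SNR + 1) when v_prediction is set. A minimal standalone version of that weighting, with a placeholder SNR table standing in for noise_scheduler.all_snr and placeholder timesteps:

import torch

# Placeholder per-timestep SNR table; in the trainers this is noise_scheduler.all_snr.
all_snr = torch.linspace(1000.0, 0.01, 1000)

def debiased_weight(timesteps, v_prediction=False):
    snr_t = torch.stack([all_snr[t] for t in timesteps])  # batch_size
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # cap SNR at 1000, as in the patch
    if v_prediction:
        return 1 / (snr_t + 1)  # v-prediction branch added by the patch above
    return 1 / torch.sqrt(snr_t)  # original epsilon-prediction weighting

timesteps = torch.tensor([10, 500, 990])
print(debiased_weight(timesteps))                     # epsilon-prediction weights
print(debiased_weight(timesteps, v_prediction=True))  # v-prediction weights

At high SNR (early timesteps) the v-prediction weight falls off much faster than 1 / sqrt(SNR), which is the behaviour change the README entry in PATCH 5/5 describes.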
From e1b63c2249345e4f14c10cbb252da68157ac13b7 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 21 Oct 2024 08:12:53 -0400
Subject: [PATCH 2/5] Only add warning for deprecated scaling vpred loss function

---
 fine_tune.py | 2 +-
 library/train_util.py | 11 ++++++-----
 sdxl_train.py | 2 +-
 sdxl_train_control_net_lllite.py | 2 +-
 sdxl_train_control_net_lllite_old.py | 2 +-
 train_db.py | 2 +-
 train_network.py | 2 +-
 train_textual_inversion.py | 2 +-
 train_textual_inversion_XTI.py | 2 +-
 9 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 19a35229..c79f97d2 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -383,7 +383,7 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
                     loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
diff --git a/library/train_util.py b/library/train_util.py
index adb983d2..f479dcc6 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3727,15 +3727,16 @@ def verify_training_args(args: argparse.Namespace):
     if args.adaptive_noise_scale is not None and args.noise_offset is None:
         raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")

+    if args.scale_v_pred_loss_like_noise_pred:
+        logger.warning(
+            f"scale_v_pred_loss_like_noise_pred is deprecated. it is suggested to use min_snr_gamma or debiased_estimation_loss"
+            + " / scale_v_pred_loss_like_noise_pred は非推奨です。min_snr_gammaまたはdebiased_estimation_lossを使用することをお勧めします"
+        )
+
     if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
         raise ValueError(
             "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
         )
-
-    if args.scale_v_pred_loss_like_noise_pred and args.zero_terminal_snr:
-        raise ValueError(
-            "zero_terminal_snr enabled. scale_v_pred_loss_like_noise_pred will not be used / zero_terminal_snrが有効です。scale_v_pred_loss_like_noise_predは使用されません"
-        )

     if args.v_pred_like_loss and args.v_parameterization:
         raise ValueError(
diff --git a/sdxl_train.py b/sdxl_train.py
index 44ee9233..b533b274 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -725,7 +725,7 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 436f0e19..0e67cde5 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -474,7 +474,7 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 8fba9eba..4a01f9e2 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -434,7 +434,7 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
diff --git a/train_db.py b/train_db.py
index d5a94a56..e7cf3cde 100644
--- a/train_db.py
+++ b/train_db.py
@@ -370,7 +370,7 @@ def train(args):

                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
                     loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
diff --git a/train_network.py b/train_network.py
index 790fbfc9..7bf125dc 100644
--- a/train_network.py
+++ b/train_network.py
@@ -993,7 +993,7 @@ class NetworkTrainer:

                     if args.min_snr_gamma:
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                    if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                    if args.scale_v_pred_loss_like_noise_pred:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
                         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 10b34db5..37349da7 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -598,7 +598,7 @@ class TextualInversionTrainer:

                     if args.min_snr_gamma:
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                    if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                    if args.scale_v_pred_loss_like_noise_pred:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
                         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 084b90c6..fac0787b 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -483,7 +483,7 @@ def train(args):
                 loss = loss * loss_weights
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
+                if args.scale_v_pred_loss_like_noise_pred:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
                     loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
From 0e7c5929336173e30d7932c0706eaf61a7d396f4 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Tue, 22 Oct 2024 11:19:34 -0400
Subject: [PATCH 3/5] Remove scale_v_pred_loss_like_noise_pred deprecation

https://github.com/kohya-ss/sd-scripts/pull/1715#issuecomment-2427876376
---
 library/train_util.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index f479dcc6..27910dc9 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3727,12 +3727,6 @@ def verify_training_args(args: argparse.Namespace):
     if args.adaptive_noise_scale is not None and args.noise_offset is None:
         raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")

-    if args.scale_v_pred_loss_like_noise_pred:
-        logger.warning(
-            f"scale_v_pred_loss_like_noise_pred is deprecated. it is suggested to use min_snr_gamma or debiased_estimation_loss"
-            + " / scale_v_pred_loss_like_noise_pred は非推奨です。min_snr_gammaまたはdebiased_estimation_lossを使用することをお勧めします"
-        )
-
     if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
         raise ValueError(
             "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
From be14c062674973d0e4fee1eb4527e04707bb72b8 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Tue, 22 Oct 2024 12:13:51 -0400
Subject: [PATCH 4/5] Remove v-pred warnings

Different model architectures, such as SDXL, can take advantage of v-pred.
It doesn't make sense to include these warnings anymore.
---
 gen_img.py | 2 --
 gen_img_diffusers.py | 2 --
 library/train_util.py | 4 ----
 3 files changed, 8 deletions(-)

diff --git a/gen_img.py b/gen_img.py
index 59bcd5b0..9427a894 100644
--- a/gen_img.py
+++ b/gen_img.py
@@ -1495,8 +1495,6 @@ def main(args):
     highres_fix = args.highres_fix_scale is not None
     # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"

-    if args.v_parameterization and not args.v2:
-        logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
     if args.v2 and args.clip_skip is not None:
         logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")

diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 2c40f1a0..04db4e9b 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -2216,8 +2216,6 @@ def main(args):
     highres_fix = args.highres_fix_scale is not None
     # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"

-    if args.v_parameterization and not args.v2:
-        logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
     if args.v2 and args.clip_skip is not None:
         logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")

diff --git a/library/train_util.py b/library/train_util.py
index 27910dc9..100ef475 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3698,10 +3698,6 @@ def verify_training_args(args: argparse.Namespace):
         global HIGH_VRAM
         HIGH_VRAM = True

-    if args.v_parameterization and not args.v2:
-        logger.warning(
-            "v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません"
-        )
     if args.v2 and args.clip_skip is not None:
         logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")

From b1e6504007aca20d15155d5c9fe880fb5e0002b8 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Fri, 25 Oct 2024 18:56:25 +0900
Subject: [PATCH 5/5] update README

---
 README.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/README.md b/README.md
index de5cddb9..ce28d004 100644
--- a/README.md
+++ b/README.md
@@ -143,6 +143,9 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 - bitsandbytes, transformers, accelerate and huggingface_hub are updated.
 - If you encounter any issues, please report them.

+- Fixed a bug where the loss weight was incorrect when `--debiased_estimation_loss` was specified with `--v_parameterization`. PR [#1715](https://github.com/kohya-ss/sd-scripts/pull/1715) Thanks to catboxanon! See [the PR](https://github.com/kohya-ss/sd-scripts/pull/1715) for details.
+  - Removed the warning when `--v_parameterization` is specified in SDXL and SD1.5. PR [#1717](https://github.com/kohya-ss/sd-scripts/pull/1717)
+
 - There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632)
 - `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds!
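To make the first README bullet concrete (illustrative numbers only, not from the patches): with --debiased_estimation_loss, a v-prediction model previously received the epsilon-style weight 1 / sqrt(snr); after this series it receives 1 / (snr + 1), for example:

import torch

snr = torch.tensor([0.1, 1.0, 10.0, 100.0])  # example per-sample SNR values
old_weight = 1 / torch.sqrt(snr)             # weight formerly applied regardless of prediction type
new_weight = 1 / (snr + 1)                   # weight now applied when --v_parameterization is set
print(old_weight)  # tensor([3.1623, 1.0000, 0.3162, 0.1000])
print(new_weight)  # tensor([0.9091, 0.5000, 0.0909, 0.0099])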