From 202f2c32927edc23788a3cab6edffde6c371c420 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Mon, 23 Oct 2023 21:59:14 +0800 Subject: [PATCH] Debias Estimation loss (#889) * update for bnb 0.41.1 * fixed generate_controlnet_subsets_config for training * Revert "update for bnb 0.41.1" This reverts commit 70bd3612d84778d491fc8006b8b9f9e21c4d2eb8. * add debiased_estimation_loss * add train_network * Revert "add train_network" This reverts commit 6539363c5c13a3e63fc0e52adf7fc26fb566d491. * Update train_network.py --- fine_tune.py | 5 ++++- library/custom_train_functions.py | 11 +++++++++++ sdxl_train.py | 5 ++++- sdxl_train_control_net_lllite.py | 3 +++ sdxl_train_control_net_lllite_old.py | 3 +++ train_db.py | 3 +++ train_network.py | 4 ++++ train_textual_inversion.py | 3 +++ train_textual_inversion_XTI.py | 3 +++ 9 files changed, 38 insertions(+), 2 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 2ecb4ff3..4a3f49c7 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -32,6 +32,7 @@ from library.custom_train_functions import ( get_weighted_text_embeddings, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, ) @@ -339,7 +340,7 @@ def train(args): else: target = noise - if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred: + if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) @@ -348,6 +349,8 @@ def train(args): loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, 
timesteps, noise_scheduler) loss = loss.mean() # mean over batch dimension else: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 677d1bf4..28b625d3 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -86,6 +86,12 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los loss = loss + loss / scale * v_pred_like_loss return loss +def apply_debiased_estimation(loss, timesteps, noise_scheduler): + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 + weight = 1/torch.sqrt(snr_t) + loss = weight * loss + return loss # TODO train_utilと分散しているのでどちらかに寄せる @@ -108,6 +114,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted default=None, help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する", ) + parser.add_argument( + "--debiased_estimation_loss", + action="store_true", + help="debiased estimation loss / debiased estimation loss", + ) if support_weighted_captions: parser.add_argument( "--weighted_captions", diff --git a/sdxl_train.py b/sdxl_train.py index 7bde3cab..55c11f9c 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -34,6 +34,7 @@ from library.custom_train_functions import ( prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, + apply_debiased_estimation, ) from library.sdxl_original_unet import SdxlUNet2DConditionModel @@ -548,7 +549,7 @@ def train(args): target = noise - if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss: + if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss loss = 
torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) @@ -559,6 +560,8 @@ def train(args): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # mean over batch dimension else: diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 0df61e84..7a141bb4 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -44,6 +44,7 @@ from library.custom_train_functions import ( pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, ) import networks.control_net_lllite_for_train as control_net_lllite_for_train @@ -465,6 +466,8 @@ def train(args): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 79920a97..e256badc 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -40,6 +40,7 @@ from library.custom_train_functions import ( pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, ) import networks.control_net_lllite as control_net_lllite @@ -435,6 +436,8 @@ def train(args): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, 
noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_db.py b/train_db.py index a1b9cac8..7316c27e 100644 --- a/train_db.py +++ b/train_db.py @@ -35,6 +35,7 @@ from library.custom_train_functions import ( pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, ) # perlin_noise, @@ -336,6 +337,8 @@ def train(args): loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_network.py b/train_network.py index 2232a384..9deb5331 100644 --- a/train_network.py +++ b/train_network.py @@ -43,6 +43,7 @@ from library.custom_train_functions import ( prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, + apply_debiased_estimation, ) @@ -528,6 +529,7 @@ class NetworkTrainer: "ss_min_snr_gamma": args.min_snr_gamma, "ss_scale_weight_norms": args.scale_weight_norms, "ss_ip_noise_gamma": args.ip_noise_gamma, + "ss_debiased_estimation": bool(args.debiased_estimation_loss), } if use_user_config: @@ -811,6 +813,8 @@ class NetworkTrainer: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 252add53..6b6e7f5a 100644 --- 
a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -32,6 +32,7 @@ from library.custom_train_functions import ( prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, + apply_debiased_estimation, ) imagenet_templates_small = [ @@ -582,6 +583,8 @@ class TextualInversionTrainer: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 525e612f..8dd5c672 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -34,6 +34,7 @@ from library.custom_train_functions import ( pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction, + apply_debiased_estimation, ) import library.original_unet as original_unet from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI @@ -471,6 +472,8 @@ def train(args): loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし