diff --git a/fine_tune.py b/fine_tune.py
index b556672d..19a35229 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -383,10 +383,10 @@ def train(args):
 
                     if args.min_snr_gamma:
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                    if args.scale_v_pred_loss_like_noise_pred:
+                    if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.debiased_estimation_loss:
-                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 2a513dc5..faf44304 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -96,10 +96,13 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
     return loss
 
 
-def apply_debiased_estimation(loss, timesteps, noise_scheduler):
+def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
     snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])  # batch_size
     snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)  # if timestep is 0, snr_t is inf, so limit it to 1000
-    weight = 1 / torch.sqrt(snr_t)
+    if v_prediction:
+        weight = 1 / (snr_t + 1)
+    else:
+        weight = 1 / torch.sqrt(snr_t)
     loss = weight * loss
     return loss
 
diff --git a/library/train_util.py b/library/train_util.py
index 27910dc9..adb983d2 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3731,6 +3731,11 @@ def verify_training_args(args: argparse.Namespace):
         raise ValueError(
             "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
         )
+
+    if args.scale_v_pred_loss_like_noise_pred and args.zero_terminal_snr:
+        raise ValueError(
+            "scale_v_pred_loss_like_noise_pred cannot be enabled with zero_terminal_snr / scale_v_pred_loss_like_noise_predはzero_terminal_snrと同時に有効にできません"
+        )
 
     if args.v_pred_like_loss and args.v_parameterization:
         raise ValueError(
diff --git a/sdxl_train.py b/sdxl_train.py
index e0a8f2b2..44ee9233 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -725,12 +725,12 @@ def train(args):
 
                     if args.min_snr_gamma:
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                    if args.scale_v_pred_loss_like_noise_pred:
+                    if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
                         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                     if args.debiased_estimation_loss:
-                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 5ff060a9..436f0e19 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -474,12 +474,12 @@ def train(args):
 
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 292a0463..8fba9eba 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -434,12 +434,12 @@ def train(args):
 
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.v_pred_like_loss:
                     loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
diff --git a/train_db.py b/train_db.py
index 2c7f0258..d5a94a56 100644
--- a/train_db.py
+++ b/train_db.py
@@ -370,10 +370,10 @@ def train(args):
 
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
diff --git a/train_network.py b/train_network.py
index 044ec3aa..790fbfc9 100644
--- a/train_network.py
+++ b/train_network.py
@@ -993,12 +993,12 @@ class NetworkTrainer:
 
                     if args.min_snr_gamma:
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                    if args.scale_v_pred_loss_like_noise_pred:
+                    if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
                         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                     if args.debiased_estimation_loss:
-                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
 
                     loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 96e7bd50..10b34db5 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -598,12 +598,12 @@ class TextualInversionTrainer:
 
                     if args.min_snr_gamma:
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                    if args.scale_v_pred_loss_like_noise_pred:
+                    if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                         loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                     if args.v_pred_like_loss:
                         loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
                     if args.debiased_estimation_loss:
-                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
 
                     loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index efb59137..084b90c6 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -483,10 +483,10 @@ def train(args):
                 loss = loss * loss_weights
                 if args.min_snr_gamma:
                     loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
-                if args.scale_v_pred_loss_like_noise_pred:
+                if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
                     loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
                 if args.debiased_estimation_loss:
-                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
 