Only add warning for deprecated scaling vpred loss function

This commit is contained in:
catboxanon
2024-10-21 08:12:53 -04:00
parent 8fc30f8205
commit e1b63c2249
9 changed files with 14 additions and 13 deletions

View File

@@ -383,7 +383,7 @@ def train(args):
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)