Fix min-snr-gamma for v-prediction and ZSNR.

This fixes min-snr for vpred+zsnr by dividing directly by SNR+1.
The old implementation did it in two steps: (min-snr/snr) * (snr/(snr+1)), which causes division by zero when combined with --zero_terminal_snr
This commit is contained in:
feffy380
2023-11-07 23:02:25 +01:00
parent f3e69531c3
commit 6b3148fd3f
9 changed files with 14 additions and 11 deletions

View File

@@ -460,7 +460,7 @@ def train(args):
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss: