Fix training for V-pred and ztSNR

1) Updates debiased estimation loss function for V-pred.
2) Prevents now-deprecated scaling of loss if ztSNR is enabled.
This commit is contained in:
catboxanon
2024-10-21 07:34:33 -04:00
parent 012e7e63a5
commit 8fc30f8205
10 changed files with 26 additions and 18 deletions

View File

@@ -383,10 +383,10 @@ def train(args):
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # mean over batch dimension
else: