Merge branch 'dev' into sd3

This commit is contained in:
Kohya S
2024-10-25 19:03:27 +09:00
13 changed files with 16 additions and 18 deletions

View File

@@ -406,7 +406,7 @@ def train(args):
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # mean over batch dimension
else: