Merge pull request #1505 from liesened/patch-2

Add v-pred support for SDXL train
This commit is contained in:
Kohya S.
2024-08-24 21:16:53 +09:00
committed by GitHub

View File

@@ -702,7 +702,11 @@ def train(args):
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
target = noise
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
if (
args.min_snr_gamma
@@ -720,7 +724,7 @@ def train(args):
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss: