add train_network

This commit is contained in:
青龍聖者@bdsqlsz
2023-10-20 09:31:43 +08:00
parent fe6f189a61
commit 6539363c5c
2 changed files with 3 additions and 0 deletions

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ venv
build
.vscode
wandb
.vs

View File

@@ -813,6 +813,8 @@ class NetworkTrainer:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし