Update train_network.py

This commit is contained in:
gesen2egee
2024-08-04 17:36:34 +08:00
committed by GitHub
parent aa850aa531
commit cdb2d9c516

View File

@@ -192,7 +192,7 @@ class NetworkTrainer:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
total_loss += loss
average_loss = total_loss / len(timesteps_list)
return average_loss