diff --git a/train_db.py b/train_db.py index d36bd8d0..e4f1e54c 100644 --- a/train_db.py +++ b/train_db.py @@ -309,7 +309,7 @@ def train(args): break if args.logging_dir is not None: - logs = {"epoch_loss": loss_total / len(loss_list)} + logs = {"loss/epoch": loss_total / len(loss_list)} accelerator.log(logs, step=epoch+1) accelerator.wait_for_everyone()