From efef5c8ead18d98770350540914bb14545509482 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Fri, 27 Oct 2023 17:59:58 +0900 Subject: [PATCH] Show "avr_loss" instead of "loss" because it is moving average --- fine_tune.py | 2 +- sdxl_train.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 0de4aff1..c5e99ad4 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -407,7 +407,7 @@ def train(args): loss_recorder.add(epoch=epoch, step=step, loss=current_loss) avr_loss: float = loss_recorder.get_moving_average() - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step >= args.max_train_steps: diff --git a/sdxl_train.py b/sdxl_train.py index 5e5d528d..096c89e9 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -634,7 +634,7 @@ def train(args): loss_recorder.add(epoch=epoch, step=step, loss=current_loss) avr_loss: float = loss_recorder.get_moving_average() - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step >= args.max_train_steps: diff --git a/train_db.py b/train_db.py index 221a1e45..11230349 100644 --- a/train_db.py +++ b/train_db.py @@ -393,7 +393,7 @@ def train(args): loss_recorder.add(epoch=epoch, step=step, loss=current_loss) avr_loss: float = loss_recorder.get_moving_average() - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step >= args.max_train_steps: diff --git a/train_network.py b/train_network.py index aeefe2a5..58f7e445 100644 --- a/train_network.py +++ b/train_network.py @@ -855,7 +855,7 @@ class NetworkTrainer: current_loss = loss.detach().item() loss_recorder.add(epoch=epoch, step=step, loss=current_loss) avr_loss: float = loss_recorder.get_moving_average() - logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if args.scale_weight_norms: