From 303c3410e26916d84a39e6ad04a295b610a6629b Mon Sep 17 00:00:00 2001
From: michaelgzhang <49577754+mgz-dev@users.noreply.github.com>
Date: Wed, 18 Jan 2023 13:10:13 -0600
Subject: [PATCH] expand details in tensorboard logs

- Update tensorboard logging to track both unet and textencoder learning rates
- Update tensorboard logging to track both current and moving average epoch loss
- Clean up tensorboard log variable names for dashboard formatting
---
 library/train_util.py | 12 ++++++++++++
 train_network.py      | 11 ++++++-----
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 57ebf1b0..eb11d6fb 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1388,5 +1388,17 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
   model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
   accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
 
+def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
+  logs = {"loss/current": current_loss, "loss/average": avr_loss}
+
+  if args.network_train_unet_only:
+    logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
+  elif args.network_train_text_encoder_only:
+    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+  else:
+    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+    logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]
+
+  return logs
 
 # endregion

diff --git a/train_network.py b/train_network.py
index c0a881ad..cac63295 100644
--- a/train_network.py
+++ b/train_network.py
@@ -330,20 +330,21 @@ def train(args):
         global_step += 1
 
       current_loss = loss.detach().item()
-      if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
-        accelerator.log(logs, step=global_step)
-
       loss_total += current_loss
       avr_loss = loss_total / (step+1)
       logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
       progress_bar.set_postfix(**logs)
 
+      if args.logging_dir is not None:
+        logs = train_util.generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
+
+        accelerator.log(logs, step=global_step)
+
       if global_step >= args.max_train_steps:
         break
 
     if args.logging_dir is not None:
-      logs = {"epoch_loss": loss_total / len(train_dataloader)}
+      logs = {"loss/epoch": loss_total / len(train_dataloader)}
       accelerator.log(logs, step=epoch+1)
 
     accelerator.wait_for_everyone()