diff --git a/library/train_util.py b/library/train_util.py index e22afe1c..ff161fea 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2946,6 +2946,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名", ) + parser.add_argument( + "--wandb_run_name", + type=str, + default=None, + help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前", + ) parser.add_argument( "--log_tracker_config", type=str, diff --git a/train_network.py b/train_network.py index 9cba78da..a75299cd 100644 --- a/train_network.py +++ b/train_network.py @@ -684,6 +684,8 @@ class NetworkTrainer: if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers(