From 32b759a328fab9094e3397f282fa586d1f1d6a44 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 14 Jan 2024 22:02:03 +0900 Subject: [PATCH] Add wandb_run_name parameter to init_kwargs #1032 --- fine_tune.py | 2 ++ sdxl_train.py | 2 ++ sdxl_train_control_net_lllite.py | 2 ++ train_controlnet.py | 2 ++ train_db.py | 2 ++ train_textual_inversion.py | 2 ++ train_textual_inversion_XTI.py | 2 ++ 7 files changed, 14 insertions(+) diff --git a/fine_tune.py b/fine_tune.py index f72e618b..be61b3d1 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -291,6 +291,8 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) diff --git a/sdxl_train.py b/sdxl_train.py index 8983673d..b4ce2770 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -457,6 +457,8 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 18c6bd05..4436dd3c 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -342,6 +342,8 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( diff --git a/train_controlnet.py b/train_controlnet.py index 1f3dbae3..cc0eaab7 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -336,6 +336,8 @@ def train(args): ) if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( diff --git a/train_db.py b/train_db.py index 5518740f..14d9dff1 100644 --- a/train_db.py +++ b/train_db.py @@ -268,6 +268,8 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 545b6ba8..0e3912b1 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -504,6 +504,8 @@ class TextualInversionTrainer: if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 42d69d2d..71b43549 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -394,6 +394,8 @@ def train(args): if accelerator.is_main_process: init_kwargs = {} + if args.wandb_run_name: + init_kwargs['wandb'] = {'name': args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)