From abff4b0ec7bb37b338924e38392593f2bea2b8d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Sat, 7 Dec 2024 16:12:46 +0800 Subject: [PATCH] Unify controlnet parameters name and change scripts name. (#1821) * Update sd3_train.py * add freeze block lr * Update train_util.py * update * Revert "add freeze block lr" This reverts commit 8b1653548f8f219e5be2cde96f65a8813cf9ea1f. # Conflicts: # library/train_util.py # sd3_train.py * use same control net model path * use controlnet_model_name_or_path --- flux_train_control_net.py | 2 +- library/flux_train_utils.py | 2 +- sdxl_train_control_net.py | 8 ++++---- train_controlnet.py => train_control_net.py | 0 4 files changed, 6 insertions(+), 6 deletions(-) rename train_controlnet.py => train_control_net.py (100%) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 5548fd99..9d36a41d 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -265,7 +265,7 @@ def train(args): # load controlnet controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype controlnet = flux_utils.load_controlnet( - args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors ) controlnet.train() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index de2e2b48..f7f06c5c 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") parser.add_argument( - "--controlnet", + "--controlnet_model_name_or_path", type=str, default=None, help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 01387409..ffbf03ca 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -184,12 +184,12 @@ def train(args): # make control net logger.info("make ControlNet") - if args.controlnet_model_path: + if args.controlnet_model_name_or_path: with init_empty_weights(): control_net = SdxlControlNet() - logger.info(f"load ControlNet from {args.controlnet_model_path}") - filename = args.controlnet_model_path + logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}") + filename = args.controlnet_model_name_or_path if os.path.splitext(filename)[1] == ".safetensors": state_dict = load_file(filename) else: @@ -675,7 +675,7 @@ def setup_parser() -> argparse.ArgumentParser: sdxl_train_util.add_sdxl_training_arguments(parser) parser.add_argument( - "--controlnet_model_path", + "--controlnet_model_name_or_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス", diff --git a/train_controlnet.py b/train_control_net.py similarity index 100% rename from train_controlnet.py rename to train_control_net.py