Unify controlnet parameters name and change scripts name. (#1821)

* Update sd3_train.py

* add freeze block lr

* Update train_util.py

* update

* Revert "add freeze block lr"

This reverts commit 8b1653548f.

# Conflicts:
#	library/train_util.py
#	sd3_train.py

* use same control net model path

* use controlnet_model_name_or_path
This commit is contained in:
青龍聖者@bdsqlsz
2024-12-07 16:12:46 +08:00
committed by GitHub
parent 2be336688d
commit abff4b0ec7
4 changed files with 6 additions and 6 deletions

View File

@@ -265,7 +265,7 @@ def train(args):
# load controlnet
controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
controlnet = flux_utils.load_controlnet(
args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
)
controlnet.train()

View File

@@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors")
parser.add_argument(
"--controlnet",
"--controlnet_model_name_or_path",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス*.sftまたは*.safetensors"

View File

@@ -184,12 +184,12 @@ def train(args):
# make control net
logger.info("make ControlNet")
if args.controlnet_model_path:
if args.controlnet_model_name_or_path:
with init_empty_weights():
control_net = SdxlControlNet()
logger.info(f"load ControlNet from {args.controlnet_model_path}")
filename = args.controlnet_model_path
logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}")
filename = args.controlnet_model_name_or_path
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
else:
@@ -675,7 +675,7 @@ def setup_parser() -> argparse.ArgumentParser:
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument(
"--controlnet_model_path",
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",