mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Unify controlnet parameters name and change scripts name. (#1821)
* Update sd3_train.py
* add freeze block lr
* Update train_util.py
* update
* Revert "add freeze block lr"
This reverts commit 8b1653548f.
# Conflicts:
# library/train_util.py
# sd3_train.py
* use same control net model path
* use controlnet_model_name_or_path
This commit is contained in:
@@ -265,7 +265,7 @@ def train(args):
|
||||
# load controlnet
|
||||
controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype
|
||||
controlnet = flux_utils.load_controlnet(
|
||||
args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
|
||||
args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors
|
||||
)
|
||||
controlnet.train()
|
||||
|
||||
|
||||
@@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
)
|
||||
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
||||
parser.add_argument(
|
||||
"--controlnet",
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
|
||||
|
||||
@@ -184,12 +184,12 @@ def train(args):
|
||||
|
||||
# make control net
|
||||
logger.info("make ControlNet")
|
||||
if args.controlnet_model_path:
|
||||
if args.controlnet_model_name_or_path:
|
||||
with init_empty_weights():
|
||||
control_net = SdxlControlNet()
|
||||
|
||||
logger.info(f"load ControlNet from {args.controlnet_model_path}")
|
||||
filename = args.controlnet_model_path
|
||||
logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}")
|
||||
filename = args.controlnet_model_name_or_path
|
||||
if os.path.splitext(filename)[1] == ".safetensors":
|
||||
state_dict = load_file(filename)
|
||||
else:
|
||||
@@ -675,7 +675,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--controlnet_model_path",
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="controlnet model name or path / controlnetのモデル名またはパス",
|
||||
|
||||
Reference in New Issue
Block a user