change option name for merging network weights

This commit is contained in:
Kohya S
2023-05-30 23:19:29 +09:00
parent fc00691898
commit c437dce056

View File

@@ -155,22 +155,22 @@ def train(args):
print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_modules is not None:
# base_modules が指定されている場合は、指定されたモジュールを読み込みマージする
for i, module_path in enumerate(args.base_modules):
print(f"merging module: {module_path}")
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
print(f"merging module: {weight_path}")
if args.base_modules_weights is None or len(args.base_modules_weights) <= i:
weight = 1.0
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
weight = args.base_modules_weights[i]
multiplier = args.base_weights_multiplier[i]
module, weights_sd = network_module.create_network_from_weights(
weight, module_path, vae, text_encoder, unet, for_inference=True
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
print(f"all modules merged: {', '.join(args.base_modules)}")
print(f"all weights merged: {', '.join(args.base_weights)}")
# 学習を準備する
if cache_latents:
@@ -789,18 +789,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
)
parser.add_argument(
"--base_modules",
"--base_weights",
type=str,
default=None,
nargs="*",
help="base modules for differential learning / 差分学習用のベースモデ",
help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイ",
)
parser.add_argument(
"--base_modules_weight",
"--base_weights_multiplier",
type=float,
default=None,
nargs="*",
help="weights of base modules for differential learning / 差分学習用のベースモデルの比重",
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
)
return parser