mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
change option name for merging network weights
This commit is contained in:
@@ -155,22 +155,22 @@ def train(args):
|
||||
print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
if args.base_modules is not None:
|
||||
# base_modules が指定されている場合は、指定されたモジュールを読み込みマージする
|
||||
for i, module_path in enumerate(args.base_modules):
|
||||
print(f"merging module: {module_path}")
|
||||
if args.base_weights is not None:
|
||||
# base_weights が指定されている場合は、指定された重みを読み込みマージする
|
||||
for i, weight_path in enumerate(args.base_weights):
|
||||
print(f"merging module: {weight_path}")
|
||||
|
||||
if args.base_modules_weights is None or len(args.base_modules_weights) <= i:
|
||||
weight = 1.0
|
||||
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
||||
multiplier = 1.0
|
||||
else:
|
||||
weight = args.base_modules_weights[i]
|
||||
multiplier = args.base_weights_multiplier[i]
|
||||
|
||||
module, weights_sd = network_module.create_network_from_weights(
|
||||
weight, module_path, vae, text_encoder, unet, for_inference=True
|
||||
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
|
||||
)
|
||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||
|
||||
print(f"all modules merged: {', '.join(args.base_modules)}")
|
||||
print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
@@ -789,18 +789,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_modules",
|
||||
"--base_weights",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="base modules for differential learning / 差分学習用のベースモデル",
|
||||
help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_modules_weight",
|
||||
"--base_weights_multiplier",
|
||||
type=float,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="weights of base modules for differential learning / 差分学習用のベースモデルの比重",
|
||||
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
Reference in New Issue
Block a user