mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Instantiate size_from_weights
This commit is contained in:
@@ -176,7 +176,32 @@ def train(args):
|
||||
net_kwargs[key] = value
|
||||
|
||||
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
|
||||
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
|
||||
if args.size_from_weights:
|
||||
network, weights = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet)
|
||||
if net_kwargs is not None:
|
||||
down_lr_weight = net_kwargs.get("down_lr_weight", None)
|
||||
mid_lr_weight = net_kwargs.get("mid_lr_weight", None)
|
||||
up_lr_weight = net_kwargs.get("up_lr_weight", None)
|
||||
if down_lr_weight is not None:
|
||||
# if some parameters are not set, use zero
|
||||
if "," in down_lr_weight:
|
||||
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
||||
|
||||
if mid_lr_weight is not None:
|
||||
mid_lr_weight = float(mid_lr_weight)
|
||||
|
||||
if up_lr_weight is not None:
|
||||
if "," in up_lr_weight:
|
||||
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
||||
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight = network_module.get_block_lr_weight(
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight, float(net_kwargs.get("block_lr_zero_threshold", 0.0))
|
||||
)
|
||||
|
||||
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
||||
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
||||
else:
|
||||
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
@@ -760,6 +785,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--size_from_weights", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する"
|
||||
)
|
||||
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user