diff --git a/train_network.py b/train_network.py index 4c4cc281..1fe9d083 100644 --- a/train_network.py +++ b/train_network.py @@ -176,7 +176,32 @@ def train(args): net_kwargs[key] = value # if a new network is added in future, add if ~ then blocks for each network (;'∀') - network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) + if args.size_from_weights: + network, weights = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet) + if net_kwargs is not None: + down_lr_weight = net_kwargs.get("down_lr_weight", None) + mid_lr_weight = net_kwargs.get("mid_lr_weight", None) + up_lr_weight = net_kwargs.get("up_lr_weight", None) + if down_lr_weight is not None: + # if some parameters are not set, use zero + if "," in down_lr_weight: + down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] + + if mid_lr_weight is not None: + mid_lr_weight = float(mid_lr_weight) + + if up_lr_weight is not None: + if "," in up_lr_weight: + up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] + + down_lr_weight, mid_lr_weight, up_lr_weight = network_module.get_block_lr_weight( + down_lr_weight, mid_lr_weight, up_lr_weight, float(net_kwargs.get("block_lr_zero_threshold", 0.0)) + ) + + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: + network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + else: + network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return @@ -760,6 +785,10 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument( "--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列" ) + parser.add_argument( + "--size_from_weights", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する" + ) + return parser