mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
pre calc LoRA in generating
This commit is contained in:
@@ -2262,6 +2262,8 @@ def main(args):
|
||||
if args.network_module:
|
||||
networks = []
|
||||
network_default_muls = []
|
||||
network_pre_calc=args.network_pre_calc
|
||||
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
@@ -2298,11 +2300,11 @@ def main(args):
|
||||
if network is None:
|
||||
return
|
||||
|
||||
mergiable = hasattr(network, "merge_to")
|
||||
if args.network_merge and not mergiable:
|
||||
mergeable = network.is_mergeable()
|
||||
if args.network_merge and not mergeable:
|
||||
print("network is not mergiable. ignore merge option.")
|
||||
|
||||
if not args.network_merge or not mergiable:
|
||||
if not args.network_merge or not mergeable:
|
||||
network.apply_to(text_encoder, unet)
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
print(f"weights are loaded: {info}")
|
||||
@@ -2311,6 +2313,10 @@ def main(args):
|
||||
network.to(memory_format=torch.channels_last)
|
||||
network.to(dtype).to(device)
|
||||
|
||||
if network_pre_calc:
|
||||
print("backup original weights")
|
||||
network.backup_weights()
|
||||
|
||||
networks.append(network)
|
||||
else:
|
||||
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
|
||||
@@ -2815,11 +2821,19 @@ def main(args):
|
||||
|
||||
# generate
|
||||
if networks:
|
||||
# 追加ネットワークの処理
|
||||
shared = {}
|
||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||
n.set_multiplier(m)
|
||||
if regional_network:
|
||||
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
||||
|
||||
if not regional_network and network_pre_calc:
|
||||
for n in networks:
|
||||
n.restore_weights()
|
||||
for n in networks:
|
||||
n.pre_calculation()
|
||||
print("pre-calculation... done")
|
||||
|
||||
images = pipe(
|
||||
prompts,
|
||||
@@ -3204,6 +3218,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
)
|
||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||
parser.add_argument("--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する")
|
||||
parser.add_argument(
|
||||
"--textual_inversion_embeddings",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user