pre calc LoRA in generating

This commit is contained in:
Kohya S
2023-05-07 09:57:54 +09:00
parent 165fc43655
commit fdbdb4748a
2 changed files with 122 additions and 32 deletions

View File

@@ -2262,6 +2262,8 @@ def main(args):
if args.network_module:
networks = []
network_default_muls = []
network_pre_calc=args.network_pre_calc
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
@@ -2298,11 +2300,11 @@ def main(args):
if network is None:
return
mergiable = hasattr(network, "merge_to")
if args.network_merge and not mergiable:
mergeable = network.is_mergeable()
if args.network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergiable:
if not args.network_merge or not mergeable:
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
@@ -2311,6 +2313,10 @@ def main(args):
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
if network_pre_calc:
print("backup original weights")
network.backup_weights()
networks.append(network)
else:
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
@@ -2815,11 +2821,19 @@ def main(args):
# generate
if networks:
# 追加ネットワークの処理
shared = {}
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m)
if regional_network:
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
if not regional_network and network_pre_calc:
for n in networks:
n.restore_weights()
for n in networks:
n.pre_calculation()
print("pre-calculation... done")
images = pipe(
prompts,
@@ -3204,6 +3218,7 @@ def setup_parser() -> argparse.ArgumentParser:
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument("--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する")
parser.add_argument(
"--textual_inversion_embeddings",
type=str,