mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
update to work interactive mode
This commit is contained in:
@@ -39,6 +39,8 @@ The trained LoRA model can be used with ComfyUI.
|
||||
|
||||
The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options.
|
||||
|
||||
Aug 12: `--interactive` option is now working.
|
||||
|
||||
```
|
||||
python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0
|
||||
```
|
||||
|
||||
@@ -5,7 +5,7 @@ import datetime
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from typing import Callable, Optional, Tuple
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
@@ -121,6 +121,9 @@ def generate_image(
|
||||
steps: Optional[int],
|
||||
guidance: float,
|
||||
):
|
||||
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
|
||||
logger.info(f"Seed: {seed}")
|
||||
|
||||
# make first noise with packed shape
|
||||
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
|
||||
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
|
||||
@@ -183,9 +186,7 @@ def generate_image(
|
||||
steps = 4 if is_schnell else 50
|
||||
|
||||
img_ids = img_ids.to(device)
|
||||
x = do_sample(
|
||||
accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, is_schnell, device, flux_dtype
|
||||
)
|
||||
x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype)
|
||||
if args.offload:
|
||||
model = model.cpu()
|
||||
# del model
|
||||
@@ -255,6 +256,7 @@ if __name__ == "__main__":
|
||||
default=[],
|
||||
help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)",
|
||||
)
|
||||
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
|
||||
parser.add_argument("--width", type=int, default=target_width)
|
||||
parser.add_argument("--height", type=int, default=target_height)
|
||||
parser.add_argument("--interactive", action="store_true")
|
||||
@@ -341,6 +343,7 @@ if __name__ == "__main__":
|
||||
ae = accelerator.prepare(ae)
|
||||
|
||||
# LoRA
|
||||
lora_models: List[lora_flux.LoRANetwork] = []
|
||||
for weights_file in args.lora_weights:
|
||||
if ";" in weights_file:
|
||||
weights_file, multiplier = weights_file.split(";")
|
||||
@@ -351,7 +354,16 @@ if __name__ == "__main__":
|
||||
lora_model, weights_sd = lora_flux.create_network_from_weights(
|
||||
multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True
|
||||
)
|
||||
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
|
||||
if args.merge_lora_weights:
|
||||
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
|
||||
else:
|
||||
lora_model.apply_to([clip_l, t5xxl], model)
|
||||
info = lora_model.load_state_dict(weights_sd, strict=True)
|
||||
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
||||
lora_model.eval()
|
||||
lora_model.to(device)
|
||||
|
||||
lora_models.append(lora_model)
|
||||
|
||||
if not args.interactive:
|
||||
generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance)
|
||||
@@ -363,7 +375,9 @@ if __name__ == "__main__":
|
||||
guidance = args.guidance
|
||||
|
||||
while True:
|
||||
print("Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance>")
|
||||
print(
|
||||
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
|
||||
)
|
||||
prompt = input()
|
||||
if prompt == "":
|
||||
break
|
||||
@@ -384,6 +398,13 @@ if __name__ == "__main__":
|
||||
seed = int(opt[1:].strip())
|
||||
elif opt.startswith("g"):
|
||||
guidance = float(opt[1:].strip())
|
||||
elif opt.startswith("m"):
|
||||
mutipliers = opt[1:].strip().split(",")
|
||||
if len(mutipliers) != len(lora_models):
|
||||
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
||||
continue
|
||||
for i, lora_model in enumerate(lora_models):
|
||||
lora_model.set_multiplier(float(mutipliers[i]))
|
||||
|
||||
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user