diff --git a/README.md b/README.md index ccc83e6e..c0d50a5a 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ The trained LoRA model can be used with ComfyUI. The inference script is also available. The script is `flux_minimal_inference.py`. See `--help` for options. +Aug 12: `--interactive` option is now working. + ``` python flux_minimal_inference.py --ckpt flux1-dev.sft --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.sft --dtype bf16 --prompt "a cat holding a sign that says hello world" --out path/to/output/dir --seed 1 --flux_dtype fp8 --offload --lora lora-flux-name.safetensors;1.0 ``` diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index f3affca8..b09f6380 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -5,7 +5,7 @@ import datetime import math import os import random -from typing import Callable, Optional, Tuple +from typing import Callable, List, Optional, Tuple import einops import numpy as np @@ -121,6 +121,9 @@ def generate_image( steps: Optional[int], guidance: float, ): + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + logger.info(f"Seed: {seed}") + # make first noise with packed shape # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) @@ -183,9 +186,7 @@ def generate_image( steps = 4 if is_schnell else 50 img_ids = img_ids.to(device) - x = do_sample( - accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance_scale, is_schnell, device, flux_dtype - ) + x = do_sample(accelerator, model, noise, img_ids, l_pooled, t5_out, txt_ids, steps, guidance, is_schnell, device, flux_dtype) if args.offload: model = model.cpu() # del model @@ -255,6 +256,7 @@ if __name__ == "__main__": default=[], help="LoRA weights, only supports networks.lora_flux, each argument is a `path;multiplier` (semi-colon separated)", ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") parser.add_argument("--width", type=int, default=target_width) parser.add_argument("--height", type=int, default=target_height) parser.add_argument("--interactive", action="store_true") @@ -341,6 +343,7 @@ if __name__ == "__main__": ae = accelerator.prepare(ae) # LoRA + lora_models: List[lora_flux.LoRANetwork] = [] for weights_file in args.lora_weights: if ";" in weights_file: weights_file, multiplier = weights_file.split(";") @@ -351,7 +354,16 @@ if __name__ == "__main__": lora_model, weights_sd = lora_flux.create_network_from_weights( multiplier, weights_file, ae, [clip_l, t5xxl], model, None, True ) - lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + if args.merge_lora_weights: + lora_model.merge_to([clip_l, t5xxl], model, weights_sd) + else: + lora_model.apply_to([clip_l, t5xxl], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + lora_model.to(device) + + lora_models.append(lora_model) if not args.interactive: generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance) @@ -363,7 +375,9 @@ if __name__ == "__main__": guidance = args.guidance while True: - print("Enter prompt (empty to exit). Options: --w --h --s --d --g ") + print( + "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " + ) prompt = input() if prompt == "": break @@ -384,6 +398,13 @@ if __name__ == "__main__": seed = int(opt[1:].strip()) elif opt.startswith("g"): guidance = float(opt[1:].strip()) + elif opt.startswith("m"): + mutipliers = opt[1:].strip().split(",") + if len(mutipliers) != len(lora_models): + logger.error(f"Invalid number of multipliers, expected {len(lora_models)}") + continue + for i, lora_model in enumerate(lora_models): + lora_model.set_multiplier(float(mutipliers[i])) generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance)