From 7fb0d30feba5f1112ad28099ac79b6109e98ec57 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 9 Jul 2025 23:28:55 +0900 Subject: [PATCH] feat: add LoRA support for lumina minimal inference --- lumina_minimal_inference.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py index ff7c21df..ba305f6f 100644 --- a/lumina_minimal_inference.py +++ b/lumina_minimal_inference.py @@ -27,6 +27,7 @@ from library import ( sd3_train_utils, strategy_lumina, ) +import networks.lora_lumina as lora_lumina from library.device_utils import get_preferred_device, init_ipex from library.utils import setup_logging, str_to_dtype @@ -248,6 +249,14 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="Use sage attention for Lumina model", ) + parser.add_argument( + "--lora_weights", + type=str, + nargs="*", + default=[], + help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)", + ) + parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model") return parser @@ -275,6 +284,30 @@ if __name__ == "__main__": # Load Autoencoder ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu") + # LoRA + lora_models = [] + for weights_file in args.lora_weights: + if ";" in weights_file: + weights_file, multiplier = weights_file.split(";") + multiplier = float(multiplier) + else: + multiplier = 1.0 + + weights_sd = load_file(weights_file) + lora_model, _ = lora_lumina.create_network_from_weights( + multiplier, None, ae, [gemma2], model, weights_sd, True + ) + + if args.merge_lora_weights: + lora_model.merge_to([gemma2], model, weights_sd) + else: + lora_model.apply_to([gemma2], model) + info = lora_model.load_state_dict(weights_sd, strict=True) + logger.info(f"Loaded LoRA weights from {weights_file}: {info}") + lora_model.eval() + + lora_models.append(lora_model) + generate_image( model, gemma2,