feat: add LoRA support for lumina minimal inference

This commit is contained in:
Kohya S
2025-07-09 23:28:55 +09:00
parent b4d1152293
commit 7fb0d30feb

View File

@@ -27,6 +27,7 @@ from library import (
sd3_train_utils,
strategy_lumina,
)
import networks.lora_lumina as lora_lumina
from library.device_utils import get_preferred_device, init_ipex
from library.utils import setup_logging, str_to_dtype
@@ -248,6 +249,14 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="Use sage attention for Lumina model",
)
parser.add_argument(
"--lora_weights",
type=str,
nargs="*",
default=[],
help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
return parser
@@ -275,6 +284,30 @@ if __name__ == "__main__":
# Load Autoencoder
ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu")
# LoRA
lora_models = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
weights_sd = load_file(weights_file)
lora_model, _ = lora_lumina.create_network_from_weights(
multiplier, None, ae, [gemma2], model, weights_sd, True
)
if args.merge_lora_weights:
lora_model.merge_to([gemma2], model, weights_sd)
else:
lora_model.apply_to([gemma2], model)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_models.append(lora_model)
generate_image(
model,
gemma2,