mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
feat: add LoRA support for lumina minimal inference
This commit is contained in:
@@ -27,6 +27,7 @@ from library import (
|
||||
sd3_train_utils,
|
||||
strategy_lumina,
|
||||
)
|
||||
import networks.lora_lumina as lora_lumina
|
||||
from library.device_utils import get_preferred_device, init_ipex
|
||||
from library.utils import setup_logging, str_to_dtype
|
||||
|
||||
@@ -248,6 +249,14 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="Use sage attention for Lumina model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_weights",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)",
|
||||
)
|
||||
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -275,6 +284,30 @@ if __name__ == "__main__":
|
||||
# Load Autoencoder
|
||||
ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu")
|
||||
|
||||
# LoRA
|
||||
lora_models = []
|
||||
for weights_file in args.lora_weights:
|
||||
if ";" in weights_file:
|
||||
weights_file, multiplier = weights_file.split(";")
|
||||
multiplier = float(multiplier)
|
||||
else:
|
||||
multiplier = 1.0
|
||||
|
||||
weights_sd = load_file(weights_file)
|
||||
lora_model, _ = lora_lumina.create_network_from_weights(
|
||||
multiplier, None, ae, [gemma2], model, weights_sd, True
|
||||
)
|
||||
|
||||
if args.merge_lora_weights:
|
||||
lora_model.merge_to([gemma2], model, weights_sd)
|
||||
else:
|
||||
lora_model.apply_to([gemma2], model)
|
||||
info = lora_model.load_state_dict(weights_sd, strict=True)
|
||||
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
||||
lora_model.eval()
|
||||
|
||||
lora_models.append(lora_model)
|
||||
|
||||
generate_image(
|
||||
model,
|
||||
gemma2,
|
||||
|
||||
Reference in New Issue
Block a user