Support Lora

This commit is contained in:
kohya-ss
2024-10-29 21:52:04 +09:00
parent 0af4edd8a6
commit d4e19fbd5e

View File

@@ -10,11 +10,13 @@ import numpy as np
import torch
from safetensors.torch import safe_open, load_file
import torch.amp
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library.device_utils import init_ipex, get_preferred_device
from networks import lora_sd3
init_ipex()
@@ -104,7 +106,8 @@ def do_sample(
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
with torch.autocast(device_type=device.type, dtype=dtype):
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
@@ -153,7 +156,7 @@ def generate_image(
clip_g.to(device)
t5xxl.to(device)
with torch.no_grad():
with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad():
tokens_and_masks = tokenize_strategy.tokenize(prompt)
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
@@ -233,13 +236,14 @@ if __name__ == "__main__":
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--steps", type=int, default=50)
# parser.add_argument(
# "--lora_weights",
# type=str,
# nargs="*",
# default=[],
# help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
# )
parser.add_argument(
"--lora_weights",
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true")
@@ -294,6 +298,30 @@ if __name__ == "__main__":
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
# LoRA
lora_models: list[lora_sd3.LoRANetwork] = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
weights_sd = load_file(weights_file)
module = lora_sd3
lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True)
if args.merge_lora_weights:
lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd)
else:
lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)
lora_models.append(lora_model)
if not args.interactive:
generate_image(
mmdit,
@@ -344,13 +372,13 @@ if __name__ == "__main__":
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
# elif opt.startswith("m"):
# mutipliers = opt[1:].strip().split(",")
# if len(mutipliers) != len(lora_models):
# logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
# continue
# for i, lora_model in enumerate(lora_models):
# lora_model.set_multiplier(float(mutipliers[i]))
elif opt.startswith("m"):
mutipliers = opt[1:].strip().split(",")
if len(mutipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(mutipliers[i]))
elif opt.startswith("n"):
negative_prompt = opt[1:].strip()
if negative_prompt == "-":