mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-scripts into sd3_5_support
This commit is contained in:
@@ -54,6 +54,10 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
|
|||||||
with safe_open(ckpt_path, framework="pt") as f:
|
with safe_open(ckpt_path, framework="pt") as f:
|
||||||
keys.extend(f.keys())
|
keys.extend(f.keys())
|
||||||
|
|
||||||
|
# if the key has annoying prefix, remove it
|
||||||
|
if keys[0].startswith("model.diffusion_model."):
|
||||||
|
keys = [key.replace("model.diffusion_model.", "") for key in keys]
|
||||||
|
|
||||||
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
||||||
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
||||||
|
|
||||||
@@ -122,6 +126,13 @@ def load_flow_model(
|
|||||||
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
|
||||||
logger.info("Converted Diffusers to BFL")
|
logger.info("Converted Diffusers to BFL")
|
||||||
|
|
||||||
|
# if the key has annoying prefix, remove it
|
||||||
|
for key in list(sd.keys()):
|
||||||
|
new_key = key.replace("model.diffusion_model.", "")
|
||||||
|
if new_key == key:
|
||||||
|
break # the model doesn't have annoying prefix
|
||||||
|
sd[new_key] = sd.pop(key)
|
||||||
|
|
||||||
info = model.load_state_dict(sd, strict=False, assign=True)
|
info = model.load_state_dict(sd, strict=False, assign=True)
|
||||||
logger.info(f"Loaded Flux: {info}")
|
logger.info(f"Loaded Flux: {info}")
|
||||||
return is_schnell, model
|
return is_schnell, model
|
||||||
|
|||||||
@@ -307,6 +307,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
target_replace_modules: List[str],
|
target_replace_modules: List[str],
|
||||||
filter: Optional[str] = None,
|
filter: Optional[str] = None,
|
||||||
default_dim: Optional[int] = None,
|
default_dim: Optional[int] = None,
|
||||||
|
include_conv2d_if_filter: bool = False,
|
||||||
) -> List[LoRAModule]:
|
) -> List[LoRAModule]:
|
||||||
prefix = (
|
prefix = (
|
||||||
self.LORA_PREFIX_SD3
|
self.LORA_PREFIX_SD3
|
||||||
@@ -332,8 +333,11 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
lora_name = prefix + "." + (name + "." if name else "") + child_name
|
lora_name = prefix + "." + (name + "." if name else "") + child_name
|
||||||
lora_name = lora_name.replace(".", "_")
|
lora_name = lora_name.replace(".", "_")
|
||||||
|
|
||||||
if filter is not None and not filter in lora_name:
|
force_incl_conv2d = False
|
||||||
continue
|
if filter is not None:
|
||||||
|
if not filter in lora_name:
|
||||||
|
continue
|
||||||
|
force_incl_conv2d = include_conv2d_if_filter
|
||||||
|
|
||||||
dim = None
|
dim = None
|
||||||
alpha = None
|
alpha = None
|
||||||
@@ -373,6 +377,10 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
elif self.conv_lora_dim is not None:
|
elif self.conv_lora_dim is not None:
|
||||||
dim = self.conv_lora_dim
|
dim = self.conv_lora_dim
|
||||||
alpha = self.conv_alpha
|
alpha = self.conv_alpha
|
||||||
|
elif force_incl_conv2d:
|
||||||
|
# x_embedder
|
||||||
|
dim = default_dim if default_dim is not None else self.lora_dim
|
||||||
|
alpha = self.alpha
|
||||||
|
|
||||||
if dim is None or dim == 0:
|
if dim is None or dim == 0:
|
||||||
# skipした情報を出力
|
# skipした情報を出力
|
||||||
@@ -428,7 +436,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
for filter, in_dim in zip(
|
for filter, in_dim in zip(
|
||||||
[
|
[
|
||||||
"context_embedder",
|
"context_embedder",
|
||||||
"t_embedder",
|
"_t_embedder", # don't use "t_embedder" because it's used in "context_embedder"
|
||||||
"x_embedder",
|
"x_embedder",
|
||||||
"y_embedder",
|
"y_embedder",
|
||||||
"final_layer_adaLN_modulation",
|
"final_layer_adaLN_modulation",
|
||||||
@@ -436,7 +444,12 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
],
|
],
|
||||||
self.emb_dims,
|
self.emb_dims,
|
||||||
):
|
):
|
||||||
loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
|
# x_embedder is conv2d, so we need to include it
|
||||||
|
loras, _ = create_modules(
|
||||||
|
True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
|
||||||
|
)
|
||||||
|
# if len(loras) > 0:
|
||||||
|
# logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
|
||||||
self.unet_loras.extend(loras)
|
self.unet_loras.extend(loras)
|
||||||
|
|
||||||
logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
|
logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
|
||||||
@@ -540,8 +553,8 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
|
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
|
||||||
|
|
||||||
# merge up weight (sum of split_dim, rank*3)
|
# merge up weight (sum of split_dim, rank*3)
|
||||||
qkv_dim, rank = up_weights[0].size()
|
split_dim, rank = up_weights[0].size()
|
||||||
split_dim = qkv_dim // 3
|
qkv_dim = split_dim * 3
|
||||||
up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
|
up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
|
||||||
i = 0
|
i = 0
|
||||||
for j in range(3):
|
for j in range(3):
|
||||||
|
|||||||
@@ -10,11 +10,13 @@ import numpy as np
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import safe_open, load_file
|
from safetensors.torch import safe_open, load_file
|
||||||
|
import torch.amp
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPTextModelWithProjection, T5EncoderModel
|
from transformers import CLIPTextModelWithProjection, T5EncoderModel
|
||||||
|
|
||||||
from library.device_utils import init_ipex, get_preferred_device
|
from library.device_utils import init_ipex, get_preferred_device
|
||||||
|
from networks import lora_sd3
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
@@ -104,7 +106,8 @@ def do_sample(
|
|||||||
x_c_nc = torch.cat([x, x], dim=0)
|
x_c_nc = torch.cat([x, x], dim=0)
|
||||||
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
|
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
|
||||||
|
|
||||||
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
|
with torch.autocast(device_type=device.type, dtype=dtype):
|
||||||
|
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
|
||||||
model_output = model_output.float()
|
model_output = model_output.float()
|
||||||
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
|
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
|
||||||
|
|
||||||
@@ -153,7 +156,7 @@ def generate_image(
|
|||||||
clip_g.to(device)
|
clip_g.to(device)
|
||||||
t5xxl.to(device)
|
t5xxl.to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad():
|
||||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||||
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
|
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
|
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
|
||||||
@@ -233,13 +236,14 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--bf16", action="store_true")
|
parser.add_argument("--bf16", action="store_true")
|
||||||
parser.add_argument("--seed", type=int, default=1)
|
parser.add_argument("--seed", type=int, default=1)
|
||||||
parser.add_argument("--steps", type=int, default=50)
|
parser.add_argument("--steps", type=int, default=50)
|
||||||
# parser.add_argument(
|
parser.add_argument(
|
||||||
# "--lora_weights",
|
"--lora_weights",
|
||||||
# type=str,
|
type=str,
|
||||||
# nargs="*",
|
nargs="*",
|
||||||
# default=[],
|
default=[],
|
||||||
# help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
|
help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)",
|
||||||
# )
|
)
|
||||||
|
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
|
||||||
parser.add_argument("--width", type=int, default=target_width)
|
parser.add_argument("--width", type=int, default=target_width)
|
||||||
parser.add_argument("--height", type=int, default=target_height)
|
parser.add_argument("--height", type=int, default=target_height)
|
||||||
parser.add_argument("--interactive", action="store_true")
|
parser.add_argument("--interactive", action="store_true")
|
||||||
@@ -294,6 +298,30 @@ if __name__ == "__main__":
|
|||||||
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
|
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
|
||||||
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
|
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
|
||||||
|
|
||||||
|
# LoRA
|
||||||
|
lora_models: list[lora_sd3.LoRANetwork] = []
|
||||||
|
for weights_file in args.lora_weights:
|
||||||
|
if ";" in weights_file:
|
||||||
|
weights_file, multiplier = weights_file.split(";")
|
||||||
|
multiplier = float(multiplier)
|
||||||
|
else:
|
||||||
|
multiplier = 1.0
|
||||||
|
|
||||||
|
weights_sd = load_file(weights_file)
|
||||||
|
module = lora_sd3
|
||||||
|
lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True)
|
||||||
|
|
||||||
|
if args.merge_lora_weights:
|
||||||
|
lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd)
|
||||||
|
else:
|
||||||
|
lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit)
|
||||||
|
info = lora_model.load_state_dict(weights_sd, strict=True)
|
||||||
|
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
||||||
|
lora_model.eval()
|
||||||
|
lora_model.to(device)
|
||||||
|
|
||||||
|
lora_models.append(lora_model)
|
||||||
|
|
||||||
if not args.interactive:
|
if not args.interactive:
|
||||||
generate_image(
|
generate_image(
|
||||||
mmdit,
|
mmdit,
|
||||||
@@ -344,13 +372,13 @@ if __name__ == "__main__":
|
|||||||
steps = int(opt[1:].strip())
|
steps = int(opt[1:].strip())
|
||||||
elif opt.startswith("d"):
|
elif opt.startswith("d"):
|
||||||
seed = int(opt[1:].strip())
|
seed = int(opt[1:].strip())
|
||||||
# elif opt.startswith("m"):
|
elif opt.startswith("m"):
|
||||||
# mutipliers = opt[1:].strip().split(",")
|
mutipliers = opt[1:].strip().split(",")
|
||||||
# if len(mutipliers) != len(lora_models):
|
if len(mutipliers) != len(lora_models):
|
||||||
# logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
|
||||||
# continue
|
continue
|
||||||
# for i, lora_model in enumerate(lora_models):
|
for i, lora_model in enumerate(lora_models):
|
||||||
# lora_model.set_multiplier(float(mutipliers[i]))
|
lora_model.set_multiplier(float(mutipliers[i]))
|
||||||
elif opt.startswith("n"):
|
elif opt.startswith("n"):
|
||||||
negative_prompt = opt[1:].strip()
|
negative_prompt = opt[1:].strip()
|
||||||
if negative_prompt == "-":
|
if negative_prompt == "-":
|
||||||
|
|||||||
Reference in New Issue
Block a user