Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-scripts into sd3_5_support

Kohya S
2024-10-30 12:51:55 +09:00
3 changed files with 74 additions and 22 deletions

View File

@@ -54,6 +54,10 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
     with safe_open(ckpt_path, framework="pt") as f:
         keys.extend(f.keys())
 
+    # if the keys have an annoying prefix, remove it
+    if keys[0].startswith("model.diffusion_model."):
+        keys = [key.replace("model.diffusion_model.", "") for key in keys]
+
     is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
     is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
@@ -122,6 +126,13 @@ def load_flow_model(
         sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
         logger.info("Converted Diffusers to BFL")
 
+    # if the keys have an annoying prefix, remove it
+    for key in list(sd.keys()):
+        new_key = key.replace("model.diffusion_model.", "")
+        if new_key == key:
+            break  # the model doesn't have the annoying prefix
+        sd[new_key] = sd.pop(key)
+
     info = model.load_state_dict(sd, strict=False, assign=True)
     logger.info(f"Loaded Flux: {info}")
     return is_schnell, model
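For context, a minimal standalone sketch of the prefix handling introduced in both hunks (the `sd` dict here is illustrative, not from the commit): checkpoints saved in the reference format wrap every key in `model.diffusion_model.`, and since the prefix is assumed to be all-or-nothing, the rename loop can bail out on the first key that is already clean.

```python
# Illustrative sketch of the prefix stripping above; `sd` is a toy state dict.
sd = {
    "model.diffusion_model.img_in.weight": 0,
    "model.diffusion_model.img_in.bias": 1,
}

for key in list(sd.keys()):
    new_key = key.replace("model.diffusion_model.", "")
    if new_key == key:
        break  # first key has no prefix -> assume none of them do
    sd[new_key] = sd.pop(key)

assert set(sd) == {"img_in.weight", "img_in.bias"}
```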

View File

@@ -307,6 +307,7 @@ class LoRANetwork(torch.nn.Module):
             target_replace_modules: List[str],
             filter: Optional[str] = None,
             default_dim: Optional[int] = None,
+            include_conv2d_if_filter: bool = False,
         ) -> List[LoRAModule]:
             prefix = (
                 self.LORA_PREFIX_SD3
@@ -332,8 +333,11 @@ class LoRANetwork(torch.nn.Module):
                     lora_name = prefix + "." + (name + "." if name else "") + child_name
                     lora_name = lora_name.replace(".", "_")
 
-                    if filter is not None and not filter in lora_name:
-                        continue
+                    force_incl_conv2d = False
+                    if filter is not None:
+                        if not filter in lora_name:
+                            continue
+                        force_incl_conv2d = include_conv2d_if_filter
 
                     dim = None
                     alpha = None
@@ -373,6 +377,10 @@ class LoRANetwork(torch.nn.Module):
                     elif self.conv_lora_dim is not None:
                         dim = self.conv_lora_dim
                         alpha = self.conv_alpha
+                    elif force_incl_conv2d:
+                        # x_embedder
+                        dim = default_dim if default_dim is not None else self.lora_dim
+                        alpha = self.alpha
 
                     if dim is None or dim == 0:
                         # output info about skipped modules
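To see why the new flag matters: Conv2d modules normally get a LoRA only when `conv_lora_dim` is set, but `x_embedder` (the Conv2d patch embedding of the MMDiT) should still be wrapped when it is targeted by filter. A simplified stand-in for the selection logic above; the helper and its arguments are illustrative, not the actual class code:

```python
# Simplified sketch of the dim/alpha selection above (illustrative only).
def pick_dim(is_linear, lora_dim, conv_lora_dim, force_incl_conv2d, default_dim):
    if is_linear:
        return lora_dim               # Linear: always wrapped
    if conv_lora_dim is not None:
        return conv_lora_dim          # Conv2d: only with an explicit conv rank...
    if force_incl_conv2d:
        # ...except x_embedder, which include_conv2d_if_filter forces in
        return default_dim if default_dim is not None else lora_dim
    return None                       # otherwise skipped

assert pick_dim(False, 16, None, True, 8) == 8      # x_embedder, emb dim 8
assert pick_dim(False, 16, None, False, 8) is None  # other Conv2d: skipped
```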
@@ -428,7 +436,7 @@ class LoRANetwork(torch.nn.Module):
            for filter, in_dim in zip(
                [
                    "context_embedder",
-                   "t_embedder",
+                   "_t_embedder",  # not "t_embedder", which is a substring of "context_embedder"
                    "x_embedder",
                    "y_embedder",
                    "final_layer_adaLN_modulation",
@@ -436,7 +444,12 @@ class LoRANetwork(torch.nn.Module):
                ],
                self.emb_dims,
            ):
-               loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
+               # x_embedder is conv2d, so we need to include it
+               loras, _ = create_modules(
+                   True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
+               )
+               # if len(loras) > 0:
+               #     logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
                self.unet_loras.extend(loras)
 
            logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
@@ -540,8 +553,8 @@ class LoRANetwork(torch.nn.Module):
                down_weight = torch.cat(down_weights, dim=0)  # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
 
                # merge up weight (sum of split_dim, rank*3)
-               qkv_dim, rank = up_weights[0].size()
-               split_dim = qkv_dim // 3
+               split_dim, rank = up_weights[0].size()
+               qkv_dim = split_dim * 3
                up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
 
                i = 0
                for j in range(3):
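The fix reads the shapes the right way around: each entry of `up_weights` covers only one of the three q/k/v chunks, so its first dimension is `split_dim`, not the fused qkv dimension. A toy shape check (sizes illustrative):

```python
import torch

# Each split up weight is (split_dim, rank); three of them fuse into qkv.
rank, split_dim = 4, 64
up_weights = [torch.randn(split_dim, rank) for _ in range(3)]

# old, buggy read: qkv_dim, rank = up_weights[0].size()  -> qkv_dim would be 64
split_dim, rank = up_weights[0].size()
qkv_dim = split_dim * 3
assert (qkv_dim, rank) == (192, 4)
```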

View File

@@ -10,11 +10,13 @@ import numpy as np
 import torch
 from safetensors.torch import safe_open, load_file
+import torch.amp
 from tqdm import tqdm
 from PIL import Image
 from transformers import CLIPTextModelWithProjection, T5EncoderModel
 
 from library.device_utils import init_ipex, get_preferred_device
+from networks import lora_sd3
 
 init_ipex()
@@ -104,7 +106,8 @@ def do_sample(
         x_c_nc = torch.cat([x, x], dim=0)
         # print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
 
-        model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
+        with torch.autocast(device_type=device.type, dtype=dtype):
+            model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
         model_output = model_output.float()
         batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
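The forward pass now runs under `torch.autocast`, so eligible ops execute in the requested low-precision dtype while the result is upcast to float32 for the sampler math. A minimal sketch of the pattern (module and shapes illustrative):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16  # e.g. what --bf16 selects

linear = torch.nn.Linear(16, 16).to(device)
x = torch.randn(1, 16, device=device)

with torch.autocast(device_type=device.type, dtype=dtype):
    y = linear(x)        # runs in bf16 where supported
y = y.float()            # upcast, mirroring model_output.float() above
```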
@@ -153,7 +156,7 @@ def generate_image(
     clip_g.to(device)
     t5xxl.to(device)
 
-    with torch.no_grad():
+    with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad():
         tokens_and_masks = tokenize_strategy.tokenize(prompt)
         lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
             tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
@@ -233,13 +236,14 @@ if __name__ == "__main__":
parser.add_argument("--bf16", action="store_true") parser.add_argument("--bf16", action="store_true")
parser.add_argument("--seed", type=int, default=1) parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--steps", type=int, default=50) parser.add_argument("--steps", type=int, default=50)
# parser.add_argument( parser.add_argument(
# "--lora_weights", "--lora_weights",
# type=str, type=str,
# nargs="*", nargs="*",
# default=[], default=[],
# help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)", help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)",
# ) )
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width) parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height) parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true") parser.add_argument("--interactive", action="store_true")
@@ -294,6 +298,30 @@ if __name__ == "__main__":
     tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
     encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
 
+    # LoRA
+    lora_models: list[lora_sd3.LoRANetwork] = []
+    for weights_file in args.lora_weights:
+        if ";" in weights_file:
+            weights_file, multiplier = weights_file.split(";")
+            multiplier = float(multiplier)
+        else:
+            multiplier = 1.0
+
+        weights_sd = load_file(weights_file)
+        module = lora_sd3
+        lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True)
+
+        if args.merge_lora_weights:
+            lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd)
+        else:
+            lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit)
+            info = lora_model.load_state_dict(weights_sd, strict=True)
+            logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
+            lora_model.eval()
+            lora_model.to(device)
+
+        lora_models.append(lora_model)
+
     if not args.interactive:
         generate_image(
             mmdit,
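The loop supports two modes: `merge_to` folds each LoRA delta into the base weights once (no runtime overhead, but the multiplier is baked in), while `apply_to` keeps the LoRA modules live so multipliers can be changed later in interactive mode. A simplified illustration of the difference, not the actual `networks.lora_sd3` implementation:

```python
import torch

W = torch.randn(8, 8)                        # base weight
up, down = torch.randn(8, 2), torch.randn(2, 8)
multiplier = 0.8

# merge_to: fold the delta into W once; the scale is fixed afterwards
merged = W + multiplier * (up @ down)

# apply_to: keep the delta separate so the scale can change per call
def forward_with_lora(x, m):
    return x @ W.T + m * (x @ (up @ down).T)

x = torch.randn(1, 8)
assert torch.allclose(x @ merged.T, forward_with_lora(x, multiplier), atol=1e-5)
```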
@@ -344,13 +372,13 @@ if __name__ == "__main__":
             steps = int(opt[1:].strip())
         elif opt.startswith("d"):
             seed = int(opt[1:].strip())
-        # elif opt.startswith("m"):
-        #     multipliers = opt[1:].strip().split(",")
-        #     if len(multipliers) != len(lora_models):
-        #         logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
-        #         continue
-        #     for i, lora_model in enumerate(lora_models):
-        #         lora_model.set_multiplier(float(multipliers[i]))
+        elif opt.startswith("m"):
+            multipliers = opt[1:].strip().split(",")
+            if len(multipliers) != len(lora_models):
+                logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
+                continue
+            for i, lora_model in enumerate(lora_models):
+                lora_model.set_multiplier(float(multipliers[i]))
         elif opt.startswith("n"):
             negative_prompt = opt[1:].strip()
             if negative_prompt == "-":
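With the option uncommented, interactive mode can retarget LoRA strength on the fly: entering e.g. `m0.5,1.0` with two LoRAs loaded sets their multipliers to 0.5 and 1.0. A quick check of the parsing (values illustrative):

```python
# Parsing of the interactive "m" option; values are illustrative.
opt = "m0.5,1.0"
multipliers = opt[1:].strip().split(",")
assert [float(m) for m in multipliers] == [0.5, 1.0]
```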