mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
303 lines
10 KiB
Python
303 lines
10 KiB
Python
from dataclasses import dataclass
|
|
import math
|
|
import re
|
|
from typing import Dict, List, Optional, Union
|
|
import torch
|
|
import safetensors
|
|
from safetensors.torch import load_file
|
|
from accelerate import init_empty_weights
|
|
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig
|
|
|
|
from .utils import setup_logging
|
|
|
|
setup_logging()
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
from library import sd3_models
|
|
|
|
# TODO move some of functions to model_util.py
|
|
from library import sdxl_model_util
|
|
|
|
# region models
|
|
|
|
# TODO remove dependency on flux_utils
|
|
from library.utils import load_safetensors
|
|
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
|
|
|
|
|
|
def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
    """Infer SD3 hyperparameters from the shapes and keys of a MMDiT state dict.

    Args:
        state_dict: MMDiT weights (keys optionally prefixed).
        prefix: common key prefix to prepend when probing the dict.

    Returns:
        sd3_models.SD3Params describing the architecture found in the dict.
    """
    logger.info(f"Analyzing state dict state...")

    # derive architecture parameters from well-known tensor shapes
    proj_weight_shape = state_dict[f"{prefix}x_embedder.proj.weight"].shape
    patch_size = proj_weight_shape[2]
    depth = proj_weight_shape[0] // 64
    num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
    pos_embed_max_size = round(math.sqrt(num_patches))
    adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
    context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
    # presence of the ln_k weight indicates RMS-normalized q/k attention
    qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None

    # collect the indices of joint blocks that carry an extra x-side self-attention (attn2)
    re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
    matches = (re_attn.search(key) for key in list(state_dict.keys()))
    x_block_self_attn_layers = [int(m.group(1)) for m in matches if m]

    context_embedder_in_features = context_shape[1]
    context_embedder_out_features = context_shape[0]

    # only supports 3-5-large, medium or 3-medium
    if qk_norm is None:
        model_type = "3-medium"
    elif x_block_self_attn_layers:
        model_type = "3-5-medium"
    else:
        model_type = "3-5-large"

    params = sd3_models.SD3Params(
        patch_size=patch_size,
        depth=depth,
        num_patches=num_patches,
        pos_embed_max_size=pos_embed_max_size,
        adm_in_channels=adm_in_channels,
        qk_norm=qk_norm,
        x_block_self_attn_layers=x_block_self_attn_layers,
        context_embedder_in_features=context_embedder_in_features,
        context_embedder_out_features=context_embedder_out_features,
        model_type=model_type,
    )
    logger.info(f"Analyzed state dict state: {params}")
    return params
|
|
|
|
|
|
def load_mmdit(
    state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
) -> sd3_models.MMDiT:
    """Build an MMDiT from checkpoint weights and load them with `assign=True`.

    Keys prefixed with "model.diffusion_model." are popped out of `state_dict`
    (the input dict is mutated) and used to build/initialize the model.
    """
    mmdit_prefix = "model.diffusion_model."
    prefix_len = len(mmdit_prefix)
    # move the MMDiT weights out of the combined checkpoint dict
    mmdit_keys = [k for k in state_dict.keys() if k.startswith(mmdit_prefix)]
    mmdit_sd = {k[prefix_len:]: state_dict.pop(k) for k in mmdit_keys}

    # load MMDiT
    logger.info("Building MMDit")
    params = analyze_state_dict_state(mmdit_sd)
    with init_empty_weights():
        mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)

    logger.info("Loading state dict...")
    info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True)
    logger.info(f"Loaded MMDiT: {info}")
    return mmdit
|
|
|
|
|
|
def load_clip_l(
    clip_l_path: Optional[str],
    dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[Dict] = None,
):
    """Load the CLIP-L text encoder, either from the combined checkpoint or a file.

    Args:
        clip_l_path: path to a standalone clip_l safetensors file (may be None).
        dtype: dtype for loaded weights (also used for the synthesized text_projection).
        device: target device.
        disable_mmap: forwarded to load_safetensors.
        state_dict: combined checkpoint dict; clip_l keys (prefix "text_encoders.clip_l.")
            are popped out of it if present (the dict is mutated).

    Returns:
        CLIPTextModelWithProjection, or None when clip_l is neither embedded nor
        provided via `clip_l_path`.
    """
    clip_l_sd = None
    # BUGFIX: was `if clip_l_path is None:` with an unreachable `elif clip_l_path is None:`,
    # which crashed when state_dict was also None and never extracted embedded weights
    # when a path was given. Now matches load_clip_g / load_t5xxl: prefer embedded weights.
    if state_dict is not None:
        if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
            # found clip_l: remove prefix "text_encoders.clip_l."
            logger.info("clip_l is included in the checkpoint")
            clip_l_sd = {}
            prefix = "text_encoders.clip_l."
            for k in list(state_dict.keys()):
                if k.startswith(prefix):
                    clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
        elif clip_l_path is None:
            logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
            return None

    if clip_l_sd is None and clip_l_path is None:
        # nothing embedded and no file to fall back to — avoid load_safetensors(None)
        logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
        return None

    # load clip_l
    logger.info("Building CLIP-L")
    config = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_position_embeddings=77,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=768,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    with init_empty_weights():
        clip = CLIPTextModelWithProjection(config)

    if clip_l_sd is None:
        logger.info(f"Loading state dict from {clip_l_path}")
        clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)

    if "text_projection.weight" not in clip_l_sd:
        # SD3 checkpoints may omit the projection; identity keeps the hidden state unchanged
        logger.info("Adding text_projection.weight to clip_l_sd")
        clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)

    info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
    logger.info(f"Loaded CLIP-L: {info}")
    return clip
|
|
|
|
|
|
def load_clip_g(
    clip_g_path: Optional[str],
    dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[Dict] = None,
):
    """Load the CLIP-G text encoder, preferring weights embedded in the checkpoint.

    Returns CLIPTextModelWithProjection, or None when clip_g is neither embedded
    in `state_dict` nor available via `clip_g_path`. `state_dict` is mutated:
    matching keys are popped out of it.
    """
    clip_g_sd = None
    if state_dict is not None:
        if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
            # found clip_g: remove prefix "text_encoders.clip_g."
            logger.info("clip_g is included in the checkpoint")
            prefix = "text_encoders.clip_g."
            embedded_keys = [k for k in state_dict.keys() if k.startswith(prefix)]
            clip_g_sd = {k[len(prefix):]: state_dict.pop(k) for k in embedded_keys}
        elif clip_g_path is None:
            logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
            return None

    # load clip_g
    logger.info("Building CLIP-G")
    config = CLIPTextConfig(
        vocab_size=49408,
        hidden_size=1280,
        intermediate_size=5120,
        num_hidden_layers=32,
        num_attention_heads=20,
        max_position_embeddings=77,
        hidden_act="gelu",
        layer_norm_eps=1e-05,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        model_type="clip_text_model",
        projection_dim=1280,
        # torch_dtype="float32",
        # transformers_version="4.25.0.dev0",
    )
    with init_empty_weights():
        clip = CLIPTextModelWithProjection(config)

    if clip_g_sd is None:
        logger.info(f"Loading state dict from {clip_g_path}")
        clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
    info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
    logger.info(f"Loaded CLIP-G: {info}")
    return clip
|
|
|
|
|
|
def load_t5xxl(
    t5xxl_path: Optional[str],
    dtype: Optional[Union[str, torch.dtype]],
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[Dict] = None,
):
    """Load T5-XXL, preferring weights embedded in the combined checkpoint.

    Extraction pops the "text_encoders.t5xxl."-prefixed keys out of `state_dict`
    (mutating it) and delegates the actual model construction to flux_utils.
    Returns None when t5xxl is neither embedded nor available via `t5xxl_path`.
    """
    t5xxl_sd = None
    if state_dict is not None:
        if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
            # found t5xxl: remove prefix "text_encoders.t5xxl."
            logger.info("t5xxl is included in the checkpoint")
            prefix = "text_encoders.t5xxl."
            embedded_keys = [k for k in state_dict.keys() if k.startswith(prefix)]
            t5xxl_sd = {k[len(prefix):]: state_dict.pop(k) for k in embedded_keys}
        elif t5xxl_path is None:
            logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
            return None

    return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)
|
|
|
|
|
|
def load_vae(
    vae_path: Optional[str],
    vae_dtype: Optional[Union[str, torch.dtype]],
    device: Optional[Union[str, torch.device]],
    disable_mmap: bool = False,
    state_dict: Optional[Dict] = None,
):
    """Build the SD VAE from a standalone file or from checkpoint-embedded weights.

    When `vae_path` is falsy, the "first_stage_model."-prefixed keys are popped
    out of `state_dict` (mutating it). The returned module is moved to `device`
    and cast to `vae_dtype`.
    """
    if vae_path:
        logger.info(f"Loading VAE from {vae_path}...")
        vae_sd = load_safetensors(vae_path, device, disable_mmap)
    else:
        # remove prefix "first_stage_model."
        vae_prefix = "first_stage_model."
        embedded_keys = [k for k in state_dict.keys() if k.startswith(vae_prefix)]
        vae_sd = {k[len(vae_prefix):]: state_dict.pop(k) for k in embedded_keys}

    logger.info("Building VAE")
    vae = sd3_models.SDVAE(vae_dtype, device)
    logger.info("Loading state dict...")
    info = vae.load_state_dict(vae_sd)
    logger.info(f"Loaded VAE: {info}")
    vae.to(device=device, dtype=vae_dtype)  # make sure it's in the right device and dtype
    return vae
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
class ModelSamplingDiscreteFlow:
    """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""

    def __init__(self, shift=1.0):
        # shift > 1 biases the schedule toward higher noise levels
        self.shift = shift
        num_timesteps = 1000
        # precompute sigma for each discrete timestep 1..1000
        self.sigmas = self.sigma(torch.arange(1, num_timesteps + 1, 1))

    @property
    def sigma_min(self):
        # smallest precomputed sigma (timestep 1)
        return self.sigmas[0]

    @property
    def sigma_max(self):
        # largest precomputed sigma (timestep 1000)
        return self.sigmas[-1]

    def timestep(self, sigma):
        # inverse of sigma() up to the shift warp: map [0,1] sigma back to [0,1000]
        return sigma * 1000

    def sigma(self, timestep: torch.Tensor):
        # normalize discrete timestep to [0, 1], then apply the shift warp
        t = timestep / 1000.0
        if self.shift == 1.0:
            return t
        return self.shift * t / (1 + (self.shift - 1) * t)

    def calculate_denoised(self, sigma, model_output, model_input):
        # broadcast sigma over all non-batch dims of the model output
        extra_dims = (1,) * (model_output.ndim - 1)
        sigma = sigma.view(sigma.shape[:1] + extra_dims)
        return model_input - model_output * sigma

    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
        # assert max_denoise is False, "max_denoise not implemented"
        # max_denoise is always True, I'm not sure why it's there
        # linear interpolation between clean latent and noise, weighted by sigma
        return sigma * noise + (1.0 - sigma) * latent_image
|