Files
Kohya-ss-sd-scripts/networks/control_net_lllite_for_train.py
2024-06-09 19:26:36 +09:00

502 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装
# ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward
import os
import re
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
# output_blocksに適用するかどうか / if True, output_blocks are not applied
SKIP_OUTPUT_BLOCKS = True
# conv2dに適用するかどうか / if True, conv2d are not applied
SKIP_CONV2D = False
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
# if True, only transformer_blocks are applied, and ResBlocks are not applied
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
ATTN1_2_ONLY = True
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
ATTN_QKV_ONLY = True
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
ATTN1_ETC_ONLY = False # True
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
# max index of transformer_blocks. if None, apply to all transformer_blocks
TRANSFORMER_MAX_BLOCK_INDEX = None
ORIGINAL_LINEAR = torch.nn.Linear
ORIGINAL_CONV2D = torch.nn.Conv2d
def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None:
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
# conditioning1 embeds conditioning image. it is not called for each timestep
modules = []
modules.append(ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
if depth == 1:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
elif depth == 2:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
elif depth == 3:
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
modules.append(torch.nn.ReLU(inplace=True))
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
modules.append(torch.nn.ReLU(inplace=True))
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
module.lllite_conditioning1 = torch.nn.Sequential(*modules)
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
# midでconditioning image embeddingと入力を結合する
# upで元の次元数に戻す
# これらはtimestepごとに呼ばれる
# reduce the number of input dimensions with down. inspired by LoRA
# combine conditioning image embedding and input with mid
# restore to the original dimension with up
# these are called for each timestep
module.lllite_down = torch.nn.Sequential(
ORIGINAL_LINEAR(in_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
module.lllite_mid = torch.nn.Sequential(
ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
module.lllite_up = torch.nn.Sequential(
ORIGINAL_LINEAR(mlp_dim, in_dim),
)
# Zero-Convにする / set to Zero-Conv
torch.nn.init.zeros_(module.lllite_up[0].weight) # zero conv
class LLLiteLinear(ORIGINAL_LINEAR):
def __init__(self, in_features: int, out_features: int, **kwargs):
super().__init__(in_features, out_features, **kwargs)
self.enabled = False
def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
self.enabled = True
self.lllite_name = name
self.cond_emb_dim = cond_emb_dim
self.dropout = dropout
self.multiplier = multiplier # ignored
in_dim = self.in_features
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
self.cond_image = None
def set_cond_image(self, cond_image):
self.cond_image = cond_image
def forward(self, x):
if not self.enabled:
return super().forward(x)
cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible
# reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape
cx = cx.view(n, c, h * w).permute(0, 2, 1)
cx = torch.cat([cx, self.lllite_down(x)], dim=2)
cx = self.lllite_mid(cx)
if self.dropout is not None and self.training:
cx = torch.nn.functional.dropout(cx, p=self.dropout)
cx = self.lllite_up(cx) * self.multiplier
x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
return x
class LLLiteConv2d(ORIGINAL_CONV2D):
def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self.enabled = False
def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
self.enabled = True
self.lllite_name = name
self.cond_emb_dim = cond_emb_dim
self.dropout = dropout
self.multiplier = multiplier # ignored
in_dim = self.in_channels
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
self.cond_image = None
self.cond_emb = None
def set_cond_image(self, cond_image):
self.cond_image = cond_image
self.cond_emb = None
def forward(self, x): # , cond_image=None):
if not self.enabled:
return super().forward(x)
cx = self.lllite_conditioning1(self.cond_image)
cx = torch.cat([cx, self.down(x)], dim=1)
cx = self.mid(cx)
if self.dropout is not None and self.training:
cx = torch.nn.functional.dropout(cx, p=self.dropout)
cx = self.up(cx) * self.multiplier
x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
return x
class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
LLLITE_PREFIX = "lllite_unet"
def __init__(self, **kwargs):
super().__init__(**kwargs)
def apply_lllite(
self,
cond_emb_dim: int = 16,
mlp_dim: int = 16,
dropout: Optional[float] = None,
varbose: Optional[bool] = False,
multiplier: Optional[float] = 1.0,
) -> None:
def apply_to_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[torch.nn.Module]:
prefix = "lllite_unet"
modules = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "LLLiteLinear"
is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d"
if is_linear or (is_conv2d and not SKIP_CONV2D):
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
# block index to depth: depth is using to calculate conditioning size and channels
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
index1 = int(index1)
if block_name == "input_blocks":
if SKIP_INPUT_BLOCKS:
continue
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
elif block_name == "middle_block":
depth = 3
elif block_name == "output_blocks":
if SKIP_OUTPUT_BLOCKS:
continue
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
if int(index2) >= 2:
depth -= 1
else:
raise NotImplementedError()
lllite_name = prefix + "." + name + "." + child_name
lllite_name = lllite_name.replace(".", "_")
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
p = lllite_name.find("transformer_blocks")
if p >= 0:
tf_index = int(lllite_name[p:].split("_")[2])
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
continue
# time embは適用外とする
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
# time emb is not applied
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
if "emb_layers" in lllite_name or (
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
):
continue
if ATTN1_2_ONLY:
if not ("attn1" in lllite_name or "attn2" in lllite_name):
continue
if ATTN_QKV_ONLY:
if "to_out" in lllite_name:
continue
if ATTN1_ETC_ONLY:
if "proj_out" in lllite_name:
pass
elif "attn1" in lllite_name and (
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
):
pass
elif "ff_net_2" in lllite_name:
pass
else:
continue
child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier)
modules.append(child_module)
return modules
target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE
if not TRANSFORMER_ONLY:
target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
# create module instances
self.lllite_modules = apply_to_modules(self, target_modules)
logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
# def prepare_optimizer_params(self):
def prepare_params(self):
train_params = []
non_train_params = []
for name, p in self.named_parameters():
if "lllite" in name:
train_params.append(p)
else:
non_train_params.append(p)
logger.info(f"count of trainable parameters: {len(train_params)}")
logger.info(f"count of non-trainable parameters: {len(non_train_params)}")
for p in non_train_params:
p.requires_grad_(False)
# without this, an error occurs in the optimizer
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
non_train_params[0].requires_grad_(True)
for p in train_params:
p.requires_grad_(True)
return train_params
# def prepare_grad_etc(self):
# self.requires_grad_(True)
# def on_epoch_start(self):
# self.train()
def get_trainable_params(self):
return [p[1] for p in self.named_parameters() if "lllite" in p[0]]
def save_lllite_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
org_state_dict = self.state_dict()
# copy LLLite keys from org_state_dict to state_dict with key conversion
state_dict = {}
for key in org_state_dict.keys():
# split with ".lllite"
pos = key.find(".lllite")
if pos < 0:
continue
lllite_key = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "." + key[:pos]
lllite_key = lllite_key.replace(".", "_") + key[pos:]
lllite_key = lllite_key.replace(".lllite_", ".")
state_dict[lllite_key] = org_state_dict[key]
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def load_lllite_weights(self, file, non_lllite_unet_sd=None):
r"""
LLLiteの重みを読み込まないinitされた値を使う場合はfileにNoneを指定する。
この場合、non_lllite_unet_sdにはU-Netのstate_dictを指定する。
If you do not want to load LLLite weights (use initialized values), specify None for file.
In this case, specify the state_dict of U-Net for non_lllite_unet_sd.
"""
if not file:
state_dict = self.state_dict()
for key in non_lllite_unet_sd:
if key in state_dict:
state_dict[key] = non_lllite_unet_sd[key]
info = self.load_state_dict(state_dict, False)
return info
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# module_name = module_name.replace("_block", "@blocks")
# module_name = module_name.replace("_layer", "@layer")
# module_name = module_name.replace("to_", "to@")
# module_name = module_name.replace("time_embed", "time@embed")
# module_name = module_name.replace("label_emb", "label@emb")
# module_name = module_name.replace("skip_connection", "skip@connection")
# module_name = module_name.replace("proj_in", "proj@in")
# module_name = module_name.replace("proj_out", "proj@out")
pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)")
# convert to lllite with U-Net state dict
state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {}
for key in weights_sd.keys():
# split with "."
pos = key.find(".")
if pos < 0:
continue
module_name = key[:pos]
weight_name = key[pos + 1 :] # exclude "."
module_name = module_name.replace(SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "_", "")
# これはうまくいかない。逆変換を考えなかった設計が悪い / this does not work well. bad design because I didn't think about inverse conversion
# module_name = module_name.replace("_", ".")
# ださいけどSDXLのU-Netの "_" を "@" に変換する / ugly but convert "_" of SDXL U-Net to "@"
matches = pattern.findall(module_name)
if matches is not None:
for m in matches:
logger.info(f"{module_name} {m}")
module_name = module_name.replace(m, m.replace("_", "@"))
module_name = module_name.replace("_", ".")
module_name = module_name.replace("@", "_")
lllite_key = module_name + ".lllite_" + weight_name
state_dict[lllite_key] = weights_sd[key]
info = self.load_state_dict(state_dict, False)
return info
def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs):
for m in self.lllite_modules:
m.set_cond_image(cond_image)
return super().forward(x, timesteps, context, y, **kwargs)
def replace_unet_linear_and_conv2d():
logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
if __name__ == "__main__":
# デバッグ用 / for debug
# sdxl_original_unet.USE_REENTRANT = False
replace_unet_linear_and_conv2d()
# test shape etc
logger.info("create unet")
unet = SdxlUNet2DConditionModelControlNetLLLite()
logger.info("enable ControlNet-LLLite")
unet.apply_lllite(32, 64, None, False, 1.0)
unet.to("cuda") # .to(torch.float16)
# from safetensors.torch import load_file
# model_sd = load_file(r"E:\Work\SD\Models\sdxl\sd_xl_base_1.0_0.9vae.safetensors")
# unet_sd = {}
# # copy U-Net keys from unet_state_dict to state_dict
# prefix = "model.diffusion_model."
# for key in model_sd.keys():
# if key.startswith(prefix):
# converted_key = key[len(prefix) :]
# unet_sd[converted_key] = model_sd[key]
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
# logger.info(info)
# logger.info(unet)
# logger.info number of parameters
params = unet.prepare_params()
logger.info(f"number of parameters {sum(p.numel() for p in params)}")
# logger.info("type any key to continue")
# input()
unet.set_use_memory_efficient_attention(True, False)
unet.set_gradient_checkpointing(True)
unet.train() # for gradient checkpointing
# # visualize
# import torchviz
# logger.info("run visualize")
# controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y)
# logger.info("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# logger.info("render")
# image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input()
import bitsandbytes
optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)
logger.info("start training")
steps = 10
batch_size = 1
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
for step in range(steps):
logger.info(f"step {step}")
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
x = torch.randn(batch_size, 4, 128, 128).cuda()
t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
ctx = torch.randn(batch_size, 77, 2048).cuda()
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
output = unet(x, t, ctx, y, conditioning_image)
target = torch.randn_like(output)
loss = torch.nn.functional.mse_loss(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
logger.info(sample_param)
# from safetensors.torch import save_file
# logger.info("save weights")
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)