# some codes are copied from:
# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/

# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
# Changes made to the original code:
# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import math
import os
import random
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import torch
from torch import nn
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class DyLoRAModule(torch.nn.Module):
    """
    Replaces the forward method of the original Linear module instead of replacing the module itself.
    """

    # NOTE: support dropout in the future
    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.unit = unit
        assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes an error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
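        # e.g. alpha=4 with lora_dim=8 gives scale = 4 / 8 = 0.5 (illustrative values, not defaults)
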
        self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
        self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)

        if self.is_conv2d_3x3:
            kernel_size = org_module.kernel_size
            self.stride = org_module.stride
            self.padding = org_module.padding
            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
        else:
            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])

        # same initialization as Microsoft's LoRA
        for lora in self.lora_A:
            torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
        for lora in self.lora_B:
            torch.nn.init.zeros_(lora)

        self.multiplier = multiplier
        self.org_module = org_module  # removed when applied

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        result = self.org_forward(x)

        # specify the dynamic rank
        trainable_rank = random.randint(0, self.lora_dim - 1)
        trainable_rank = trainable_rank - trainable_rank % self.unit  # make sure the rank is a multiple of unit

        # freeze some of the parameters and train the rest
        for i in range(0, trainable_rank):
            self.lora_A[i].requires_grad = False
            self.lora_B[i].requires_grad = False
        for i in range(trainable_rank, trainable_rank + self.unit):
            self.lora_A[i].requires_grad = True
            self.lora_B[i].requires_grad = True
        for i in range(trainable_rank + self.unit, self.lora_dim):
            self.lora_A[i].requires_grad = False
            self.lora_B[i].requires_grad = False

        lora_A = torch.cat(tuple(self.lora_A), dim=0)
        lora_B = torch.cat(tuple(self.lora_B), dim=1)

        # calculate with lora_A and lora_B
        if self.is_conv2d_3x3:
            ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
            ab = torch.nn.functional.conv2d(ab, lora_B)
        else:
            ab = x
            if self.is_conv2d:
                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)

            ab = torch.nn.functional.linear(ab, lora_A)
            ab = torch.nn.functional.linear(ab, lora_B)

            if self.is_conv2d:
                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])  # (N, H*W, C) -> (N, C, H, W)

        # the last factor scales lower ranks up more strongly (presumably)
        result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))

        # NOTE: it might be faster to add to the weight first and then call linear/conv2d
        return result

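    # A worked example of the scaling above (illustrative values, not from the
    # original source): with lora_dim=16 and unit=4, the sampled trainable_rank
    # is rounded down to 0, 4, 8, or 12. trainable_rank=0 trains ranks 0..3 and
    # scales their output by sqrt(16 / 4) = 2.0, while trainable_rank=12 trains
    # ranks 12..15 and scales by sqrt(16 / 16) = 1.0, so updates restricted to
    # fewer ranks are amplified.
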
    def state_dict(self, destination=None, prefix="", keep_vars=False):
        # make the state dict identical to a normal LoRA's:
        # nn.ParameterList produces names like `.lora_A.0`, so cat the entries as in forward and swap them in
        sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

        lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
        if self.is_conv2d and not self.is_conv2d_3x3:
            lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)

        lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
        if self.is_conv2d and not self.is_conv2d_3x3:
            lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)

        sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
        sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()

        i = 0
        while True:
            key_a = f"{self.lora_name}.lora_A.{i}"
            key_b = f"{self.lora_name}.lora_B.{i}"
            if key_a in sd:
                sd.pop(key_a)
                sd.pop(key_b)
            else:
                break
            i += 1
        return sd

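    # The resulting keys follow the standard LoRA layout (the module name below
    # is hypothetical):
    #   lora_unet_mid_block_attentions_0_proj_in.lora_down.weight  (cat of lora_A)
    #   lora_unet_mid_block_attentions_0_proj_in.lora_up.weight    (cat of lora_B)
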
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # allow loading a state dict in the same format as a normal LoRA (this approach was suggested by ChatGPT)
        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)

        if lora_A_weight is None or lora_B_weight is None:
            if strict:
                raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
            else:
                return

        if self.is_conv2d and not self.is_conv2d_3x3:
            lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
            lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)

        state_dict.update(
            {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
        )
        state_dict.update(
            {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
        )

        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


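# A minimal usage sketch for DyLoRAModule (illustrative only; the layer name
# "lora_example" and the shapes are arbitrary):
#
#   linear = torch.nn.Linear(320, 320)
#   lora = DyLoRAModule("lora_example", linear, multiplier=1.0, lora_dim=8, alpha=4, unit=2)
#   lora.apply_to()  # monkey-patches linear.forward
#   y = linear(torch.randn(1, 320))  # original output + scaled DyLoRA delta

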
def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    vae: AutoencoderKL,
    text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
    unet,
    **kwargs,
):
    if network_dim is None:
        network_dim = 4  # default
    if network_alpha is None:
        network_alpha = 1.0

    # extract dim/alpha for conv2d, and block dim
    conv_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    unit = kwargs.get("unit", None)
    if conv_dim is not None:
        conv_dim = int(conv_dim)
        assert conv_dim == network_dim, "conv_dim must be the same as network_dim"
        if conv_alpha is None:
            conv_alpha = 1.0
        else:
            conv_alpha = float(conv_alpha)

    if unit is not None:
        unit = int(unit)
    else:
        unit = 1

    network = DyLoRANetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        apply_to_conv=conv_dim is not None,
        unit=unit,
        varbose=True,
    )

    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network


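# Example invocation (illustrative; assumes `text_encoder`, `vae`, and `unet`
# are already-loaded SD components, and mirrors how trainers forward
# --network_args such as "unit=4" as string kwargs):
#
#   network = create_network(1.0, 16, 8.0, vae, text_encoder, unet, unit="4")
#   network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)

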
# Create a network from weights for inference; the weights are not loaded here (because they may be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # get dim/alpha mapping
    modules_dim = {}
    modules_alpha = {}
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            dim = value.size()[0]
            modules_dim[lora_name] = dim
            # logger.info(f"{lora_name} {value.size()} {dim}")

    # support old LoRA without alpha: default alpha to dim
    for key in modules_dim.keys():
        if key not in modules_alpha:
            modules_alpha[key] = modules_dim[key]

    module_class = DyLoRAModule

    network = DyLoRANetwork(
        text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
    )
    return network, weights_sd


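# Example (illustrative; "my_dylora.safetensors" is a placeholder path):
#
#   network, weights_sd = create_network_from_weights(1.0, "my_dylora.safetensors", vae, text_encoder, unet)
#   network.apply_to(text_encoder, unet)
#   network.load_state_dict(weights_sd, strict=False)

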
class DyLoRANetwork(torch.nn.Module):
    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(
        self,
        text_encoder,
        unet,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        apply_to_conv=False,
        modules_dim=None,
        modules_alpha=None,
        unit=1,
        module_class=DyLoRAModule,
        varbose=False,
    ) -> None:
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.apply_to_conv = apply_to_conv

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if modules_dim is not None:
            logger.info("create LoRA network from weights")
        else:
            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
            if self.apply_to_conv:
                logger.info("apply LoRA to Conv2d with kernel size (3,3).")

        # create module instances
        def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
            prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
            loras = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            dim = None
                            alpha = None
                            if modules_dim is not None:
                                if lora_name in modules_dim:
                                    dim = modules_dim[lora_name]
                                    alpha = modules_alpha[lora_name]
                            else:
                                if is_linear or is_conv2d_1x1 or apply_to_conv:
                                    dim = self.lora_dim
                                    alpha = self.alpha

                            if dim is None or dim == 0:
                                continue

                            # dropout and fan_in_fan_out are left at their defaults
                            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
                            loras.append(lora)
            return loras

        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

        self.text_encoder_loras = []
        for i, text_encoder in enumerate(text_encoders):
            if len(text_encoders) > 1:
                index = i + 1
                logger.info(f"create LoRA for Text Encoder {index}")
            else:
                index = None
                logger.info("create LoRA for Text Encoder")

            text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            self.text_encoder_loras.extend(text_encoder_loras)

        # self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
        logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
        target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
        if modules_dim is not None or self.apply_to_conv:
            # build a new list instead of `+=`, which would mutate the class attribute in place
            target_modules = target_modules + DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras = create_modules(True, unet, target_modules)
        logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
        self.loraplus_lr_ratio = loraplus_lr_ratio
        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

        logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
        logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        if apply_text_encoder:
            logger.info("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

"""
|
||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||
apply_text_encoder = apply_unet = False
|
||
for key in weights_sd.keys():
|
||
if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||
apply_text_encoder = True
|
||
elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
|
||
apply_unet = True
|
||
|
||
if apply_text_encoder:
|
||
logger.info("enable LoRA for text encoder")
|
||
else:
|
||
self.text_encoder_loras = []
|
||
|
||
if apply_unet:
|
||
logger.info("enable LoRA for U-Net")
|
||
else:
|
||
self.unet_loras = []
|
||
|
||
for lora in self.text_encoder_loras + self.unet_loras:
|
||
sd_for_lora = {}
|
||
for key in weights_sd.keys():
|
||
if key.startswith(lora.lora_name):
|
||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||
lora.merge_to(sd_for_lora, dtype, device)
|
||
|
||
logger.info(f"weights are merged")
|
||
"""
|
||
|
||
    # it might be good to allow separate learning rates for the two Text Encoders
    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        self.requires_grad_(True)
        all_params = []

        def assemble_params(loras, lr, ratio):
            param_groups = {"lora": {}, "plus": {}}
            for lora in loras:
                for name, param in lora.named_parameters():
                    if ratio is not None and "lora_B" in name:
                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                    else:
                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param

            params = []
            for key in param_groups.keys():
                param_data = {"params": param_groups[key].values()}

                if len(param_data["params"]) == 0:
                    continue

                if lr is not None:
                    if key == "plus":
                        param_data["lr"] = lr * ratio
                    else:
                        param_data["lr"] = lr

                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
                    continue

                params.append(param_data)

            return params

        if self.text_encoder_loras:
            params = assemble_params(
                self.text_encoder_loras,
                text_encoder_lr if text_encoder_lr is not None else default_lr,
                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
            )
            all_params.extend(params)

        if self.unet_loras:
            params = assemble_params(
                self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
            )
            all_params.extend(params)

        return all_params

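    # Sketch of how a trainer might consume the groups above (illustrative;
    # the learning rates are placeholders):
    #
    #   params = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4, default_lr=1e-4)
    #   optimizer = torch.optim.AdamW(params)
    #
    # With a LoRA+ ratio set, "lora_B" parameters land in a separate group
    # whose lr is multiplied by the ratio.
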
    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        self.train()

    def get_trainable_params(self):
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    # mask is a tensor with values from 0 to 1
    def set_region(self, sub_prompt_index, is_last_network, mask):
        pass

    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
        pass