Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-06 13:47:06 +00:00)

Commit: Update embedder_dims, add more flexible caption extension
@@ -887,6 +887,9 @@ class NextDiT(nn.Module):
),
)

nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
nn.init.zeros_(self.cap_embedder[1].bias)

self.context_refiner = nn.ModuleList(
[
JointTransformerBlock(
@@ -929,9 +932,6 @@ class NextDiT(nn.Module):
]
)

nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
# nn.init.zeros_(self.cap_embedder[1].weight)
nn.init.zeros_(self.cap_embedder[1].bias)

self.layers = nn.ModuleList(
[
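The two hunks above move the cap_embedder weight initialization so it runs right after the embedder is built, instead of sitting between context_refiner and layers. A minimal sketch of that initialization pattern, assuming cap_embedder is an nn.Sequential whose second entry is the nn.Linear (the sizes below are placeholders, not NextDiT's real dimensions):

import torch.nn as nn

cap_feat_dim, dim = 256, 512  # placeholder sizes, not the model's actual dims
cap_embedder = nn.Sequential(
    nn.LayerNorm(cap_feat_dim),
    nn.Linear(cap_feat_dim, dim, bias=True),
)
# index [1] is the Linear layer, matching self.cap_embedder[1] in the hunk
nn.init.trunc_normal_(cap_embedder[1].weight, std=0.02)
nn.init.zeros_(cap_embedder[1].bias)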
@@ -529,8 +529,8 @@ class DreamBoothSubset(BaseSubset):
self.is_reg = is_reg
self.class_tokens = class_tokens
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
# if self.caption_extension and not self.caption_extension.startswith("."):
# self.caption_extension = "." + self.caption_extension
self.cache_info = cache_info

def __eq__(self, other) -> bool:
@@ -1895,30 +1895,33 @@ class DreamBoothDataset(BaseDataset):
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False

def read_caption(img_path, caption_extension, enable_wildcard):
def read_caption(img_path: str, caption_extension: str, enable_wildcard: bool):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
cap_paths = [(base_name, caption_extension), (base_name_face_det, caption_extension)]

caption = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
else:
caption = lines[0].strip()
break
for base, cap_extension in cap_paths:
# check with and without . to allow for extension flexibility (img_var.txt, img.txt, img + txt)
for cap_path in [base + cap_extension, base + "." + cap_extension]:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
else:
caption = lines[0].strip()
break
break
return caption

def load_dreambooth_dir(subset: DreamBoothSubset):
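The reworked read_caption no longer concatenates the extension up front; it keeps (base, extension) pairs and probes both base + ext and base + "." + ext, which is why DreamBoothSubset no longer needs to force a leading dot onto caption_extension. A standalone sketch of just that probe order, with hypothetical file names:

import os

def candidate_caption_paths(img_path: str, caption_extension: str) -> list:
    # Same probe order as the hunk: base + ext first, then base + "." + ext.
    base = os.path.splitext(img_path)[0]
    return [base + caption_extension, base + "." + caption_extension]

# Hypothetical inputs for illustration:
print(candidate_caption_paths("data/img.png", ".txt"))      # ['data/img.txt', 'data/img..txt']
print(candidate_caption_paths("data/img.png", "txt"))       # ['data/imgtxt', 'data/img.txt']
print(candidate_caption_paths("data/img.png", "_var.txt"))  # ['data/img_var.txt', 'data/img._var.txt']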
@@ -1,5 +1,5 @@
# temporary minimum implementation of LoRA
# FLUX doesn't have Conv2d, so we ignore it
# Lumina 2 does not have Conv2d, so ignore
# TODO commonize with the original implementation

# LoRA network module
@@ -10,13 +10,11 @@
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re
from torch import Tensor, nn
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel

setup_logging()
import logging
@@ -35,14 +33,14 @@ class LoRAModule(torch.nn.Module):

def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
lora_name: str,
org_module: nn.Module,
multiplier: float =1.0,
lora_dim: int = 4,
alpha: Optional[float | int | Tensor] = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
split_dims: Optional[List[int]] = None,
):
"""
@@ -60,6 +58,9 @@ class LoRAModule(torch.nn.Module):
in_dim = org_module.in_features
out_dim = org_module.out_features

assert isinstance(in_dim, int)
assert isinstance(out_dim, int)

self.lora_dim = lora_dim
self.split_dims = split_dims
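The new split_dims parameter, wired up in the hunk that follows, lets a single fused Linear (for example a qkv projection) get one down-projection and one up-projection per split, with the per-split outputs concatenated back to out_dim. A self-contained sketch of that shape bookkeeping, with invented sizes:

import torch
from torch import nn

in_dim, lora_dim = 16, 4
split_dims = [8, 8, 8]              # hypothetical q/k/v split of a fused 24-dim output
assert sum(split_dims) == 24

lora_down = nn.ModuleList([nn.Linear(in_dim, lora_dim, bias=False) for _ in split_dims])
lora_up = nn.ModuleList([nn.Linear(lora_dim, d, bias=False) for d in split_dims])

x = torch.randn(2, in_dim)
lxs = [up(down(x)) for down, up in zip(lora_down, lora_up)]
delta = torch.cat(lxs, dim=-1)      # concatenated back to the fused out_dim
print(delta.shape)                  # torch.Size([2, 24])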
@@ -68,30 +69,31 @@
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
self.lora_down = nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
self.lora_down = nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = nn.Linear(self.lora_dim, out_dim, bias=False)

torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_up.weight)
else:
# conv2d not supported
assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
# print(f"split_dims: {split_dims}")
self.lora_down = torch.nn.ModuleList(
[torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
self.lora_down = nn.ModuleList(
[nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
)
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
for lora_down in self.lora_down:
torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
for lora_up in self.lora_up:
torch.nn.init.zeros_(lora_up.weight)
self.lora_up = nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])

if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
for lora_down in self.lora_down:
nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
for lora_up in self.lora_up:
nn.init.zeros_(lora_up.weight)

if isinstance(alpha, Tensor):
alpha = alpha.detach().cpu().float().item() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
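The alpha handling now goes through isinstance(alpha, Tensor) and .item() instead of converting to a NumPy scalar. A small sketch of the resulting scale computation, with made-up values:

import torch
from torch import Tensor

def compute_scale(alpha, lora_dim: int) -> float:
    # Mirrors the hunk: a Tensor alpha is detached and cast to a Python float
    # (bf16 tensors cannot be converted without the float() cast).
    if isinstance(alpha, Tensor):
        alpha = alpha.detach().cpu().float().item()
    alpha = lora_dim if alpha is None or alpha == 0 else alpha
    return alpha / lora_dim

print(compute_scale(torch.tensor(1.0, dtype=torch.bfloat16), lora_dim=4))  # 0.25
print(compute_scale(None, lora_dim=4))                                     # 1.0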
@@ -140,6 +142,9 @@ class LoRAModule(torch.nn.Module):

lx = self.lora_up(lx)

if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)

return org_forwarded + lx * self.multiplier * scale
else:
lxs = [lora_down(x) for lora_down in self.lora_down]
@@ -152,9 +157,9 @@ class LoRAModule(torch.nn.Module):
if self.rank_dropout is not None and self.training:
masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
for i in range(len(lxs)):
if len(lx.size()) == 3:
if len(lxs[i].size()) == 3:
masks[i] = masks[i].unsqueeze(1)
elif len(lx.size()) == 4:
elif len(lxs[i].size()) == 4:
masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
lxs[i] = lxs[i] * masks[i]

@@ -165,6 +170,9 @@ class LoRAModule(torch.nn.Module):

lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]

if self.dropout is not None and self.training:
lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]

return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
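The @@ -152 hunk fixes a shape check that still read the variable lx left over from the non-split branch instead of lxs[i]; with split_dims active, each per-split activation now decides its own unsqueeze. A reduced sketch of that masking step, with illustrative shapes:

import torch

lora_dim, rank_dropout = 4, 0.5
# two hypothetical per-split activations: one 3D (batch, seq, rank), one 2D (batch, rank)
lxs = [torch.randn(2, 7, lora_dim), torch.randn(2, lora_dim)]

masks = [torch.rand((lx.size(0), lora_dim)) > rank_dropout for lx in lxs]
for i in range(len(lxs)):
    # the fix: inspect lxs[i], not the stale variable from the single-branch code
    if len(lxs[i].size()) == 3:
        masks[i] = masks[i].unsqueeze(1)          # broadcast over the sequence dim
    elif len(lxs[i].size()) == 4:
        masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
    lxs[i] = lxs[i] * masks[i]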
@@ -339,14 +347,14 @@ def create_network(
if all([d is None for d in type_dims]):
type_dims = None

# in_dims for embedders
in_dims = kwargs.get("in_dims", None)
if in_dims is not None:
in_dims = in_dims.strip()
if in_dims.startswith("[") and in_dims.endswith("]"):
in_dims = in_dims[1:-1]
in_dims = [int(d) for d in in_dims.split(",")]
assert len(in_dims) == 4, f"invalid in_dims: {in_dims}, must be 4 dimensions (x_embedder, t_embedder, cap_embedder, final_layer)"
# embedder_dims for embedders
embedder_dims = kwargs.get("embedder_dims", None)
if embedder_dims is not None:
embedder_dims = embedder_dims.strip()
if embedder_dims.startswith("[") and embedder_dims.endswith("]"):
embedder_dims = embedder_dims[1:-1]
embedder_dims = [int(d) for d in embedder_dims.split(",")]
assert len(embedder_dims) == 3, f"invalid embedder_dims: {embedder_dims}, must be 4 dimensions (x_embedder, t_embedder, cap_embedder)"

# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
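in_dims (four entries, including final_layer) is replaced by embedder_dims with three entries for x_embedder, t_embedder and cap_embedder; note that the assertion message still says "must be 4 dimensions" even though the length check is now 3. A sketch of how that network-arg string is parsed, with an example input:

def parse_embedder_dims(raw: str):
    # Same parsing as the hunk: strip whitespace, drop optional [brackets], split on commas.
    raw = raw.strip()
    if raw.startswith("[") and raw.endswith("]"):
        raw = raw[1:-1]
    dims = [int(d) for d in raw.split(",")]
    assert len(dims) == 3, f"invalid embedder_dims: {dims}, must be 3 dimensions (x_embedder, t_embedder, cap_embedder)"
    return dims

print(parse_embedder_dims("[4, 8, 4]"))  # [4, 8, 4]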
@@ -357,9 +365,9 @@ def create_network(
module_dropout = float(module_dropout)

# single or double blocks
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "transformer", "refiners", "noise_refiner", "context_refiner"
if train_blocks is not None:
assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"
assert train_blocks in ["all", "transformer", "refiners", "noise_refiner", "context_refiner"], f"invalid train_blocks: {train_blocks}"

# split qkv
split_qkv = kwargs.get("split_qkv", False)
@@ -386,7 +394,7 @@ def create_network(
train_blocks=train_blocks,
split_qkv=split_qkv,
type_dims=type_dims,
in_dims=in_dims,
embedder_dims=embedder_dims,
verbose=verbose,
)
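The train_blocks kwarg now accepts Lumina-specific values instead of FLUX's "single"/"double" (the "# single or double blocks" comment above it is left over). Note that create_network accepts "context_refiner" here, while the block-selection code further down in this diff checks for "cap_refiner". A minimal sketch of the updated validation:

def validate_train_blocks(train_blocks):
    # Mirrors the updated check; None means train everything.
    if train_blocks is not None:
        assert train_blocks in ["all", "transformer", "refiners", "noise_refiner", "context_refiner"], (
            f"invalid train_blocks: {train_blocks}"
        )
    return train_blocks

validate_train_blocks("refiners")       # ok
# validate_train_blocks("cap_refiner")  # would fail here, though the selection code
#                                       # later in this diff matches "cap_refiner"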
@@ -461,7 +469,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei


class LoRANetwork(torch.nn.Module):
LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"]
LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock", "FinalLayer"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2FlashAttention2", "Gemma2SdpaAttention", "Gemma2MLP"]
LORA_PREFIX_LUMINA = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder
@@ -478,13 +486,14 @@ class LoRANetwork(torch.nn.Module):
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
module_class: Type[LoRAModule] = LoRAModule,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
train_blocks: Optional[str] = None,
split_qkv: bool = False,
type_dims: Optional[List[int]] = None,
in_dims: Optional[List[int]] = None,
embedder_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -501,7 +510,9 @@ class LoRANetwork(torch.nn.Module):
self.split_qkv = split_qkv

self.type_dims = type_dims
self.in_dims = in_dims
self.embedder_dims = embedder_dims

self.train_block_indices = train_block_indices

self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
@@ -509,7 +520,7 @@ class LoRANetwork(torch.nn.Module):

if modules_dim is not None:
logger.info(f"create LoRA network from weights")
self.in_dims = [0] * 5 # create in_dims
self.embedder_dims = [0] * 5 # create embedder_dims
# verbose = True
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
@@ -529,7 +540,7 @@ class LoRANetwork(torch.nn.Module):
def create_modules(
is_lumina: bool,
root_module: torch.nn.Module,
target_replace_modules: List[str],
target_replace_modules: Optional[List[str]],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
) -> List[LoRAModule]:
@@ -544,63 +555,77 @@ class LoRANetwork(torch.nn.Module):

for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")

if filter is not None and not filter in lora_name:
continue
# Only Linear is supported
if not is_linear:
skipped.append(lora_name)
continue

dim = None
alpha = None
if filter is not None and filter not in lora_name:
continue

if modules_dim is not None:
# モジュール指定あり
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
# 通常、すべて対象とする
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha

if is_lumina and type_dims is not None:
identifier = [
("attention",), # attention layers
("mlp",), # MLP layers
("modulation",), # modulation layers
("refiner",), # refiner blocks
]
for i, d in enumerate(type_dims):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip
break
# Set dim/alpha to modules dim/alpha
if modules_dim is not None and modules_alpha is not None:
# モジュール指定あり
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]

elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
# Set dims to type_dims
if is_lumina and type_dims is not None:
identifier = [
("attention",), # attention layers
("mlp",), # MLP layers
("modulation",), # modulation layers
("refiner",), # refiner blocks
]
for i, d in enumerate(type_dims):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip
break

if dim is None or dim == 0:
# skipした情報を出力
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
skipped.append(lora_name)
continue

lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
# Drop blocks if we are only training some blocks
if (
is_lumina
and dim
and (
self.train_block_indices is not None
)
loras.append(lora)
and ("layer" in lora_name)
):
# "lora_unet_layers_0_..." or "lora_unet_cap_refiner_0_..." or "lora_unet_noise_refiner_0_..."
block_index = int(lora_name.split("_")[3]) # bit dirty
if (
"layer" in lora_name
and self.train_block_indices is not None
and not self.train_block_indices[block_index]
):
dim = 0

if dim is None or dim == 0:
# skipした情報を出力
skipped.append(lora_name)
continue

lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
logger.info(f"Add LoRA module: {lora_name}")
loras.append(lora)

if target_replace_modules is None:
break # all modules are searched
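The rewritten selection logic zeroes dim for transformer blocks whose index is not enabled in train_block_indices, extracting the index from the module name. A reduced sketch of that name-based filter (the names and index list below are made up):

train_block_indices = [True, False, True]  # hypothetical: train blocks 0 and 2 only

def keep_block(lora_name: str) -> bool:
    # Mirrors the hunk: names look like "lora_unet_layers_<idx>_...", so the index
    # is the 4th underscore-separated token ("bit dirty", as the original comment says).
    if "layer" not in lora_name:
        return True
    block_index = int(lora_name.split("_")[3])
    return train_block_indices[block_index]

print(keep_block("lora_unet_layers_1_attention_qkv"))  # False -> dim set to 0, module skipped
print(keep_block("lora_unet_layers_2_mlp_fc1"))        # True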
@@ -617,15 +642,25 @@ class LoRANetwork(torch.nn.Module):
skipped_te += skipped

# create LoRA for U-Net
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
if self.train_blocks == "all":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
# TODO: limit different blocks
elif self.train_blocks == "transformer":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
elif self.train_blocks == "refiners":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
elif self.train_blocks == "noise_refiner":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
elif self.train_blocks == "cap_refiner":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE

self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)

# Handle embedders
if self.in_dims:
for filter, in_dim in zip(["x_embedder", "t_embedder", "cap_embedder", "final_layer"], self.in_dims):
loras, _ = create_modules(True, unet, None, filter=filter, default_dim=in_dim)
if self.embedder_dims:
for filter, embedder_dim in zip(["x_embedder", "t_embedder", "cap_embedder"], self.embedder_dims):
loras, _ = create_modules(True, unet, None, filter=filter, default_dim=embedder_dim)
self.unet_loras.extend(loras)

logger.info(f"create LoRA for Lumina blocks: {len(self.unet_loras)} modules.")