Merge pull request #23 from rockerBOO/lumina-lora

Lumina LoRA updates
Merged by 青龍聖者@bdsqlsz on 2025-03-09 21:04:45 +08:00, committed by GitHub.
3 changed files with 152 additions and 125 deletions

Changed file 1 of 3

@@ -887,6 +887,9 @@ class NextDiT(nn.Module):
),
)
nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
nn.init.zeros_(self.cap_embedder[1].bias)
self.context_refiner = nn.ModuleList(
[
JointTransformerBlock(
@@ -929,9 +932,6 @@ class NextDiT(nn.Module):
]
)
nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
# nn.init.zeros_(self.cap_embedder[1].weight)
nn.init.zeros_(self.cap_embedder[1].bias)
self.layers = nn.ModuleList(
[
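For reference, the two hunks above only move the cap_embedder initialization so it runs immediately after the embedder is built, rather than after context_refiner is constructed. A minimal sketch of the intended ordering, assuming cap_embedder is an nn.Sequential whose second entry (index [1]) is the Linear projection; the dimensions and the norm layer here are illustrative, not the model's actual values:

import torch.nn as nn

cap_feat_dim, dim = 2304, 2048  # hypothetical sizes, for illustration only

cap_embedder = nn.Sequential(
    nn.LayerNorm(cap_feat_dim),              # the real model may use a different norm
    nn.Linear(cap_feat_dim, dim, bias=True),
)
# the moved lines: initialize the projection right after constructing it
nn.init.trunc_normal_(cap_embedder[1].weight, std=0.02)
nn.init.zeros_(cap_embedder[1].bias)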

Changed file 2 of 3

@@ -529,8 +529,8 @@ class DreamBoothSubset(BaseSubset):
self.is_reg = is_reg
self.class_tokens = class_tokens
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
# if self.caption_extension and not self.caption_extension.startswith("."):
# self.caption_extension = "." + self.caption_extension
self.cache_info = cache_info
def __eq__(self, other) -> bool:
@@ -1895,30 +1895,33 @@ class DreamBoothDataset(BaseDataset):
self.bucket_reso_steps = None # this value is not used
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension, enable_wildcard):
def read_caption(img_path: str, caption_extension: str, enable_wildcard: bool):
# build the candidate caption file names
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
cap_paths = [(base_name, caption_extension), (base_name_face_det, caption_extension)]
caption = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # drop empty lines and join with newlines
else:
caption = lines[0].strip()
break
for base, cap_extension in cap_paths:
# check with and without . to allow for extension flexibility (img_var.txt, img.txt, img + txt)
for cap_path in [base + cap_extension, base + "." + cap_extension]:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # drop empty lines and join with newlines
else:
caption = lines[0].strip()
break
break
return caption
def load_dreambooth_dir(subset: DreamBoothSubset):
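Taken together, the two caption-extension changes above mean the subset no longer forces a leading dot onto caption_extension, and read_caption instead probes each candidate base name both with and without an inserted dot. A simplified sketch of the resulting lookup order (the face-detection fallback is omitted; file names are illustrative):

import os

def candidate_caption_paths(img_path: str, caption_extension: str) -> list:
    base_name = os.path.splitext(img_path)[0]
    # check with and without "." so ".txt", "txt" and suffixes like "_cap.txt" all resolve
    return [base_name + caption_extension, base_name + "." + caption_extension]

print(candidate_caption_paths("data/img001.png", ".txt"))
# ['data/img001.txt', 'data/img001..txt']  -> first candidate hits the usual case
print(candidate_caption_paths("data/img001.png", "txt"))
# ['data/img001txt', 'data/img001.txt']    -> second candidate recovers the missing dot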

Changed file 3 of 3

@@ -1,5 +1,5 @@
# temporary minimum implementation of LoRA
# FLUX doesn't have Conv2d, so we ignore it
# Lumina 2 does not have Conv2d, so ignore
# TODO commonize with the original implementation
# LoRA network module
@@ -10,13 +10,11 @@
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re
from torch import Tensor, nn
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
setup_logging()
import logging
@@ -24,10 +22,6 @@ import logging
logger = logging.getLogger(__name__)
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
class LoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -35,14 +29,14 @@ class LoRAModule(torch.nn.Module):
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
lora_name: str,
org_module: nn.Module,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: Optional[float | int | Tensor] = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
split_dims: Optional[List[int]] = None,
):
"""
@@ -60,6 +54,9 @@ class LoRAModule(torch.nn.Module):
in_dim = org_module.in_features
out_dim = org_module.out_features
assert isinstance(in_dim, int)
assert isinstance(out_dim, int)
self.lora_dim = lora_dim
self.split_dims = split_dims
@@ -68,30 +65,31 @@ class LoRAModule(torch.nn.Module):
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
self.lora_down = nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
self.lora_down = nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = nn.Linear(self.lora_dim, out_dim, bias=False)
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_up.weight)
else:
# conv2d not supported
assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
# print(f"split_dims: {split_dims}")
self.lora_down = torch.nn.ModuleList(
[torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
self.lora_down = nn.ModuleList(
[nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
)
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
for lora_down in self.lora_down:
torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
for lora_up in self.lora_up:
torch.nn.init.zeros_(lora_up.weight)
self.lora_up = nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
for lora_down in self.lora_down:
nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
for lora_up in self.lora_up:
nn.init.zeros_(lora_up.weight)
if isinstance(alpha, Tensor):
alpha = alpha.detach().cpu().float().item() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # can be treated as a constant
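The alpha handling above now collapses a tensor alpha to a plain Python float with .item() (instead of going through numpy) and still falls back to lora_dim when alpha is missing or zero. A small worked example of the resulting scale, written as a standalone sketch:

import torch

def lora_scale(alpha, lora_dim: int) -> float:
    # mirrors the alpha -> scale logic in LoRAModule.__init__
    if isinstance(alpha, torch.Tensor):
        alpha = alpha.detach().cpu().float().item()  # safe for bf16 checkpoints
    alpha = lora_dim if alpha is None or alpha == 0 else alpha
    return alpha / lora_dim

print(lora_scale(1, 4))                                          # 0.25
print(lora_scale(None, 16))                                      # 1.0 (alpha defaults to lora_dim)
print(lora_scale(torch.tensor(8.0, dtype=torch.bfloat16), 16))   # 0.5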
@@ -152,9 +150,9 @@ class LoRAModule(torch.nn.Module):
if self.rank_dropout is not None and self.training:
masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
for i in range(len(lxs)):
if len(lx.size()) == 3:
if len(lxs[i].size()) == 3:
masks[i] = masks[i].unsqueeze(1)
elif len(lx.size()) == 4:
elif len(lxs[i].size()) == 4:
masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
lxs[i] = lxs[i] * masks[i]
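The fix above indexes lxs[i] instead of the stale loop variable lx, so each split output is reshaped against a mask matching its own rank dimensions. A standalone sketch of the same rank-dropout masking for a 3-D activation; the 1/(1-p) compensation is the usual convention and may be applied elsewhere in the real module:

import torch

lora_dim, rank_dropout = 4, 0.5
lx = torch.randn(2, 7, lora_dim)  # (batch, seq, rank), the 3-D case

# one Bernoulli mask per (batch, rank) entry, broadcast over the sequence dimension
mask = torch.rand((lx.size(0), lora_dim)) > rank_dropout
if lx.dim() == 3:
    mask = mask.unsqueeze(1)                  # (batch, 1, rank)
elif lx.dim() == 4:
    mask = mask.unsqueeze(-1).unsqueeze(-1)   # conv case: (batch, rank, 1, 1)

lx = lx * mask                                # zero out dropped ranks
lx = lx * (1.0 / (1.0 - rank_dropout))        # compensate for the dropped ranks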
@@ -339,14 +337,14 @@ def create_network(
if all([d is None for d in type_dims]):
type_dims = None
# in_dims for embedders
in_dims = kwargs.get("in_dims", None)
if in_dims is not None:
in_dims = in_dims.strip()
if in_dims.startswith("[") and in_dims.endswith("]"):
in_dims = in_dims[1:-1]
in_dims = [int(d) for d in in_dims.split(",")]
assert len(in_dims) == 4, f"invalid in_dims: {in_dims}, must be 4 dimensions (x_embedder, t_embedder, cap_embedder, final_layer)"
# embedder_dims for embedders
embedder_dims = kwargs.get("embedder_dims", None)
if embedder_dims is not None:
embedder_dims = embedder_dims.strip()
if embedder_dims.startswith("[") and embedder_dims.endswith("]"):
embedder_dims = embedder_dims[1:-1]
embedder_dims = [int(d) for d in embedder_dims.split(",")]
assert len(embedder_dims) == 3, f"invalid embedder_dims: {embedder_dims}, must be 3 dimensions (x_embedder, t_embedder, cap_embedder)"
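Both in_dims (old) and embedder_dims (new) arrive as strings through network_args, so the parsing accepts either a bare comma list or a bracketed one. A quick sketch of the accepted forms, assuming the usual kwargs plumbing delivers the raw string:

def parse_int_list(value: str, expected_len: int):
    # mirrors the embedder_dims parsing above
    value = value.strip()
    if value.startswith("[") and value.endswith("]"):
        value = value[1:-1]
    dims = [int(d) for d in value.split(",")]
    assert len(dims) == expected_len, f"expected {expected_len} values, got {dims}"
    return dims

# both forms are accepted for embedder_dims (x_embedder, t_embedder, cap_embedder)
print(parse_int_list("[4,4,8]", 3))  # [4, 4, 8]
print(parse_int_list("4,4,8", 3))    # [4, 4, 8]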
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
@@ -357,9 +355,9 @@ def create_network(
module_dropout = float(module_dropout)
# single or double blocks
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double"
train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "transformer", "refiners", "noise_refiner", "context_refiner"
if train_blocks is not None:
assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}"
assert train_blocks in ["all", "transformer", "refiners", "noise_refiner", "context_refiner"], f"invalid train_blocks: {train_blocks}"
# split qkv
split_qkv = kwargs.get("split_qkv", False)
@@ -386,7 +384,7 @@ def create_network(
train_blocks=train_blocks,
split_qkv=split_qkv,
type_dims=type_dims,
in_dims=in_dims,
embedder_dims=embedder_dims,
verbose=verbose,
)
@@ -461,7 +459,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei
class LoRANetwork(torch.nn.Module):
LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"]
LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock", "FinalLayer"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2FlashAttention2", "Gemma2SdpaAttention", "Gemma2MLP"]
LORA_PREFIX_LUMINA = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder
@@ -478,13 +476,14 @@ class LoRANetwork(torch.nn.Module):
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
module_class: Type[LoRAModule] = LoRAModule,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
train_blocks: Optional[str] = None,
split_qkv: bool = False,
type_dims: Optional[List[int]] = None,
in_dims: Optional[List[int]] = None,
embedder_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -501,7 +500,9 @@ class LoRANetwork(torch.nn.Module):
self.split_qkv = split_qkv
self.type_dims = type_dims
self.in_dims = in_dims
self.embedder_dims = embedder_dims
self.train_block_indices = train_block_indices
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
@@ -509,7 +510,7 @@ class LoRANetwork(torch.nn.Module):
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
self.in_dims = [0] * 5 # create in_dims
self.embedder_dims = [0] * 5 # create embedder_dims
# verbose = True
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
@@ -529,7 +530,7 @@ class LoRANetwork(torch.nn.Module):
def create_modules(
is_lumina: bool,
root_module: torch.nn.Module,
target_replace_modules: List[str],
target_replace_modules: Optional[List[str]],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
) -> List[LoRAModule]:
@@ -544,63 +545,76 @@ class LoRANetwork(torch.nn.Module):
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
if filter is not None and not filter in lora_name:
continue
# Only Linear is supported
if not is_linear:
skipped.append(lora_name)
continue
dim = None
alpha = None
if filter is not None and filter not in lora_name:
continue
if modules_dim is not None:
# per-module dims were specified
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
# normal case: target all modules
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
if is_lumina and type_dims is not None:
identifier = [
("attention",), # attention layers
("mlp",), # MLP layers
("modulation",), # modulation layers
("refiner",), # refiner blocks
]
for i, d in enumerate(type_dims):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip
break
# Set dim/alpha to modules dim/alpha
if modules_dim is not None and modules_alpha is not None:
# per-module dims were specified
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
# Set dims to type_dims
if is_lumina and type_dims is not None:
identifier = [
("attention",), # attention layers
("mlp",), # MLP layers
("modulation",), # modulation layers
("refiner",), # refiner blocks
]
for i, d in enumerate(type_dims):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip
break
if dim is None or dim == 0:
# log which modules were skipped
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
skipped.append(lora_name)
continue
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
loras.append(lora)
# Drop blocks if we are only training some blocks
if (
is_lumina
and dim
and (self.train_block_indices is not None)
and ("layer" in lora_name)
):
# "lora_unet_layers_0_..." or "lora_unet_cap_refiner_0_..." or or "lora_unet_noise_refiner_0_..."
block_index = int(lora_name.split("_")[3]) # bit dirty
if (
"layer" in lora_name
and self.train_block_indices is not None
and not self.train_block_indices[block_index]
):
dim = 0
if dim is None or dim == 0:
# log which modules were skipped
skipped.append(lora_name)
continue
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
loras.append(lora)
if target_replace_modules is None:
break # all modules are searched
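The block-dropping branch keys on the numeric segment of the module name and only applies to the main transformer stack (names containing "layer"). A worked example of the split("_")[3] extraction, assuming names follow the lora_unet prefix convention used above:

from typing import Optional

def block_index_of(lora_name: str) -> Optional[int]:
    # sketch of the extraction used when train_block_indices is set
    if "layer" not in lora_name:
        return None  # refiners and embedders are not index-filtered here
    return int(lora_name.split("_")[3])  # "lora_unet_layers_<idx>_..." -> <idx>

print(block_index_of("lora_unet_layers_0_attention_qkv"))     # 0
print(block_index_of("lora_unet_layers_17_mlp_fc1"))          # 17
print(block_index_of("lora_unet_noise_refiner_0_attention"))  # None

# when train_block_indices[i] is False, every module of transformer block i
# gets dim = 0 and is reported in the skipped list instead of receiving a LoRA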
@@ -617,15 +631,25 @@ class LoRANetwork(torch.nn.Module):
skipped_te += skipped
# create LoRA for U-Net
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
if self.train_blocks == "all":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
# TODO: limit different blocks
elif self.train_blocks == "transformer":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
elif self.train_blocks == "refiners":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
elif self.train_blocks == "noise_refiner":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
elif self.train_blocks == "cap_refiner":
target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)
# Handle embedders
if self.in_dims:
for filter, in_dim in zip(["x_embedder", "t_embedder", "cap_embedder", "final_layer"], self.in_dims):
loras, _ = create_modules(True, unet, None, filter=filter, default_dim=in_dim)
if self.embedder_dims:
for filter, embedder_dim in zip(["x_embedder", "t_embedder", "cap_embedder"], self.embedder_dims):
loras, _ = create_modules(True, unet, None, filter=filter, default_dim=embedder_dim)
self.unet_loras.extend(loras)
logger.info(f"create LoRA for Lumina blocks: {len(self.unet_loras)} modules.")
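When embedder_dims is set, create_modules is re-run once per embedder with target_replace_modules=None and a name filter, so LoRA also attaches to Linear layers that live outside the JointTransformerBlock/FinalLayer targets. A rough sketch of how the filters select modules; the candidate names here are illustrative, not the exact names produced by the network:

candidate_names = [
    "lora_unet_x_embedder",              # latent/patch input projection
    "lora_unet_t_embedder_mlp_0",        # timestep embedder MLP
    "lora_unet_cap_embedder_1",          # caption embedder projection
    "lora_unet_layers_0_attention_qkv",  # regular transformer block (not matched)
]

embedder_dims = [4, 4, 8]  # hypothetical dims for x_embedder, t_embedder, cap_embedder
for filt, dim in zip(["x_embedder", "t_embedder", "cap_embedder"], embedder_dims):
    matched = [n for n in candidate_names if filt in n]
    print(f"{filt}: dim={dim}, modules={matched}")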