diff --git a/library/lumina_models.py b/library/lumina_models.py index e00dcf96..2508cc7d 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -887,6 +887,9 @@ class NextDiT(nn.Module): ), ) + nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02) + nn.init.zeros_(self.cap_embedder[1].bias) + self.context_refiner = nn.ModuleList( [ JointTransformerBlock( @@ -929,9 +932,6 @@ class NextDiT(nn.Module): ] ) - nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02) - # nn.init.zeros_(self.cap_embedder[1].weight) - nn.init.zeros_(self.cap_embedder[1].bias) self.layers = nn.ModuleList( [ diff --git a/library/train_util.py b/library/train_util.py index 34b98f89..c07a4a73 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -529,8 +529,8 @@ class DreamBoothSubset(BaseSubset): self.is_reg = is_reg self.class_tokens = class_tokens self.caption_extension = caption_extension - if self.caption_extension and not self.caption_extension.startswith("."): - self.caption_extension = "." + self.caption_extension + # if self.caption_extension and not self.caption_extension.startswith("."): + # self.caption_extension = "." 
+ self.caption_extension self.cache_info = cache_info def __eq__(self, other) -> bool: @@ -1895,30 +1895,33 @@ class DreamBoothDataset(BaseDataset): self.bucket_reso_steps = None # この情報は使われない self.bucket_no_upscale = False - def read_caption(img_path, caption_extension, enable_wildcard): + def read_caption(img_path: str, caption_extension: str, enable_wildcard: bool): # captionの候補ファイル名を作る base_name = os.path.splitext(img_path)[0] base_name_face_det = base_name tokens = base_name.split("_") if len(tokens) >= 5: base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] + cap_paths = [(base_name, caption_extension), (base_name_face_det, caption_extension)] caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding="utf-8") as f: - try: - lines = f.readlines() - except UnicodeDecodeError as e: - logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") - raise e - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - if enable_wildcard: - caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 - else: - caption = lines[0].strip() - break + for base, cap_extension in cap_paths: + # check with and without . to allow for extension flexibility (img_var.txt, img.txt, img + txt) + for cap_path in [base + cap_extension, base + "." 
+ cap_extension]: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding="utf-8") as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + if enable_wildcard: + caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結 + else: + caption = lines[0].strip() + break + if caption is not None: break return caption def load_dreambooth_dir(subset: DreamBoothSubset): diff --git a/networks/lora_lumina.py b/networks/lora_lumina.py index 431c183d..f856d4e7 100644 --- a/networks/lora_lumina.py +++ b/networks/lora_lumina.py @@ -1,5 +1,5 @@ # temporary minimum implementation of LoRA -# FLUX doesn't have Conv2d, so we ignore it +# Lumina 2 does not have Conv2d, so ignore # TODO commonize with the original implementation # LoRA network module @@ -10,13 +10,11 @@ import math import os from typing import Dict, List, Optional, Tuple, Type, Union -from diffusers import AutoencoderKL +from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL from transformers import CLIPTextModel -import numpy as np import torch -import re +from torch import Tensor, nn from library.utils import setup_logging -from library.sdxl_original_unet import SdxlUNet2DConditionModel setup_logging() import logging @@ -35,14 +33,14 @@ class LoRAModule(torch.nn.Module): def __init__( self, - lora_name, - org_module: torch.nn.Module, - multiplier=1.0, - lora_dim=4, - alpha=1, - dropout=None, - rank_dropout=None, - module_dropout=None, + lora_name: str, + org_module: nn.Module, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: Optional[float | int | Tensor] = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, split_dims: Optional[List[int]] = None, ): """ @@ -60,6 +58,9 @@ class LoRAModule(torch.nn.Module): in_dim =
org_module.in_features out_dim = org_module.out_features + assert isinstance(in_dim, int) + assert isinstance(out_dim, int) + self.lora_dim = lora_dim self.split_dims = split_dims @@ -68,30 +69,31 @@ class LoRAModule(torch.nn.Module): kernel_size = org_module.kernel_size stride = org_module.stride padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + self.lora_down = nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + self.lora_down = nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = nn.Linear(self.lora_dim, out_dim, bias=False) - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) + nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_up.weight) else: # conv2d not supported assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" # print(f"split_dims: {split_dims}") - self.lora_down = torch.nn.ModuleList( - [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] + self.lora_down = nn.ModuleList( + [nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] ) - self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) - for lora_down in self.lora_down: - torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) - for lora_up in self.lora_up: - torch.nn.init.zeros_(lora_up.weight) + self.lora_up = 
nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) - if type(alpha) == torch.Tensor: - alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + for lora_down in self.lora_down: + nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + for lora_up in self.lora_up: + nn.init.zeros_(lora_up.weight) + + if isinstance(alpha, Tensor): + alpha = alpha.detach().cpu().float().item() # without casting, bf16 causes error alpha = self.lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える @@ -140,6 +142,9 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + return org_forwarded + lx * self.multiplier * scale else: lxs = [lora_down(x) for lora_down in self.lora_down] @@ -152,9 +157,9 @@ class LoRAModule(torch.nn.Module): if self.rank_dropout is not None and self.training: masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] for i in range(len(lxs)): - if len(lx.size()) == 3: + if len(lxs[i].size()) == 3: masks[i] = masks[i].unsqueeze(1) - elif len(lx.size()) == 4: + elif len(lxs[i].size()) == 4: masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) lxs[i] = lxs[i] * masks[i] @@ -165,6 +170,9 @@ class LoRAModule(torch.nn.Module): lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)] + if self.dropout is not None and self.training: + lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs] + return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale @@ -339,14 +347,14 @@ def create_network( if all([d is None for d in type_dims]): type_dims = None - # in_dims for embedders - in_dims = kwargs.get("in_dims", None) - if in_dims is not None: - in_dims = in_dims.strip() - if in_dims.startswith("[") and in_dims.endswith("]"): - in_dims 
= in_dims[1:-1] - in_dims = [int(d) for d in in_dims.split(",")] - assert len(in_dims) == 4, f"invalid in_dims: {in_dims}, must be 4 dimensions (x_embedder, t_embedder, cap_embedder, final_layer)" + # embedder_dims for embedders + embedder_dims = kwargs.get("embedder_dims", None) + if embedder_dims is not None: + embedder_dims = embedder_dims.strip() + if embedder_dims.startswith("[") and embedder_dims.endswith("]"): + embedder_dims = embedder_dims[1:-1] + embedder_dims = [int(d) for d in embedder_dims.split(",")] + assert len(embedder_dims) == 3, f"invalid embedder_dims: {embedder_dims}, must be 3 dimensions (x_embedder, t_embedder, cap_embedder)" # rank/module dropout rank_dropout = kwargs.get("rank_dropout", None) @@ -357,9 +365,9 @@ def create_network( module_dropout = float(module_dropout) # single or double blocks - train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "single", "double" + train_blocks = kwargs.get("train_blocks", None) # None (default), "all" (same as None), "transformer", "refiners", "noise_refiner", "context_refiner" if train_blocks is not None: - assert train_blocks in ["all", "single", "double"], f"invalid train_blocks: {train_blocks}" + assert train_blocks in ["all", "transformer", "refiners", "noise_refiner", "context_refiner"], f"invalid train_blocks: {train_blocks}" # split qkv split_qkv = kwargs.get("split_qkv", False) @@ -386,7 +394,7 @@ def create_network( train_blocks=train_blocks, split_qkv=split_qkv, type_dims=type_dims, - in_dims=in_dims, + embedder_dims=embedder_dims, verbose=verbose, ) @@ -461,7 +469,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, wei class LoRANetwork(torch.nn.Module): - LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock"] + LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock", "FinalLayer"] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2FlashAttention2", "Gemma2SdpaAttention", "Gemma2MLP"] LORA_PREFIX_LUMINA =
"lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" # Simplified prefix since we only have one text encoder @@ -478,13 +486,14 @@ class LoRANetwork(torch.nn.Module): module_dropout: Optional[float] = None, conv_lora_dim: Optional[int] = None, conv_alpha: Optional[float] = None, - module_class: Type[object] = LoRAModule, + module_class: Type[LoRAModule] = LoRAModule, modules_dim: Optional[Dict[str, int]] = None, modules_alpha: Optional[Dict[str, int]] = None, train_blocks: Optional[str] = None, split_qkv: bool = False, type_dims: Optional[List[int]] = None, - in_dims: Optional[List[int]] = None, + embedder_dims: Optional[List[int]] = None, + train_block_indices: Optional[List[bool]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -501,7 +510,9 @@ class LoRANetwork(torch.nn.Module): self.split_qkv = split_qkv self.type_dims = type_dims - self.in_dims = in_dims + self.embedder_dims = embedder_dims + + self.train_block_indices = train_block_indices self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -509,7 +520,7 @@ class LoRANetwork(torch.nn.Module): if modules_dim is not None: logger.info(f"create LoRA network from weights") - self.in_dims = [0] * 5 # create in_dims + self.embedder_dims = [0] * 3 # create embedder_dims # verbose = True else: logger.info(f"create LoRA network.
base dim (rank): {lora_dim}, alpha: {alpha}") @@ -529,7 +540,7 @@ class LoRANetwork(torch.nn.Module): def create_modules( is_lumina: bool, root_module: torch.nn.Module, - target_replace_modules: List[str], + target_replace_modules: Optional[List[str]], filter: Optional[str] = None, default_dim: Optional[int] = None, ) -> List[LoRAModule]: @@ -544,63 +555,77 @@ class LoRANetwork(torch.nn.Module): for child_name, child_module in module.named_modules(): is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" - is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) - if is_linear or is_conv2d: - lora_name = prefix + "." + (name + "." if name else "") + child_name - lora_name = lora_name.replace(".", "_") + lora_name = prefix + "." + (name + "." if name else "") + child_name + lora_name = lora_name.replace(".", "_") - if filter is not None and not filter in lora_name: - continue + # Only Linear is supported + if not is_linear: + skipped.append(lora_name) + continue - dim = None - alpha = None + if filter is not None and filter not in lora_name: + continue - if modules_dim is not None: - # モジュール指定あり - if lora_name in modules_dim: - dim = modules_dim[lora_name] - alpha = modules_alpha[lora_name] - else: - # 通常、すべて対象とする - if is_linear or is_conv2d_1x1: - dim = default_dim if default_dim is not None else self.lora_dim - alpha = self.alpha + dim = default_dim if default_dim is not None else self.lora_dim + alpha = self.alpha - if is_lumina and type_dims is not None: - identifier = [ - ("attention",), # attention layers - ("mlp",), # MLP layers - ("modulation",), # modulation layers - ("refiner",), # refiner blocks - ] - for i, d in enumerate(type_dims): - if d is not None and all([id in lora_name for id in identifier[i]]): - dim = d # may be 0 for skip - break + # Set dim/alpha to modules dim/alpha + if modules_dim is not None and modules_alpha is not None: + # モジュール指定あり + if lora_name in modules_dim: + dim 
= modules_dim[lora_name] + alpha = modules_alpha[lora_name] - elif self.conv_lora_dim is not None: - dim = self.conv_lora_dim - alpha = self.conv_alpha + # Set dims to type_dims + if is_lumina and type_dims is not None: + identifier = [ + ("attention",), # attention layers + ("mlp",), # MLP layers + ("modulation",), # modulation layers + ("refiner",), # refiner blocks + ] + for i, d in enumerate(type_dims): + if d is not None and all([id in lora_name for id in identifier[i]]): + dim = d # may be 0 for skip + break - if dim is None or dim == 0: - # skipした情報を出力 - if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None): - skipped.append(lora_name) - continue - - lora = module_class( - lora_name, - child_module, - self.multiplier, - dim, - alpha, - dropout=dropout, - rank_dropout=rank_dropout, - module_dropout=module_dropout, + # Drop blocks if we are only training some blocks + if ( + is_lumina + and dim + and ( + self.train_block_indices is not None ) - loras.append(lora) + and ("layers" in lora_name) + ): + # "lora_unet_layers_0_..." or "lora_unet_cap_refiner_0_..." or "lora_unet_noise_refiner_0_..."
+ block_index = int(lora_name.split("_")[3]) # bit dirty + if ( + "layers" in lora_name + and self.train_block_indices is not None + and not self.train_block_indices[block_index] + ): + dim = 0 + + + if dim is None or dim == 0: + # skipした情報を出力 + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + ) + logger.info(f"Add LoRA module: {lora_name}") + loras.append(lora) if target_replace_modules is None: break # all modules are searched @@ -617,15 +642,25 @@ class LoRANetwork(torch.nn.Module): skipped_te += skipped # create LoRA for U-Net + if self.train_blocks == "all": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + # TODO: limit different blocks + elif self.train_blocks == "transformer": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + elif self.train_blocks == "refiners": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + elif self.train_blocks == "noise_refiner": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE + elif self.train_blocks == "context_refiner": + target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE self.unet_loras: List[Union[LoRAModule, LoRAInfModule]] self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules) # Handle embedders + if self.embedder_dims: + for filter, embedder_dim in zip(["x_embedder", "t_embedder", "cap_embedder"], self.embedder_dims): + loras, _ = create_modules(True, unet, None, filter=filter, default_dim=embedder_dim) self.unet_loras.extend(loras) logger.info(f"create LoRA for Lumina blocks:
{len(self.unet_loras)} modules.")