From 872124c5e147db30c47d63055ebccc00b7f49f0c Mon Sep 17 00:00:00 2001 From: woctordho Date: Mon, 17 Nov 2025 09:20:08 +0800 Subject: [PATCH 01/17] Use svd_lowrank for large matrices in resize_lora.py --- networks/resize_lora.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 18326437..5dd1132f 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -87,7 +87,14 @@ def index_sv_ratio(S, target): # Modified from Kohaku-blueleaf's extract/merge functions def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): out_size, in_size, kernel_size, _ = weight.size() - U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device)) + weight = weight.reshape(out_size, -1) + _in_size = in_size * kernel_size * kernel_size + + if out_size > 2048 and _in_size > 2048: + U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size)) + Vh = V.T + else: + U, S, Vh = torch.linalg.svd(weight.to(device)) param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) lora_rank = param_dict["new_rank"] @@ -106,7 +113,11 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): out_size, in_size = weight.size() - U, S, Vh = torch.linalg.svd(weight.to(device)) + if out_size > 2048 and in_size > 2048: + U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size)) + Vh = V.T + else: + U, S, Vh = torch.linalg.svd(weight.to(device)) param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale) lora_rank = param_dict["new_rank"] From 609d1292f6e262b27a8c5b2849e7bf0df2ecd7a8 Mon Sep 17 00:00:00 2001 From: duongve13112002 <71595470+duongve13112002@users.noreply.github.com> Date: Mon, 23 Feb 2026 13:13:40 +0700 Subject: [PATCH 02/17] Fix the LoRA dropout issue in the Anima model during LoRA training. 
(#2272) * Support network_reg_alphas and fix bug when setting rank_dropout in training lora for anima model * Update anima_train_network.md * Update anima_train_network.md * Remove network_reg_alphas * Update document --- docs/anima_train_network.md | 2 +- networks/lora_anima.py | 2 +- networks/lora_flux.py | 13 ++++++++----- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/anima_train_network.md b/docs/anima_train_network.md index f97aa975..5d67ae36 100644 --- a/docs/anima_train_network.md +++ b/docs/anima_train_network.md @@ -652,4 +652,4 @@ The following metadata is saved in the LoRA model file: * `ss_sigmoid_scale` * `ss_discrete_flow_shift` - + \ No newline at end of file diff --git a/networks/lora_anima.py b/networks/lora_anima.py index 224ef20c..9413e8c8 100644 --- a/networks/lora_anima.py +++ b/networks/lora_anima.py @@ -636,4 +636,4 @@ class LoRANetwork(torch.nn.Module): scalednorm = updown.norm() * ratio norms.append(scalednorm.item()) - return keys_scaled, sum(norms) / len(norms), max(norms) + return keys_scaled, sum(norms) / len(norms), max(norms) \ No newline at end of file diff --git a/networks/lora_flux.py b/networks/lora_flux.py index d74d0172..947733fe 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -141,10 +141,13 @@ class LoRAModule(torch.nn.Module): # rank dropout if self.rank_dropout is not None and self.training: mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout - if len(lx.size()) == 3: - mask = mask.unsqueeze(1) # for Text Encoder - elif len(lx.size()) == 4: - mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + if isinstance(self.lora_down, torch.nn.Conv2d): + # Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1] + mask = mask.unsqueeze(-1).unsqueeze(-1) + else: + # Linear: lora_dim is at last dim → [B, 1, ..., 1, dim] + for _ in range(len(lx.size()) - 2): + mask = mask.unsqueeze(1) lx = lx * mask # scaling for rank dropout: treat as if the rank is changed @@ -1445,4 +1448,4 @@ class LoRANetwork(torch.nn.Module): scalednorm = updown.norm() * ratio norms.append(scalednorm.item()) - return keys_scaled, sum(norms) / len(norms), max(norms) + return keys_scaled, sum(norms) / len(norms), max(norms) \ No newline at end of file From 50694df3cf0c02bbdac9d94464c9d93d908c654c Mon Sep 17 00:00:00 2001 From: woctordho Date: Mon, 23 Feb 2026 14:30:36 +0800 Subject: [PATCH 03/17] Multi-resolution dataset for SD1/SDXL (#2269) * Multi-resolution dataset for SD1/SDXL * Add fallback to legacy key without resolution suffix * Support numpy 2.2 --- library/strategy_base.py | 102 +++++++++++++++++++++++++++------------ library/strategy_sd.py | 23 ++++++++- 2 files changed, 92 insertions(+), 33 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 6e6487ea..9a2acdba 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -6,6 +6,11 @@ from typing import Any, List, Optional, Tuple, Union, Callable import numpy as np import torch + +try: + from numpy.lib import _format_impl as np_format_impl +except ImportError: + from numpy.lib import format as np_format_impl from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection @@ -424,6 +429,16 @@ class LatentsCachingStrategy: def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): raise NotImplementedError + def _get_npz_array_shape(self, npz: Any, key: str) -> Optional[Tuple[int, ...]]: + """Get array shape in npz file by only reading the header.""" + if key 
not in npz: + return None + + with npz.zip.open(key + ".npy") as npy_file: + version = np.lib.format.read_magic(npy_file) + shape, _, _ = np_format_impl._read_array_header(npy_file, version) + return shape + def _default_is_disk_cached_latents_expected( self, latents_stride: int, @@ -432,6 +447,7 @@ class LatentsCachingStrategy: flip_aug: bool, apply_alpha_mask: bool, multi_resolution: bool = False, + fallback_no_reso: bool = False, ) -> bool: """ Args: @@ -441,6 +457,7 @@ class LatentsCachingStrategy: flip_aug: whether to flip images apply_alpha_mask: whether to apply alpha mask multi_resolution: whether to use multi-resolution latents + fallback_no_reso: fallback to legacy key without resolution suffix Returns: bool @@ -458,13 +475,21 @@ class LatentsCachingStrategy: key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" try: - npz = np.load(npz_path) - if "latents" + key_reso_suffix not in npz: - return False - if flip_aug and "latents_flipped" + key_reso_suffix not in npz: - return False - if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz: - return False + with np.load(npz_path) as npz: + if "latents" + key_reso_suffix not in npz: + if not (multi_resolution and fallback_no_reso): + return False + + latents_shape = self._get_npz_array_shape(npz, "latents") + if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size: + return False + + key_reso_suffix = "" + + if flip_aug and "latents_flipped" + key_reso_suffix not in npz: + return False + if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -495,8 +520,8 @@ class LatentsCachingStrategy: apply_alpha_mask: whether to apply alpha mask random_crop: whether to random crop images multi_resolution: whether to use multi-resolution latents - - Returns: + + Returns: None """ from library import train_util # import here to avoid circular import @@ -548,52 +573,67 @@ class LatentsCachingStrategy: Args: npz_path (str): Path to the npz file. bucket_reso (Tuple[int, int]): The resolution of the bucket. - + Returns: Tuple[ - Optional[np.ndarray], - Optional[List[int]], - Optional[List[int]], - Optional[np.ndarray], + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], Optional[np.ndarray] ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask """ return self._default_load_latents_from_disk(None, npz_path, bucket_reso) def _default_load_latents_from_disk( - self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] + self, + latents_stride: Optional[int], + npz_path: str, + bucket_reso: Tuple[int, int], + fallback_no_reso: bool = False, ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ Args: latents_stride (Optional[int]): Stride for latents. If None, load all latents. npz_path (str): Path to the npz file. bucket_reso (Tuple[int, int]): The resolution of the bucket. 
- + fallback_no_reso (bool): fallback to legacy key without resolution suffix + Returns: Tuple[ - Optional[np.ndarray], - Optional[List[int]], - Optional[List[int]], - Optional[np.ndarray], + Optional[np.ndarray], + Optional[List[int]], + Optional[List[int]], + Optional[np.ndarray], Optional[np.ndarray] ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask """ if latents_stride is None: + expected_latents_size = None key_reso_suffix = "" else: - latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) - key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW + expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) + key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" # e.g. "_32x64", HxW - npz = np.load(npz_path) - if "latents" + key_reso_suffix not in npz: - raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + with np.load(npz_path) as npz: + key_reso_suffix = key_reso_suffix - latents = npz["latents" + key_reso_suffix] - original_size = npz["original_size" + key_reso_suffix].tolist() - crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() - flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None - alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + if "latents" + key_reso_suffix not in npz: + if not fallback_no_reso or expected_latents_size is None: + raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + + latents_shape = self._get_npz_array_shape(npz, "latents") + if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size: + raise ValueError(f"latents with legacy key has unexpected shape {latents_shape} in {npz_path}") + + key_reso_suffix = "" + + latents = npz["latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask def save_latents_to_disk( self, diff --git a/library/strategy_sd.py b/library/strategy_sd.py index a44fc409..837b8f5a 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -2,6 +2,7 @@ import glob import os from typing import Any, List, Optional, Tuple, Union +import numpy as np import torch from transformers import CLIPTokenizer from library import train_util @@ -157,7 +158,25 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected( + 8, + bucket_reso, + npz_path, + flip_aug, + alpha_mask, + multi_resolution=True, + fallback_no_reso=True, + ) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], 
Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk( + 8, + npz_path, + bucket_reso, + fallback_no_reso=True, + ) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -165,7 +184,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) From 892f8be78fc01989ab27c01bfd02173676d43bd3 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 23 Feb 2026 21:12:57 +0900 Subject: [PATCH 04/17] fix: cast input tensor to float32 for improved numerical stability in residual connections --- library/anima_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/library/anima_models.py b/library/anima_models.py index 6828e598..037ffd77 100644 --- a/library/anima_models.py +++ b/library/anima_models.py @@ -864,6 +864,10 @@ class Block(nn.Module): adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if x_B_T_H_W_D.dtype == torch.float16: + # Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context. + x_B_T_H_W_D = x_B_T_H_W_D.float() + if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb From f90fa1a89a717093286dd784c268811883f5c345 Mon Sep 17 00:00:00 2001 From: "Kohya S." 
<52813779+kohya-ss@users.noreply.github.com> Date: Mon, 23 Feb 2026 21:44:51 +0900 Subject: [PATCH 05/17] feat: backward compatibility for SD/SDXL latent cache (#2276) * fix: improve handling of legacy npz files and add logging for fallback scenarios * fix: simplify fallback handling in SdSdxlLatentsCachingStrategy --- library/strategy_base.py | 88 ++++++++++++++-------------------------- library/strategy_sd.py | 23 +++-------- 2 files changed, 37 insertions(+), 74 deletions(-) diff --git a/library/strategy_base.py b/library/strategy_base.py index 9a2acdba..5a043342 100644 --- a/library/strategy_base.py +++ b/library/strategy_base.py @@ -6,11 +6,6 @@ from typing import Any, List, Optional, Tuple, Union, Callable import numpy as np import torch - -try: - from numpy.lib import _format_impl as np_format_impl -except ImportError: - from numpy.lib import format as np_format_impl from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection @@ -387,6 +382,8 @@ class LatentsCachingStrategy: _strategy = None # strategy instance: actual strategy class + _warned_fallback_to_old_npz = False # to avoid spamming logs about fallback + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: self._cache_to_disk = cache_to_disk self._batch_size = batch_size @@ -429,16 +426,6 @@ class LatentsCachingStrategy: def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): raise NotImplementedError - def _get_npz_array_shape(self, npz: Any, key: str) -> Optional[Tuple[int, ...]]: - """Get array shape in npz file by only reading the header.""" - if key not in npz: - return None - - with npz.zip.open(key + ".npy") as npy_file: - version = np.lib.format.read_magic(npy_file) - shape, _, _ = np_format_impl._read_array_header(npy_file, version) - return shape - def _default_is_disk_cached_latents_expected( self, latents_stride: int, @@ -447,7 +434,6 @@ class LatentsCachingStrategy: flip_aug: bool, apply_alpha_mask: bool, multi_resolution: bool = False, - fallback_no_reso: bool = False, ) -> bool: """ Args: @@ -457,7 +443,6 @@ class LatentsCachingStrategy: flip_aug: whether to flip images apply_alpha_mask: whether to apply alpha mask multi_resolution: whether to use multi-resolution latents - fallback_no_reso: fallback to legacy key without resolution suffix Returns: bool @@ -475,21 +460,16 @@ class LatentsCachingStrategy: key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else "" try: - with np.load(npz_path) as npz: - if "latents" + key_reso_suffix not in npz: - if not (multi_resolution and fallback_no_reso): - return False + npz = np.load(npz_path) - latents_shape = self._get_npz_array_shape(npz, "latents") - if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size: - return False - - key_reso_suffix = "" - - if flip_aug and "latents_flipped" + key_reso_suffix not in npz: - return False - if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz: - return False + # In old SD/SDXL npz files, if the actual latents shape does not match the expected shape, it doesn't raise an error as long as "latents" key exists (backward compatibility) + # In non-SD/SDXL npz files (multi-resolution support), the latents key always has the resolution suffix, and no latents key without suffix exists, so it raises an error if the expected resolution suffix key is not found (this doesn't change the behavior for non-SD/SDXL npz files). 
+ if "latents" + key_reso_suffix not in npz and "latents" not in npz: + return False + if flip_aug and ("latents_flipped" + key_reso_suffix not in npz and "latents_flipped" not in npz): + return False + if apply_alpha_mask and ("alpha_mask" + key_reso_suffix not in npz and "alpha_mask" not in npz): + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -568,7 +548,7 @@ class LatentsCachingStrategy: self, npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ - for SD/SDXL + For single resolution architectures (currently no architecture is single resolution specific). Kept for reference. Args: npz_path (str): Path to the npz file. @@ -586,18 +566,13 @@ class LatentsCachingStrategy: return self._default_load_latents_from_disk(None, npz_path, bucket_reso) def _default_load_latents_from_disk( - self, - latents_stride: Optional[int], - npz_path: str, - bucket_reso: Tuple[int, int], - fallback_no_reso: bool = False, + self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: """ Args: latents_stride (Optional[int]): Stride for latents. If None, load all latents. npz_path (str): Path to the npz file. bucket_reso (Tuple[int, int]): The resolution of the bucket. - fallback_no_reso (bool): fallback to legacy key without resolution suffix Returns: Tuple[ @@ -609,31 +584,30 @@ class LatentsCachingStrategy: ]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask """ if latents_stride is None: - expected_latents_size = None key_reso_suffix = "" else: expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H) key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" # e.g. "_32x64", HxW - with np.load(npz_path) as npz: - key_reso_suffix = key_reso_suffix + npz = np.load(npz_path) + if "latents" + key_reso_suffix not in npz: + # raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") + # Fallback to old npz without resolution suffix + if "latents" not in npz: + raise ValueError(f"latents not found in {npz_path} (either with or without resolution suffix: {key_reso_suffix})") + if not self._warned_fallback_to_old_npz: + logger.warning( + f"latents{key_reso_suffix} not found in {npz_path}. Falling back to latents without resolution suffix (old npz). This warning will only be shown once. To avoid this warning, please re-cache the latents with the latest version." 
+ ) + self._warned_fallback_to_old_npz = True + key_reso_suffix = "" - if "latents" + key_reso_suffix not in npz: - if not fallback_no_reso or expected_latents_size is None: - raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}") - - latents_shape = self._get_npz_array_shape(npz, "latents") - if latents_shape is None or tuple(latents_shape[-2:]) != expected_latents_size: - raise ValueError(f"latents with legacy key has unexpected shape {latents_shape} in {npz_path}") - - key_reso_suffix = "" - - latents = npz["latents" + key_reso_suffix] - original_size = npz["original_size" + key_reso_suffix].tolist() - crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() - flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None - alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask + latents = npz["latents" + key_reso_suffix] + original_size = npz["original_size" + key_reso_suffix].tolist() + crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist() + flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None + alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask def save_latents_to_disk( self, diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 837b8f5a..4521ae8d 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -145,7 +145,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): self.suffix = ( SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX ) - + @property def cache_suffix(self) -> str: return self.suffix @@ -158,25 +158,12 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected( - 8, - bucket_reso, - npz_path, - flip_aug, - alpha_mask, - multi_resolution=True, - fallback_no_reso=True, - ) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) def load_latents_from_disk( self, npz_path: str, bucket_reso: Tuple[int, int] ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: - return self._default_load_latents_from_disk( - 8, - npz_path, - bucket_reso, - fallback_no_reso=True, - ) + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -184,7 +171,9 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True) + self._default_cache_batch_latents( + encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True + ) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device) From 
2217704ce17c8650838627d46e9e8864762070b9 Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 23 Feb 2026 22:09:00 +0900 Subject: [PATCH 06/17] feat: Support LoKr/LoHa for SDXL and Anima (#2275) * feat: Add LoHa/LoKr network support for SDXL and Anima - networks/network_base.py: shared AdditionalNetwork base class with architecture auto-detection (SDXL/Anima) and generic module injection - networks/loha.py: LoHa (Low-rank Hadamard Product) module with HadaWeight custom autograd, training/inference classes, and factory functions - networks/lokr.py: LoKr (Low-rank Kronecker Product) module with factorization, training/inference classes, and factory functions - library/lora_utils.py: extend weight merge hook to detect and merge LoHa/LoKr weights alongside standard LoRA Linear and Conv2d 1x1 layers only; Conv2d 3x3 (Tucker decomposition) support will be added separately. Co-Authored-By: Claude Opus 4.6 * feat: Enhance LoHa and LoKr modules with Tucker decomposition support - Added Tucker decomposition functionality to LoHa and LoKr modules. - Implemented new methods for weight rebuilding using Tucker decomposition. - Updated initialization and weight handling for Conv2d 3x3+ layers. - Modified get_diff_weight methods to accommodate Tucker and non-Tucker modes. - Enhanced network base to include unet_conv_target_modules for architecture detection. * fix: rank dropout handling in LoRAModule for Conv2d and Linear layers, see #2272 for details * doc: add dtype comment for load_safetensors_with_lora_and_fp8 function * fix: enhance architecture detection to support InferSdxlUNet2DConditionModel for gen_img.py * doc: update model support structure to include Lumina Image 2.0, HunyuanImage-2.1, and Anima-Preview * doc: add documentation for LoHa and LoKr fine-tuning methods * Update networks/network_base.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update docs/loha_lokr.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: refactor LoHa and LoKr imports for weight merging in load_safetensors_with_lora_and_fp8 function --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .ai/context/01-overview.md | 3 + .gitignore | 1 + docs/loha_lokr.md | 359 +++++++++++++++++++ library/lora_utils.py | 554 +++++++++++++++--------------- networks/loha.py | 643 ++++++++++++++++++++++++++++++++++ networks/lokr.py | 683 +++++++++++++++++++++++++++++++++++++ networks/lora_anima.py | 209 +++++++++++- networks/network_base.py | 545 +++++++++++++++++++++++++++++ 8 files changed, 2729 insertions(+), 268 deletions(-) create mode 100644 docs/loha_lokr.md create mode 100644 networks/loha.py create mode 100644 networks/lokr.py create mode 100644 networks/network_base.py diff --git a/.ai/context/01-overview.md b/.ai/context/01-overview.md index 41133e98..c37aba19 100644 --- a/.ai/context/01-overview.md +++ b/.ai/context/01-overview.md @@ -21,6 +21,9 @@ Each supported model family has a consistent structure: - **SDXL**: `sdxl_train*.py`, `library/sdxl_*` - **SD3**: `sd3_train*.py`, `library/sd3_*` - **FLUX.1**: `flux_train*.py`, `library/flux_*` +- **Lumina Image 2.0**: `lumina_train*.py`, `library/lumina_*` +- **HunyuanImage-2.1**: `hunyuan_image_train*.py`, `library/hunyuan_image_*` +- **Anima-Preview**: `anima_train*.py`, `library/anima_*` ### Key Components diff --git a/.gitignore b/.gitignore index cfdc0268..f5772a7f 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 
+11,4 @@ GEMINI.md .claude .gemini MagicMock +references \ No newline at end of file diff --git a/docs/loha_lokr.md b/docs/loha_lokr.md new file mode 100644 index 00000000..6f16ba66 --- /dev/null +++ b/docs/loha_lokr.md @@ -0,0 +1,359 @@ +> 📝 Click on the language section to expand / 言語をクリックして展開 + +# LoHa / LoKr (LyCORIS) + +## Overview / 概要 + +In addition to standard LoRA, sd-scripts supports **LoHa** (Low-rank Hadamard Product) and **LoKr** (Low-rank Kronecker Product) as alternative parameter-efficient fine-tuning methods. These are based on techniques from the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project. + +- **LoHa**: Represents weight updates as a Hadamard (element-wise) product of two low-rank matrices. Reference: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098) +- **LoKr**: Represents weight updates as a Kronecker product with optional low-rank decomposition. Reference: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859) + +The algorithms and recommended settings are described in the [LyCORIS documentation](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md) and [guidelines](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md). + +Both methods target Linear and Conv2d layers. Conv2d 1x1 layers are treated similarly to Linear layers. For Conv2d 3x3+ layers, optional Tucker decomposition or flat (kernel-flattened) mode is available. + +This feature is experimental. + +
+日本語 + +sd-scriptsでは、標準的なLoRAに加え、代替のパラメータ効率の良いファインチューニング手法として **LoHa**(Low-rank Hadamard Product)と **LoKr**(Low-rank Kronecker Product)をサポートしています。これらは [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) プロジェクトの手法に基づいています。 + +- **LoHa**: 重みの更新を2つの低ランク行列のHadamard積(要素ごとの積)で表現します。参考文献: [FedPara (arXiv:2108.06098)](https://arxiv.org/abs/2108.06098) +- **LoKr**: 重みの更新をKronecker積と、オプションの低ランク分解で表現します。参考文献: [LoKr (arXiv:2309.14859)](https://arxiv.org/abs/2309.14859) + +アルゴリズムと推奨設定は[LyCORISのアルゴリズム解説](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Algo-List.md)と[ガイドライン](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md)を参照してください。 + +LinearおよびConv2d層の両方を対象としています。Conv2d 1x1層はLinear層と同様に扱われます。Conv2d 3x3+層については、オプションのTucker分解またはflat(カーネル平坦化)モードが利用可能です。 + +この機能は実験的なものです。 + +
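+
+As a back-of-the-envelope illustration of the size trade-offs (a rough sketch only; it assumes LoKr's second factor uses the low-rank branch, and the exact counts depend on the factorization described later in this document):
+
+```python
+# Trainable parameters for a single 512x512 Linear layer at rank 8 (illustrative).
+out_dim = in_dim = 512
+rank = 8
+
+lora = out_dim * rank + rank * in_dim        # lora_up + lora_down -> 8192
+loha = 2 * (out_dim * rank + rank * in_dim)  # two low-rank pairs  -> 16384
+# LoKr: 512 factorizes to (16, 32); W1 is a full 16x16 matrix, W2 is low-rank
+lokr = 16 * 16 + (32 * rank + rank * 32)     # w1 + (w2_a, w2_b)   -> 768
+print(lora, loha, lokr)
+```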
+ +## Acknowledgments / 謝辞 + +The LoHa and LoKr implementations in sd-scripts are based on the [LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS) project by [KohakuBlueleaf](https://github.com/KohakuBlueleaf). We would like to express our sincere gratitude for the excellent research and open-source contributions that made this implementation possible. + +
+日本語 + +sd-scriptsのLoHaおよびLoKrの実装は、[KohakuBlueleaf](https://github.com/KohakuBlueleaf)氏による[LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS)プロジェクトに基づいています。この実装を可能にしてくださった素晴らしい研究とオープンソースへの貢献に心から感謝いたします。 + +
+ +## Supported architectures / 対応アーキテクチャ + +LoHa and LoKr automatically detect the model architecture and apply appropriate default settings. The following architectures are currently supported: + +- **SDXL**: Targets `Transformer2DModel` for UNet and `CLIPAttention`/`CLIPMLP` for text encoders. Conv2d layers in `ResnetBlock2D`, `Downsample2D`, and `Upsample2D` are also supported when `conv_dim` is specified. No default `exclude_patterns`. +- **Anima**: Targets `Block`, `PatchEmbed`, `TimestepEmbedding`, and `FinalLayer` for DiT, and `Qwen3Attention`/`Qwen3MLP` for the text encoder. Default `exclude_patterns` automatically skips modulation, normalization, embedder, and final_layer modules. + +
+日本語 + +LoHaとLoKrは、モデルのアーキテクチャを自動で検出し、適切なデフォルト設定を適用します。現在、以下のアーキテクチャに対応しています: + +- **SDXL**: UNetの`Transformer2DModel`、テキストエンコーダの`CLIPAttention`/`CLIPMLP`を対象とします。`conv_dim`を指定した場合、`ResnetBlock2D`、`Downsample2D`、`Upsample2D`のConv2d層も対象になります。デフォルトの`exclude_patterns`はありません。 +- **Anima**: DiTの`Block`、`PatchEmbed`、`TimestepEmbedding`、`FinalLayer`、テキストエンコーダの`Qwen3Attention`/`Qwen3MLP`を対象とします。デフォルトの`exclude_patterns`により、modulation、normalization、embedder、final_layerモジュールは自動的にスキップされます。 + +
+ +## Training / 学習 + +To use LoHa or LoKr, change the `--network_module` argument in your training command. All other training options (dataset config, optimizer, etc.) remain the same as LoRA. + +
+日本語 + +LoHaまたはLoKrを使用するには、学習コマンドの `--network_module` 引数を変更します。その他の学習オプション(データセット設定、オプティマイザなど)はLoRAと同じです。 + +
+ +### LoHa (SDXL) + +```bash +accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \ + --pretrained_model_name_or_path path/to/sdxl.safetensors \ + --dataset_config path/to/toml \ + --mixed_precision bf16 --fp8_base \ + --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \ + --network_module networks.loha --network_dim 32 --network_alpha 16 \ + --max_train_epochs 16 --save_every_n_epochs 1 \ + --output_dir path/to/output --output_name my-loha +``` + +### LoKr (SDXL) + +```bash +accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 sdxl_train_network.py \ + --pretrained_model_name_or_path path/to/sdxl.safetensors \ + --dataset_config path/to/toml \ + --mixed_precision bf16 --fp8_base \ + --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \ + --network_module networks.lokr --network_dim 32 --network_alpha 16 \ + --max_train_epochs 16 --save_every_n_epochs 1 \ + --output_dir path/to/output --output_name my-lokr +``` + +For Anima, replace `sdxl_train_network.py` with `anima_train_network.py` and use the appropriate model path and options. + +
+日本語 + +Animaの場合は、`sdxl_train_network.py` を `anima_train_network.py` に置き換え、適切なモデルパスとオプションを使用してください。 + +
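+
+For reference, a LoHa run on Anima might look like the following. This is an untested sketch: only the shared network-trainer options from the SDXL examples above are shown, and the Anima-specific model-loading arguments are omitted (see `docs/anima_train_network.md` for those):
+
+```bash
+accelerate launch --num_cpu_threads_per_process 1 --mixed_precision bf16 anima_train_network.py \
+  --dataset_config path/to/toml \
+  --mixed_precision bf16 \
+  --optimizer_type adamw8bit --learning_rate 2e-4 --gradient_checkpointing \
+  --network_module networks.loha --network_dim 32 --network_alpha 16 \
+  --max_train_epochs 16 --save_every_n_epochs 1 \
+  --output_dir path/to/output --output_name my-loha-anima
+```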
+ +### Common training options / 共通の学習オプション + +The following `--network_args` options are available for both LoHa and LoKr, same as LoRA: + +| Option | Description | +|---|---| +| `verbose=True` | Display detailed information about the network modules | +| `rank_dropout=0.1` | Apply dropout to the rank dimension during training | +| `module_dropout=0.1` | Randomly skip entire modules during training | +| `exclude_patterns=[r'...']` | Exclude modules matching the regex patterns (in addition to architecture defaults) | +| `include_patterns=[r'...']` | Override excludes: modules matching these regex patterns will be included even if they match `exclude_patterns` | +| `network_reg_lrs=regex1=lr1,regex2=lr2` | Set per-module learning rates using regex patterns | +| `network_reg_dims=regex1=dim1,regex2=dim2` | Set per-module dimensions (rank) using regex patterns | + +
+日本語
+
+以下の `--network_args` オプションは、LoRAと同様にLoHaとLoKrの両方で使用できます:
+
+| オプション | 説明 |
+|---|---|
+| `verbose=True` | ネットワークモジュールの詳細情報を表示 |
+| `rank_dropout=0.1` | 学習時にランク次元にドロップアウトを適用 |
+| `module_dropout=0.1` | 学習時にモジュール全体をランダムにスキップ |
+| `exclude_patterns=[r'...']` | 正規表現パターンに一致するモジュールを除外(アーキテクチャのデフォルトに追加) |
+| `include_patterns=[r'...']` | 除外設定を上書き:これらの正規表現パターンに一致するモジュールは、`exclude_patterns` に一致していても対象に含まれます |
+| `network_reg_lrs=regex1=lr1,regex2=lr2` | 正規表現パターンでモジュールごとの学習率を設定 |
+| `network_reg_dims=regex1=dim1,regex2=dim2` | 正規表現パターンでモジュールごとの次元(ランク)を設定 |
+
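+
+For example, several of these options can be combined in a single run (the regex patterns below are purely illustrative and should be adapted to the actual module names, which `verbose=True` prints):
+
+```bash
+--network_args "verbose=True" "module_dropout=0.1" "exclude_patterns=[r'.*time_emb.*']" "network_reg_dims=.*attn.*=16"
+```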
+ +### Conv2d support / Conv2dサポート + +By default, LoHa and LoKr target Linear and Conv2d 1x1 layers. To also train Conv2d 3x3+ layers (e.g., in SDXL's ResNet blocks), use the `conv_dim` and `conv_alpha` options: + +```bash +--network_args "conv_dim=16" "conv_alpha=8" +``` + +For Conv2d 3x3+ layers, you can enable Tucker decomposition for more efficient parameter representation: + +```bash +--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True" +``` + +- Without `use_tucker`: The kernel dimensions are flattened into the input dimension (flat mode). +- With `use_tucker=True`: A separate Tucker tensor is used to handle the kernel dimensions, which can be more parameter-efficient. + +
+日本語 + +デフォルトでは、LoHaとLoKrはLinearおよびConv2d 1x1層を対象とします。Conv2d 3x3+層(SDXLのResNetブロックなど)も学習するには、`conv_dim`と`conv_alpha`オプションを使用します: + +```bash +--network_args "conv_dim=16" "conv_alpha=8" +``` + +Conv2d 3x3+層に対して、Tucker分解を有効にすることで、より効率的なパラメータ表現が可能です: + +```bash +--network_args "conv_dim=16" "conv_alpha=8" "use_tucker=True" +``` + +- `use_tucker`なし: カーネル次元が入力次元に平坦化されます(flatモード)。 +- `use_tucker=True`: カーネル次元を扱う別のTuckerテンソルが使用され、よりパラメータ効率が良くなる場合があります。 + +
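+
+The difference between the two modes can be sketched shape-wise in PyTorch. This shows a single low-rank pair only (tensor names are illustrative, not the actual module attributes); the einsum is the reconstruction quoted in the "How LoHa and LoKr work" section below:
+
+```python
+import torch
+
+out_ch, in_ch, k, rank = 320, 320, 3, 16
+
+# flat mode: the 3x3 kernel is folded into the input dimension
+wa = torch.randn(out_ch, rank)
+wb = torch.randn(rank, in_ch * k * k)
+delta_flat = (wa @ wb).reshape(out_ch, in_ch, k, k)  # 51,200 params
+
+# Tucker mode: a small core tensor carries the kernel dimensions
+core = torch.randn(rank, rank, k, k)
+wb_t = torch.randn(rank, in_ch)
+wa_t = torch.randn(rank, out_ch)
+delta_tucker = torch.einsum("i j ..., j r, i p -> p r ...", core, wb_t, wa_t)  # 12,544 params
+
+print(delta_flat.shape, delta_tucker.shape)  # both torch.Size([320, 320, 3, 3])
+```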
+ +### LoKr-specific option: `factor` / LoKr固有のオプション: `factor` + +LoKr decomposes weight dimensions using factorization. The `factor` option controls how dimensions are split: + +- `factor=-1` (default): Automatically find balanced factors. For example, dimension 512 is split into (16, 32). +- `factor=N` (positive integer): Force factorization using the specified value. For example, `factor=4` splits dimension 512 into (4, 128). + +```bash +--network_args "factor=4" +``` + +When `network_dim` (rank) is large enough relative to the factorized dimensions, LoKr uses a full matrix instead of a low-rank decomposition for the second factor. A warning will be logged in this case. + +
+日本語 + +LoKrは重みの次元を因数分解して分割します。`factor` オプションでその分割方法を制御します: + +- `factor=-1`(デフォルト): バランスの良い因数を自動的に見つけます。例えば、次元512は(16, 32)に分割されます。 +- `factor=N`(正の整数): 指定した値で因数分解します。例えば、`factor=4` は次元512を(4, 128)に分割します。 + +```bash +--network_args "factor=4" +``` + +`network_dim`(ランク)が因数分解された次元に対して十分に大きい場合、LoKrは第2因子に低ランク分解ではなくフル行列を使用します。その場合、警告がログに出力されます。 + +
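+
+The factorization itself is easy to sketch (a minimal illustration of the idea; the actual implementation in `networks/lokr.py` may handle edge cases such as non-divisible `factor` values differently):
+
+```python
+def factorize(dim: int, factor: int = -1) -> tuple[int, int]:
+    """Split dim into (a, b) with a * b == dim."""
+    if factor > 0 and dim % factor == 0:
+        return factor, dim // factor
+    a = int(dim ** 0.5)  # balanced split: largest divisor <= sqrt(dim)
+    while dim % a != 0:
+        a -= 1
+    return a, dim // a
+
+print(factorize(512))     # (16, 32)
+print(factorize(512, 4))  # (4, 128)
+```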
+ +### Anima-specific option: `train_llm_adapter` / Anima固有のオプション: `train_llm_adapter` + +For Anima, you can additionally train the LLM adapter modules by specifying: + +```bash +--network_args "train_llm_adapter=True" +``` + +This includes `LLMAdapterTransformerBlock` modules as training targets. + +
+日本語 + +Animaでは、以下を指定することでLLMアダプターモジュールも追加で学習できます: + +```bash +--network_args "train_llm_adapter=True" +``` + +これにより、`LLMAdapterTransformerBlock` モジュールが学習対象に含まれます。 + +
+ +### LoRA+ / LoRA+ + +LoRA+ (`loraplus_lr_ratio` etc. in `--network_args`) is supported with LoHa/LoKr. For LoHa, the second pair of matrices (`hada_w2_a`) is treated as the "plus" (higher learning rate) parameter group. For LoKr, the scale factor (`lokr_w1`) is treated as the "plus" parameter group. + +```bash +--network_args "loraplus_lr_ratio=4" +``` + +This feature has been confirmed to work in basic testing, but feedback is welcome. If you encounter any issues, please report them. + +
+日本語 + +LoRA+(`--network_args` の `loraplus_lr_ratio` 等)はLoHa/LoKrでもサポートされています。LoHaでは第2ペアの行列(`hada_w2_a`)が「plus」(より高い学習率)パラメータグループとして扱われます。LoKrではスケール係数(`lokr_w1`)が「plus」パラメータグループとして扱われます。 + +```bash +--network_args "loraplus_lr_ratio=4" +``` + +この機能は基本的なテストでは動作確認されていますが、フィードバックをお待ちしています。問題が発生した場合はご報告ください。 + +
+ +## How LoHa and LoKr work / LoHaとLoKrの仕組み + +### LoHa + +LoHa represents the weight update as a Hadamard (element-wise) product of two low-rank matrices: + +``` +ΔW = (W1a × W1b) ⊙ (W2a × W2b) +``` + +where `W1a`, `W1b`, `W2a`, `W2b` are low-rank matrices with rank `network_dim`. This means LoHa has roughly **twice the number of trainable parameters** compared to LoRA at the same rank, but can capture more complex weight structures due to the element-wise product. + +For Conv2d 3x3+ layers with Tucker decomposition, each pair additionally has a Tucker tensor `T` and the reconstruction becomes: `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)`. + +### LoKr + +LoKr represents the weight update using a Kronecker product: + +``` +ΔW = W1 ⊗ W2 (where W2 = W2a × W2b in low-rank mode) +``` + +The original weight dimensions are factorized (e.g., a 512×512 weight might be split so that W1 is 16×16 and W2 is 32×32). W1 is always a full matrix (small), while W2 can be either low-rank decomposed or a full matrix depending on the rank setting. LoKr tends to produce **smaller models** compared to LoRA at the same rank. + +
+日本語 + +### LoHa + +LoHaは重みの更新を2つの低ランク行列のHadamard積(要素ごとの積)で表現します: + +``` +ΔW = (W1a × W1b) ⊙ (W2a × W2b) +``` + +ここで `W1a`, `W1b`, `W2a`, `W2b` はランク `network_dim` の低ランク行列です。LoHaは同じランクのLoRAと比較して学習可能なパラメータ数が **約2倍** になりますが、要素ごとの積により、より複雑な重み構造を捉えることができます。 + +Conv2d 3x3+層でTucker分解を使用する場合、各ペアにはさらにTuckerテンソル `T` があり、再構成は `einsum("i j ..., j r, i p -> p r ...", T, Wb, Wa)` となります。 + +### LoKr + +LoKrはKronecker積を使って重みの更新を表現します: + +``` +ΔW = W1 ⊗ W2 (低ランクモードでは W2 = W2a × W2b) +``` + +元の重みの次元が因数分解されます(例: 512×512の重みが、W1が16×16、W2が32×32に分割されます)。W1は常にフル行列(小さい)で、W2はランク設定に応じて低ランク分解またはフル行列になります。LoKrは同じランクのLoRAと比較して **より小さいモデル** を生成する傾向があります。 + +
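+
+Both constructions are easy to verify shape-wise in PyTorch (a minimal sketch with random tensors; variable names follow the formulas above, not the saved state-dict keys):
+
+```python
+import torch
+
+out_dim, in_dim, rank = 512, 512, 16
+
+# LoHa: element-wise product of two low-rank products
+w1a, w1b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
+w2a, w2b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
+delta_loha = (w1a @ w1b) * (w2a @ w2b)  # (512, 512)
+
+# LoKr: Kronecker product of a small full matrix and a low-rank second factor
+w1 = torch.randn(16, 16)                            # 512 factorized to (16, 32)
+w2 = torch.randn(32, rank) @ torch.randn(rank, 32)  # low-rank W2 = W2a @ W2b
+delta_lokr = torch.kron(w1, w2)                     # (16*32, 16*32) = (512, 512)
+
+print(delta_loha.shape, delta_lokr.shape)
+```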
+ +## Inference / 推論 + +Trained LoHa/LoKr weights are saved in safetensors format, just like LoRA. + +
+日本語 + +学習済みのLoHa/LoKrの重みは、LoRAと同様にsafetensors形式で保存されます。 + +
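+
+Because the three formats use distinct key names (`.lora_down`/`.lora_up` for LoRA, `.hada_w1_a` etc. for LoHa, `.lokr_w1` etc. for LoKr, mirroring the detection in `library/lora_utils.py`), you can check what a file contains without loading any tensors. A small sketch:
+
+```python
+from safetensors import safe_open
+
+with safe_open("my-loha.safetensors", framework="pt", device="cpu") as f:
+    keys = list(f.keys())
+
+if any(".hada_w1_a" in k for k in keys):
+    print("LoHa")
+elif any(".lokr_w1" in k for k in keys):
+    print("LoKr")
+elif any(".lora_down." in k for k in keys):
+    print("LoRA")
+```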
+ +### SDXL + +For SDXL, use `gen_img.py` with `--network_module` and `--network_weights`, the same way as LoRA: + +```bash +python gen_img.py --ckpt path/to/sdxl.safetensors \ + --network_module networks.loha --network_weights path/to/loha.safetensors \ + --prompt "your prompt" ... +``` + +Replace `networks.loha` with `networks.lokr` for LoKr weights. + +
+日本語 + +SDXLでは、LoRAと同様に `gen_img.py` で `--network_module` と `--network_weights` を指定します: + +```bash +python gen_img.py --ckpt path/to/sdxl.safetensors \ + --network_module networks.loha --network_weights path/to/loha.safetensors \ + --prompt "your prompt" ... +``` + +LoKrの重みを使用する場合は `networks.loha` を `networks.lokr` に置き換えてください。 + +
+ +### Anima + +For Anima, use `anima_minimal_inference.py` with the `--lora_weight` argument. LoRA, LoHa, and LoKr weights are automatically detected and merged: + +```bash +python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \ + --lora_weight path/to/loha_or_lokr.safetensors ... +``` + +
+日本語 + +Animaでは、`anima_minimal_inference.py` に `--lora_weight` 引数を指定します。LoRA、LoHa、LoKrの重みは自動的に判定されてマージされます: + +```bash +python anima_minimal_inference.py --dit path/to/dit --prompt "your prompt" \ + --lora_weight path/to/loha_or_lokr.safetensors ... +``` + +
diff --git a/library/lora_utils.py b/library/lora_utils.py index 90e3c389..dadad898 100644 --- a/library/lora_utils.py +++ b/library/lora_utils.py @@ -1,267 +1,287 @@ -import os -import re -from typing import Dict, List, Optional, Union -import torch -from tqdm import tqdm -from library.device_utils import synchronize_device -from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization -from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames -from library.utils import setup_logging - -setup_logging() -import logging - -logger = logging.getLogger(__name__) - - -def filter_lora_state_dict( - weights_sd: Dict[str, torch.Tensor], - include_pattern: Optional[str] = None, - exclude_pattern: Optional[str] = None, -) -> Dict[str, torch.Tensor]: - # apply include/exclude patterns - original_key_count = len(weights_sd.keys()) - if include_pattern is not None: - regex_include = re.compile(include_pattern) - weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} - logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") - - if exclude_pattern is not None: - original_key_count_ex = len(weights_sd.keys()) - regex_exclude = re.compile(exclude_pattern) - weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} - logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}") - - if len(weights_sd) != original_key_count: - remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) - remaining_keys.sort() - logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") - if len(weights_sd) == 0: - logger.warning("No keys left after filtering.") - - return weights_sd - - -def load_safetensors_with_lora_and_fp8( - model_files: Union[str, List[str]], - lora_weights_list: Optional[List[Dict[str, torch.Tensor]]], - lora_multipliers: Optional[List[float]], - fp8_optimization: bool, - calc_device: torch.device, - move_to_device: bool = False, - dit_weight_dtype: Optional[torch.dtype] = None, - target_keys: Optional[List[str]] = None, - exclude_keys: Optional[List[str]] = None, - disable_numpy_memmap: bool = False, - weight_transform_hooks: Optional[WeightTransformHooks] = None, -) -> dict[str, torch.Tensor]: - """ - Merge LoRA weights into the state dict of a model with fp8 optimization if needed. - - Args: - model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix. - lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load. - lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights. - fp8_optimization (bool): Whether to apply FP8 optimization. - calc_device (torch.device): Device to calculate on. - move_to_device (bool): Whether to move tensors to the calculation device after loading. - target_keys (Optional[List[str]]): Keys to target for optimization. - exclude_keys (Optional[List[str]]): Keys to exclude from optimization. - disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors. - weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading. 
- """ - - # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix - if isinstance(model_files, str): - model_files = [model_files] - - extended_model_files = [] - for model_file in model_files: - split_filenames = get_split_weight_filenames(model_file) - if split_filenames is not None: - extended_model_files.extend(split_filenames) - else: - extended_model_files.append(model_file) - model_files = extended_model_files - logger.info(f"Loading model files: {model_files}") - - # load LoRA weights - weight_hook = None - if lora_weights_list is None or len(lora_weights_list) == 0: - lora_weights_list = [] - lora_multipliers = [] - list_of_lora_weight_keys = [] - else: - list_of_lora_weight_keys = [] - for lora_sd in lora_weights_list: - lora_weight_keys = set(lora_sd.keys()) - list_of_lora_weight_keys.append(lora_weight_keys) - - if lora_multipliers is None: - lora_multipliers = [1.0] * len(lora_weights_list) - while len(lora_multipliers) < len(lora_weights_list): - lora_multipliers.append(1.0) - if len(lora_multipliers) > len(lora_weights_list): - lora_multipliers = lora_multipliers[: len(lora_weights_list)] - - # Merge LoRA weights into the state dict - logger.info(f"Merging LoRA weights into state dict. multipliers: {lora_multipliers}") - - # make hook for LoRA merging - def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False): - nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device - - if not model_weight_key.endswith(".weight"): - return model_weight - - original_device = model_weight.device - if original_device != calc_device: - model_weight = model_weight.to(calc_device) # to make calculation faster - - for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers): - # check if this weight has LoRA weights - lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight" - found = False - for prefix in ["lora_unet_", ""]: - lora_name = prefix + lora_name_without_prefix.replace(".", "_") - down_key = lora_name + ".lora_down.weight" - up_key = lora_name + ".lora_up.weight" - alpha_key = lora_name + ".alpha" - if down_key in lora_weight_keys and up_key in lora_weight_keys: - found = True - break - if not found: - continue # no LoRA weights for this model weight - - # get LoRA weights - down_weight = lora_sd[down_key] - up_weight = lora_sd[up_key] - - dim = down_weight.size()[0] - alpha = lora_sd.get(alpha_key, dim) - scale = alpha / dim - - down_weight = down_weight.to(calc_device) - up_weight = up_weight.to(calc_device) - - original_dtype = model_weight.dtype - if original_dtype.itemsize == 1: # fp8 - # temporarily convert to float16 for calculation - model_weight = model_weight.to(torch.float16) - down_weight = down_weight.to(torch.float16) - up_weight = up_weight.to(torch.float16) - - # W <- W + U * D - if len(model_weight.size()) == 2: - # linear - if len(up_weight.size()) == 4: # use linear projection mismatch - up_weight = up_weight.squeeze(3).squeeze(2) - down_weight = down_weight.squeeze(3).squeeze(2) - model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale - elif down_weight.size()[2:4] == (1, 1): - # conv2d 1x1 - model_weight = ( - model_weight - + multiplier - * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) - * scale - ) - else: - # conv2d 3x3 - conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), 
up_weight).permute(1, 0, 2, 3) - # logger.info(conved.size(), weight.size(), module.stride, module.padding) - model_weight = model_weight + multiplier * conved * scale - - if original_dtype.itemsize == 1: # fp8 - model_weight = model_weight.to(original_dtype) # convert back to original dtype - - # remove LoRA keys from set - lora_weight_keys.remove(down_key) - lora_weight_keys.remove(up_key) - if alpha_key in lora_weight_keys: - lora_weight_keys.remove(alpha_key) - - if not keep_on_calc_device and original_device != calc_device: - model_weight = model_weight.to(original_device) # move back to original device - return model_weight - - weight_hook = weight_hook_func - - state_dict = load_safetensors_with_fp8_optimization_and_hook( - model_files, - fp8_optimization, - calc_device, - move_to_device, - dit_weight_dtype, - target_keys, - exclude_keys, - weight_hook=weight_hook, - disable_numpy_memmap=disable_numpy_memmap, - weight_transform_hooks=weight_transform_hooks, - ) - - for lora_weight_keys in list_of_lora_weight_keys: - # check if all LoRA keys are used - if len(lora_weight_keys) > 0: - # if there are still LoRA keys left, it means they are not used in the model - # this is a warning, not an error - logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}") - - return state_dict - - -def load_safetensors_with_fp8_optimization_and_hook( - model_files: list[str], - fp8_optimization: bool, - calc_device: torch.device, - move_to_device: bool = False, - dit_weight_dtype: Optional[torch.dtype] = None, - target_keys: Optional[List[str]] = None, - exclude_keys: Optional[List[str]] = None, - weight_hook: callable = None, - disable_numpy_memmap: bool = False, - weight_transform_hooks: Optional[WeightTransformHooks] = None, -) -> dict[str, torch.Tensor]: - """ - Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed. - """ - if fp8_optimization: - logger.info( - f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" - ) - # dit_weight_dtype is not used because we use fp8 optimization - state_dict = load_safetensors_with_fp8_optimization( - model_files, - calc_device, - target_keys, - exclude_keys, - move_to_device=move_to_device, - weight_hook=weight_hook, - disable_numpy_memmap=disable_numpy_memmap, - weight_transform_hooks=weight_transform_hooks, - ) - else: - logger.info( - f"Loading state dict without FP8 optimization. 
Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" - ) - state_dict = {} - for model_file in model_files: - with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f: - f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f - for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False): - if weight_hook is None and move_to_device: - value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype) - else: - value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer - if weight_hook is not None: - value = weight_hook(key, value, keep_on_calc_device=move_to_device) - if move_to_device: - value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True) - elif dit_weight_dtype is not None: - value = value.to(dit_weight_dtype) - - state_dict[key] = value - if move_to_device: - synchronize_device(calc_device) - - return state_dict +import os +import re +from typing import Dict, List, Optional, Union +import torch +from tqdm import tqdm +from library.device_utils import synchronize_device +from library.fp8_optimization_utils import load_safetensors_with_fp8_optimization +from library.safetensors_utils import MemoryEfficientSafeOpen, TensorWeightAdapter, WeightTransformHooks, get_split_weight_filenames +from networks.loha import merge_weights_to_tensor as loha_merge +from networks.lokr import merge_weights_to_tensor as lokr_merge + +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def filter_lora_state_dict( + weights_sd: Dict[str, torch.Tensor], + include_pattern: Optional[str] = None, + exclude_pattern: Optional[str] = None, +) -> Dict[str, torch.Tensor]: + # apply include/exclude patterns + original_key_count = len(weights_sd.keys()) + if include_pattern is not None: + regex_include = re.compile(include_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} + logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") + + if exclude_pattern is not None: + original_key_count_ex = len(weights_sd.keys()) + regex_exclude = re.compile(exclude_pattern) + weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} + logger.info(f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}") + + if len(weights_sd) != original_key_count: + remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) + remaining_keys.sort() + logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") + if len(weights_sd) == 0: + logger.warning("No keys left after filtering.") + + return weights_sd + + +def load_safetensors_with_lora_and_fp8( + model_files: Union[str, List[str]], + lora_weights_list: Optional[List[Dict[str, torch.Tensor]]], + lora_multipliers: Optional[List[float]], + fp8_optimization: bool, + calc_device: torch.device, + move_to_device: bool = False, + dit_weight_dtype: Optional[torch.dtype] = None, + target_keys: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + disable_numpy_memmap: bool = False, + weight_transform_hooks: Optional[WeightTransformHooks] = None, +) -> dict[str, torch.Tensor]: + """ + Merge LoRA weights into the state dict of a model with fp8 optimization if needed. 
+ + Args: + model_files (Union[str, List[str]]): Path to the model file or list of paths. If the path matches a pattern like `00001-of-00004`, it will load all files with the same prefix. + lora_weights_list (Optional[List[Dict[str, torch.Tensor]]]): List of dictionaries of LoRA weight tensors to load. + lora_multipliers (Optional[List[float]]): List of multipliers for LoRA weights. + fp8_optimization (bool): Whether to apply FP8 optimization. + calc_device (torch.device): Device to calculate on. + move_to_device (bool): Whether to move tensors to the calculation device after loading. + dit_weight_dtype (Optional[torch.dtype]): Dtype to load weights in when not using FP8 optimization. + target_keys (Optional[List[str]]): Keys to target for optimization. + exclude_keys (Optional[List[str]]): Keys to exclude from optimization. + disable_numpy_memmap (bool): Whether to disable numpy memmap when loading safetensors. + weight_transform_hooks (Optional[WeightTransformHooks]): Hooks for transforming weights during loading. + """ + + # if the file name ends with 00001-of-00004 etc, we need to load the files with the same prefix + if isinstance(model_files, str): + model_files = [model_files] + + extended_model_files = [] + for model_file in model_files: + split_filenames = get_split_weight_filenames(model_file) + if split_filenames is not None: + extended_model_files.extend(split_filenames) + else: + extended_model_files.append(model_file) + model_files = extended_model_files + logger.info(f"Loading model files: {model_files}") + + # load LoRA weights + weight_hook = None + if lora_weights_list is None or len(lora_weights_list) == 0: + lora_weights_list = [] + lora_multipliers = [] + list_of_lora_weight_keys = [] + else: + list_of_lora_weight_keys = [] + for lora_sd in lora_weights_list: + lora_weight_keys = set(lora_sd.keys()) + list_of_lora_weight_keys.append(lora_weight_keys) + + if lora_multipliers is None: + lora_multipliers = [1.0] * len(lora_weights_list) + while len(lora_multipliers) < len(lora_weights_list): + lora_multipliers.append(1.0) + if len(lora_multipliers) > len(lora_weights_list): + lora_multipliers = lora_multipliers[: len(lora_weights_list)] + + # Merge LoRA weights into the state dict + logger.info(f"Merging LoRA weights into state dict. 
multipliers: {lora_multipliers}") + + # make hook for LoRA merging + def weight_hook_func(model_weight_key, model_weight: torch.Tensor, keep_on_calc_device=False): + nonlocal list_of_lora_weight_keys, lora_weights_list, lora_multipliers, calc_device + + if not model_weight_key.endswith(".weight"): + return model_weight + + original_device = model_weight.device + if original_device != calc_device: + model_weight = model_weight.to(calc_device) # to make calculation faster + + for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, lora_weights_list, lora_multipliers): + # check if this weight has LoRA weights + lora_name_without_prefix = model_weight_key.rsplit(".", 1)[0] # remove trailing ".weight" + found = False + for prefix in ["lora_unet_", ""]: + lora_name = prefix + lora_name_without_prefix.replace(".", "_") + down_key = lora_name + ".lora_down.weight" + up_key = lora_name + ".lora_up.weight" + alpha_key = lora_name + ".alpha" + if down_key in lora_weight_keys and up_key in lora_weight_keys: + found = True + break + + if found: + # Standard LoRA merge + # get LoRA weights + down_weight = lora_sd[down_key] + up_weight = lora_sd[up_key] + + dim = down_weight.size()[0] + alpha = lora_sd.get(alpha_key, dim) + scale = alpha / dim + + down_weight = down_weight.to(calc_device) + up_weight = up_weight.to(calc_device) + + original_dtype = model_weight.dtype + if original_dtype.itemsize == 1: # fp8 + # temporarily convert to float16 for calculation + model_weight = model_weight.to(torch.float16) + down_weight = down_weight.to(torch.float16) + up_weight = up_weight.to(torch.float16) + + # W <- W + U * D + if len(model_weight.size()) == 2: + # linear + if len(up_weight.size()) == 4: # use linear projection mismatch + up_weight = up_weight.squeeze(3).squeeze(2) + down_weight = down_weight.squeeze(3).squeeze(2) + model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + model_weight = ( + model_weight + + multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + model_weight = model_weight + multiplier * conved * scale + + if original_dtype.itemsize == 1: # fp8 + model_weight = model_weight.to(original_dtype) # convert back to original dtype + + # remove LoRA keys from set + lora_weight_keys.remove(down_key) + lora_weight_keys.remove(up_key) + if alpha_key in lora_weight_keys: + lora_weight_keys.remove(alpha_key) + continue + + # Check for LoHa/LoKr weights with same prefix search + for prefix in ["lora_unet_", ""]: + lora_name = prefix + lora_name_without_prefix.replace(".", "_") + hada_key = lora_name + ".hada_w1_a" + lokr_key = lora_name + ".lokr_w1" + + if hada_key in lora_weight_keys: + # LoHa merge + model_weight = loha_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device) + break + elif lokr_key in lora_weight_keys: + # LoKr merge + model_weight = lokr_merge(model_weight, lora_name, lora_sd, lora_weight_keys, multiplier, calc_device) + break + + if not keep_on_calc_device and original_device != calc_device: + model_weight = model_weight.to(original_device) # move back to original device + return model_weight + + weight_hook = weight_hook_func + + state_dict = 
load_safetensors_with_fp8_optimization_and_hook( + model_files, + fp8_optimization, + calc_device, + move_to_device, + dit_weight_dtype, + target_keys, + exclude_keys, + weight_hook=weight_hook, + disable_numpy_memmap=disable_numpy_memmap, + weight_transform_hooks=weight_transform_hooks, + ) + + for lora_weight_keys in list_of_lora_weight_keys: + # check if all LoRA keys are used + if len(lora_weight_keys) > 0: + # if there are still LoRA keys left, it means they are not used in the model + # this is a warning, not an error + logger.warning(f"Warning: not all LoRA keys are used: {', '.join(lora_weight_keys)}") + + return state_dict + + +def load_safetensors_with_fp8_optimization_and_hook( + model_files: list[str], + fp8_optimization: bool, + calc_device: torch.device, + move_to_device: bool = False, + dit_weight_dtype: Optional[torch.dtype] = None, + target_keys: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + weight_hook: callable = None, + disable_numpy_memmap: bool = False, + weight_transform_hooks: Optional[WeightTransformHooks] = None, +) -> dict[str, torch.Tensor]: + """ + Load state dict from safetensors files and merge LoRA weights into the state dict with fp8 optimization if needed. + """ + if fp8_optimization: + logger.info( + f"Loading state dict with FP8 optimization. Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" + ) + # dit_weight_dtype is not used because we use fp8 optimization + state_dict = load_safetensors_with_fp8_optimization( + model_files, + calc_device, + target_keys, + exclude_keys, + move_to_device=move_to_device, + weight_hook=weight_hook, + disable_numpy_memmap=disable_numpy_memmap, + weight_transform_hooks=weight_transform_hooks, + ) + else: + logger.info( + f"Loading state dict without FP8 optimization. 
Dtype of weight: {dit_weight_dtype}, hook enabled: {weight_hook is not None}" + ) + state_dict = {} + for model_file in model_files: + with MemoryEfficientSafeOpen(model_file, disable_numpy_memmap=disable_numpy_memmap) as original_f: + f = TensorWeightAdapter(weight_transform_hooks, original_f) if weight_transform_hooks is not None else original_f + for key in tqdm(f.keys(), desc=f"Loading {os.path.basename(model_file)}", leave=False): + if weight_hook is None and move_to_device: + value = f.get_tensor(key, device=calc_device, dtype=dit_weight_dtype) + else: + value = f.get_tensor(key) # we cannot directly load to device because get_tensor does non-blocking transfer + if weight_hook is not None: + value = weight_hook(key, value, keep_on_calc_device=move_to_device) + if move_to_device: + value = value.to(calc_device, dtype=dit_weight_dtype, non_blocking=True) + elif dit_weight_dtype is not None: + value = value.to(dit_weight_dtype) + + state_dict[key] = value + if move_to_device: + synchronize_device(calc_device) + + return state_dict diff --git a/networks/loha.py b/networks/loha.py new file mode 100644 index 00000000..8734f9c5 --- /dev/null +++ b/networks/loha.py @@ -0,0 +1,643 @@ +# LoHa (Low-rank Hadamard Product) network module +# Reference: https://arxiv.org/abs/2108.06098 +# +# Based on the LyCORIS project by KohakuBlueleaf +# https://github.com/KohakuBlueleaf/LyCORIS + +import ast +import os +import logging +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs +from library.utils import setup_logging + +setup_logging() +logger = logging.getLogger(__name__) + + +class HadaWeight(torch.autograd.Function): + """Efficient Hadamard product forward/backward for LoHa. + + Computes ((w1a @ w1b) * (w2a @ w2b)) * scale with custom backward + that recomputes intermediates instead of storing them. + """ + + @staticmethod + def forward(ctx, w1a, w1b, w2a, w2b, scale=None): + if scale is None: + scale = torch.tensor(1, device=w1a.device, dtype=w1a.dtype) + ctx.save_for_backward(w1a, w1b, w2a, w2b, scale) + diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale + return diff_weight + + @staticmethod + def backward(ctx, grad_out): + (w1a, w1b, w2a, w2b, scale) = ctx.saved_tensors + grad_out = grad_out * scale + temp = grad_out * (w2a @ w2b) + grad_w1a = temp @ w1b.T + grad_w1b = w1a.T @ temp + + temp = grad_out * (w1a @ w1b) + grad_w2a = temp @ w2b.T + grad_w2b = w2a.T @ temp + + del temp + return grad_w1a, grad_w1b, grad_w2a, grad_w2b, None + + +class HadaWeightTucker(torch.autograd.Function): + """Tucker-decomposed Hadamard product forward/backward for LoHa Conv2d 3x3+. + + Computes (rebuild(t1, w1b, w1a) * rebuild(t2, w2b, w2a)) * scale + where rebuild = einsum("i j ..., j r, i p -> p r ...", t, wb, wa). + Compatible with LyCORIS parameter naming convention. 
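+
+    A minimal shape sketch (illustrative values only, not part of the API):
+    for a 3x3 conv with out_dim=64, in_dim=32, rank=8::
+
+        t = torch.randn(8, 8, 3, 3)   # (rank, rank, kh, kw)
+        wb = torch.randn(8, 32)       # (rank, in_dim)
+        wa = torch.randn(8, 64)       # (rank, out_dim)
+        # rebuild: (out_dim, in_dim, kh, kw) == (64, 32, 3, 3)
+        rebuild = torch.einsum("i j ..., j r, i p -> p r ...", t, wb, wa)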
+ """ + + @staticmethod + def forward(ctx, t1, w1b, w1a, t2, w2b, w2a, scale=None): + if scale is None: + scale = torch.tensor(1, device=t1.device, dtype=t1.dtype) + ctx.save_for_backward(t1, w1b, w1a, t2, w2b, w2a, scale) + + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + + return rebuild1 * rebuild2 * scale + + @staticmethod + def backward(ctx, grad_out): + (t1, w1b, w1a, t2, w2b, w2a, scale) = ctx.saved_tensors + grad_out = grad_out * scale + + # Gradients for w1a, w1b, t1 (using rebuild2) + temp = torch.einsum("i j ..., j r -> i r ...", t2, w2b) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2a) + + grad_w = rebuild * grad_out + del rebuild + + grad_w1a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1a.T) + del grad_w, temp + + grad_w1b = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp) + grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1b.T) + del grad_temp + + # Gradients for w2a, w2b, t2 (using rebuild1) + temp = torch.einsum("i j ..., j r -> i r ...", t1, w1b) + rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1a) + + grad_w = rebuild * grad_out + del rebuild + + grad_w2a = torch.einsum("r j ..., i j ... -> r i", temp, grad_w) + grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2a.T) + del grad_w, temp + + grad_w2b = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp) + grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2b.T) + del grad_temp + + return grad_t1, grad_w1b, grad_w1a, grad_t2, grad_w2b, grad_w2a, None + + +class LoHaModule(torch.nn.Module): + """LoHa module for training. Replaces forward method of the original Linear/Conv2d.""" + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + use_tucker=False, + **kwargs, + ): + super().__init__() + self.lora_name = lora_name + self.lora_dim = lora_dim + + is_conv2d = org_module.__class__.__name__ == "Conv2d" + if is_conv2d: + in_dim = org_module.in_channels + out_dim = org_module.out_channels + kernel_size = org_module.kernel_size + self.is_conv = True + self.stride = org_module.stride + self.padding = org_module.padding + self.dilation = org_module.dilation + self.groups = org_module.groups + self.kernel_size = kernel_size + + self.tucker = use_tucker and any(k != 1 for k in kernel_size) + + if kernel_size == (1, 1): + self.conv_mode = "1x1" + elif self.tucker: + self.conv_mode = "tucker" + else: + self.conv_mode = "flat" + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.is_conv = False + self.tucker = False + self.conv_mode = None + self.kernel_size = None + + self.in_dim = in_dim + self.out_dim = out_dim + + # Create parameters based on mode + if self.conv_mode == "tucker": + # Tucker decomposition for Conv2d 3x3+ + # Shapes follow LyCORIS convention: w_a = (rank, out_dim), w_b = (rank, in_dim) + self.hada_t1 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size)) + self.hada_w1_a = nn.Parameter(torch.empty(lora_dim, out_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + self.hada_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *kernel_size)) + self.hada_w2_a = nn.Parameter(torch.empty(lora_dim, out_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + + # LyCORIS init: w1_a = 0 (ensures ΔW=0), t1/t2 
normal(0.1) + torch.nn.init.normal_(self.hada_t1, std=0.1) + torch.nn.init.normal_(self.hada_t2, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1.0) + torch.nn.init.constant_(self.hada_w1_a, 0) + torch.nn.init.normal_(self.hada_w2_b, std=1.0) + torch.nn.init.normal_(self.hada_w2_a, std=0.1) + elif self.conv_mode == "flat": + # Non-Tucker Conv2d 3x3+: flatten kernel into in_dim + k_prod = 1 + for k in kernel_size: + k_prod *= k + flat_in = in_dim * k_prod + + self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, flat_in)) + self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, flat_in)) + + torch.nn.init.normal_(self.hada_w1_a, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1.0) + torch.nn.init.constant_(self.hada_w2_a, 0) + torch.nn.init.normal_(self.hada_w2_b, std=1.0) + else: + # Linear or Conv2d 1x1 + self.hada_w1_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w1_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + self.hada_w2_a = nn.Parameter(torch.empty(out_dim, lora_dim)) + self.hada_w2_b = nn.Parameter(torch.empty(lora_dim, in_dim)) + + torch.nn.init.normal_(self.hada_w1_a, std=0.1) + torch.nn.init.normal_(self.hada_w1_b, std=1.0) + torch.nn.init.constant_(self.hada_w2_a, 0) + torch.nn.init.normal_(self.hada_w2_b, std=1.0) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() + alpha = lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def get_diff_weight(self): + """Return materialized weight delta. 
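+
+        The delta is the Hadamard product of two low-rank products,
+        dW = (hada_w1_a @ hada_w1_b) * (hada_w2_a @ hada_w2_b) * scale,
+        or the corresponding Tucker rebuilds in tucker mode.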
+ + Returns: + - Linear: 2D tensor (out_dim, in_dim) + - Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d + - Conv2d 3x3+ Tucker: 4D tensor (out_dim, in_dim, k1, k2) + - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) + """ + if self.tucker: + scale = torch.tensor(self.scale, dtype=self.hada_t1.dtype, device=self.hada_t1.device) + return HadaWeightTucker.apply( + self.hada_t1, self.hada_w1_b, self.hada_w1_a, + self.hada_t2, self.hada_w2_b, self.hada_w2_a, scale + ) + elif self.conv_mode == "flat": + scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device) + diff = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + return diff.reshape(self.out_dim, self.in_dim, *self.kernel_size) + else: + scale = torch.tensor(self.scale, dtype=self.hada_w1_a.dtype, device=self.hada_w1_a.device) + return HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + diff_weight = self.get_diff_weight() + + # rank dropout (applied on output dimension) + if self.rank_dropout is not None and self.training: + drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype) + drop = drop.view(-1, *([1] * (diff_weight.dim() - 1))) + diff_weight = diff_weight * drop + scale = 1.0 / (1.0 - self.rank_dropout) + else: + scale = 1.0 + + if self.is_conv: + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale + else: + # Conv2d 3x3+: diff_weight is already 4D from get_diff_weight + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale + else: + return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class LoHaInfModule(LoHaModule): + """LoHa module for inference. 
Supports merge_to and get_weight.""" + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference; pass use_tucker from kwargs + use_tucker = kwargs.pop("use_tucker", False) + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, use_tucker=use_tucker) + + self.org_module_ref = [org_module] + self.enabled = True + self.network: AdditionalNetwork = None + + def set_network(self, network): + self.network = network + + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + # get LoHa weights + w1a = sd["hada_w1_a"].to(torch.float).to(device) + w1b = sd["hada_w1_b"].to(torch.float).to(device) + w2a = sd["hada_w2_a"].to(torch.float).to(device) + w2b = sd["hada_w2_b"].to(torch.float).to(device) + + if self.tucker: + # Tucker mode + t1 = sd["hada_t1"].to(torch.float).to(device) + t2 = sd["hada_t2"].to(torch.float).to(device) + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + diff_weight = rebuild1 * rebuild2 * self.scale + else: + diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale + # reshape diff_weight to match original weight shape if needed + if diff_weight.shape != weight.shape: + diff_weight = diff_weight.reshape(weight.shape) + + weight = weight.to(device) + self.multiplier * diff_weight + + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + if self.tucker: + t1 = self.hada_t1.to(torch.float) + w1a = self.hada_w1_a.to(torch.float) + w1b = self.hada_w1_b.to(torch.float) + t2 = self.hada_t2.to(torch.float) + w2a = self.hada_w2_a.to(torch.float) + w2b = self.hada_w2_b.to(torch.float) + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + weight = rebuild1 * rebuild2 * self.scale * multiplier + else: + w1a = self.hada_w1_a.to(torch.float) + w1b = self.hada_w1_b.to(torch.float) + w2a = self.hada_w2_a.to(torch.float) + w2b = self.hada_w2_b.to(torch.float) + weight = ((w1a @ w1b) * (w2a @ w2b)) * self.scale * multiplier + + if self.is_conv: + if self.conv_mode == "1x1": + weight = weight.unsqueeze(2).unsqueeze(3) + elif self.conv_mode == "flat": + weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size) + + return weight + + def default_forward(self, x): + diff_weight = self.get_diff_weight() + if self.is_conv: + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return self.org_forward(x) + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier + else: + return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae, + text_encoder, + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + """Create a LoHa network. 
Called by train_network.py via network_module.create_network().""" + if network_dim is None: + network_dim = 4 + if network_alpha is None: + network_alpha = 1.0 + + # handle text_encoder as list + text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] + + # detect architecture + arch_config = detect_arch_config(unet, text_encoders) + + # train LLM adapter + train_llm_adapter = kwargs.get("train_llm_adapter", "false") + if train_llm_adapter is not None: + train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False + + # exclude patterns + exclude_patterns = kwargs.get("exclude_patterns", None) + if exclude_patterns is None: + exclude_patterns = [] + else: + exclude_patterns = ast.literal_eval(exclude_patterns) + if not isinstance(exclude_patterns, list): + exclude_patterns = [exclude_patterns] + + # add default exclude patterns from arch config + exclude_patterns.extend(arch_config.default_excludes) + + # include patterns + include_patterns = kwargs.get("include_patterns", None) + if include_patterns is not None: + include_patterns = ast.literal_eval(include_patterns) + if not isinstance(include_patterns, list): + include_patterns = [include_patterns] + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # conv dim/alpha for Conv2d 3x3 + conv_lora_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_lora_dim is not None: + conv_lora_dim = int(conv_lora_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # Tucker decomposition for Conv2d 3x3 + use_tucker = kwargs.get("use_tucker", "false") + if use_tucker is not None: + use_tucker = True if str(use_tucker).lower() == "true" else False + + # verbose + verbose = kwargs.get("verbose", "false") + if verbose is not None: + verbose = True if str(verbose).lower() == "true" else False + + # regex-specific learning rates / dimensions + network_reg_lrs = kwargs.get("network_reg_lrs", None) + reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None + + network_reg_dims = kwargs.get("network_reg_dims", None) + reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None + + network = AdditionalNetwork( + text_encoders, + unet, + arch_config=arch_config, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + module_class=LoHaModule, + module_kwargs={"use_tucker": use_tucker}, + conv_lora_dim=conv_lora_dim, + conv_alpha=conv_alpha, + train_llm_adapter=train_llm_adapter, + exclude_patterns=exclude_patterns, + include_patterns=include_patterns, + reg_dims=reg_dims, + reg_lrs=reg_lrs, + verbose=verbose, + ) + + # LoRA+ support + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if 
loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + + return network + + +def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + """Create a LoHa network from saved weights. Called by train_network.py.""" + if weights_sd is None: + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + # detect dim/alpha from weights + modules_dim = {} + modules_alpha = {} + train_llm_adapter = False + for key, value in weights_sd.items(): + if "." not in key: + continue + + lora_name = key.split(".")[0] + if "alpha" in key: + modules_alpha[lora_name] = value + elif "hada_w1_b" in key: + dim = value.shape[0] + modules_dim[lora_name] = dim + + if "llm_adapter" in lora_name: + train_llm_adapter = True + + # detect Tucker mode from weights + use_tucker = any("hada_t1" in key for key in weights_sd.keys()) + + # handle text_encoder as list + text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] + + # detect architecture + arch_config = detect_arch_config(unet, text_encoders) + + module_class = LoHaInfModule if for_inference else LoHaModule + module_kwargs = {"use_tucker": use_tucker} + + network = AdditionalNetwork( + text_encoders, + unet, + arch_config=arch_config, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + module_kwargs=module_kwargs, + train_llm_adapter=train_llm_adapter, + ) + return network, weights_sd + + +def merge_weights_to_tensor( + model_weight: torch.Tensor, + lora_name: str, + lora_sd: Dict[str, torch.Tensor], + lora_weight_keys: set, + multiplier: float, + calc_device: torch.device, +) -> torch.Tensor: + """Merge LoHa weights directly into a model weight tensor. + + Supports standard LoHa, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3. + No Module/Network creation needed. Consumed keys are removed from lora_weight_keys. + Returns model_weight unchanged if no matching LoHa keys found. 
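+
+    Illustrative call (tensor and key names here are made up; the intended
+    caller is the LoRA weight-loading hook)::
+
+        keys = set(lora_sd.keys())
+        w = merge_weights_to_tensor(
+            w, "lora_unet_blocks_0_attn", lora_sd, keys,
+            multiplier=1.0, calc_device=torch.device("cpu"),
+        )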
+ """ + w1a_key = lora_name + ".hada_w1_a" + w1b_key = lora_name + ".hada_w1_b" + w2a_key = lora_name + ".hada_w2_a" + w2b_key = lora_name + ".hada_w2_b" + t1_key = lora_name + ".hada_t1" + t2_key = lora_name + ".hada_t2" + alpha_key = lora_name + ".alpha" + + if w1a_key not in lora_weight_keys: + return model_weight + + w1a = lora_sd[w1a_key].to(calc_device) + w1b = lora_sd[w1b_key].to(calc_device) + w2a = lora_sd[w2a_key].to(calc_device) + w2b = lora_sd[w2b_key].to(calc_device) + + has_tucker = t1_key in lora_weight_keys + + dim = w1b.shape[0] + alpha = lora_sd.get(alpha_key, torch.tensor(dim)) + if isinstance(alpha, torch.Tensor): + alpha = alpha.item() + scale = alpha / dim + + original_dtype = model_weight.dtype + if original_dtype.itemsize == 1: # fp8 + model_weight = model_weight.to(torch.float16) + w1a, w1b = w1a.to(torch.float16), w1b.to(torch.float16) + w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16) + + if has_tucker: + # Tucker decomposition: rebuild via einsum + t1 = lora_sd[t1_key].to(calc_device) + t2 = lora_sd[t2_key].to(calc_device) + if original_dtype.itemsize == 1: + t1, t2 = t1.to(torch.float16), t2.to(torch.float16) + rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1b, w1a) + rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2b, w2a) + diff_weight = rebuild1 * rebuild2 * scale + else: + # Standard LoHa: ΔW = ((w1a @ w1b) * (w2a @ w2b)) * scale + diff_weight = ((w1a @ w1b) * (w2a @ w2b)) * scale + + # Reshape diff_weight to match model_weight shape if needed + # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.) + if diff_weight.shape != model_weight.shape: + diff_weight = diff_weight.reshape(model_weight.shape) + + model_weight = model_weight + multiplier * diff_weight + + if original_dtype.itemsize == 1: + model_weight = model_weight.to(original_dtype) + + # remove consumed keys + consumed = [w1a_key, w1b_key, w2a_key, w2b_key, alpha_key] + if has_tucker: + consumed.extend([t1_key, t2_key]) + for key in consumed: + lora_weight_keys.discard(key) + + return model_weight diff --git a/networks/lokr.py b/networks/lokr.py new file mode 100644 index 00000000..03b50ca0 --- /dev/null +++ b/networks/lokr.py @@ -0,0 +1,683 @@ +# LoKr (Low-rank Kronecker Product) network module +# Reference: https://arxiv.org/abs/2309.14859 +# +# Based on the LyCORIS project by KohakuBlueleaf +# https://github.com/KohakuBlueleaf/LyCORIS + +import ast +import math +import os +import logging +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .network_base import ArchConfig, AdditionalNetwork, detect_arch_config, _parse_kv_pairs +from library.utils import setup_logging + +setup_logging() +logger = logging.getLogger(__name__) + + +def factorization(dimension: int, factor: int = -1) -> tuple: + """Return a tuple of two values whose product equals dimension, + optimized for balanced factors. + + In LoKr, the first value is for the weight scale (smaller), + and the second value is for the weight (larger). 
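+    When factor=-1, the most balanced divisor pair is found by walking the
+    smaller factor upward through the divisors of dimension.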
+ + Examples: + factor=-1: 128 -> (8, 16), 512 -> (16, 32), 1024 -> (32, 32) + factor=4: 128 -> (4, 32), 512 -> (4, 128) + """ + if factor > 0 and (dimension % factor) == 0: + m = factor + n = dimension // factor + if m > n: + n, m = m, n + return m, n + if factor < 0: + factor = dimension + m, n = 1, dimension + length = m + n + while m < n: + new_m = m + 1 + while dimension % new_m != 0: + new_m += 1 + new_n = dimension // new_m + if new_m + new_n > length or new_m > factor: + break + else: + m, n = new_m, new_n + if m > n: + n, m = m, n + return m, n + + +def make_kron(w1, w2, scale): + """Compute Kronecker product of w1 and w2, scaled by scale.""" + if w1.dim() != w2.dim(): + for _ in range(w2.dim() - w1.dim()): + w1 = w1.unsqueeze(-1) + w2 = w2.contiguous() + rebuild = torch.kron(w1, w2) + if scale != 1: + rebuild = rebuild * scale + return rebuild + + +def rebuild_tucker(t, wa, wb): + """Rebuild weight from Tucker decomposition: einsum("i j ..., i p, j r -> p r ...", t, wa, wb). + + Compatible with LyCORIS convention. + """ + return torch.einsum("i j ..., i p, j r -> p r ...", t, wa, wb) + + +class LoKrModule(torch.nn.Module): + """LoKr module for training. Replaces forward method of the original Linear/Conv2d.""" + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + factor=-1, + use_tucker=False, + **kwargs, + ): + super().__init__() + self.lora_name = lora_name + self.lora_dim = lora_dim + + is_conv2d = org_module.__class__.__name__ == "Conv2d" + if is_conv2d: + in_dim = org_module.in_channels + out_dim = org_module.out_channels + kernel_size = org_module.kernel_size + self.is_conv = True + self.stride = org_module.stride + self.padding = org_module.padding + self.dilation = org_module.dilation + self.groups = org_module.groups + self.kernel_size = kernel_size + + self.tucker = use_tucker and any(k != 1 for k in kernel_size) + + if kernel_size == (1, 1): + self.conv_mode = "1x1" + elif self.tucker: + self.conv_mode = "tucker" + else: + self.conv_mode = "flat" + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + self.is_conv = False + self.tucker = False + self.conv_mode = None + self.kernel_size = None + + self.in_dim = in_dim + self.out_dim = out_dim + + factor = int(factor) + self.use_w2 = False + + # Factorize dimensions + in_m, in_n = factorization(in_dim, factor) + out_l, out_k = factorization(out_dim, factor) + + # w1 is always a full matrix (the "scale" factor, small) + self.lokr_w1 = nn.Parameter(torch.empty(out_l, in_m)) + + # w2: depends on mode + if self.conv_mode in ("tucker", "flat"): + # Conv2d 3x3+ modes + k_size = kernel_size + + if lora_dim >= max(out_k, in_n) / 2: + # Full matrix mode (includes kernel dimensions) + self.use_w2 = True + self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n, *k_size)) + logger.warning( + f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} " + f"and factor={factor}, using full matrix mode for Conv2d." 
+                )
+ ) + elif self.tucker: + # Tucker mode: separate kernel into t2 tensor + self.lokr_t2 = nn.Parameter(torch.empty(lora_dim, lora_dim, *k_size)) + self.lokr_w2_a = nn.Parameter(torch.empty(lora_dim, out_k)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n)) + else: + # Non-Tucker: flatten kernel into w2_b + k_prod = 1 + for k in k_size: + k_prod *= k + self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n * k_prod)) + else: + # Linear or Conv2d 1x1 + if lora_dim < max(out_k, in_n) / 2: + self.lokr_w2_a = nn.Parameter(torch.empty(out_k, lora_dim)) + self.lokr_w2_b = nn.Parameter(torch.empty(lora_dim, in_n)) + else: + self.use_w2 = True + self.lokr_w2 = nn.Parameter(torch.empty(out_k, in_n)) + if lora_dim >= max(out_k, in_n) / 2: + logger.warning( + f"LoKr: lora_dim {lora_dim} is large for dim={max(in_dim, out_dim)} " + f"and factor={factor}, using full matrix mode." + ) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() + alpha = lora_dim if alpha is None or alpha == 0 else alpha + # if both w1 and w2 are full matrices, use scale = 1 + if self.use_w2: + alpha = lora_dim + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) + + # Initialization + torch.nn.init.kaiming_uniform_(self.lokr_w1, a=math.sqrt(5)) + if self.use_w2: + torch.nn.init.constant_(self.lokr_w2, 0) + else: + if self.tucker: + torch.nn.init.kaiming_uniform_(self.lokr_t2, a=math.sqrt(5)) + torch.nn.init.kaiming_uniform_(self.lokr_w2_a, a=math.sqrt(5)) + torch.nn.init.constant_(self.lokr_w2_b, 0) + # Ensures ΔW = kron(w1, 0) = 0 at init + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def get_diff_weight(self): + """Return materialized weight delta. 
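+
+        The delta is a Kronecker product, dW = kron(lokr_w1, w2) * scale,
+        where w2 is the full lokr_w2 matrix, the Tucker rebuild, or
+        lokr_w2_a @ lokr_w2_b, depending on mode.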
+ + Returns: + - Linear: 2D tensor (out_dim, in_dim) + - Conv2d 1x1: 2D tensor (out_dim, in_dim) — caller should unsqueeze for F.conv2d + - Conv2d 3x3+ Tucker/full: 4D tensor (out_dim, in_dim, k1, k2) + - Conv2d 3x3+ flat: 4D tensor (out_dim, in_dim, k1, k2) — reshaped from 2D + """ + w1 = self.lokr_w1 + + if self.use_w2: + w2 = self.lokr_w2 + elif self.tucker: + w2 = rebuild_tucker(self.lokr_t2, self.lokr_w2_a, self.lokr_w2_b) + else: + w2 = self.lokr_w2_a @ self.lokr_w2_b + + result = make_kron(w1, w2, self.scale) + + # For non-Tucker Conv2d 3x3+, result is 2D; reshape to 4D + if self.conv_mode == "flat" and result.dim() == 2: + result = result.reshape(self.out_dim, self.in_dim, *self.kernel_size) + + return result + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + diff_weight = self.get_diff_weight() + + # rank dropout + if self.rank_dropout is not None and self.training: + drop = (torch.rand(diff_weight.size(0), device=diff_weight.device) > self.rank_dropout).to(diff_weight.dtype) + drop = drop.view(-1, *([1] * (diff_weight.dim() - 1))) + diff_weight = diff_weight * drop + scale = 1.0 / (1.0 - self.rank_dropout) + else: + scale = 1.0 + + if self.is_conv: + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale + else: + # Conv2d 3x3+: diff_weight is already 4D from get_diff_weight + return org_forwarded + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier * scale + else: + return org_forwarded + F.linear(x, diff_weight) * self.multiplier * scale + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class LoKrInfModule(LoKrModule): + """LoKr module for inference. 
Supports merge_to and get_weight.""" + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference; pass factor and use_tucker from kwargs + factor = kwargs.pop("factor", -1) + use_tucker = kwargs.pop("use_tucker", False) + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha, factor=factor, use_tucker=use_tucker) + + self.org_module_ref = [org_module] + self.enabled = True + self.network: AdditionalNetwork = None + + def set_network(self, network): + self.network = network + + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) + + if dtype is None: + dtype = org_dtype + if device is None: + device = org_device + + # get LoKr weights + w1 = sd["lokr_w1"].to(torch.float).to(device) + + if "lokr_w2" in sd: + w2 = sd["lokr_w2"].to(torch.float).to(device) + elif "lokr_t2" in sd: + # Tucker mode + t2 = sd["lokr_t2"].to(torch.float).to(device) + w2a = sd["lokr_w2_a"].to(torch.float).to(device) + w2b = sd["lokr_w2_b"].to(torch.float).to(device) + w2 = rebuild_tucker(t2, w2a, w2b) + else: + w2a = sd["lokr_w2_a"].to(torch.float).to(device) + w2b = sd["lokr_w2_b"].to(torch.float).to(device) + w2 = w2a @ w2b + + # compute ΔW via Kronecker product + diff_weight = make_kron(w1, w2, self.scale) + + # reshape diff_weight to match original weight shape if needed + if diff_weight.shape != weight.shape: + diff_weight = diff_weight.reshape(weight.shape) + + weight = weight.to(device) + self.multiplier * diff_weight + + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + w1 = self.lokr_w1.to(torch.float) + + if self.use_w2: + w2 = self.lokr_w2.to(torch.float) + elif self.tucker: + w2 = rebuild_tucker( + self.lokr_t2.to(torch.float), + self.lokr_w2_a.to(torch.float), + self.lokr_w2_b.to(torch.float), + ) + else: + w2 = (self.lokr_w2_a @ self.lokr_w2_b).to(torch.float) + + weight = make_kron(w1, w2, self.scale) * multiplier + + # reshape to match original weight shape if needed + if self.is_conv: + if self.conv_mode == "1x1": + weight = weight.unsqueeze(2).unsqueeze(3) + elif self.conv_mode == "flat" and weight.dim() == 2: + weight = weight.reshape(self.out_dim, self.in_dim, *self.kernel_size) + # Tucker and full matrix modes: already 4D from kron + + return weight + + def default_forward(self, x): + diff_weight = self.get_diff_weight() + if self.is_conv: + if self.conv_mode == "1x1": + diff_weight = diff_weight.unsqueeze(2).unsqueeze(3) + return self.org_forward(x) + F.conv2d( + x, diff_weight, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups + ) * self.multiplier + else: + return self.org_forward(x) + F.linear(x, diff_weight) * self.multiplier + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + vae, + text_encoder, + unet, + neuron_dropout: Optional[float] = None, + **kwargs, +): + """Create a LoKr network. 
Called by train_network.py via network_module.create_network().""" + if network_dim is None: + network_dim = 4 + if network_alpha is None: + network_alpha = 1.0 + + # handle text_encoder as list + text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder] + + # detect architecture + arch_config = detect_arch_config(unet, text_encoders) + + # train LLM adapter + train_llm_adapter = kwargs.get("train_llm_adapter", "false") + if train_llm_adapter is not None: + train_llm_adapter = True if str(train_llm_adapter).lower() == "true" else False + + # exclude patterns + exclude_patterns = kwargs.get("exclude_patterns", None) + if exclude_patterns is None: + exclude_patterns = [] + else: + exclude_patterns = ast.literal_eval(exclude_patterns) + if not isinstance(exclude_patterns, list): + exclude_patterns = [exclude_patterns] + + # add default exclude patterns from arch config + exclude_patterns.extend(arch_config.default_excludes) + + # include patterns + include_patterns = kwargs.get("include_patterns", None) + if include_patterns is not None: + include_patterns = ast.literal_eval(include_patterns) + if not isinstance(include_patterns, list): + include_patterns = [include_patterns] + + # rank/module dropout + rank_dropout = kwargs.get("rank_dropout", None) + if rank_dropout is not None: + rank_dropout = float(rank_dropout) + module_dropout = kwargs.get("module_dropout", None) + if module_dropout is not None: + module_dropout = float(module_dropout) + + # conv dim/alpha for Conv2d 3x3 + conv_lora_dim = kwargs.get("conv_dim", None) + conv_alpha = kwargs.get("conv_alpha", None) + if conv_lora_dim is not None: + conv_lora_dim = int(conv_lora_dim) + if conv_alpha is None: + conv_alpha = 1.0 + else: + conv_alpha = float(conv_alpha) + + # Tucker decomposition for Conv2d 3x3 + use_tucker = kwargs.get("use_tucker", "false") + if use_tucker is not None: + use_tucker = True if str(use_tucker).lower() == "true" else False + + # factor for LoKr + factor = int(kwargs.get("factor", -1)) + + # verbose + verbose = kwargs.get("verbose", "false") + if verbose is not None: + verbose = True if str(verbose).lower() == "true" else False + + # regex-specific learning rates / dimensions + network_reg_lrs = kwargs.get("network_reg_lrs", None) + reg_lrs = _parse_kv_pairs(network_reg_lrs, is_int=False) if network_reg_lrs is not None else None + + network_reg_dims = kwargs.get("network_reg_dims", None) + reg_dims = _parse_kv_pairs(network_reg_dims, is_int=True) if network_reg_dims is not None else None + + network = AdditionalNetwork( + text_encoders, + unet, + arch_config=arch_config, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + module_class=LoKrModule, + module_kwargs={"factor": factor, "use_tucker": use_tucker}, + conv_lora_dim=conv_lora_dim, + conv_alpha=conv_alpha, + train_llm_adapter=train_llm_adapter, + exclude_patterns=exclude_patterns, + include_patterns=include_patterns, + reg_dims=reg_dims, + reg_lrs=reg_lrs, + verbose=verbose, + ) + + # LoRA+ support + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + 
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
+    return network
+
+
+def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
+    """Create a LoKr network from saved weights. Called by train_network.py."""
+    if weights_sd is None:
+        if os.path.splitext(file)[1] == ".safetensors":
+            from safetensors.torch import load_file
+
+            weights_sd = load_file(file)
+        else:
+            weights_sd = torch.load(file, map_location="cpu")
+
+    # detect dim/alpha from weights
+    modules_dim = {}
+    modules_alpha = {}
+    train_llm_adapter = False
+    use_tucker = False
+    for key, value in weights_sd.items():
+        if "." not in key:
+            continue
+
+        lora_name = key.split(".")[0]
+        if "alpha" in key:
+            modules_alpha[lora_name] = value
+        elif "lokr_w2_a" in key:
+            # low-rank mode: dim detection depends on Tucker vs non-Tucker
+            if lora_name + ".lokr_t2" in weights_sd:
+                # Tucker: w2_a = (rank, out_k) → dim = w2_a.shape[0]
+                dim = value.shape[0]
+            else:
+                # Non-Tucker: w2_a = (out_k, rank) → dim = w2_a.shape[1]
+                dim = value.shape[1]
+            modules_dim[lora_name] = dim
+        elif "lokr_w2" in key and "lokr_w2_a" not in key and "lokr_w2_b" not in key:
+            # full matrix mode: set dim large enough to trigger full-matrix path
+            if lora_name not in modules_dim:
+                modules_dim[lora_name] = max(value.shape[0], value.shape[1])
+
+        if "lokr_t2" in key:
+            use_tucker = True
+
+        if "llm_adapter" in lora_name:
+            train_llm_adapter = True
+
+    # handle text_encoder as list
+    text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
+
+    # detect architecture
+    arch_config = detect_arch_config(unet, text_encoders)
+
+    # extract factor for LoKr
+    factor = int(kwargs.get("factor", -1))
+
+    module_class = LoKrInfModule if for_inference else LoKrModule
+    module_kwargs = {"factor": factor, "use_tucker": use_tucker}
+
+    network = AdditionalNetwork(
+        text_encoders,
+        unet,
+        arch_config=arch_config,
+        multiplier=multiplier,
+        modules_dim=modules_dim,
+        modules_alpha=modules_alpha,
+        module_class=module_class,
+        module_kwargs=module_kwargs,
+        train_llm_adapter=train_llm_adapter,
+    )
+    return network, weights_sd
+
+
+def merge_weights_to_tensor(
+    model_weight: torch.Tensor,
+    lora_name: str,
+    lora_sd: Dict[str, torch.Tensor],
+    lora_weight_keys: set,
+    multiplier: float,
+    calc_device: torch.device,
+) -> torch.Tensor:
+    """Merge LoKr weights directly into a model weight tensor.
+
+    Supports standard LoKr, non-Tucker Conv2d 3x3, and Tucker Conv2d 3x3.
+    No Module/Network creation needed. Consumed keys are removed from lora_weight_keys.
+    Returns model_weight unchanged if no matching LoKr keys found.
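+
+    A quick shape check (illustrative numbers): with factor=8 on a Linear
+    weight of shape (1280, 3072), w1 is (8, 8) and w2 is (160, 384), so
+    kron(w1, w2) is (8*160, 8*384) = (1280, 3072), matching the model weight.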
+ """ + w1_key = lora_name + ".lokr_w1" + w2_key = lora_name + ".lokr_w2" + w2a_key = lora_name + ".lokr_w2_a" + w2b_key = lora_name + ".lokr_w2_b" + t2_key = lora_name + ".lokr_t2" + alpha_key = lora_name + ".alpha" + + if w1_key not in lora_weight_keys: + return model_weight + + w1 = lora_sd[w1_key].to(calc_device) + + # determine mode: full matrix vs Tucker vs low-rank + has_tucker = t2_key in lora_weight_keys + + if w2a_key in lora_weight_keys: + w2a = lora_sd[w2a_key].to(calc_device) + w2b = lora_sd[w2b_key].to(calc_device) + + if has_tucker: + # Tucker: w2a = (rank, out_k), dim = rank + dim = w2a.shape[0] + else: + # Non-Tucker low-rank: w2a = (out_k, rank), dim = rank + dim = w2a.shape[1] + + consumed_keys = [w1_key, w2a_key, w2b_key, alpha_key] + if has_tucker: + consumed_keys.append(t2_key) + elif w2_key in lora_weight_keys: + # full matrix mode + w2a = None + w2b = None + dim = None + consumed_keys = [w1_key, w2_key, alpha_key] + else: + return model_weight + + alpha = lora_sd.get(alpha_key, None) + if alpha is not None and isinstance(alpha, torch.Tensor): + alpha = alpha.item() + + # compute scale + if w2a is not None: + if alpha is None: + alpha = dim + scale = alpha / dim + else: + # full matrix mode: scale = 1.0 + scale = 1.0 + + original_dtype = model_weight.dtype + if original_dtype.itemsize == 1: # fp8 + model_weight = model_weight.to(torch.float16) + w1 = w1.to(torch.float16) + if w2a is not None: + w2a, w2b = w2a.to(torch.float16), w2b.to(torch.float16) + + # compute w2 + if w2a is not None: + if has_tucker: + t2 = lora_sd[t2_key].to(calc_device) + if original_dtype.itemsize == 1: + t2 = t2.to(torch.float16) + w2 = rebuild_tucker(t2, w2a, w2b) + else: + w2 = w2a @ w2b + else: + w2 = lora_sd[w2_key].to(calc_device) + if original_dtype.itemsize == 1: + w2 = w2.to(torch.float16) + + # ΔW = kron(w1, w2) * scale + diff_weight = make_kron(w1, w2, scale) + + # Reshape diff_weight to match model_weight shape if needed + # (handles Conv2d 1x1 unsqueeze, Conv2d 3x3 non-Tucker reshape, etc.) + if diff_weight.shape != model_weight.shape: + diff_weight = diff_weight.reshape(model_weight.shape) + + model_weight = model_weight + multiplier * diff_weight + + if original_dtype.itemsize == 1: + model_weight = model_weight.to(original_dtype) + + # remove consumed keys + for key in consumed_keys: + lora_weight_keys.discard(key) + + return model_weight diff --git a/networks/lora_anima.py b/networks/lora_anima.py index 9413e8c8..4cff2819 100644 --- a/networks/lora_anima.py +++ b/networks/lora_anima.py @@ -1,11 +1,11 @@ # LoRA network module for Anima import ast +import math import os import re from typing import Dict, List, Optional, Tuple, Type, Union import torch from library.utils import setup_logging -from networks.lora_flux import LoRAModule, LoRAInfModule import logging @@ -13,6 +13,213 @@ setup_logging() logger = logging.getLogger(__name__) +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """ + if alpha == 0 or None, alpha is rank (no scaling). 
+ """ + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える + + # same as microsoft's + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + + del self.org_module + + def forward(self, x): + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if isinstance(self.lora_down, torch.nn.Conv2d): + # Conv2d: lora_dim is at dim 1 → [B, dim, 1, 1] + mask = mask.unsqueeze(-1).unsqueeze(-1) + else: + # Linear: lora_dim is at last dim → [B, 1, ..., 1, dim] + for _ in range(len(lx.size()) - 2): + mask = mask.unsqueeze(1) + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + # maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded + lx * self.multiplier * scale + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype + + +class LoRAInfModule(LoRAModule): + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + **kwargs, + ): + # no dropout for inference + super().__init__(lora_name, org_module, multiplier, lora_dim, alpha) + + self.org_module_ref = [org_module] # 後から参照できるように + self.enabled = True + self.network: LoRANetwork = None + + def set_network(self, network): + self.network = network + + # freezeしてマージする + def merge_to(self, sd, dtype, device): + # extract weight from org_module + org_sd = self.org_module.state_dict() + weight = org_sd["weight"] + org_dtype = weight.dtype + org_device = weight.device + weight = weight.to(torch.float) # calc in float + + if dtype is None: + dtype = org_dtype + if device is None: + 
device = org_device + + # get up/down weight + down_weight = sd["lora_down.weight"].to(torch.float).to(device) + up_weight = sd["lora_up.weight"].to(torch.float).to(device) + + # merge weight + if len(weight.size()) == 2: + # linear + weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + weight + + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + # logger.info(conved.size(), weight.size(), module.stride, module.padding) + weight = weight + self.multiplier * conved * self.scale + + # set weight to org_module + org_sd["weight"] = weight.to(dtype) + self.org_module.load_state_dict(org_sd) + + # 復元できるマージのため、このモジュールのweightを返す + def get_weight(self, multiplier=None): + if multiplier is None: + multiplier = self.multiplier + + # get up/down weight from module + up_weight = self.lora_up.weight.to(torch.float) + down_weight = self.lora_down.weight.to(torch.float) + + # pre-calculated weight + if len(down_weight.size()) == 2: + # linear + weight = self.multiplier * (up_weight @ down_weight) * self.scale + elif down_weight.size()[2:4] == (1, 1): + # conv2d 1x1 + weight = ( + self.multiplier + * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + * self.scale + ) + else: + # conv2d 3x3 + conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3) + weight = self.multiplier * conved * self.scale + + return weight + + def default_forward(self, x): + # logger.info(f"default_forward {self.lora_name} {x.size()}") + lx = self.lora_down(x) + lx = self.lora_up(lx) + return self.org_forward(x) + lx * self.multiplier * self.scale + + def forward(self, x): + if not self.enabled: + return self.org_forward(x) + return self.default_forward(x) + + def create_network( multiplier: float, network_dim: Optional[int], diff --git a/networks/network_base.py b/networks/network_base.py new file mode 100644 index 00000000..d9697562 --- /dev/null +++ b/networks/network_base.py @@ -0,0 +1,545 @@ +# Shared network base for additional network modules (like LyCORIS-family modules: LoHa, LoKr, etc). +# Provides architecture detection and a generic AdditionalNetwork class. 
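+# The concrete module type (e.g. LoHaModule or LoKrModule) is injected via the
+# module_class argument of AdditionalNetwork, keeping this file algorithm-agnostic.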
+ +import os +import re +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Type, Union + +import torch +from library.sdxl_original_unet import InferSdxlUNet2DConditionModel +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +@dataclass +class ArchConfig: + unet_target_modules: List[str] + te_target_modules: List[str] + unet_prefix: str + te_prefixes: List[str] + default_excludes: List[str] = field(default_factory=list) + adapter_target_modules: List[str] = field(default_factory=list) + unet_conv_target_modules: List[str] = field(default_factory=list) + + +def detect_arch_config(unet, text_encoders) -> ArchConfig: + """Detect architecture from model structure and return ArchConfig.""" + from library.sdxl_original_unet import SdxlUNet2DConditionModel + + # Check SDXL first + if unet is not None and ( + issubclass(unet.__class__, SdxlUNet2DConditionModel) or issubclass(unet.__class__, InferSdxlUNet2DConditionModel) + ): + return ArchConfig( + unet_target_modules=["Transformer2DModel"], + te_target_modules=["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"], + unet_prefix="lora_unet", + te_prefixes=["lora_te1", "lora_te2"], + default_excludes=[], + unet_conv_target_modules=["ResnetBlock2D", "Downsample2D", "Upsample2D"], + ) + + # Check Anima: look for Block class in named_modules + module_class_names = set() + if unet is not None: + for module in unet.modules(): + module_class_names.add(type(module).__name__) + + if "Block" in module_class_names: + return ArchConfig( + unet_target_modules=["Block", "PatchEmbed", "TimestepEmbedding", "FinalLayer"], + te_target_modules=["Qwen3Attention", "Qwen3MLP", "Qwen3SdpaAttention", "Qwen3FlashAttention2"], + unet_prefix="lora_unet", + te_prefixes=["lora_te"], + default_excludes=[r".*(_modulation|_norm|_embedder|final_layer).*"], + adapter_target_modules=["LLMAdapterTransformerBlock"], + ) + + raise ValueError(f"Cannot auto-detect architecture for LyCORIS. Module classes found: {sorted(module_class_names)}") + + +def _parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, Union[int, float]]: + """Parse a string of key-value pairs separated by commas.""" + pairs = {} + for pair in kv_pair_str.split(","): + pair = pair.strip() + if not pair: + continue + if "=" not in pair: + logger.warning(f"Invalid format: {pair}, expected 'key=value'") + continue + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip() + try: + pairs[key] = int(value) if is_int else float(value) + except ValueError: + logger.warning(f"Invalid value for {key}: {value}") + return pairs + + +class AdditionalNetwork(torch.nn.Module): + """Generic Additional network that supports LoHa, LoKr, and similar module types. + + Constructed with a module_class parameter to inject the specific module type. + Based on the lora_anima.py LoRANetwork, generalized for multiple architectures. 
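+
+    Illustrative construction (arguments abbreviated; the real call sites are
+    create_network() in networks/loha.py and networks/lokr.py)::
+
+        network = AdditionalNetwork(
+            text_encoders,
+            unet,
+            arch_config=detect_arch_config(unet, text_encoders),
+            lora_dim=8,
+            alpha=4.0,
+            module_class=LoHaModule,  # any compatible module class
+        )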
+ """ + + def __init__( + self, + text_encoders: list, + unet, + arch_config: ArchConfig, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + rank_dropout: Optional[float] = None, + module_dropout: Optional[float] = None, + module_class: Type[torch.nn.Module] = None, + module_kwargs: Optional[Dict] = None, + modules_dim: Optional[Dict[str, int]] = None, + modules_alpha: Optional[Dict[str, int]] = None, + conv_lora_dim: Optional[int] = None, + conv_alpha: Optional[float] = None, + exclude_patterns: Optional[List[str]] = None, + include_patterns: Optional[List[str]] = None, + reg_dims: Optional[Dict[str, int]] = None, + reg_lrs: Optional[Dict[str, float]] = None, + train_llm_adapter: bool = False, + verbose: bool = False, + ) -> None: + super().__init__() + assert module_class is not None, "module_class must be specified" + + self.multiplier = multiplier + self.lora_dim = lora_dim + self.alpha = alpha + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + self.conv_lora_dim = conv_lora_dim + self.conv_alpha = conv_alpha + self.train_llm_adapter = train_llm_adapter + self.reg_dims = reg_dims + self.reg_lrs = reg_lrs + self.arch_config = arch_config + + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + + if module_kwargs is None: + module_kwargs = {} + + if modules_dim is not None: + logger.info(f"create {module_class.__name__} network from weights") + else: + logger.info(f"create {module_class.__name__} network. base dim (rank): {lora_dim}, alpha: {alpha}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) + + # compile regular expressions + def str_to_re_patterns(patterns: Optional[List[str]]) -> List[re.Pattern]: + re_patterns = [] + if patterns is not None: + for pattern in patterns: + try: + re_pattern = re.compile(pattern) + except re.error as e: + logger.error(f"Invalid pattern '{pattern}': {e}") + continue + re_patterns.append(re_pattern) + return re_patterns + + exclude_re_patterns = str_to_re_patterns(exclude_patterns) + include_re_patterns = str_to_re_patterns(include_patterns) + + # create module instances + def create_modules( + prefix: str, + root_module: torch.nn.Module, + target_replace_modules: List[str], + default_dim: Optional[int] = None, + ) -> Tuple[List[torch.nn.Module], List[str]]: + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if target_replace_modules is None or module.__class__.__name__ in target_replace_modules: + if target_replace_modules is None: + module = root_module + + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + if is_linear or is_conv2d: + original_name = (name + "." 
if name else "") + child_name + lora_name = f"{prefix}.{original_name}".replace(".", "_") + + # exclude/include filter + excluded = any(pattern.fullmatch(original_name) for pattern in exclude_re_patterns) + included = any(pattern.fullmatch(original_name) for pattern in include_re_patterns) + if excluded and not included: + if verbose: + logger.info(f"exclude: {original_name}") + continue + + dim = None + alpha_val = None + + if modules_dim is not None: + if lora_name in modules_dim: + dim = modules_dim[lora_name] + alpha_val = modules_alpha[lora_name] + else: + if self.reg_dims is not None: + for reg, d in self.reg_dims.items(): + if re.fullmatch(reg, original_name): + dim = d + alpha_val = self.alpha + logger.info(f"Module {original_name} matched with regex '{reg}' -> dim: {dim}") + break + # fallback to default dim + if dim is None: + if is_linear or is_conv2d_1x1: + dim = default_dim if default_dim is not None else self.lora_dim + alpha_val = self.alpha + elif is_conv2d and self.conv_lora_dim is not None: + dim = self.conv_lora_dim + alpha_val = self.conv_alpha + + if dim is None or dim == 0: + if is_linear or is_conv2d_1x1: + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha_val, + dropout=dropout, + rank_dropout=rank_dropout, + module_dropout=module_dropout, + **module_kwargs, + ) + lora.original_name = original_name + loras.append(lora) + + if target_replace_modules is None: + break + return loras, skipped + + # Create modules for text encoders + self.text_encoder_loras: List[torch.nn.Module] = [] + skipped_te = [] + if text_encoders is not None: + for i, text_encoder in enumerate(text_encoders): + if text_encoder is None: + continue + + # Determine prefix for this text encoder + if i < len(arch_config.te_prefixes): + te_prefix = arch_config.te_prefixes[i] + else: + te_prefix = arch_config.te_prefixes[0] + + logger.info(f"create {module_class.__name__} for Text Encoder {i+1} (prefix={te_prefix}):") + te_loras, te_skipped = create_modules(te_prefix, text_encoder, arch_config.te_target_modules) + logger.info(f"create {module_class.__name__} for Text Encoder {i+1}: {len(te_loras)} modules.") + self.text_encoder_loras.extend(te_loras) + skipped_te += te_skipped + + # Create modules for UNet/DiT + target_modules = list(arch_config.unet_target_modules) + if modules_dim is not None or conv_lora_dim is not None: + target_modules.extend(arch_config.unet_conv_target_modules) + if train_llm_adapter and arch_config.adapter_target_modules: + target_modules.extend(arch_config.adapter_target_modules) + + self.unet_loras: List[torch.nn.Module] + self.unet_loras, skipped_un = create_modules(arch_config.unet_prefix, unet, target_modules) + logger.info(f"create {module_class.__name__} for UNet/DiT: {len(self.unet_loras)} modules.") + + if verbose: + for lora in self.unet_loras: + logger.info(f"\t{lora.lora_name:60} {lora.lora_dim}, {lora.alpha}") + + skipped = skipped_te + skipped_un + if verbose and len(skipped) > 0: + logger.warning(f"dim (rank) is 0, {len(skipped)} modules are skipped:") + for name in skipped: + logger.info(f"\t{name}") + + # assertion: no duplicate names + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def 
set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + + info = self.load_state_dict(weights_sd, False) + return info + + def apply_to(self, text_encoders, unet, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + logger.info(f"enable modules for text encoder: {len(self.text_encoder_loras)} modules") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info(f"enable modules for UNet/DiT: {len(self.unet_loras)} modules") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + def is_mergeable(self): + return True + + def merge_to(self, text_encoders, unet, weights_sd, dtype=None, device=None): + apply_text_encoder = apply_unet = False + te_prefixes = self.arch_config.te_prefixes + unet_prefix = self.arch_config.unet_prefix + + for key in weights_sd.keys(): + if any(key.startswith(p) for p in te_prefixes): + apply_text_encoder = True + elif key.startswith(unet_prefix): + apply_unet = True + + if apply_text_encoder: + logger.info("enable modules for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + logger.info("enable modules for UNet/DiT") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + sd_for_lora = {} + for key in weights_sd.keys(): + if key.startswith(lora.lora_name): + sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key] + lora.merge_to(sd_for_lora, dtype, device) + + logger.info("weights are merged") + + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + + def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr): + if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0): + text_encoder_lr = [default_lr] + elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int): + text_encoder_lr = [float(text_encoder_lr)] + elif len(text_encoder_lr) == 1: + pass # already a list with one element + + self.requires_grad_(True) + + all_params = [] + lr_descriptions = [] + + def assemble_params(loras, lr, loraplus_ratio): + param_groups = {"lora": {}, "plus": {}} + reg_groups = {} + reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else [] + + for lora in loras: + matched_reg_lr = None + for i, (regex_str, reg_lr) in enumerate(reg_lrs_list): + if re.fullmatch(regex_str, lora.original_name): + matched_reg_lr = (i, reg_lr) + logger.info(f"Module {lora.original_name} matched regex '{regex_str}' -> LR {reg_lr}") + break + + for name, param in lora.named_parameters(): + if matched_reg_lr is not None: + reg_idx, reg_lr = matched_reg_lr + group_key = f"reg_lr_{reg_idx}" + if group_key not in reg_groups: + reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr} + # LoRA+ detection: 
check for "up" weight parameters + if loraplus_ratio is not None and self._is_plus_param(name): + reg_groups[group_key]["plus"][f"{lora.lora_name}.{name}"] = param + else: + reg_groups[group_key]["lora"][f"{lora.lora_name}.{name}"] = param + continue + + if loraplus_ratio is not None and self._is_plus_param(name): + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + params = [] + descriptions = [] + for group_key, group in reg_groups.items(): + reg_lr = group["lr"] + for key in ("lora", "plus"): + param_data = {"params": group[key].values()} + if len(param_data["params"]) == 0: + continue + if key == "plus": + param_data["lr"] = reg_lr * loraplus_ratio if loraplus_ratio is not None else reg_lr + else: + param_data["lr"] = reg_lr + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + params.append(param_data) + desc = f"reg_lr_{group_key.split('_')[-1]}" + descriptions.append(desc + (" plus" if key == "plus" else "")) + + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if len(param_data["params"]) == 0: + continue + if lr is not None: + if key == "plus": + param_data["lr"] = lr * loraplus_ratio + else: + param_data["lr"] = lr + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + logger.info("NO LR skipping!") + continue + params.append(param_data) + descriptions.append("plus" if key == "plus" else "") + return params, descriptions + + if self.text_encoder_loras: + loraplus_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio + # Group TE loras by prefix + for te_idx, te_prefix in enumerate(self.arch_config.te_prefixes): + te_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(te_prefix)] + if len(te_loras) > 0: + te_lr = text_encoder_lr[te_idx] if te_idx < len(text_encoder_lr) else text_encoder_lr[0] + logger.info(f"Text Encoder {te_idx+1} ({te_prefix}): {len(te_loras)} modules, LR {te_lr}") + params, descriptions = assemble_params(te_loras, te_lr, loraplus_ratio) + all_params.extend(params) + lr_descriptions.extend([f"textencoder {te_idx+1}" + (" " + d if d else "") for d in descriptions]) + + if self.unet_loras: + params, descriptions = assemble_params( + self.unet_loras, + unet_lr if unet_lr is not None else default_lr, + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, + ) + all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) + + return all_params, lr_descriptions + + def _is_plus_param(self, name: str) -> bool: + """Check if a parameter name corresponds to a 'plus' (higher LR) param for LoRA+. + + For LoRA: lora_up. For LoHa: hada_w2_a (the second pair). For LoKr: lokr_w1 (the scale factor). + Override in subclass if needed. Default: check for common 'up' patterns. 
+ """ + return "lora_up" in name or "hada_w2_a" in name or "lokr_w1" in name + + def enable_gradient_checkpointing(self): + pass # not supported + + def prepare_grad_etc(self, text_encoder, unet): + self.requires_grad_(True) + + def on_epoch_start(self, text_encoder, unet): + self.train() + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + from library import train_util + + if metadata is None: + metadata = {} + model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + + def backup_weights(self): + loras = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not hasattr(org_module, "_lora_org_weight"): + sd = org_module.state_dict() + org_module._lora_org_weight = sd["weight"].detach().clone() + org_module._lora_restored = True + + def restore_weights(self): + loras = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + if not org_module._lora_restored: + sd = org_module.state_dict() + sd["weight"] = org_module._lora_org_weight + org_module.load_state_dict(sd) + org_module._lora_restored = True + + def pre_calculation(self): + loras = self.text_encoder_loras + self.unet_loras + for lora in loras: + org_module = lora.org_module_ref[0] + sd = org_module.state_dict() + + org_weight = sd["weight"] + lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype) + sd["weight"] = org_weight + lora_weight + assert sd["weight"].shape == org_weight.shape + org_module.load_state_dict(sd) + + org_module._lora_restored = False + lora.enabled = False From 1cd95b2d8b3b994214681ab30cbdc74f9abc44ef Mon Sep 17 00:00:00 2001 From: woctordho Date: Thu, 19 Mar 2026 07:43:39 +0800 Subject: [PATCH 07/17] Add `skip_image_resolution` to deduplicate multi-resolution dataset (#2273) * Add min_orig_resolution and max_orig_resolution * Rename min_orig_resolution to skip_image_resolution; remove max_orig_resolution * Change skip_image_resolution to tuple * Move filtering to __init__ * Minor fix --- library/config_util.py | 6 ++- library/train_util.py | 84 ++++++++++++++++++++++++++++++++++++++---- train_network.py | 2 + 3 files changed, 84 insertions(+), 8 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 53727f25..b31f9665 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -108,6 +108,7 @@ class BaseDatasetParams: validation_seed: Optional[int] = None validation_split: float = 0.0 resize_interpolation: Optional[str] = None + skip_image_resolution: Optional[Tuple[int, int]] = None @dataclass class DreamBoothDatasetParams(BaseDatasetParams): @@ -118,7 +119,7 @@ class DreamBoothDatasetParams(BaseDatasetParams): bucket_reso_steps: int = 64 bucket_no_upscale: bool = False prior_loss_weight: float = 1.0 - + @dataclass class FineTuningDatasetParams(BaseDatasetParams): batch_size: int = 1 @@ -244,6 +245,7 @@ class ConfigSanitizer: "resolution": 
functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), "network_multiplier": float, "resize_interpolation": str, + "skip_image_resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int), } # options handled by argparse but not handled by user config @@ -256,6 +258,7 @@ class ConfigSanitizer: ARGPARSE_NULLABLE_OPTNAMES = [ "face_crop_aug_range", "resolution", + "skip_image_resolution", ] # prepare map because option name may differ among argparse and user config ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = { @@ -528,6 +531,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu [{dataset_type} {i}] batch_size: {dataset.batch_size} resolution: {(dataset.width, dataset.height)} + skip_image_resolution: {dataset.skip_image_resolution} resize_interpolation: {dataset.resize_interpolation} enable_bucket: {dataset.enable_bucket} """) diff --git a/library/train_util.py b/library/train_util.py index d8577b9d..b65f06b9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -687,6 +687,7 @@ class BaseDataset(torch.utils.data.Dataset): network_multiplier: float, debug_dataset: bool, resize_interpolation: Optional[str] = None, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: super().__init__() @@ -727,6 +728,8 @@ class BaseDataset(torch.utils.data.Dataset): ), f'Resize interpolation "{resize_interpolation}" is not a valid interpolation' self.resize_interpolation = resize_interpolation + self.skip_image_resolution = skip_image_resolution + self.image_data: Dict[str, ImageInfo] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} @@ -1915,8 +1918,15 @@ class DreamBoothDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str], + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) + super().__init__( + resolution, + network_multiplier, + debug_dataset, + resize_interpolation, + skip_image_resolution, + ) assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" @@ -2034,6 +2044,22 @@ class DreamBoothDataset(BaseDataset): size_set_count += 1 logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + if self.skip_image_resolution is not None: + filtered_img_paths = [] + filtered_sizes = [] + skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1] + for img_path, size in zip(img_paths, sizes): + if size is None: # no latents cache file, get image size by reading image file (slow) + size = self.get_image_size(img_path) + if size[0] * size[1] <= skip_image_area: + continue + filtered_img_paths.append(img_path) + filtered_sizes.append(size) + if len(filtered_img_paths) < len(img_paths): + logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}") + img_paths = filtered_img_paths + sizes = filtered_sizes + # We want to create a training and validation split. This should be improved in the future # to allow a clearer distinction between training and validation. 
This can be seen as a # short-term solution to limit what is necessary to implement validation datasets @@ -2059,7 +2085,7 @@ class DreamBoothDataset(BaseDataset): logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: - captions = [meta["caption"] for meta in metas.values()] + captions = [metas[img_path]["caption"] for img_path in img_paths] missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""] else: # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う @@ -2200,8 +2226,15 @@ class FineTuningDataset(BaseDataset): validation_seed: int, validation_split: float, resize_interpolation: Optional[str], + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) + super().__init__( + resolution, + network_multiplier, + debug_dataset, + resize_interpolation, + skip_image_resolution, + ) self.batch_size = batch_size self.size = min(self.width, self.height) # 短いほう @@ -2297,6 +2330,7 @@ class FineTuningDataset(BaseDataset): tags_list = [] size_set_from_metadata = 0 size_set_from_cache_filename = 0 + num_filtered = 0 for image_key in image_keys_sorted_by_length_desc: img_md = metadata[image_key] caption = img_md.get("caption") @@ -2355,6 +2389,16 @@ class FineTuningDataset(BaseDataset): image_info.image_size = (w, h) size_set_from_cache_filename += 1 + if self.skip_image_resolution is not None: + size = image_info.image_size + if size is None: # no image size in metadata or latents cache file, get image size by reading image file (slow) + size = self.get_image_size(abs_path) + image_info.image_size = size + skip_image_area = self.skip_image_resolution[0] * self.skip_image_resolution[1] + if size[0] * size[1] <= skip_image_area: + num_filtered += 1 + continue + self.register_image(image_info, subset) if size_set_from_cache_filename > 0: @@ -2363,6 +2407,8 @@ class FineTuningDataset(BaseDataset): ) if size_set_from_metadata > 0: logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}") + if num_filtered > 0: + logger.info(f"filtered {num_filtered} images by original resolution from {subset.metadata_file}") self.num_train_images += len(metadata) * subset.num_repeats # TODO do not record tag freq when no tag @@ -2387,8 +2433,15 @@ class ControlNetDataset(BaseDataset): validation_split: float, validation_seed: Optional[int], resize_interpolation: Optional[str] = None, + skip_image_resolution: Optional[Tuple[int, int]] = None, ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) + super().__init__( + resolution, + network_multiplier, + debug_dataset, + resize_interpolation, + skip_image_resolution, + ) db_subsets = [] for subset in subsets: @@ -2440,6 +2493,7 @@ class ControlNetDataset(BaseDataset): validation_split, validation_seed, resize_interpolation, + skip_image_resolution, ) # config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい) @@ -2487,9 +2541,10 @@ class ControlNetDataset(BaseDataset): assert ( len(missing_imgs) == 0 ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" - assert ( - len(extra_imgs) == 0 - ), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" + if len(extra_imgs) > 0: + logger.warning( + f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" + ) self.conditioning_image_transforms = 
IMAGE_TRANSFORMS @@ -4601,6 +4656,13 @@ def add_dataset_arguments( help="maximum resolution for buckets, must be divisible by bucket_reso_steps " " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", ) + parser.add_argument( + "--skip_image_resolution", + type=str, + default=None, + help="images not larger than this resolution will be skipped ('size' or 'width,height')" + " / この解像度以下の画像はスキップされます('サイズ'指定、または'幅,高さ'指定)", + ) parser.add_argument( "--bucket_reso_steps", type=int, @@ -5414,6 +5476,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): len(args.resolution) == 2 ), f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + if args.skip_image_resolution is not None: + args.skip_image_resolution = tuple([int(r) for r in args.skip_image_resolution.split(",")]) + if len(args.skip_image_resolution) == 1: + args.skip_image_resolution = (args.skip_image_resolution[0], args.skip_image_resolution[0]) + assert ( + len(args.skip_image_resolution) == 2 + ), f"skip_image_resolution must be 'size' or 'width,height' / skip_image_resolutionは'サイズ'または'幅','高さ'で指定してください: {args.skip_image_resolution}" + if args.face_crop_aug_range is not None: args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(",")]) assert ( diff --git a/train_network.py b/train_network.py index 2f8797d2..2ee671e9 100644 --- a/train_network.py +++ b/train_network.py @@ -1085,6 +1085,7 @@ class NetworkTrainer: "enable_bucket": bool(dataset.enable_bucket), "min_bucket_reso": dataset.min_bucket_reso, "max_bucket_reso": dataset.max_bucket_reso, + "skip_image_resolution": dataset.skip_image_resolution, "tag_frequency": dataset.tag_frequency, "bucket_info": dataset.bucket_info, "resize_interpolation": dataset.resize_interpolation, @@ -1191,6 +1192,7 @@ class NetworkTrainer: "ss_bucket_no_upscale": bool(dataset.bucket_no_upscale), "ss_min_bucket_reso": dataset.min_bucket_reso, "ss_max_bucket_reso": dataset.max_bucket_reso, + "ss_skip_image_resolution": dataset.skip_image_resolution, "ss_keep_tokens": args.keep_tokens, "ss_dataset_dirs": json.dumps(dataset_dirs_info), "ss_reg_dataset_dirs": json.dumps(reg_dataset_dirs_info), From 7c159291e9dc5afb074d8f95e0028c4e87f0dc5b Mon Sep 17 00:00:00 2001 From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com> Date: Thu, 19 Mar 2026 09:17:29 +0900 Subject: [PATCH 08/17] docs: add skip_image_resolution to config README (#2288) * docs: add skip_image_resolution option to config README Document the skip_image_resolution dataset option added in PR #2273. Add option description, multi-resolution dataset TOML example, and command-line argument entry to both Japanese and English config READMEs. Co-Authored-By: Claude Opus 4.6 * docs: clarify `skip_image_resolution` functionality in dataset config --------- Co-authored-by: Claude Opus 4.6 --- docs/config_README-en.md | 33 +++++++++++++++++++++++++++++++++ docs/config_README-ja.md | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/docs/config_README-en.md b/docs/config_README-en.md index 78687ee6..6b55a985 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -122,11 +122,15 @@ These are options related to the configuration of the data set. 
They cannot be d | `max_bucket_reso` | `1024` | o | o | | `min_bucket_reso` | `128` | o | o | | `resolution` | `256`, `[512, 512]` | o | o | +| `skip_image_resolution` | `768`, `[512, 768]` | o | o | * `batch_size` * This corresponds to the command-line argument `--train_batch_size`. * `max_bucket_reso`, `min_bucket_reso` * Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`. +* `skip_image_resolution` + * Images whose original resolution (area) is equal to or smaller than the specified resolution will be skipped. Specify as `'size'` or `[width, height]`. This corresponds to the command-line argument `--skip_image_resolution`. + * Useful when sharing the same image directory across multiple datasets with different resolutions, to exclude low-resolution source images from higher-resolution datasets. These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each. @@ -254,6 +258,34 @@ resolution = 768 image_dir = 'C:\hoge' ``` +When using multi-resolution datasets, you can use `skip_image_resolution` to exclude images whose original size is too small for higher-resolution datasets. This prevents overlapping of low-resolution images across datasets and improves training quality. This option can also be used to simply exclude low-resolution source images from datasets. + +```toml +[general] +enable_bucket = true +bucket_no_upscale = true +max_bucket_reso = 1536 + +[[datasets]] +resolution = 768 + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 1024 +skip_image_resolution = 768 + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 1280 +skip_image_resolution = 1024 + [[datasets.subsets]] + image_dir = 'C:\hoge' +``` + +In this example, the 1024-resolution dataset skips images whose original size is 768x768 or smaller, and the 1280-resolution dataset skips images whose original size is 1024x1024 or smaller. + ## Command Line Argument and Configuration File There are options in the configuration file that have overlapping roles with command line argument options. 
@@ -284,6 +316,7 @@ For the command line options listed below, if an option is specified in both the | `--random_crop` | | | `--resolution` | | | `--shuffle_caption` | | +| `--skip_image_resolution` | | | `--train_batch_size` | `batch_size` | ## Error Guide diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index aec0eca5..61d3e251 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -115,11 +115,15 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 | `max_bucket_reso` | `1024` | o | o | | `min_bucket_reso` | `128` | o | o | | `resolution` | `256`, `[512, 512]` | o | o | +| `skip_image_resolution` | `768`, `[512, 768]` | o | o | * `batch_size` * コマンドライン引数の `--train_batch_size` と同等です。 * `max_bucket_reso`, `min_bucket_reso` * bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。 +* `skip_image_resolution` + * 指定した解像度(面積)以下の画像をスキップします。`'サイズ'` または `[幅, 高さ]` で指定します。コマンドライン引数の `--skip_image_resolution` と同等です。 + * 同じ画像ディレクトリを異なる解像度の複数のデータセットで使い回す場合に、低解像度の元画像を高解像度のデータセットから除外するために使用します。 これらの設定はデータセットごとに固定です。 つまり、データセットに所属するサブセットはこれらの設定を共有することになります。 @@ -259,6 +263,34 @@ resolution = 768 image_dir = 'C:\hoge' ``` +なお、マルチ解像度データセットでは `skip_image_resolution` を使用して、元の画像サイズが小さい画像を高解像度データセットから除外できます。これにより、低解像度画像のデータセット間での重複を防ぎ、学習品質を向上させることができます。また、小さい画像を除外するフィルターとしても機能します。 + +```toml +[general] +enable_bucket = true +bucket_no_upscale = true +max_bucket_reso = 1536 + +[[datasets]] +resolution = 768 + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 1024 +skip_image_resolution = 768 + [[datasets.subsets]] + image_dir = 'C:\hoge' + +[[datasets]] +resolution = 1280 +skip_image_resolution = 1024 + [[datasets.subsets]] + image_dir = 'C:\hoge' +``` + +この例では、1024 解像度のデータセットでは元の画像サイズが 768x768 以下の画像がスキップされ、1280 解像度のデータセットでは 1024x1024 以下の画像がスキップされます。 + ## コマンドライン引数との併用 設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。 @@ -289,6 +321,7 @@ resolution = 768 | `--random_crop` | | | `--resolution` | | | `--shuffle_caption` | | +| `--skip_image_resolution` | | | `--train_batch_size` | `batch_size` | ## エラーの手引き From 343c929e39801c8f9b0131dbe224943da692b4a2 Mon Sep 17 00:00:00 2001 From: woctordho Date: Wed, 24 Sep 2025 00:32:00 +0800 Subject: [PATCH 09/17] Log d*lr for ProdigyPlusScheduleFree --- library/train_util.py | 6 +++++- train_network.py | 33 ++++++++------------------------- 2 files changed, 13 insertions(+), 26 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d8577b9d..e95a4612 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -6183,10 +6183,14 @@ def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names): name = names[lr_index] logs["lr/" + name] = float(lrs[lr_index]) - if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower(): + if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower().startswith("Prodigy".lower()): logs["lr/d*lr/" + name] = ( lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"] ) + if "effective_lr" in lr_scheduler.optimizers[-1].param_groups[lr_index]: + logs["lr/d*eff_lr/" + name] = ( + lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["effective_lr"] + ) # scheduler: diff --git a/train_network.py b/train_network.py index 2f8797d2..ae8d6d0f 100644 --- a/train_network.py +++ b/train_network.py @@ -90,40 +90,23 @@ class NetworkTrainer: if lr_descriptions is not None: lr_desc = 
lr_descriptions[i] else: - idx = i - (0 if args.network_train_unet_only else -1) + idx = i - (0 if args.network_train_unet_only else 1) if idx == -1: lr_desc = "textencoder" else: if len(lrs) > 2: - lr_desc = f"group{idx}" + lr_desc = f"group{i}" else: lr_desc = "unet" logs[f"lr/{lr_desc}"] = lr - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - # tracking d*lr value - logs[f"lr/d*lr/{lr_desc}"] = ( - lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] - ) - if ( - args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None - ): # tracking d*lr value of unet. - logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] - else: - idx = 0 - if not args.network_train_unet_only: - logs["lr/textencoder"] = float(lrs[0]) - idx = 1 - - for i in range(idx, len(lrs)): - logs[f"lr/group{i}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{i}"] = ( - lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] - ) - if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None: - logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower().startswith("Prodigy".lower()): + opt = lr_scheduler.optimizers[-1] if hasattr(lr_scheduler, "optimizers") else optimizer + if opt is not None: + logs[f"lr/d*lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["lr"] + if "effective_lr" in opt.param_groups[i]: + logs[f"lr/d*eff_lr/{lr_desc}"] = opt.param_groups[i]["d"] * opt.param_groups[i]["effective_lr"] return logs From cdb49f9fe7730164e068fea06159cd9bd76a1cb3 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 22 Mar 2026 22:19:47 +0900 Subject: [PATCH 10/17] fix: Anima validation dataset not working with Text Encoder output caching due to caption dropout --- anima_train_network.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/anima_train_network.py b/anima_train_network.py index eaad7197..ff770a9f 100644 --- a/anima_train_network.py +++ b/anima_train_network.py @@ -286,7 +286,9 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): t.requires_grad_(True) # Unpack text encoder conditions - prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds + prompt_embeds, attn_mask, t5_input_ids, t5_attn_mask = text_encoder_conds[ + :4 + ] # ignore caption_dropout_rate which is not needed for training step # Move to device prompt_embeds = prompt_embeds.to(accelerator.device, dtype=weight_dtype) @@ -353,7 +355,8 @@ class AnimaNetworkTrainer(train_network.NetworkTrainer): text_encoder_outputs_list = anima_text_encoding_strategy.drop_cached_text_encoder_outputs( *text_encoder_outputs_list, caption_dropout_rates=caption_dropout_rates ) - batch["text_encoder_outputs_list"] = text_encoder_outputs_list + # Add the caption dropout rates back to the list for validation dataset (which is re-used batch items) + batch["text_encoder_outputs_list"] = text_encoder_outputs_list + [caption_dropout_rates] return super().process_batch( batch, From 0e168dd1eb9eb683faff315f754cc7d1da36096a Mon Sep 17 00:00:00 2001 From: Kohya S 
<52813779+kohya-ss@users.noreply.github.com> Date: Sun, 29 Mar 2026 20:33:33 +0900 Subject: [PATCH 11/17] add --svd_lowrank_niter option to resize_lora.py Allow users to control the number of iterations for torch.svd_lowrank on large matrices. Default is 2 (matching PR #2240 behavior). Set to 0 to disable svd_lowrank and use full SVD instead. Co-Authored-By: Claude Opus 4.6 --- networks/resize_lora.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 5dd1132f..a616b6ac 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -85,13 +85,13 @@ def index_sv_ratio(S, target): # Modified from Kohaku-blueleaf's extract/merge functions -def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): +def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2): out_size, in_size, kernel_size, _ = weight.size() weight = weight.reshape(out_size, -1) _in_size = in_size * kernel_size * kernel_size - if out_size > 2048 and _in_size > 2048: - U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size)) + if svd_lowrank_niter > 0 and out_size > 2048 and _in_size > 2048: + U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, _in_size), niter=svd_lowrank_niter) Vh = V.T else: U, S, Vh = torch.linalg.svd(weight.to(device)) @@ -110,11 +110,11 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale return param_dict -def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1): +def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1, svd_lowrank_niter=2): out_size, in_size = weight.size() - if out_size > 2048 and in_size > 2048: - U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size)) + if svd_lowrank_niter > 0 and out_size > 2048 and in_size > 2048: + U, S, V = torch.svd_lowrank(weight.to(device), q=min(2 * lora_rank, out_size, in_size), niter=svd_lowrank_niter) Vh = V.T else: U, S, Vh = torch.linalg.svd(weight.to(device)) @@ -209,7 +209,7 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): return param_dict -def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose): +def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2): max_old_rank = None new_alpha = None verbose_str = "\n" @@ -273,10 +273,10 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna if conv2d: full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device) - param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale) + param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter) else: full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device) - param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale) + param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale, svd_lowrank_niter) if verbose: max_ratio = param_dict["max_ratio"] @@ -347,7 +347,7 @@ def resize(args): logger.info("Resizing Lora...") state_dict, old_dim, new_alpha = resize_lora_model( - lora_sd, args.new_rank, 
args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
+        lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose, args.svd_lowrank_niter
     )
 
     # update metadata
@@ -425,6 +425,13 @@ def setup_parser() -> argparse.ArgumentParser:
         help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
     )
     parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
+    parser.add_argument(
+        "--svd_lowrank_niter",
+        type=int,
+        default=2,
+        help="Number of iterations for svd_lowrank on large matrices (>2048 dims). 0 to disable and use full SVD"
+        " / 大行列(2048次元超)に対するsvd_lowrankの反復回数。0で無効化し完全SVDを使用",
+    )
 
     return parser
 

From 4be0e94fad67a58a1c7d941f68aa50639110e7c5 Mon Sep 17 00:00:00 2001
From: woctordho
Date: Sun, 29 Mar 2026 19:35:00 +0800
Subject: [PATCH 12/17] Merge pull request #2194 from woct0rdho/rank1
 Fix the 'off by 1' problem in dynamically resized LoRA rank

---
 networks/resize_lora.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/networks/resize_lora.py b/networks/resize_lora.py
index 5dd1132f..2f586a8a 100644
--- a/networks/resize_lora.py
+++ b/networks/resize_lora.py
@@ -59,8 +59,8 @@ def save_to_file(file_name, state_dict, metadata):
 def index_sv_cumulative(S, target):
     original_sum = float(torch.sum(S))
     cumulative_sums = torch.cumsum(S, dim=0) / original_sum
-    index = int(torch.searchsorted(cumulative_sums, target)) + 1
-    index = max(1, min(index, len(S) - 1))
+    index = int(torch.searchsorted(cumulative_sums, target))
+    index = max(0, min(index, len(S) - 1))
 
     return index
 
@@ -69,8 +69,8 @@ def index_sv_fro(S, target):
     S_squared = S.pow(2)
     S_fro_sq = float(torch.sum(S_squared))
     sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
-    index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
-    index = max(1, min(index, len(S) - 1))
+    index = int(torch.searchsorted(sum_S_squared, target**2))
+    index = max(0, min(index, len(S) - 1))
 
     return index
 
@@ -78,8 +78,8 @@ def index_sv_fro(S, target):
 def index_sv_ratio(S, target):
     max_sv = S[0]
     min_sv = max_sv / target
-    index = int(torch.sum(S > min_sv).item())
-    index = max(1, min(index, len(S) - 1))
+    index = int(torch.sum(S > min_sv).item()) - 1
+    index = max(0, min(index, len(S) - 1))
 
     return index
 

From 5cdad10de52ec87640afb729adfba94ecef4a3bf Mon Sep 17 00:00:00 2001
From: "Kohya S." <52813779+kohya-ss@users.noreply.github.com>
Date: Sun, 29 Mar 2026 20:41:43 +0900
Subject: [PATCH 13/17] Fix/leco cleanup (#2294)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat: add LECO training scripts for SD 1.x/2.x and SDXL (#2285)

* Add LECO training script and associated tests

- Implemented `sdxl_train_leco.py` for training with LECO prompts, including argument parsing, model setup, training loop, and weight saving functionality.
- Created unit tests for `load_prompt_settings` in `test_leco_train_util.py` to validate loading of prompt configurations in both original and slider formats.
- Added basic syntax tests for `train_leco.py` and `sdxl_train_leco.py` to ensure modules are importable.

* fix: use getattr for safe attribute access in argument verification

* feat: add CUDA device compatibility validation and corresponding tests

* Revert "feat: add CUDA device compatibility validation and corresponding tests"

This reverts commit 6d3e51431be4f207b2ebddc975c6b0a2196576ad.
* feat: update predict_noise_xl to use vector embedding from add_time_ids

* feat: implement checkpointing in predict_noise and predict_noise_xl functions

* feat: remove unused submodules and update .gitignore to exclude .codex-tmp

---------

Co-authored-by: Kohya S. <52813779+kohya-ss@users.noreply.github.com>

* fix: format

* fix: address review feedback on LECO PR #2285

- Revert the getattr changes in train_util.py/deepspeed_utils.py and add dummy arguments to the LECO parser
- Change the module-level import of sdxl_train_util to a local import
- Make PromptEmbedsCache.__getitem__ raise KeyError on a cache miss
- Change the config file format from YAML to TOML (matching the repository convention)
- Consolidate duplicated code (build_network_kwargs, get_save_extension, save_weights) into leco_train_util.py
- Simplify the redundant PromptSettings construction in _expand_slider_target
- Add a dedicated batch_add_time_ids function for add_time_ids

Co-Authored-By: Claude Opus 4.6

* docs: greatly expand the LECO training guide

Add per-category explanations of all command-line arguments, descriptions of
every prompt TOML field, the difference between the two guidance_scale
parameters, a recommended-settings table, a conversion guide from YAML, and more.
Bilingual layout with English body text and collapsible Japanese sections.

Co-Authored-By: Claude Opus 4.6

* fix: dtype mismatch in apply_noise_offset

Fix latents being implicitly upcast because torch.randn defaults to float32.
Adopt the safe pattern of generating in float32/CPU and then converting to the
latents' dtype/device.

Co-Authored-By: Claude Opus 4.6

---------

Co-authored-by: Umisetokikaze <52318966+umisetokikaze@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6
---
 .gitignore                            |   3 +-
 docs/train_leco.md                    | 736 ++++++++++++++++++++++++++
 library/leco_train_util.py            | 522 ++++++++++++++++++
 library/train_util.py                 |  11 +-
 sdxl_train_leco.py                    | 342 ++++++++++++
 tests/library/test_leco_train_util.py | 116 ++++
 tests/test_sdxl_train_leco.py         |  16 +
 tests/test_train_leco.py              |  15 +
 train_leco.py                         | 319 +++++++++++
 9 files changed, 2074 insertions(+), 6 deletions(-)
 create mode 100644 docs/train_leco.md
 create mode 100644 library/leco_train_util.py
 create mode 100644 sdxl_train_leco.py
 create mode 100644 tests/library/test_leco_train_util.py
 create mode 100644 tests/test_sdxl_train_leco.py
 create mode 100644 tests/test_train_leco.py
 create mode 100644 train_leco.py

diff --git a/.gitignore b/.gitignore
index f5772a7f..79b9dc3d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,5 @@ GEMINI.md
 .claude
 .gemini
 MagicMock
-references
\ No newline at end of file
+.codex-tmp
+references

diff --git a/docs/train_leco.md b/docs/train_leco.md
new file mode 100644
index 00000000..0896c58c
--- /dev/null
+++ b/docs/train_leco.md
@@ -0,0 +1,736 @@
+# LECO Training Guide / LECO 学習ガイド
+
+LECO (Low-rank adaptation for Erasing COncepts from diffusion models) is a technique for training LoRA models that modify or erase concepts from a diffusion model **without requiring any image dataset**. It works by training a LoRA against the model's own noise predictions using text prompts only.
+
+This repository provides two LECO training scripts:
+
+- `train_leco.py` for Stable Diffusion 1.x / 2.x
+- `sdxl_train_leco.py` for SDXL
+
+<details>
+<summary>日本語</summary>
+
+LECO (Low-rank adaptation for Erasing COncepts from diffusion models) は、**画像データセットを一切必要とせず**、テキストプロンプトのみを使用してモデル自身のノイズ予測に対して LoRA を学習させる手法です。拡散モデルから概念を変更・消去する LoRA モデルを作成できます。
+
+このリポジトリでは以下の2つの LECO 学習スクリプトを提供しています:
+
+- `train_leco.py` : Stable Diffusion 1.x / 2.x 用
+- `sdxl_train_leco.py` : SDXL 用
+</details>
+ +## 1. Overview / 概要 + +### What LECO Can Do / LECO でできること + +LECO can be used for: + +- **Concept erasing**: Remove a specific style or concept (e.g., erase "van gogh" style from generated images) +- **Concept enhancing**: Strengthen a specific attribute (e.g., make "detailed" more pronounced) +- **Slider LoRA**: Create a LoRA that controls an attribute bidirectionally (e.g., a slider between "short hair" and "long hair") + +Unlike standard LoRA training, LECO does not use any training images. All training signals come from the difference between the model's own noise predictions on different text prompts. + +
+<details>
+<summary>日本語</summary>
+
+LECO は以下の用途に使用できます:
+
+- **概念の消去**: 特定のスタイルや概念を除去する(例:生成画像から「van gogh」スタイルを消去)
+- **概念の強化**: 特定の属性を強化する(例:「detailed」をより顕著にする)
+- **スライダー LoRA**: 属性を双方向に制御する LoRA を作成する(例:「short hair」と「long hair」の間のスライダー)
+
+通常の LoRA 学習とは異なり、LECO は学習画像を一切使用しません。学習のシグナルは全て、異なるテキストプロンプトに対するモデル自身のノイズ予測の差分から得られます。
+</details>
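+
+The training signal described above can be sketched in a few lines. The snippet below is only an illustration of the idea, not code from these scripts: `predict_noise` is a placeholder for a frozen U-Net forward pass, and the real training loop adds partial denoising, loss weighting, and batching on top of this.
+
+```python
+import torch
+
+def leco_target(predict_noise, latents, t, positive, neutral, unconditional,
+                guidance_scale=1.0, action="erase"):
+    # All three predictions come from the frozen base model (no LoRA applied).
+    with torch.no_grad():
+        e_pos = predict_noise(latents, t, positive)       # concept direction
+        e_neu = predict_noise(latents, t, neutral)        # neutral baseline
+        e_unc = predict_noise(latents, t, unconditional)  # unconditional anchor
+    offset = guidance_scale * (e_pos - e_unc)
+    # "erase" pushes the prediction away from the concept, "enhance" toward it.
+    return e_neu - offset if action == "erase" else e_neu + offset
+
+# The LoRA-enabled model is then trained so that its prediction on the
+# `target` prompt matches this constructed target, typically with MSE loss.
+```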
+ +### Key Differences from Standard LoRA Training / 通常の LoRA 学習との違い + +| | Standard LoRA | LECO | +|---|---|---| +| Training data | Image dataset required | **No images needed** | +| Configuration | Dataset TOML | Prompt TOML | +| Training target | U-Net and/or Text Encoder | **U-Net only** | +| Training unit | Epochs and steps | **Steps only** | +| Saving | Per-epoch or per-step | **Per-step only** (`--save_every_n_steps`) | + +
+<details>
+<summary>日本語</summary>
+
+| | 通常の LoRA | LECO |
+|---|---|---|
+| 学習データ | 画像データセットが必要 | **画像不要** |
+| 設定ファイル | データセット TOML | プロンプト TOML |
+| 学習対象 | U-Net と Text Encoder | **U-Net のみ** |
+| 学習単位 | エポックとステップ | **ステップのみ** |
+| 保存 | エポック毎またはステップ毎 | **ステップ毎のみ** (`--save_every_n_steps`) |
+</details>
+ +## 2. Prompt Configuration File / プロンプト設定ファイル + +LECO uses a TOML file to define training prompts. Two formats are supported: the **original LECO format** and the **slider target format** (ai-toolkit style). + +
+<details>
+<summary>日本語</summary>
+LECO は学習プロンプトの定義に TOML ファイルを使用します。**オリジナル LECO 形式**と**スライダーターゲット形式**(ai-toolkit スタイル)の2つの形式に対応しています。
+</details>
+ +### 2.1. Original LECO Format / オリジナル LECO 形式 + +Use `[[prompts]]` sections to define prompt pairs directly. This gives you full control over each training pair. + +```toml +[[prompts]] +target = "van gogh" +positive = "van gogh" +unconditional = "" +neutral = "" +action = "erase" +guidance_scale = 1.0 +resolution = 512 +batch_size = 1 +multiplier = 1.0 +weight = 1.0 +``` + +Each `[[prompts]]` entry defines one training pair with the following fields: + +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `target` | Yes | - | The concept to be modified by the LoRA | +| `positive` | No | same as `target` | The "positive direction" prompt for building the training target | +| `unconditional` | No | `""` | The unconditional/negative prompt | +| `neutral` | No | `""` | The neutral baseline prompt | +| `action` | No | `"erase"` | `"erase"` to remove the concept, `"enhance"` to strengthen it | +| `guidance_scale` | No | `1.0` | Scale factor for target construction (higher = stronger effect) | +| `resolution` | No | `512` | Training resolution (int or `[height, width]`) | +| `batch_size` | No | `1` | Number of latent samples per training step for this prompt | +| `multiplier` | No | `1.0` | LoRA strength multiplier during training | +| `weight` | No | `1.0` | Loss weight for this prompt pair | + +
+<details>
+<summary>日本語</summary>
+
+`[[prompts]]` セクションを使用して、プロンプトペアを直接定義します。各学習ペアを細かく制御できます。
+
+各 `[[prompts]]` エントリのフィールド:
+
+| フィールド | 必須 | デフォルト | 説明 |
+|-----------|------|-----------|------|
+| `target` | はい | - | LoRA によって変更される概念 |
+| `positive` | いいえ | `target` と同じ | 学習ターゲット構築時の「正方向」プロンプト |
+| `unconditional` | いいえ | `""` | 無条件/ネガティブプロンプト |
+| `neutral` | いいえ | `""` | ニュートラルベースラインプロンプト |
+| `action` | いいえ | `"erase"` | `"erase"` で概念を除去、`"enhance"` で強化 |
+| `guidance_scale` | いいえ | `1.0` | ターゲット構築時のスケール係数(大きいほど効果が強い) |
+| `resolution` | いいえ | `512` | 学習解像度(整数または `[height, width]`) |
+| `batch_size` | いいえ | `1` | このプロンプトの学習ステップごとの latent サンプル数 |
+| `multiplier` | いいえ | `1.0` | 学習時の LoRA 強度乗数 |
+| `weight` | いいえ | `1.0` | このプロンプトペアの loss 重み |
+</details>
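+
+For reference, the defaulting rules in the table above map directly onto code. This is a minimal sketch using the standard-library `tomllib` (Python 3.11+), not the scripts' actual loader (`load_prompt_settings` in `library/leco_train_util.py`):
+
+```python
+import tomllib
+
+with open("prompts.toml", "rb") as f:
+    config = tomllib.load(f)
+
+for entry in config.get("prompts", []):
+    target = entry["target"]                  # required
+    positive = entry.get("positive", target)  # defaults to target
+    unconditional = entry.get("unconditional", "")
+    neutral = entry.get("neutral", "")
+    action = entry.get("action", "erase")
+    guidance_scale = entry.get("guidance_scale", 1.0)
+    print(f"{action}: {target!r} (guidance_scale={guidance_scale})")
+```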
+ +### 2.2. Slider Target Format / スライダーターゲット形式 + +Use `[[targets]]` sections to define slider-style LoRAs. Each target is automatically expanded into bidirectional training pairs (4 pairs when both `positive` and `negative` are provided, 2 pairs when only one is provided). + +```toml +guidance_scale = 1.0 +resolution = 1024 +neutral = "" + +[[targets]] +target_class = "1girl" +positive = "1girl, long hair" +negative = "1girl, short hair" +multiplier = 1.0 +weight = 1.0 +``` + +Top-level fields (`guidance_scale`, `resolution`, `neutral`, `batch_size`, etc.) serve as defaults for all targets. + +Each `[[targets]]` entry supports the following fields: + +| Field | Required | Default | Description | +|-------|----------|---------|-------------| +| `target_class` | Yes | - | The base class/subject prompt | +| `positive` | No* | `""` | Prompt for the positive direction of the slider | +| `negative` | No* | `""` | Prompt for the negative direction of the slider | +| `multiplier` | No | `1.0` | LoRA strength multiplier | +| `weight` | No | `1.0` | Loss weight | + +\* At least one of `positive` or `negative` must be provided. + +
+<details>
+<summary>日本語</summary>
+
+`[[targets]]` セクションを使用してスライダースタイルの LoRA を定義します。各ターゲットは自動的に双方向の学習ペアに展開されます(`positive` と `negative` の両方がある場合は4ペア、片方のみの場合は2ペア)。
+
+トップレベルのフィールド(`guidance_scale`、`resolution`、`neutral`、`batch_size` など)は全ターゲットのデフォルト値として機能します。
+
+各 `[[targets]]` エントリのフィールド:
+
+| フィールド | 必須 | デフォルト | 説明 |
+|-----------|------|-----------|------|
+| `target_class` | はい | - | ベースとなるクラス/被写体プロンプト |
+| `positive` | いいえ* | `""` | スライダーの正方向プロンプト |
+| `negative` | いいえ* | `""` | スライダーの負方向プロンプト |
+| `multiplier` | いいえ | `1.0` | LoRA 強度乗数 |
+| `weight` | いいえ | `1.0` | loss 重み |
+
+\* `positive` と `negative` のうち少なくとも一方を指定する必要があります。
+</details>
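+
+To make the pair counts concrete, one plausible expansion is sketched below. The exact prompts and signs produced by the scripts' `_expand_slider_target` may differ; the point is only that an entry with both directions yields four pairs and a one-sided entry yields two, as stated above.
+
+```python
+# Hedged sketch of slider expansion; the real _expand_slider_target may
+# choose prompts and signs differently. Both directions -> 4 pairs,
+# one direction -> 2 pairs.
+def expand_target(target_class, positive="", negative="", multiplier=1.0):
+    pairs = []
+    if positive:  # push the +1 direction of the LoRA toward `positive`
+        pairs.append(dict(target=target_class, positive=positive,
+                          action="enhance", multiplier=multiplier))
+        pairs.append(dict(target=target_class, positive=positive,
+                          action="erase", multiplier=-multiplier))
+    if negative:  # and the -1 direction toward `negative`
+        pairs.append(dict(target=target_class, positive=negative,
+                          action="erase", multiplier=multiplier))
+        pairs.append(dict(target=target_class, positive=negative,
+                          action="enhance", multiplier=-multiplier))
+    return pairs
+
+print(len(expand_target("1girl", "1girl, long hair", "1girl, short hair")))  # 4
+print(len(expand_target("1girl", positive="1girl, long hair")))              # 2
+```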
+ +### 2.3. Multiple Neutral Prompts / 複数のニュートラルプロンプト + +You can provide multiple neutral prompts for slider targets. Each neutral prompt generates a separate set of training pairs, which can improve generalization. + +```toml +guidance_scale = 1.5 +resolution = 1024 +neutrals = ["", "photo of a person", "cinematic portrait"] + +[[targets]] +target_class = "person" +positive = "smiling person" +negative = "expressionless person" +``` + +You can also load neutral prompts from a text file (one prompt per line): + +```toml +neutral_prompt_file = "neutrals.txt" + +[[targets]] +target_class = "" +positive = "high detail" +negative = "low detail" +``` + +
+<details>
+<summary>日本語</summary>
+
+スライダーターゲットに対して複数のニュートラルプロンプトを指定できます。各ニュートラルプロンプトごとに個別の学習ペアが生成され、汎化性能の向上が期待できます。
+
+ニュートラルプロンプトをテキストファイル(1行1プロンプト)から読み込むこともできます。
+</details>
+ +### 2.4. Converting from ai-toolkit YAML / ai-toolkit の YAML からの変換 + +If you have an existing ai-toolkit style YAML config, convert it to TOML as follows: + +
+<details>
+<summary>日本語</summary>
+既存の ai-toolkit スタイルの YAML 設定がある場合、以下のように TOML に変換してください。
+</details>
+ +**YAML:** +```yaml +targets: + - target_class: "" + positive: "high detail" + negative: "low detail" + multiplier: 1.0 +guidance_scale: 1.0 +resolution: 512 +``` + +**TOML:** +```toml +guidance_scale = 1.0 +resolution = 512 + +[[targets]] +target_class = "" +positive = "high detail" +negative = "low detail" +multiplier = 1.0 +``` + +Key syntax differences: + +- Use `=` instead of `:` for key-value pairs +- Use `[[targets]]` header instead of `targets:` with `- ` list items +- Arrays use `[brackets]` (e.g., `neutrals = ["a", "b"]`) + +
+<details>
+<summary>日本語</summary>
+
+主な構文の違い:
+
+- キーと値の区切りに `:` ではなく `=` を使用
+- `targets:` と `- ` のリスト記法ではなく `[[targets]]` ヘッダを使用
+- 配列は `[brackets]` で記述(例:`neutrals = ["a", "b"]`)
+</details>
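+
+For larger configs, a throwaway script can handle most of the conversion mechanically. This sketch assumes the `PyYAML` package and the `toml` package (the latter is already used by this repository); review the generated TOML by hand afterwards, since YAML nulls and anchors need manual cleanup.
+
+```python
+import sys
+
+import toml  # already a dependency of this repository
+import yaml  # PyYAML
+
+with open(sys.argv[1], "r", encoding="utf-8") as f:
+    data = yaml.safe_load(f)
+
+# Usage (illustrative): python yaml2toml.py config.yaml > config.toml
+print(toml.dumps(data))
+```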
+ +## 3. Running the Training / 学習の実行 + +Training is started by executing the script from the terminal. Below are basic command-line examples. + +In reality, you need to write the command in a single line, but it is shown with line breaks for readability. On Linux/Mac, add `\` at the end of each line; on Windows, add `^`. + +
+<details>
+<summary>日本語</summary>
+学習はターミナルからスクリプトを実行して開始します。以下に基本的なコマンドライン例を示します。
+
+実際には1行で書く必要がありますが、見やすさのために改行しています。Linux/Mac では各行末に `\` を、Windows では `^` を追加してください。
+</details>
+ +### SD 1.x / 2.x + +```bash +accelerate launch --mixed_precision bf16 train_leco.py + --pretrained_model_name_or_path="model.safetensors" + --prompts_file="prompts.toml" + --output_dir="output" + --output_name="my_leco" + --network_dim=8 + --network_alpha=4 + --learning_rate=1e-4 + --optimizer_type="AdamW8bit" + --max_train_steps=500 + --max_denoising_steps=40 + --mixed_precision=bf16 + --sdpa + --gradient_checkpointing + --save_every_n_steps=100 +``` + +### SDXL + +```bash +accelerate launch --mixed_precision bf16 sdxl_train_leco.py + --pretrained_model_name_or_path="sdxl_model.safetensors" + --prompts_file="slider.toml" + --output_dir="output" + --output_name="my_sdxl_slider" + --network_dim=8 + --network_alpha=4 + --learning_rate=1e-4 + --optimizer_type="AdamW8bit" + --max_train_steps=1000 + --max_denoising_steps=40 + --mixed_precision=bf16 + --sdpa + --gradient_checkpointing + --save_every_n_steps=200 +``` + +## 4. Command-Line Arguments / コマンドライン引数 + +### 4.1. LECO-Specific Arguments / LECO 固有の引数 + +These arguments are unique to LECO and not found in standard LoRA training scripts. + +
+<details>
+<summary>日本語</summary>
+以下の引数は LECO 固有のもので、通常の LoRA 学習スクリプトにはありません。
+</details>
+ +* `--prompts_file="prompts.toml"` **[Required]** + * Path to the LECO prompt configuration TOML file. See [Section 2](#2-prompt-configuration-file--プロンプト設定ファイル) for the file format. + +* `--max_denoising_steps=40` + * Number of partial denoising steps per training iteration. At each step, a random number of denoising steps (from 1 to this value) is performed. Default: `40`. + +* `--leco_denoise_guidance_scale=3.0` + * Guidance scale used during the partial denoising pass. This is separate from `guidance_scale` in the TOML file. Default: `3.0`. + +
+<details>
+<summary>日本語</summary>
+
+* `--prompts_file="prompts.toml"` **[必須]**
+  * LECO プロンプト設定 TOML ファイルのパス。ファイル形式については[セクション2](#2-prompt-configuration-file--プロンプト設定ファイル)を参照してください。
+
+* `--max_denoising_steps=40`
+  * 各学習イテレーションでの部分デノイズステップ数。各ステップで1からこの値の間のランダムなステップ数でデノイズが行われます。デフォルト: `40`。
+
+* `--leco_denoise_guidance_scale=3.0`
+  * 部分デノイズ時の guidance scale。TOML ファイル内の `guidance_scale` とは別のパラメータです。デフォルト: `3.0`。
+</details>
+ +#### Understanding the Two `guidance_scale` Parameters / 2つの `guidance_scale` の違い + +There are two separate guidance scale parameters that control different aspects of LECO training: + +1. **`--leco_denoise_guidance_scale` (command-line)**: Controls CFG strength during the partial denoising pass that generates intermediate latents. Higher values produce more prompt-adherent latents for the training signal. + +2. **`guidance_scale` (in TOML file)**: Controls the magnitude of the concept offset when constructing the training target. Higher values produce a stronger erase/enhance effect. This can be set per-prompt or per-target. + +If training results are too subtle, try increasing the TOML `guidance_scale` (e.g., `1.5` to `3.0`). + +
+<details>
+<summary>日本語</summary>
+
+LECO の学習では、異なる役割を持つ2つの guidance scale パラメータがあります:
+
+1. **`--leco_denoise_guidance_scale`(コマンドライン)**: 中間 latent を生成する部分デノイズパスの CFG 強度を制御します。大きな値にすると、プロンプトにより忠実な latent が学習シグナルとして生成されます。
+
+2. **`guidance_scale`(TOML ファイル内)**: 学習ターゲット構築時の概念オフセットの大きさを制御します。大きな値にすると、消去/強化の効果が強くなります。プロンプトごと・ターゲットごとに設定可能です。
+
+学習結果の効果が弱い場合は、TOML の `guidance_scale` を大きくしてみてください(例:`1.5` から `3.0`)。
+</details>
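+
+Schematically, the two scales act at different points of the computation. The toy tensors and variable names below are illustrative only, not the scripts' actual code:
+
+```python
+import torch
+
+e_cond, e_uncond = torch.randn(4), torch.randn(4)
+e_neutral, e_positive = torch.randn(4), torch.randn(4)
+
+# 1) --leco_denoise_guidance_scale: ordinary CFG applied while producing
+#    the partially denoised latents that serve as training input.
+denoise_gs = 3.0
+noise_pred = e_uncond + denoise_gs * (e_cond - e_uncond)
+
+# 2) guidance_scale from the TOML file: size of the concept offset used
+#    when the training target is constructed (shown for action = "erase").
+target_gs = 1.5
+train_target = e_neutral - target_gs * (e_positive - e_uncond)
+```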
+ +### 4.2. Model Arguments / モデル引数 + +* `--pretrained_model_name_or_path="model.safetensors"` **[Required]** + * Path to the base Stable Diffusion model (`.ckpt`, `.safetensors`, Diffusers directory, or Hugging Face model ID). + +* `--v2` (SD 1.x/2.x only) + * Specify when using a Stable Diffusion v2.x model. + +* `--v_parameterization` (SD 1.x/2.x only) + * Specify when using a v-prediction model (e.g., SD 2.x 768px models). + +
+<details>
+<summary>日本語</summary>
+
+* `--pretrained_model_name_or_path="model.safetensors"` **[必須]**
+  * ベースとなる Stable Diffusion モデルのパス(`.ckpt`、`.safetensors`、Diffusers ディレクトリ、Hugging Face モデル ID)。
+
+* `--v2`(SD 1.x/2.x のみ)
+  * Stable Diffusion v2.x モデルを使用する場合に指定します。
+
+* `--v_parameterization`(SD 1.x/2.x のみ)
+  * v-prediction モデル(SD 2.x 768px モデルなど)を使用する場合に指定します。
+</details>
+ +### 4.3. LoRA Network Arguments / LoRA ネットワーク引数 + +* `--network_module=networks.lora` + * Network module to train. Default: `networks.lora`. + +* `--network_dim=8` + * LoRA rank (dimension). Higher values increase expressiveness but also file size. Typical values: `4` to `16`. Default: `4`. + +* `--network_alpha=4` + * LoRA alpha for learning rate scaling. A common choice is to set this to half of `network_dim`. Default: `1.0`. + +* `--network_dropout=0.1` + * Dropout rate for LoRA layers. Optional. + +* `--network_args "key=value" ...` + * Additional network-specific arguments. For example, `--network_args "conv_dim=4"` to enable Conv2d LoRA. + +* `--network_weights="path/to/weights.safetensors"` + * Load pretrained LoRA weights to continue training. + +* `--dim_from_weights` + * Infer `network_dim` from the weights specified by `--network_weights`. Requires `--network_weights`. + +
+<details>
+<summary>日本語</summary>
+
+* `--network_module=networks.lora`
+  * 学習するネットワークモジュール。デフォルト: `networks.lora`。
+
+* `--network_dim=8`
+  * LoRA のランク(次元数)。大きいほど表現力が上がりますがファイルサイズも増加します。一般的な値: `4` から `16`。デフォルト: `4`。
+
+* `--network_alpha=4`
+  * 学習率スケーリング用の LoRA alpha。`network_dim` の半分程度に設定するのが一般的です。デフォルト: `1.0`。
+
+* `--network_dropout=0.1`
+  * LoRA レイヤーのドロップアウト率。省略可。
+
+* `--network_args "key=value" ...`
+  * ネットワーク固有の追加引数。例:`--network_args "conv_dim=4"` で Conv2d LoRA を有効にします。
+
+* `--network_weights="path/to/weights.safetensors"`
+  * 事前学習済み LoRA ウェイトを読み込んで学習を続行します。
+
+* `--dim_from_weights`
+  * `--network_weights` で指定したウェイトから `network_dim` を推定します。`--network_weights` の指定が必要です。
+</details>
+ +### 4.4. Training Parameters / 学習パラメータ + +* `--max_train_steps=500` + * Total number of training steps. Default: `1600`. Typical range for LECO: `300` to `2000`. + * Note: `--max_train_epochs` is **not supported** for LECO (the training loop is step-based only). + +* `--learning_rate=1e-4` + * Learning rate. Typical range for LECO: `1e-4` to `1e-3`. + +* `--unet_lr=1e-4` + * Separate learning rate for U-Net LoRA modules. If not specified, `--learning_rate` is used. + +* `--optimizer_type="AdamW8bit"` + * Optimizer type. Options include `AdamW8bit` (requires `bitsandbytes`), `AdamW`, `Lion`, `Adafactor`, etc. + +* `--lr_scheduler="constant"` + * Learning rate scheduler. Options: `constant`, `cosine`, `linear`, `constant_with_warmup`, etc. + +* `--lr_warmup_steps=0` + * Number of warmup steps for the learning rate scheduler. + +* `--gradient_accumulation_steps=1` + * Number of steps to accumulate gradients before updating. Effectively multiplies the batch size. + +* `--max_grad_norm=1.0` + * Maximum gradient norm for gradient clipping. Set to `0` to disable. + +* `--min_snr_gamma=5.0` + * Min-SNR weighting gamma. Applies SNR-based loss weighting. Optional. + +
+日本語 + +* `--max_train_steps=500` + * 学習の総ステップ数。デフォルト: `1600`。LECO の一般的な範囲: `300` から `2000`。 + * 注意: `--max_train_epochs` は LECO では**サポートされていません**(学習ループはステップベースのみです)。 + +* `--learning_rate=1e-4` + * 学習率。LECO の一般的な範囲: `1e-4` から `1e-3`。 + +* `--unet_lr=1e-4` + * U-Net LoRA モジュール用の個別の学習率。指定しない場合は `--learning_rate` が使用されます。 + +* `--optimizer_type="AdamW8bit"` + * オプティマイザの種類。`AdamW8bit`(要 `bitsandbytes`)、`AdamW`、`Lion`、`Adafactor` 等が選択可能です。 + +* `--lr_scheduler="constant"` + * 学習率スケジューラ。`constant`、`cosine`、`linear`、`constant_with_warmup` 等が選択可能です。 + +* `--lr_warmup_steps=0` + * 学習率スケジューラのウォームアップステップ数。 + +* `--gradient_accumulation_steps=1` + * 勾配を累積するステップ数。実質的にバッチサイズを増加させます。 + +* `--max_grad_norm=1.0` + * 勾配クリッピングの最大勾配ノルム。`0` で無効化。 + +* `--min_snr_gamma=5.0` + * Min-SNR 重み付けのガンマ値。SNR ベースの loss 重み付けを適用します。省略可。 +
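For reference, the Min-SNR weighting applied by `--min_snr_gamma` clamps the per-timestep SNR at gamma before normalizing. A rough sketch of the weight computation (the exact implementation is `apply_snr_weight` in `library/custom_train_functions.py`):

```python
import torch

def min_snr_weight(snr: torch.Tensor, gamma: float, v_parameterization: bool) -> torch.Tensor:
    # eps-prediction: min(SNR, gamma) / SNR
    # v-prediction:   min(SNR, gamma) / (SNR + 1)
    clamped = torch.minimum(snr, torch.full_like(snr, gamma))
    return clamped / (snr + 1.0) if v_parameterization else clamped / snr

# low-noise timesteps (high SNR) are down-weighted toward gamma / SNR
print(min_snr_weight(torch.tensor([0.5, 5.0, 50.0]), gamma=5.0, v_parameterization=False))
# -> tensor([1.0000, 1.0000, 0.1000])
```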
+ +### 4.5. Output and Save Arguments / 出力・保存引数 + +* `--output_dir="output"` **[Required]** + * Directory for saving trained LoRA models and logs. + +* `--output_name="my_leco"` **[Required]** + * Base filename for the trained LoRA (without extension). + +* `--save_model_as="safetensors"` + * Model save format. Options: `safetensors` (default, recommended), `ckpt`, `pt`. + +* `--save_every_n_steps=100` + * Save an intermediate checkpoint every N steps. If not specified, only the final model is saved. + * Note: `--save_every_n_epochs` is **not supported** for LECO. + +* `--save_precision="fp16"` + * Precision for saving the model. Options: `float`, `fp16`, `bf16`. If not specified, the training precision is used. + +* `--no_metadata` + * Do not write metadata into the saved model file. + +* `--training_comment="my comment"` + * A comment string stored in the model metadata. + +
+日本語 + +* `--output_dir="output"` **[必須]** + * 学習済み LoRA モデルとログの保存先ディレクトリ。 + +* `--output_name="my_leco"` **[必須]** + * 学習済み LoRA のベースファイル名(拡張子なし)。 + +* `--save_model_as="safetensors"` + * モデルの保存形式。`safetensors`(デフォルト、推奨)、`ckpt`、`pt` から選択。 + +* `--save_every_n_steps=100` + * N ステップごとに中間チェックポイントを保存。指定しない場合は最終モデルのみ保存されます。 + * 注意: `--save_every_n_epochs` は LECO では**サポートされていません**。 + +* `--save_precision="fp16"` + * モデル保存時の精度。`float`、`fp16`、`bf16` から選択。省略時は学習時の精度が使用されます。 + +* `--no_metadata` + * 保存するモデルファイルにメタデータを書き込みません。 + +* `--training_comment="my comment"` + * モデルのメタデータに保存されるコメント文字列。 +
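Unless `--no_metadata` is given, the saved file carries `ss_*` metadata (network dim/alpha, LECO prompt count, a prompt preview, etc.). A small sketch for inspecting it with the `safetensors` API (the file path is illustrative):

```python
from safetensors import safe_open

with safe_open("output/my_leco.safetensors", framework="pt") as f:
    metadata = f.metadata() or {}

for key in ("ss_network_dim", "ss_network_alpha", "ss_leco_prompt_count", "ss_training_comment"):
    print(key, "=", metadata.get(key))
```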
+ +### 4.6. Memory and Performance Arguments / メモリ・パフォーマンス引数 + +* `--mixed_precision="bf16"` + * Mixed precision training. Options: `no`, `fp16`, `bf16`. Using `bf16` or `fp16` is recommended. + +* `--full_fp16` + * Train entirely in fp16 precision including gradients. + +* `--full_bf16` + * Train entirely in bf16 precision including gradients. + +* `--gradient_checkpointing` + * Enable gradient checkpointing to reduce VRAM usage at the cost of slightly slower training. **Recommended for LECO**, especially with larger models or higher resolutions. + +* `--sdpa` + * Use Scaled Dot-Product Attention. Reduces memory usage and can improve speed. Recommended. + +* `--xformers` + * Use xformers for memory-efficient attention (requires `xformers` package). Alternative to `--sdpa`. + +* `--mem_eff_attn` + * Use memory-efficient attention implementation. Another alternative to `--sdpa`. + +
+日本語 + +* `--mixed_precision="bf16"` + * 混合精度学習。`no`、`fp16`、`bf16` から選択。`bf16` または `fp16` の使用を推奨します。 + +* `--full_fp16` + * 勾配を含め全体を fp16 精度で学習します。 + +* `--full_bf16` + * 勾配を含め全体を bf16 精度で学習します。 + +* `--gradient_checkpointing` + * gradient checkpointing を有効にしてVRAM使用量を削減します(学習速度は若干低下)。特に大きなモデルや高解像度での LECO 学習時に**推奨**です。 + +* `--sdpa` + * Scaled Dot-Product Attention を使用します。メモリ使用量を削減し速度向上が期待できます。推奨。 + +* `--xformers` + * xformers を使用したメモリ効率の良い attention(`xformers` パッケージが必要)。`--sdpa` の代替。 + +* `--mem_eff_attn` + * メモリ効率の良い attention 実装を使用。`--sdpa` の別の代替。 +
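Note that independently of `--gradient_checkpointing`, the LECO denoising helpers wrap each U-Net call in `torch.utils.checkpoint` whenever gradients are enabled, recomputing activations during the backward pass to save memory. Condensed from `_run_with_checkpoint` in `library/leco_train_util.py`:

```python
import torch
from torch.utils.checkpoint import checkpoint

def run_with_checkpoint(function, *args):
    # checkpoint only when gradients are tracked; pure-inference passes run normally
    if torch.is_grad_enabled():
        return checkpoint(function, *args, use_reentrant=False)
    return function(*args)
```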
+ +### 4.7. Other Useful Arguments / その他の便利な引数 + +* `--seed=42` + * Random seed for reproducibility. If not specified, a random seed is automatically generated. + +* `--noise_offset=0.05` + * Enable noise offset. Small values like `0.02` to `0.1` can help with training stability. + +* `--zero_terminal_snr` + * Fix noise scheduler betas to enforce zero terminal SNR. + +* `--clip_skip=2` (SD 1.x/2.x only) + * Use the output from the Nth-to-last layer of the text encoder. Common values: `1` (no skip) or `2`. + +* `--logging_dir="logs"` + * Directory for TensorBoard logs. Enables logging when specified. + +* `--log_with="tensorboard"` + * Logging tool. Options: `tensorboard`, `wandb`, `all`. + +
+日本語 + +* `--seed=42` + * 再現性のための乱数シード。指定しない場合は自動生成されます。 + +* `--noise_offset=0.05` + * ノイズオフセットを有効にします。`0.02` から `0.1` 程度の小さい値で学習の安定性が向上する場合があります。 + +* `--zero_terminal_snr` + * noise scheduler の betas を修正してゼロ終端 SNR を強制します。 + +* `--clip_skip=2`(SD 1.x/2.x のみ) + * text encoder の後ろから N 番目の層の出力を使用します。一般的な値: `1`(スキップなし)または `2`。 + +* `--logging_dir="logs"` + * TensorBoard ログの出力ディレクトリ。指定時にログ出力が有効になります。 + +* `--log_with="tensorboard"` + * ログツール。`tensorboard`、`wandb`、`all` から選択。 +
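For reference, `--noise_offset` shifts the initial latents by Gaussian noise that is sampled per sample and channel but constant over the spatial dimensions. Condensed from `apply_noise_offset` in `library/leco_train_util.py`:

```python
import torch

def apply_noise_offset(latents: torch.Tensor, noise_offset: float) -> torch.Tensor:
    # one noise value per (batch, channel), broadcast across height and width
    noise = torch.randn(latents.shape[0], latents.shape[1], 1, 1,
                        dtype=latents.dtype, device=latents.device)
    return latents + noise_offset * noise
```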
+ +## 5. Tips / ヒント + +### Tuning the Effect Strength / 効果の強さの調整 + +If the trained LoRA has a weak or unnoticeable effect: + +1. **Increase `guidance_scale` in TOML** (e.g., `1.5` to `3.0`). This is the most direct way to strengthen the effect. +2. **Increase `multiplier` in TOML** (e.g., `1.5` to `2.0`). +3. **Increase `--max_denoising_steps`** for more refined intermediate latents. +4. **Increase `--max_train_steps`** to train longer. +5. **Apply the LoRA with a higher weight** at inference time. + +
+日本語 + +学習した LoRA の効果が弱い、または認識できない場合: + +1. **TOML の `guidance_scale` を上げる**(例:`1.5` から `3.0`)。効果を強める最も直接的な方法です。 +2. **TOML の `multiplier` を上げる**(例:`1.5` から `2.0`)。 +3. **`--max_denoising_steps` を増やす**。より精緻な中間 latent が生成されます。 +4. **`--max_train_steps` を増やして**、より長く学習する。 +5. **推論時に LoRA のウェイトを大きくして**適用する。 +
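For example, tips 1 and 2 above translate into the prompt file like this (values illustrative):

```toml
[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
guidance_scale = 3.0  # raised from 1.5 for a stronger concept offset
multiplier = 2.0      # raised from 1.5 for a stronger LoRA effect during training
```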
+ +### Recommended Starting Settings / 推奨の開始設定 + +| Parameter | SD 1.x/2.x | SDXL | +|-----------|-------------|------| +| `--network_dim` | `4`-`8` | `8`-`16` | +| `--learning_rate` | `1e-4` | `1e-4` | +| `--max_train_steps` | `300`-`1000` | `500`-`2000` | +| `resolution` (in TOML) | `512` | `1024` | +| `guidance_scale` (in TOML) | `1.0`-`2.0` | `1.0`-`3.0` | +| `batch_size` (in TOML) | `1`-`4` | `1`-`4` | + +
+日本語 + +| パラメータ | SD 1.x/2.x | SDXL | +|-----------|-------------|------| +| `--network_dim` | `4`-`8` | `8`-`16` | +| `--learning_rate` | `1e-4` | `1e-4` | +| `--max_train_steps` | `300`-`1000` | `500`-`2000` | +| `resolution`(TOML内) | `512` | `1024` | +| `guidance_scale`(TOML内) | `1.0`-`2.0` | `1.0`-`3.0` | +| `batch_size`(TOML内) | `1`-`4` | `1`-`4` | +
+

### Dynamic Resolution and Crops (SDXL) / 動的解像度とクロップ(SDXL)

For SDXL slider targets, you can enable dynamic resolution and crops in the TOML file:

```toml
resolution = 1024
dynamic_resolution = true
dynamic_crops = true

[[targets]]
target_class = ""
positive = "high detail"
negative = "low detail"
```

- `dynamic_resolution`: Randomly varies the training resolution per step, sampling height and width independently in 64 px aspect ratio buckets between half and the full base resolution.
- `dynamic_crops`: Randomizes the assumed original size and crop position in the SDXL size conditioning embeddings.

These options can improve the LoRA's generalization across different aspect ratios.

+日本語

SDXL のスライダーターゲットでは、TOML ファイルで動的解像度とクロップを有効にできます。

- `dynamic_resolution`: ステップごとに、ベース解像度の 1/2 から等倍までの範囲で、64 ピクセル単位のアスペクト比バケットから学習解像度をランダムにサンプリングします。
- `dynamic_crops`: SDXL のサイズ条件付け埋め込みにおける元サイズとクロップ位置をランダム化します。

これらのオプションにより、異なるアスペクト比に対する LoRA の汎化性能が向上する場合があります。
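When `dynamic_resolution` is set and the configured resolution is square, each training step samples height and width independently in 64-pixel buckets between half and the full base resolution. Condensed from `get_random_resolution_in_bucket` in `library/leco_train_util.py`:

```python
import torch

def random_bucket_resolution(base: int = 1024, step: int = 64) -> tuple[int, int]:
    # sample height and width independently in [base // 2, base], in multiples of `step`
    lo, hi = (base // 2) // step, base // step
    height = torch.randint(lo, hi + 1, (1,)).item() * step
    width = torch.randint(lo, hi + 1, (1,)).item() * step
    return height, width

print(random_bucket_resolution())  # e.g. (704, 960)
```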
+

## 6. Using the Trained Model / 学習済みモデルの利用

The trained LoRA file (`.safetensors`) is saved in the `--output_dir` directory. It can be used with GUI tools such as AUTOMATIC1111/stable-diffusion-webui and ComfyUI.

For slider LoRAs, apply positive weights (e.g., `0.5` to `1.5`) to move in the positive direction, and negative weights (e.g., `-0.5` to `-1.5`) to move in the negative direction.

+日本語 + +学習済みの LoRA ファイル(`.safetensors`)は `--output_dir` ディレクトリに保存されます。AUTOMATIC1111/stable-diffusion-webui、ComfyUI 等の GUI ツールで使用できます。 + +スライダー LoRA の場合、正のウェイト(例:`0.5` から `1.5`)で正方向に、負のウェイト(例:`-0.5` から `-1.5`)で負方向に効果を適用できます。 +
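In AUTOMATIC1111's webui, for example, this corresponds to prompt tags such as `<lora:my_leco:1>` and `<lora:my_leco:-1>`, where the name is whatever you passed to `--output_name`.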
diff --git a/library/leco_train_util.py b/library/leco_train_util.py new file mode 100644 index 00000000..5e95c163 --- /dev/null +++ b/library/leco_train_util.py @@ -0,0 +1,522 @@ +import argparse +import json +import math +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import torch +import toml +from torch.utils.checkpoint import checkpoint + +from library import train_util + +import logging + +logger = logging.getLogger(__name__) + + +def build_network_kwargs(args: argparse.Namespace) -> Dict[str, str]: + kwargs = {} + if args.network_args: + for net_arg in args.network_args: + key, value = net_arg.split("=", 1) + kwargs[key] = value + if "dropout" not in kwargs: + kwargs["dropout"] = args.network_dropout + return kwargs + + +def get_save_extension(args: argparse.Namespace) -> str: + if args.save_model_as == "ckpt": + return ".ckpt" + if args.save_model_as == "pt": + return ".pt" + return ".safetensors" + + +def save_weights( + accelerator, + network, + args: argparse.Namespace, + save_dtype, + prompt_settings, + global_step: int, + last: bool = False, + extra_metadata: Optional[Dict[str, str]] = None, +) -> None: + os.makedirs(args.output_dir, exist_ok=True) + ext = get_save_extension(args) + ckpt_name = train_util.get_last_ckpt_name(args, ext) if last else train_util.get_step_ckpt_name(args, ext, global_step) + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + metadata = None + if not args.no_metadata: + metadata = { + "ss_network_module": args.network_module, + "ss_network_dim": str(args.network_dim), + "ss_network_alpha": str(args.network_alpha), + "ss_leco_prompt_count": str(len(prompt_settings)), + "ss_leco_prompts_file": os.path.basename(args.prompts_file), + } + if extra_metadata: + metadata.update(extra_metadata) + if args.training_comment: + metadata["ss_training_comment"] = args.training_comment + metadata["ss_leco_preview"] = json.dumps( + [ + { + "target": p.target, + "positive": p.positive, + "unconditional": p.unconditional, + "neutral": p.neutral, + "action": p.action, + "multiplier": p.multiplier, + "weight": p.weight, + } + for p in prompt_settings[:16] + ], + ensure_ascii=False, + ) + + unwrapped = accelerator.unwrap_model(network) + unwrapped.save_weights(ckpt_file, save_dtype, metadata) + logger.info(f"saved model to: {ckpt_file}") + + + +ResolutionValue = Union[int, Tuple[int, int]] + + +@dataclass +class PromptEmbedsXL: + text_embeds: torch.Tensor + pooled_embeds: torch.Tensor + + +class PromptEmbedsCache: + def __init__(self): + self.prompts: dict[str, Any] = {} + + def __setitem__(self, name: str, value: Any) -> None: + self.prompts[name] = value + + def __getitem__(self, name: str) -> Any: + return self.prompts[name] + + +@dataclass +class PromptSettings: + target: str + positive: Optional[str] = None + unconditional: str = "" + neutral: Optional[str] = None + action: str = "erase" + guidance_scale: float = 1.0 + resolution: ResolutionValue = 512 + dynamic_resolution: bool = False + batch_size: int = 1 + dynamic_crops: bool = False + multiplier: float = 1.0 + weight: float = 1.0 + + def __post_init__(self): + if self.positive is None: + self.positive = self.target + if self.neutral is None: + self.neutral = self.unconditional + if self.action not in ("erase", "enhance"): + raise ValueError(f"Invalid action: {self.action}") + + self.guidance_scale = float(self.guidance_scale) + self.batch_size = int(self.batch_size) + self.multiplier = 
float(self.multiplier) + self.weight = float(self.weight) + self.dynamic_resolution = bool(self.dynamic_resolution) + self.dynamic_crops = bool(self.dynamic_crops) + self.resolution = normalize_resolution(self.resolution) + + def get_resolution(self) -> Tuple[int, int]: + if isinstance(self.resolution, tuple): + return self.resolution + return (self.resolution, self.resolution) + + def build_target(self, positive_latents, neutral_latents, unconditional_latents): + offset = self.guidance_scale * (positive_latents - unconditional_latents) + if self.action == "erase": + return neutral_latents - offset + return neutral_latents + offset + + +def normalize_resolution(value: Any) -> ResolutionValue: + if isinstance(value, tuple): + if len(value) != 2: + raise ValueError(f"resolution tuple must have 2 items: {value}") + return (int(value[0]), int(value[1])) + if isinstance(value, list): + if len(value) == 2 and all(isinstance(v, (int, float)) for v in value): + return (int(value[0]), int(value[1])) + raise ValueError(f"resolution list must have 2 numeric items: {value}") + return int(value) + + +def _read_non_empty_lines(path: Union[str, Path]) -> List[str]: + with open(path, "r", encoding="utf-8") as f: + return [line.strip() for line in f.readlines() if line.strip()] + + +def _recognized_prompt_keys() -> set[str]: + return { + "target", + "positive", + "unconditional", + "neutral", + "action", + "guidance_scale", + "resolution", + "dynamic_resolution", + "batch_size", + "dynamic_crops", + "multiplier", + "weight", + } + + +def _recognized_slider_keys() -> set[str]: + return { + "target_class", + "positive", + "negative", + "neutral", + "guidance_scale", + "resolution", + "resolutions", + "dynamic_resolution", + "batch_size", + "dynamic_crops", + "multiplier", + "weight", + } + + +def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]: + merged = {k: v for k, v in defaults.items() if k in known_keys} + merged.update(item) + return merged + + +def _normalize_resolution_values(value: Any) -> List[ResolutionValue]: + if value is None: + return [512] + if isinstance(value, list) and value and isinstance(value[0], (list, tuple)): + return [normalize_resolution(v) for v in value] + return [normalize_resolution(value)] + + +def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]: + target_class = str(target.get("target_class", "")) + positive = str(target.get("positive", "") or "") + negative = str(target.get("negative", "") or "") + multiplier = target.get("multiplier", 1.0) + resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512))) + + if not positive.strip() and not negative.strip(): + raise ValueError("slider target requires either positive or negative prompt") + + base = dict( + target=target_class, + neutral=neutral, + guidance_scale=target.get("guidance_scale", 1.0), + dynamic_resolution=target.get("dynamic_resolution", False), + batch_size=target.get("batch_size", 1), + dynamic_crops=target.get("dynamic_crops", False), + weight=target.get("weight", 1.0), + ) + + # Build bidirectional (positive_prompt, unconditional_prompt, action, multiplier_sign) pairs. + # With both positive and negative: 4 pairs; with only one: 2 pairs. 
+ pairs: list[tuple[str, str, str, float]] = [] + if positive.strip() and negative.strip(): + pairs = [ + (negative, positive, "erase", multiplier), + (positive, negative, "enhance", multiplier), + (positive, negative, "erase", -multiplier), + (negative, positive, "enhance", -multiplier), + ] + elif negative.strip(): + pairs = [ + (negative, "", "erase", multiplier), + (negative, "", "enhance", -multiplier), + ] + else: + pairs = [ + (positive, "", "enhance", multiplier), + (positive, "", "erase", -multiplier), + ] + + prompt_settings: List[PromptSettings] = [] + for resolution in resolutions: + for pos, uncond, action, mult in pairs: + prompt_settings.append( + PromptSettings(**base, positive=pos, unconditional=uncond, action=action, resolution=resolution, multiplier=mult) + ) + + return prompt_settings + + +def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]: + path = Path(path) + with open(path, "r", encoding="utf-8") as f: + data = toml.load(f) + + if not data: + raise ValueError("prompt file is empty") + + default_prompt_values = { + "guidance_scale": 1.0, + "resolution": 512, + "dynamic_resolution": False, + "batch_size": 1, + "dynamic_crops": False, + "multiplier": 1.0, + "weight": 1.0, + } + + prompt_settings: List[PromptSettings] = [] + + def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None: + merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys()) + prompt_settings.append(PromptSettings(**merged)) + + def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None: + merged = _merge_known_defaults(defaults, item, _recognized_slider_keys()) + if not neutral_values: + neutral_values = [str(merged.get("neutral", "") or "")] + for neutral in neutral_values: + prompt_settings.extend(_expand_slider_target(merged, neutral)) + + if "prompts" in data: + defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}} + for item in data["prompts"]: + if "target_class" in item: + append_slider_item(item, defaults, [str(item.get("neutral", "") or "")]) + else: + append_prompt_item(item, defaults) + else: + slider_config = data.get("slider", data) + targets = slider_config.get("targets") + if targets is None: + if "target_class" in slider_config: + targets = [slider_config] + elif "target" in slider_config: + targets = [slider_config] + else: + raise ValueError("prompt file does not contain prompts or slider targets") + if len(targets) == 0: + raise ValueError("prompt file contains an empty targets list") + + if "target" in targets[0]: + defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()}} + for item in targets: + append_prompt_item(item, defaults) + else: + defaults = {**default_prompt_values, **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()}} + neutral_values: List[str] = [] + if "neutrals" in slider_config: + neutral_values.extend(str(v) for v in slider_config["neutrals"]) + if "neutral_prompt_file" in slider_config: + neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"])) + if "prompt_file" in slider_config: + neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"])) + if not neutral_values: + neutral_values = [str(slider_config.get("neutral", "") or "")] + + for item in targets: + item_neutrals = neutral_values + if "neutrals" in item: + item_neutrals = [str(v) for v in 
item["neutrals"]] + elif "neutral_prompt_file" in item: + item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"]) + elif "prompt_file" in item: + item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"]) + elif "neutral" in item: + item_neutrals = [str(item["neutral"] or "")] + + append_slider_item(item, defaults, item_neutrals) + + if not prompt_settings: + raise ValueError("no prompt settings found") + + return prompt_settings + + +def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor: + tokens = tokenize_strategy.tokenize(prompt) + return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0] + + +def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL: + tokens = tokenize_strategy.tokenize(prompt) + hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens) + return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2) + + +def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor: + if noise_offset is None: + return latents + noise = torch.randn((latents.shape[0], latents.shape[1], 1, 1), dtype=torch.float32, device="cpu") + noise = noise.to(dtype=latents.dtype, device=latents.device) + return latents + noise_offset * noise + + +def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor: + noise = torch.randn( + (batch_size, 4, height // 8, width // 8), + device="cpu", + ).repeat(n_prompts, 1, 1, 1) + return noise * scheduler.init_noise_sigma + + +def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor: + return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0) + + +def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL: + text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0) + pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(batch_size, dim=0) + return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds) + + +def batch_add_time_ids(add_time_ids: torch.Tensor, batch_size: int) -> torch.Tensor: + """Duplicate add_time_ids for CFG (unconditional + conditional) and repeat for the batch.""" + return torch.cat([add_time_ids, add_time_ids], dim=0).repeat_interleave(batch_size, dim=0) + + +def _run_with_checkpoint(function, *args): + if torch.is_grad_enabled(): + return checkpoint(function, *args, use_reentrant=False) + return function(*args) + + +def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0): + latent_model_input = torch.cat([latents] * 2) + latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) + + def run_unet(model_input, encoder_hidden_states): + return unet(model_input, timestep, encoder_hidden_states=encoder_hidden_states).sample + + noise_pred = _run_with_checkpoint(run_unet, latent_model_input, text_embeddings) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + +def diffusion( + unet, + scheduler, + latents: torch.Tensor, + text_embeddings: torch.Tensor, + total_timesteps: int, + 
start_timesteps: int = 0, + guidance_scale: float = 3.0, +): + for timestep in scheduler.timesteps[start_timesteps:total_timesteps]: + noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale) + latents = scheduler.step(noise_pred, timestep, latents).prev_sample + return latents + + +def get_add_time_ids( + height: int, + width: int, + dynamic_crops: bool = False, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + if dynamic_crops: + random_scale = torch.rand(1).item() * 2 + 1 + original_size = (int(height * random_scale), int(width * random_scale)) + crops_coords_top_left = ( + torch.randint(0, max(original_size[0] - height, 1), (1,)).item(), + torch.randint(0, max(original_size[1] - width, 1), (1,)).item(), + ) + target_size = (height, width) + else: + original_size = (height, width) + crops_coords_top_left = (0, 0) + target_size = (height, width) + + add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype) + if device is not None: + add_time_ids = add_time_ids.to(device) + return add_time_ids + + +def predict_noise_xl( + unet, + scheduler, + timestep, + latents: torch.Tensor, + prompt_embeds: PromptEmbedsXL, + add_time_ids: torch.Tensor, + guidance_scale: float = 1.0, +): + latent_model_input = torch.cat([latents] * 2) + latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) + + orig_size = add_time_ids[:, :2] + crop_size = add_time_ids[:, 2:4] + target_size = add_time_ids[:, 4:6] + from library import sdxl_train_util + + size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device) + vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1) + + def run_unet(model_input, text_embeds, vector_embeds): + return unet(model_input, timestep, text_embeds, vector_embeds) + + noise_pred = _run_with_checkpoint(run_unet, latent_model_input, prompt_embeds.text_embeds, vector_embedding) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + +def diffusion_xl( + unet, + scheduler, + latents: torch.Tensor, + prompt_embeds: PromptEmbedsXL, + add_time_ids: torch.Tensor, + total_timesteps: int, + start_timesteps: int = 0, + guidance_scale: float = 3.0, +): + for timestep in scheduler.timesteps[start_timesteps:total_timesteps]: + noise_pred = predict_noise_xl( + unet, + scheduler, + timestep, + latents, + prompt_embeds=prompt_embeds, + add_time_ids=add_time_ids, + guidance_scale=guidance_scale, + ) + latents = scheduler.step(noise_pred, timestep, latents).prev_sample + return latents + + +def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]: + max_resolution = bucket_resolution + min_resolution = bucket_resolution // 2 + step = 64 + min_step = min_resolution // step + max_step = max_resolution // step + height = torch.randint(min_step, max_step + 1, (1,)).item() * step + width = torch.randint(min_step, max_step + 1, (1,)).item() * step + return height, width + + +def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]: + height, width = prompt.get_resolution() + if prompt.dynamic_resolution and height == width: + return get_random_resolution_in_bucket(height) + return height, width diff --git a/library/train_util.py b/library/train_util.py index 672aa597..83d04f5e 100644 --- 
a/library/train_util.py +++ b/library/train_util.py @@ -1106,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset): return all( [ not ( - subset.caption_dropout_rate > 0 and not cache_supports_dropout + subset.caption_dropout_rate > 0 + and not cache_supports_dropout or subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0 @@ -2056,7 +2057,9 @@ class DreamBoothDataset(BaseDataset): filtered_img_paths.append(img_path) filtered_sizes.append(size) if len(filtered_img_paths) < len(img_paths): - logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}") + logger.info( + f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}" + ) img_paths = filtered_img_paths sizes = filtered_sizes @@ -2542,9 +2545,7 @@ class ControlNetDataset(BaseDataset): len(missing_imgs) == 0 ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}" if len(extra_imgs) > 0: - logger.warning( - f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}" - ) + logger.warning(f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}") self.conditioning_image_transforms = IMAGE_TRANSFORMS diff --git a/sdxl_train_leco.py b/sdxl_train_leco.py new file mode 100644 index 00000000..ff5550f9 --- /dev/null +++ b/sdxl_train_leco.py @@ -0,0 +1,342 @@ +import argparse +import importlib +import random + +import torch +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from tqdm import tqdm + +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import custom_train_functions, sdxl_model_util, sdxl_train_util, strategy_sdxl, train_util +from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training +from library.leco_train_util import ( + PromptEmbedsCache, + apply_noise_offset, + batch_add_time_ids, + build_network_kwargs, + concat_embeddings_xl, + diffusion_xl, + encode_prompt_sdxl, + get_add_time_ids, + get_initial_latents, + get_random_resolution, + load_prompt_settings, + predict_noise_xl, + save_weights, +) +from library.utils import add_logging_arguments, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + train_util.add_sd_models_arguments(parser) + train_util.add_optimizer_arguments(parser) + train_util.add_training_arguments(parser, support_dreambooth=False) + custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False) + sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False) + add_logging_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない") + + parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml") + parser.add_argument( + "--max_denoising_steps", + type=int, + default=40, + help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数", + ) + parser.add_argument( + "--leco_denoise_guidance_scale", + type=float, + 
default=3.0, + help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale", + ) + + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network") + parser.add_argument("--network_module", type=str, default="networks.lora", help="network module to train") + parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank") + parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha") + parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout") + parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments") + parser.add_argument( + "--network_train_text_encoder_only", + action="store_true", + help="unsupported for LECO; kept for compatibility / LECOでは未対応", + ) + parser.add_argument( + "--network_train_unet_only", + action="store_true", + help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習", + ) + parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata") + parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + + # dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed) + parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS) + parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS) + parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS) + + return parser + + +def main(): + parser = setup_parser() + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + train_util.verify_training_args(args) + sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False) + + if args.output_dir is None: + raise ValueError("--output_dir is required") + if args.network_train_text_encoder_only: + raise ValueError("LECO does not support text encoder LoRA training") + + if args.seed is None: + args.seed = random.randint(0, 2**32 - 1) + set_seed(args.seed) + + accelerator = train_util.prepare_accelerator(args) + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + prompt_settings = load_prompt_settings(args.prompts_file) + logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}") + + _, text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_train_util.load_target_model( + args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype + ) + del vae + text_encoders = [text_encoder1, text_encoder2] + + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + unet.train() + + tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir) + text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy() + + for text_encoder in text_encoders: + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + prompt_cache = PromptEmbedsCache() + unique_prompts = sorted( + { + prompt + for setting in prompt_settings + for prompt in (setting.target, setting.positive, setting.unconditional, 
setting.neutral) + } + ) + with torch.no_grad(): + for prompt in unique_prompts: + prompt_cache[prompt] = encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt) + + for text_encoder in text_encoders: + text_encoder.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + network_module = importlib.import_module(args.network_module) + net_kwargs = build_network_kwargs(args) + if args.dim_from_weights: + if args.network_weights is None: + raise ValueError("--dim_from_weights requires --network_weights") + network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoders, unet, **net_kwargs) + else: + network = network_module.create_network( + 1.0, + args.network_dim, + args.network_alpha, + None, + text_encoders, + unet, + neuron_dropout=args.network_dropout, + **net_kwargs, + ) + + network.apply_to(text_encoders, unet, apply_text_encoder=False, apply_unet=True) + network.set_multiplier(0.0) + + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + logger.info(f"loaded network weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + network.enable_gradient_checkpointing() + + unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate + trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate) + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler) + accelerator.unwrap_model(network).prepare_grad_etc(text_encoders, unet) + + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args) + optimizer_train_fn() + train_util.init_trackers(accelerator, args, "sdxl_leco_train") + + progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + while global_step < args.max_train_steps: + with accelerator.accumulate(network): + optimizer.zero_grad(set_to_none=True) + + setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()] + noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device) + + timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item() + height, width = get_random_resolution(setting) + + latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to( + accelerator.device, dtype=weight_dtype + ) + latents = apply_noise_offset(latents, args.noise_offset) + add_time_ids = get_add_time_ids( + height, + width, + dynamic_crops=setting.dynamic_crops, + dtype=weight_dtype, + device=accelerator.device, + ) + batched_time_ids = batch_add_time_ids(add_time_ids, setting.batch_size) + + network_multiplier = accelerator.unwrap_model(network) + network_multiplier.set_multiplier(setting.multiplier) + with accelerator.autocast(): + denoised_latents = 
diffusion_xl( + unet, + noise_scheduler, + latents, + concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size), + add_time_ids=batched_time_ids, + total_timesteps=timesteps_to, + guidance_scale=args.leco_denoise_guidance_scale, + ) + + noise_scheduler.set_timesteps(1000, device=accelerator.device) + current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps) + current_timestep = noise_scheduler.timesteps[current_timestep_index] + + network_multiplier.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + positive_latents = predict_noise_xl( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size), + add_time_ids=batched_time_ids, + guidance_scale=1.0, + ) + neutral_latents = predict_noise_xl( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size), + add_time_ids=batched_time_ids, + guidance_scale=1.0, + ) + unconditional_latents = predict_noise_xl( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size), + add_time_ids=batched_time_ids, + guidance_scale=1.0, + ) + + network_multiplier.set_multiplier(setting.multiplier) + with accelerator.autocast(): + target_latents = predict_noise_xl( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + concat_embeddings_xl(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size), + add_time_ids=batched_time_ids, + guidance_scale=1.0, + ) + + target = setting.build_target(positive_latents, neutral_latents, unconditional_latents) + loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none") + loss = loss.mean(dim=(1, 2, 3)) + if args.min_snr_gamma is not None and args.min_snr_gamma > 0: + timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + loss = loss.mean() * setting.weight + + accelerator.backward(loss) + + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + if accelerator.sync_gradients: + global_step += 1 + progress_bar.update(1) + network_multiplier = accelerator.unwrap_model(network) + network_multiplier.set_multiplier(0.0) + + logs = { + "loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + "guidance_scale": setting.guidance_scale, + "network_multiplier": setting.multiplier, + } + accelerator.log(logs, step=global_step) + progress_bar.set_postfix(loss=f"{logs['loss']:.4f}") + + if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0} + save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False, extra_metadata=sdxl_extra) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + sdxl_extra = {"ss_base_model_version": sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0} + 
save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True, extra_metadata=sdxl_extra) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/tests/library/test_leco_train_util.py b/tests/library/test_leco_train_util.py new file mode 100644 index 00000000..5e950f43 --- /dev/null +++ b/tests/library/test_leco_train_util.py @@ -0,0 +1,116 @@ +from pathlib import Path + +import torch + +from library.leco_train_util import load_prompt_settings + + +def test_load_prompt_settings_with_original_format(tmp_path: Path): + prompt_file = tmp_path / "prompts.toml" + prompt_file.write_text( + """ +[[prompts]] +target = "van gogh" +guidance_scale = 1.5 +resolution = 512 +""".strip(), + encoding="utf-8", + ) + + prompts = load_prompt_settings(prompt_file) + + assert len(prompts) == 1 + assert prompts[0].target == "van gogh" + assert prompts[0].positive == "van gogh" + assert prompts[0].unconditional == "" + assert prompts[0].neutral == "" + assert prompts[0].action == "erase" + assert prompts[0].guidance_scale == 1.5 + + +def test_load_prompt_settings_with_slider_targets(tmp_path: Path): + prompt_file = tmp_path / "slider.toml" + prompt_file.write_text( + """ +guidance_scale = 2.0 +resolution = 768 +neutral = "" + +[[targets]] +target_class = "" +positive = "high detail" +negative = "low detail" +multiplier = 1.25 +weight = 0.5 +""".strip(), + encoding="utf-8", + ) + + prompts = load_prompt_settings(prompt_file) + + assert len(prompts) == 4 + + first = prompts[0] + second = prompts[1] + third = prompts[2] + fourth = prompts[3] + + assert first.target == "" + assert first.positive == "low detail" + assert first.unconditional == "high detail" + assert first.action == "erase" + assert first.multiplier == 1.25 + assert first.weight == 0.5 + assert first.get_resolution() == (768, 768) + + assert second.positive == "high detail" + assert second.unconditional == "low detail" + assert second.action == "enhance" + assert second.multiplier == 1.25 + + assert third.action == "erase" + assert third.multiplier == -1.25 + + assert fourth.action == "enhance" + assert fourth.multiplier == -1.25 + + +def test_predict_noise_xl_uses_vector_embedding_from_add_time_ids(): + from library import sdxl_train_util + from library.leco_train_util import PromptEmbedsXL, predict_noise_xl + + class DummyScheduler: + def scale_model_input(self, latent_model_input, timestep): + return latent_model_input + + class DummyUNet: + def __call__(self, x, timesteps, context, y): + self.x = x + self.timesteps = timesteps + self.context = context + self.y = y + return torch.zeros_like(x) + + latents = torch.randn(1, 4, 8, 8) + prompt_embeds = PromptEmbedsXL( + text_embeds=torch.randn(2, 77, 2048), + pooled_embeds=torch.randn(2, 1280), + ) + add_time_ids = torch.tensor( + [ + [1024, 1024, 0, 0, 1024, 1024], + [1024, 1024, 0, 0, 1024, 1024], + ], + dtype=prompt_embeds.pooled_embeds.dtype, + ) + + unet = DummyUNet() + noise_pred = predict_noise_xl(unet, DummyScheduler(), torch.tensor(10), latents, prompt_embeds, add_time_ids) + + expected_size_embeddings = sdxl_train_util.get_size_embeddings( + add_time_ids[:, :2], add_time_ids[:, 2:4], add_time_ids[:, 4:6], latents.device + ).to(prompt_embeds.pooled_embeds.dtype) + + assert noise_pred.shape == latents.shape + assert unet.context is prompt_embeds.text_embeds + assert torch.equal(unet.y, torch.cat([prompt_embeds.pooled_embeds, expected_size_embeddings], dim=1)) diff --git a/tests/test_sdxl_train_leco.py b/tests/test_sdxl_train_leco.py 
new file mode 100644 index 00000000..637aa28f --- /dev/null +++ b/tests/test_sdxl_train_leco.py @@ -0,0 +1,16 @@ +import sdxl_train_leco +from library import deepspeed_utils, sdxl_train_util, train_util + + +def test_syntax(): + assert sdxl_train_leco is not None + + +def test_setup_parser_supports_shared_training_validation(): + args = sdxl_train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"]) + + train_util.verify_training_args(args) + sdxl_train_util.verify_sdxl_training_args(args, support_text_encoder_caching=False) + + assert args.min_snr_gamma is None + assert deepspeed_utils.prepare_deepspeed_plugin(args) is None diff --git a/tests/test_train_leco.py b/tests/test_train_leco.py new file mode 100644 index 00000000..4a43d3d7 --- /dev/null +++ b/tests/test_train_leco.py @@ -0,0 +1,15 @@ +import train_leco +from library import deepspeed_utils, train_util + + +def test_syntax(): + assert train_leco is not None + + +def test_setup_parser_supports_shared_training_validation(): + args = train_leco.setup_parser().parse_args(["--prompts_file", "slider.yaml"]) + + train_util.verify_training_args(args) + + assert args.min_snr_gamma is None + assert deepspeed_utils.prepare_deepspeed_plugin(args) is None diff --git a/train_leco.py b/train_leco.py new file mode 100644 index 00000000..e5439e0f --- /dev/null +++ b/train_leco.py @@ -0,0 +1,319 @@ +import argparse +import importlib +import random + +import torch +from accelerate.utils import set_seed +from diffusers import DDPMScheduler +from tqdm import tqdm + +from library.device_utils import init_ipex, clean_memory_on_device + +init_ipex() + +from library import custom_train_functions, strategy_sd, train_util +from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training +from library.leco_train_util import ( + PromptEmbedsCache, + apply_noise_offset, + build_network_kwargs, + concat_embeddings, + diffusion, + encode_prompt_sd, + get_initial_latents, + get_random_resolution, + get_save_extension, + load_prompt_settings, + predict_noise, + save_weights, +) +from library.utils import add_logging_arguments, setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +def setup_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + train_util.add_sd_models_arguments(parser) + train_util.add_optimizer_arguments(parser) + train_util.add_training_arguments(parser, support_dreambooth=False) + custom_train_functions.add_custom_train_arguments(parser, support_weighted_captions=False) + add_logging_arguments(parser) + + parser.add_argument( + "--save_model_as", + type=str, + default="safetensors", + choices=[None, "ckpt", "pt", "safetensors"], + help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)", + ) + parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを保存しない") + + parser.add_argument("--prompts_file", type=str, required=True, help="LECO prompt toml / LECO用のprompt toml") + parser.add_argument( + "--max_denoising_steps", + type=int, + default=40, + help="number of partial denoising steps per iteration / 各イテレーションで部分デノイズするステップ数", + ) + parser.add_argument( + "--leco_denoise_guidance_scale", + type=float, + default=3.0, + help="guidance scale for the partial denoising pass / 部分デノイズ時のguidance scale", + ) + + parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network") + parser.add_argument("--network_module", type=str, 
default="networks.lora", help="network module to train") + parser.add_argument("--network_dim", type=int, default=4, help="network rank / ネットワークのrank") + parser.add_argument("--network_alpha", type=float, default=1.0, help="network alpha / ネットワークのalpha") + parser.add_argument("--network_dropout", type=float, default=None, help="network dropout / ネットワークのdropout") + parser.add_argument("--network_args", type=str, default=None, nargs="*", help="additional network arguments") + parser.add_argument( + "--network_train_text_encoder_only", + action="store_true", + help="unsupported for LECO; kept for compatibility / LECOでは未対応", + ) + parser.add_argument( + "--network_train_unet_only", + action="store_true", + help="LECO always trains U-Net LoRA only / LECOは常にU-Net LoRAのみを学習", + ) + parser.add_argument("--training_comment", type=str, default=None, help="comment stored in metadata") + parser.add_argument("--dim_from_weights", action="store_true", help="infer network dim from network_weights") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") + + # dummy arguments required by train_util.verify_training_args / deepspeed_utils (LECO does not use datasets or deepspeed) + parser.add_argument("--cache_latents", action="store_true", default=False, help=argparse.SUPPRESS) + parser.add_argument("--cache_latents_to_disk", action="store_true", default=False, help=argparse.SUPPRESS) + parser.add_argument("--deepspeed", action="store_true", default=False, help=argparse.SUPPRESS) + + return parser + + +def main(): + parser = setup_parser() + args = parser.parse_args() + args = train_util.read_config_from_file(args, parser) + train_util.verify_training_args(args) + + if args.output_dir is None: + raise ValueError("--output_dir is required") + if args.network_train_text_encoder_only: + raise ValueError("LECO does not support text encoder LoRA training") + + if args.seed is None: + args.seed = random.randint(0, 2**32 - 1) + set_seed(args.seed) + + accelerator = train_util.prepare_accelerator(args) + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + prompt_settings = load_prompt_settings(args.prompts_file) + logger.info(f"loaded {len(prompt_settings)} LECO prompt settings from {args.prompts_file}") + + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) + del vae + + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa) + unet.requires_grad_(False) + unet.to(accelerator.device, dtype=weight_dtype) + unet.train() + + tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir) + text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + + prompt_cache = PromptEmbedsCache() + unique_prompts = sorted( + { + prompt + for setting in prompt_settings + for prompt in (setting.target, setting.positive, setting.unconditional, setting.neutral) + } + ) + with torch.no_grad(): + for prompt in unique_prompts: + prompt_cache[prompt] = encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt) + + text_encoder.to("cpu") + clean_memory_on_device(accelerator.device) + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + clip_sample=False, + ) + prepare_scheduler_for_custom_training(noise_scheduler, 
accelerator.device) + if args.zero_terminal_snr: + custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler) + + network_module = importlib.import_module(args.network_module) + net_kwargs = build_network_kwargs(args) + if args.dim_from_weights: + if args.network_weights is None: + raise ValueError("--dim_from_weights requires --network_weights") + network, _ = network_module.create_network_from_weights(1.0, args.network_weights, None, text_encoder, unet, **net_kwargs) + else: + network = network_module.create_network( + 1.0, + args.network_dim, + args.network_alpha, + None, + text_encoder, + unet, + neuron_dropout=args.network_dropout, + **net_kwargs, + ) + + network.apply_to(text_encoder, unet, apply_text_encoder=False, apply_unet=True) + network.set_multiplier(0.0) + + if args.network_weights is not None: + info = network.load_weights(args.network_weights) + logger.info(f"loaded network weights from {args.network_weights}: {info}") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + network.enable_gradient_checkpointing() + + unet_lr = args.unet_lr if args.unet_lr is not None else args.learning_rate + trainable_params, _ = network.prepare_optimizer_params(None, unet_lr, args.learning_rate) + _, _, optimizer = train_util.get_optimizer(args, trainable_params) + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + + network, optimizer, lr_scheduler = accelerator.prepare(network, optimizer, lr_scheduler) + accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet) + + if args.full_fp16: + train_util.patch_accelerator_for_fp16_training(accelerator) + + optimizer_train_fn, _ = train_util.get_optimizer_train_eval_fn(optimizer, args) + optimizer_train_fn() + train_util.init_trackers(accelerator, args, "leco_train") + + progress_bar = tqdm(total=args.max_train_steps, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 + + while global_step < args.max_train_steps: + with accelerator.accumulate(network): + optimizer.zero_grad(set_to_none=True) + + setting = prompt_settings[torch.randint(0, len(prompt_settings), (1,)).item()] + noise_scheduler.set_timesteps(args.max_denoising_steps, device=accelerator.device) + + timesteps_to = torch.randint(1, args.max_denoising_steps, (1,), device=accelerator.device).item() + height, width = get_random_resolution(setting) + + latents = get_initial_latents(noise_scheduler, setting.batch_size, height, width, 1).to( + accelerator.device, dtype=weight_dtype + ) + latents = apply_noise_offset(latents, args.noise_offset) + + network_multiplier = accelerator.unwrap_model(network) + network_multiplier.set_multiplier(setting.multiplier) + with accelerator.autocast(): + denoised_latents = diffusion( + unet, + noise_scheduler, + latents, + concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size), + total_timesteps=timesteps_to, + guidance_scale=args.leco_denoise_guidance_scale, + ) + + noise_scheduler.set_timesteps(1000, device=accelerator.device) + current_timestep_index = int(timesteps_to * 1000 / args.max_denoising_steps) + current_timestep = noise_scheduler.timesteps[current_timestep_index] + + network_multiplier.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + positive_latents = predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.positive], setting.batch_size), + 
guidance_scale=1.0, + ) + neutral_latents = predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.neutral], setting.batch_size), + guidance_scale=1.0, + ) + unconditional_latents = predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.unconditional], setting.batch_size), + guidance_scale=1.0, + ) + + network_multiplier.set_multiplier(setting.multiplier) + with accelerator.autocast(): + target_latents = predict_noise( + unet, + noise_scheduler, + current_timestep, + denoised_latents, + concat_embeddings(prompt_cache[setting.unconditional], prompt_cache[setting.target], setting.batch_size), + guidance_scale=1.0, + ) + + target = setting.build_target(positive_latents, neutral_latents, unconditional_latents) + loss = torch.nn.functional.mse_loss(target_latents.float(), target.float(), reduction="none") + loss = loss.mean(dim=(1, 2, 3)) + if args.min_snr_gamma is not None and args.min_snr_gamma > 0: + timesteps = torch.full((loss.shape[0],), current_timestep_index, device=loss.device, dtype=torch.long) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + loss = loss.mean() * setting.weight + + accelerator.backward(loss) + + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(network.parameters(), args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + if accelerator.sync_gradients: + global_step += 1 + progress_bar.update(1) + network_multiplier = accelerator.unwrap_model(network) + network_multiplier.set_multiplier(0.0) + + logs = { + "loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + "guidance_scale": setting.guidance_scale, + "network_multiplier": setting.multiplier, + } + accelerator.log(logs, step=global_step) + progress_bar.set_postfix(loss=f"{logs['loss']:.4f}") + + if args.save_every_n_steps and global_step % args.save_every_n_steps == 0 and global_step < args.max_train_steps: + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=False) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + save_weights(accelerator, network, args, save_dtype, prompt_settings, global_step, last=True) + + accelerator.end_training() + + +if __name__ == "__main__": + main() From 5fb3172baf66248b4192ae19a95cbab1ad33a024 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 29 Mar 2026 21:25:53 +0900 Subject: [PATCH 14/17] fix: AdaLN modulation to use float32 for numerical stability in fp16 --- library/anima_models.py | 59 ++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 27 deletions(-) diff --git a/library/anima_models.py b/library/anima_models.py index 037ffd77..00e9c6c6 100644 --- a/library/anima_models.py +++ b/library/anima_models.py @@ -739,13 +739,16 @@ class FinalLayer(nn.Module): emb_B_T_D: torch.Tensor, adaln_lora_B_T_3D: Optional[torch.Tensor] = None, ): - if self.use_adaln_lora: - assert adaln_lora_B_T_3D is not None - shift_B_T_D, scale_B_T_D = (self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size]).chunk( - 2, dim=-1 - ) - else: - shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) + # Compute AdaLN modulation parameters (in 
float32 when fp16 to avoid overflow in Linear layers) + use_fp32 = x_B_T_H_W_D.dtype == torch.float16 + with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32): + if self.use_adaln_lora: + assert adaln_lora_B_T_3D is not None + shift_B_T_D, scale_B_T_D = ( + self.adaln_modulation(emb_B_T_D) + adaln_lora_B_T_3D[:, :, : 2 * self.hidden_size] + ).chunk(2, dim=-1) + else: + shift_B_T_D, scale_B_T_D = self.adaln_modulation(emb_B_T_D).chunk(2, dim=-1) shift_B_T_1_1_D = rearrange(shift_B_T_D, "b t d -> b t 1 1 d") scale_B_T_1_1_D = rearrange(scale_B_T_D, "b t d -> b t 1 1 d") @@ -864,32 +867,34 @@ class Block(nn.Module): adaln_lora_B_T_3D: Optional[torch.Tensor] = None, extra_per_block_pos_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if x_B_T_H_W_D.dtype == torch.float16: + use_fp32 = x_B_T_H_W_D.dtype == torch.float16 + if use_fp32: # Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context. x_B_T_H_W_D = x_B_T_H_W_D.float() if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb - # Compute AdaLN modulation parameters - if self.use_adaln_lora: - shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( - self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D - ).chunk(3, dim=-1) - shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( - self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D - ).chunk(3, dim=-1) - shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk( - 3, dim=-1 - ) - else: - shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn(emb_B_T_D).chunk( - 3, dim=-1 - ) - shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( - emb_B_T_D - ).chunk(3, dim=-1) - shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) + # Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers) + with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32): + if self.use_adaln_lora: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = ( + self.adaln_modulation_self_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = ( + self.adaln_modulation_cross_attn(emb_B_T_D) + adaln_lora_B_T_3D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = (self.adaln_modulation_mlp(emb_B_T_D) + adaln_lora_B_T_3D).chunk( + 3, dim=-1 + ) + else: + shift_self_attn_B_T_D, scale_self_attn_B_T_D, gate_self_attn_B_T_D = self.adaln_modulation_self_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_cross_attn_B_T_D, scale_cross_attn_B_T_D, gate_cross_attn_B_T_D = self.adaln_modulation_cross_attn( + emb_B_T_D + ).chunk(3, dim=-1) + shift_mlp_B_T_D, scale_mlp_B_T_D, gate_mlp_B_T_D = self.adaln_modulation_mlp(emb_B_T_D).chunk(3, dim=-1) # Reshape for broadcasting: (B, T, D) -> (B, T, 1, 1, D) shift_self_attn_B_T_1_1_D = rearrange(shift_self_attn_B_T_D, "b t d -> b t 1 1 d") From b637c3136527675712ff0b830901f2e9609f76b7 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 29 Mar 2026 21:58:38 +0900 Subject: [PATCH 15/17] fix: update table of contents and change history in README files for 
clarity --- README-ja.md | 51 +++++++++++++++++++++++++++++++-------------------- README.md | 35 +++++++++++++++++++++++------------ 2 files changed, 54 insertions(+), 32 deletions(-) diff --git a/README-ja.md b/README-ja.md index 51935728..e04934aa 100644 --- a/README-ja.md +++ b/README-ja.md @@ -8,25 +8,25 @@ クリックすると展開します - [はじめに](#はじめに) - - [スポンサー](#スポンサー) - - [スポンサー募集のお知らせ](#スポンサー募集のお知らせ) - - [更新履歴](#更新履歴) - - [サポートモデル](#サポートモデル) - - [機能](#機能) + - [スポンサー](#スポンサー) + - [スポンサー募集のお知らせ](#スポンサー募集のお知らせ) + - [更新履歴](#更新履歴) + - [サポートモデル](#サポートモデル) + - [機能](#機能) - [ドキュメント](#ドキュメント) - - [学習ドキュメント(英語および日本語)](#学習ドキュメント英語および日本語) - - [その他のドキュメント](#その他のドキュメント) - - [旧ドキュメント(日本語)](#旧ドキュメント日本語) + - [学習ドキュメント(英語および日本語)](#学習ドキュメント英語および日本語) + - [その他のドキュメント](#その他のドキュメント) + - [旧ドキュメント(日本語)](#旧ドキュメント日本語) - [AIコーディングエージェントを使う開発者の方へ](#aiコーディングエージェントを使う開発者の方へ) - [Windows環境でのインストール](#windows環境でのインストール) - - [Windowsでの動作に必要なプログラム](#windowsでの動作に必要なプログラム) - - [インストール手順](#インストール手順) - - [requirements.txtとPyTorchについて](#requirementstxtとpytorchについて) - - [xformersのインストール(オプション)](#xformersのインストールオプション) + - [Windowsでの動作に必要なプログラム](#windowsでの動作に必要なプログラム) + - [インストール手順](#インストール手順) + - [requirements.txtとPyTorchについて](#requirementstxtとpytorchについて) + - [xformersのインストール(オプション)](#xformersのインストールオプション) - [Linux/WSL2環境でのインストール](#linuxwsl2環境でのインストール) - - [DeepSpeedのインストール(実験的、LinuxまたはWSL2のみ)](#deepspeedのインストール実験的linuxまたはwsl2のみ) + - [DeepSpeedのインストール(実験的、LinuxまたはWSL2のみ)](#deepspeedのインストール実験的linuxまたはwsl2のみ) - [アップグレード](#アップグレード) - - [PyTorchのアップグレード](#pytorchのアップグレード) + - [PyTorchのアップグレード](#pytorchのアップグレード) - [謝意](#謝意) - [ライセンス](#ライセンス) @@ -50,15 +50,26 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像 ### 更新履歴 +- **Version 0.10.2 (2026-03-30):** + - `networks/resize_lora.py`が`torch.svd_lowrank`に対応し、大幅に高速化されました。[PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) および [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296) woct0rdho氏に深く感謝します。 + - デフォルトは有効になっています。`--svd_lowrank_niter`オプションで反復回数を指定できます(デフォルトは2、多いほど精度が向上します)。0にすると従来の方法になります。詳細は `--help` でご確認ください。 + - LoKr/LoHaをSDXL/Animaでサポートしました。[PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275) + - 詳細は[ドキュメント](./docs/loha_lokr.md)をご覧ください。 + - マルチ解像度データセット(同じ画像を複数のbucketサイズにリサイズして使用)がSD/SDXLの学習でサポートされました。[PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) また、マルチ解像度データセットで同じ解像度の画像が重複して使用される事象への対応を行いました。[PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273) + - woct0rdho氏に感謝します。 + - [ドキュメント英語版](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [ドキュメント日本語版](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) をご覧ください。 + - Animaでfp16で学習する際の安定性が向上しました。[PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) ただし、依然として不安定な場合があるようです。問題が発生する場合は、詳細をIssueでお知らせください。 + - その他、細かいバグ修正や改善を行いました。 + - **Version 0.10.1 (2026-02-13):** - - [Anima Preview](https://huggingface.co/circlestone-labs/Anima)モデルのLoRA学習およびfine-tuningをサポートしました。[PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) および[PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261) - - 素晴らしいモデルを公開された CircleStone Labs、および PR #2260を提出していただいたduongve13112002氏に深く感謝します。 - - 詳細は[ドキュメント](./docs/anima_train_network.md)をご覧ください。 + - [Anima Preview](https://huggingface.co/circlestone-labs/Anima)モデルのLoRA学習およびfine-tuningをサポートしました。[PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) および[PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261) + - 素晴らしいモデルを公開された CircleStone Labs、および PR #2260を提出していただいたduongve13112002氏に深く感謝します。 + - 詳細は[ドキュメント](./docs/anima_train_network.md)をご覧ください。 - **Version 
0.10.0 (2026-01-19):** - - `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。 - - ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。 - - `sd3`ブランチは当面、`dev`ブランチと同期して開発ブランチとして維持します。 + - `sd3`ブランチを`main`ブランチにマージしました。このバージョンからFLUX.1およびSD3/SD3.5等のモデルが`main`ブランチでサポートされます。 + - ドキュメントにはまだ不備があるため、お気づきの点はIssue等でお知らせください。 + - `sd3`ブランチは当面、`dev`ブランチと同期して開発ブランチとして維持します。 ### サポートモデル diff --git a/README.md b/README.md index 88b58359..6e889a82 100644 --- a/README.md +++ b/README.md @@ -7,23 +7,23 @@ Click to expand - [Introduction](#introduction) - - [Supported Models](#supported-models) - - [Features](#features) - - [Sponsors](#sponsors) - - [Support the Project](#support-the-project) + - [Supported Models](#supported-models) + - [Features](#features) + - [Sponsors](#sponsors) + - [Support the Project](#support-the-project) - [Documentation](#documentation) - - [Training Documentation (English and Japanese)](#training-documentation-english-and-japanese) - - [Other Documentation (English and Japanese)](#other-documentation-english-and-japanese) + - [Training Documentation (English and Japanese)](#training-documentation-english-and-japanese) + - [Other Documentation (English and Japanese)](#other-documentation-english-and-japanese) - [For Developers Using AI Coding Agents](#for-developers-using-ai-coding-agents) - [Windows Installation](#windows-installation) - - [Windows Required Dependencies](#windows-required-dependencies) - - [Installation Steps](#installation-steps) - - [About requirements.txt and PyTorch](#about-requirementstxt-and-pytorch) - - [xformers installation (optional)](#xformers-installation-optional) + - [Windows Required Dependencies](#windows-required-dependencies) + - [Installation Steps](#installation-steps) + - [About requirements.txt and PyTorch](#about-requirementstxt-and-pytorch) + - [xformers installation (optional)](#xformers-installation-optional) - [Linux/WSL2 Installation](#linuxwsl2-installation) - - [DeepSpeed installation (experimental, Linux or WSL2 only)](#deepspeed-installation-experimental-linux-or-wsl2-only) + - [DeepSpeed installation (experimental, Linux or WSL2 only)](#deepspeed-installation-experimental-linux-or-wsl2-only) - [Upgrade](#upgrade) - - [Upgrade PyTorch](#upgrade-pytorch) + - [Upgrade PyTorch](#upgrade-pytorch) - [Credits](#credits) - [License](#license) @@ -47,6 +47,17 @@ If you find this project helpful, please consider supporting its development via ### Change History +- **Version 0.10.2 (2026-03-30):** + - `networks/resize_lora.py` has been updated to use `torch.svd_lowrank`, resulting in a significant speedup. Many thanks to woct0rdho for [PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) and [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296). + - It is enabled by default. You can specify the number of iterations with the `--svd_lowrank_niter` option (default is 2, more iterations will improve accuracy). Setting it to 0 will revert to the previous method. Please check `--help` for details. + - LoKr/LoHa is now supported for SDXL/Anima. See [PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275) for details. + - Please refer to the [documentation](./docs/loha_lokr.md) for details. + - Multi-resolution datasets (using the same image resized to multiple bucket sizes) are now supported in SD/SDXL training. We also addressed the issue of duplicate images with the same resolution being used in multi-resolution datasets. 
See [PR #2269](https://github.com/kohya-ss/sd-scripts/pull/2269) and [PR #2273](https://github.com/kohya-ss/sd-scripts/pull/2273) for details. + - Thanks to woct0rdho for the contribution. + - Please refer to the [English documentation](./docs/config_README-en.md#behavior-when-there-are-duplicate-subsets) / [Japanese documentation](./docs/config_README-ja.md#重複したサブセットが存在する時の挙動) for details. + - Stability when training with fp16 on Anima has been improved. See [PR #2297](https://github.com/kohya-ss/sd-scripts/pull/2297) for details. However, it still seems to be unstable in some cases. If you encounter any issues, please let us know the details via Issues. + - Other minor bug fixes and improvements were made. + - **Version 0.10.1 (2026-02-13):** - [Anima Preview](https://huggingface.co/circlestone-labs/Anima) model LoRA training and fine-tuning are now supported. See [PR #2260](https://github.com/kohya-ss/sd-scripts/pull/2260) and [PR #2261](https://github.com/kohya-ss/sd-scripts/pull/2261). - Many thanks to CircleStone Labs for releasing this amazing model, and to duongve13112002 for submitting great PR #2260. From 3cb9025b4bed96cd1e42885cd685f6e49d176665 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Sun, 29 Mar 2026 22:07:52 +0900 Subject: [PATCH 16/17] doc: update change history in README files to include LECO training support for SD/SDXL --- README-ja.md | 2 ++ README.md | 2 ++ 2 files changed, 4 insertions(+) diff --git a/README-ja.md b/README-ja.md index e04934aa..f4f912a2 100644 --- a/README-ja.md +++ b/README-ja.md @@ -51,6 +51,8 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像 ### 更新履歴 - **Version 0.10.2 (2026-03-30):** + - SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。 + - 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。 - `networks/resize_lora.py`が`torch.svd_lowrank`に対応し、大幅に高速化されました。[PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) および [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296) woct0rdho氏に深く感謝します。 - デフォルトは有効になっています。`--svd_lowrank_niter`オプションで反復回数を指定できます(デフォルトは2、多いほど精度が向上します)。0にすると従来の方法になります。詳細は `--help` でご確認ください。 - LoKr/LoHaをSDXL/Animaでサポートしました。[PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275) diff --git a/README.md b/README.md index 6e889a82..fc041db3 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,8 @@ If you find this project helpful, please consider supporting its development via ### Change History - **Version 0.10.2 (2026-03-30):** + - LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294). + - Please refer to the [documentation](./docs/train_leco.md) for details. - `networks/resize_lora.py` has been updated to use `torch.svd_lowrank`, resulting in a significant speedup. Many thanks to woct0rdho for [PR #2240](https://github.com/kohya-ss/sd-scripts/pull/2240) and [PR #2296](https://github.com/kohya-ss/sd-scripts/pull/2296). - It is enabled by default. You can specify the number of iterations with the `--svd_lowrank_niter` option (default is 2, more iterations will improve accuracy). Setting it to 0 will revert to the previous method. Please check `--help` for details. - LoKr/LoHa is now supported for SDXL/Anima. See [PR #2275](https://github.com/kohya-ss/sd-scripts/pull/2275) for details. 
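A note on the LECO support referenced just above: in the PATCH 13 training loop, the positive, neutral, and unconditional noise predictions are taken with the LoRA multiplier at 0.0, one more prediction for the target prompt is taken with the multiplier enabled, and the latter is regressed onto `setting.build_target(positive_latents, neutral_latents, unconditional_latents)`. The patch series does not show `build_target` itself, so the following is a hypothetical reduction following the original LECO formulation; the class name, the `action` field, and the erase/enhance formulas are assumptions, not the shipped code.

import torch


class PromptSettingSketch:
    """Hypothetical stand-in for a LECO prompt setting; the real class
    also carries the prompts, batch size, weight, and more."""

    def __init__(self, action: str = "erase", guidance_scale: float = 1.0):
        self.action = action
        self.guidance_scale = guidance_scale

    def build_target(
        self,
        positive_latents: torch.Tensor,       # eps(x_t, c_positive), LoRA off
        neutral_latents: torch.Tensor,        # eps(x_t, c_neutral), LoRA off
        unconditional_latents: torch.Tensor,  # eps(x_t, c_uncond), LoRA off
    ) -> torch.Tensor:
        # "erase": push the target prompt's prediction away from the
        # positive concept, relative to the unconditional baseline.
        if self.action == "erase":
            return neutral_latents - self.guidance_scale * (positive_latents - unconditional_latents)
        # "enhance": pull the prediction toward the positive concept instead.
        if self.action == "enhance":
            return neutral_latents + self.guidance_scale * (positive_latents - unconditional_latents)
        raise ValueError(f"unknown action: {self.action}")

Under this reading, "erase" trains the LoRA so that, when active, the model predicts for the target prompt what guidance away from the positive concept would have produced, which is why the unconditional prediction is cached alongside the neutral one.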
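Stepping back to the fp16 fix in PATCH 14: fp16 tops out around 65504, and the pre-chunk Linear projections in the AdaLN modulation can exceed that when embeddings grow large, so only the shift/scale/gate computation is forced to float32 while the rest of the block stays in half precision. Below is a minimal sketch of the same selective upcast, written with an explicit cast instead of the patch's float32 autocast region; the SiLU -> Linear layout and the module name are assumptions, and the sketch assumes fp32 parameters with possibly-fp16 activations (standard mixed-precision training):

import torch
import torch.nn as nn


class AdaLNModulationSketch(nn.Module):
    """Hypothetical reduction of the PATCH 14 pattern: run the modulation
    projection in fp32 even when the incoming activations are fp16."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.adaln_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 3 * hidden_size))

    def forward(self, emb_B_T_D: torch.Tensor):
        # Upcast the input; with fp32 parameters the projection then runs
        # entirely in fp32, so the chunked shift/scale/gate cannot overflow.
        shift, scale, gate = self.adaln_modulation(emb_B_T_D.float()).chunk(3, dim=-1)
        return shift, scale, gate

The autocast form used in the patch achieves the same effect without manual casts and keeps the adaln_lora addition inside the same fp32 region.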
From b2c330407b8124fed8dae14e0ed0d329d663d5f4 Mon Sep 17 00:00:00 2001 From: woctordho Date: Thu, 4 Sep 2025 15:38:53 +0800 Subject: [PATCH 17/17] Print verbose info while extracting --- networks/resize_lora.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/networks/resize_lora.py b/networks/resize_lora.py index 2a44592b..f64edb1f 100644 --- a/networks/resize_lora.py +++ b/networks/resize_lora.py @@ -212,7 +212,6 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1): def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose, svd_lowrank_niter=2): max_old_rank = None new_alpha = None - verbose_str = "\n" fro_list = [] if dynamic_method: @@ -285,15 +284,13 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna if not np.isnan(fro_retained): fro_list.append(float(fro_retained)) - verbose_str += f"{block_down_name:75} | " + verbose_str = f"{block_down_name:75} | " verbose_str += ( f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}" ) - - if verbose and dynamic_method: - verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n" - else: - verbose_str += "\n" + if dynamic_method: + verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}" + tqdm.write(verbose_str) new_alpha = param_dict["new_alpha"] o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous() @@ -308,7 +305,6 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna del param_dict if verbose: - print(verbose_str) print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}") logger.info("resizing complete") return o_lora_sd, max_old_rank, new_alpha
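Two closing notes on PATCH 17. First, what the printed columns mean: each block's line reports how much of the original weight's spectrum survives the resize. Here is a minimal sketch of how such numbers can be computed, assuming the full singular-value vector is available in descending order; the helper name and the exact definition of the max(S) ratio are assumptions (rank_resize in the script also handles dynamic methods, clamping, and the low-rank SVD path, where only the leading singular values exist):

import torch


def retention_metrics(S: torch.Tensor, new_rank: int):
    """Hypothetical helper mirroring the verbose columns; S must be the
    singular values sorted in descending order."""
    S_kept = S[:new_rank]
    # Share of the total singular mass carried by the kept ranks.
    sum_retained = (S_kept.sum() / S.sum()).item()
    # Frobenius norm ratio: sqrt(sum of kept s_i^2 over sum of all s_i^2).
    fro_retained = (S_kept.square().sum() / S.square().sum()).sqrt().item()
    # One plausible reading of "max(S) ratio": spread within the kept block.
    max_ratio = (S[0] / S_kept[-1]).item()
    return sum_retained, fro_retained, max_ratio


# Example on a random matrix:
S = torch.linalg.svdvals(torch.randn(320, 1280))
print(retention_metrics(S, new_rank=8))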
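Second, why the patch prints with tqdm.write rather than print: the resize loop runs under a tqdm progress bar, and a plain print would interleave with the bar's carriage-return redraws. tqdm.write clears the bar, emits the line, and redraws it, so the per-block lines scroll cleanly above the bar as extraction proceeds. A self-contained illustration:

import time

from tqdm import tqdm

for i in tqdm(range(5), desc="resizing"):
    time.sleep(0.1)
    # print(...) here would mangle the bar; tqdm.write prints above it.
    tqdm.write(f"block {i:2d} | fro retained: {0.99:.1%}")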