From c28e7a47c3bd3c4efc81404bf4dadba2b41d4fe4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 26 Jul 2025 19:35:42 +0900 Subject: [PATCH] feat: add regex-based rank and learning rate configuration for FLUX.1 LoRA --- docs/flux_train_network.md | 49 ++++++++- networks/lora_flux.py | 197 +++++++++++++++++++++++++++---------- train_network.py | 2 +- 3 files changed, 194 insertions(+), 54 deletions(-) diff --git a/docs/flux_train_network.md b/docs/flux_train_network.md index f324b959..1b584180 100644 --- a/docs/flux_train_network.md +++ b/docs/flux_train_network.md @@ -398,7 +398,50 @@ FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_s -### 6.5. Text Encoder LoRA Support / Text Encoder LoRAのサポート + + + +### 6.4. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定 + +You can specify ranks (dims) and learning rates for LoRA modules using regular expressions. This allows for more flexible and fine-grained control than specifying by layer. + +These settings are specified via the `network_args` argument. + +* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`. + * Example: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"` + * This sets the rank to 4 for modules whose names contain `single` and contain `_modulation`, and to 8 for modules containing `img_attn`. +* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`. + * Example: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"` + * This sets the learning rate to `1e-3` for modules whose names contain `single_blocks` followed by a digit (`0` to `9`) or `10`, and to `2e-3` for modules whose names contain `double_blocks`. + +**Notes:** + +* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings. 
+* If a module name matches multiple patterns, the setting from the first matching pattern in the string will be applied.
+* These settings are applied after the block-specific training settings (`train_double_block_indices`, `train_single_block_indices`).
+
+<details>
+日本語 + +正規表現を用いて、LoRAのモジュールごとにランク(dim)や学習率を指定することができます。これにより、層ごとの指定よりも柔軟できめ細やかな制御が可能になります。 + +これらの設定は `network_args` 引数で指定します。 + +* `network_reg_dims`: 正規表現にマッチするモジュールに対してランクを指定します。`pattern=rank` という形式の文字列をカンマで区切って指定します。 + * 例: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"` + * この例では、名前に `single` で始まり `_modulation` を含むモジュールのランクを4に、`img_attn` を含むモジュールのランクを8に設定します。 +* `network_reg_lrs`: 正規表現にマッチするモジュールに対して学習率を指定します。`pattern=lr` という形式の文字列をカンマで区切って指定します。 + * 例: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"` + * この例では、名前が `single_blocks` で始まり、後に数字(`0`から`9`)または`10`が続くモジュールの学習率を `1e-3` に、`double_blocks` を含むモジュールの学習率を `2e-3` に設定します。 +**注意点:** + +* `network_reg_dims` および `network_reg_lrs` での設定は、全体設定である `--network_dim` や `--learning_rate` よりも優先されます。 +* あるモジュール名が複数のパターンにマッチした場合、文字列の中で後方にあるパターンの設定が適用されます。 +* これらの設定は、ブロック指定(`train_double_block_indices`, `train_single_block_indices`)が適用された後に行われます。 + +
+ +### 6.6. Text Encoder LoRA Support / Text Encoder LoRAのサポート FLUX.1 LoRA training supports training CLIP-L and T5XXL LoRA: @@ -417,7 +460,7 @@ FLUX.1 LoRA学習は、CLIP-LとT5XXL LoRAのトレーニングもサポート -### 6.6. Multi-Resolution Training / マルチ解像度トレーニング +### 6.7. Multi-Resolution Training / マルチ解像度トレーニング You can define multiple resolutions in the dataset configuration file, with different batch sizes for each resolution. @@ -462,7 +505,7 @@ resolution = [768, 768] -### 6.7. Validation / 検証 +### 6.8. Validation / 検証 You can calculate validation loss during training using a validation dataset to evaluate model generalization performance. diff --git a/networks/lora_flux.py b/networks/lora_flux.py index ddc91608..320bc463 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -156,11 +156,19 @@ class LoRAModule(torch.nn.Module): lx = self.lora_up(lx) # LoRA Gradient-Guided Perturbation Optimization - if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None: + if ( + self.training + and self.ggpo_sigma is not None + and self.ggpo_beta is not None + and self.combined_weight_norms is not None + and self.grad_norms is not None + ): with torch.no_grad(): - perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2)) + perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + ( + self.ggpo_beta * (self.grad_norms**2) + ) perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device) - perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device) + perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device) perturbation.mul_(perturbation_scale_factor) perturbation_output = x @ perturbation.T # Result: (batch × n) return org_forwarded + (self.multiplier * scale * lx) + perturbation_output @@ -197,24 
+205,24 @@ class LoRAModule(torch.nn.Module): # Choose a reasonable sample size n_rows = org_module_weight.shape[0] sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller - + # Sample random indices across all rows indices = torch.randperm(n_rows)[:sample_size] - + # Convert to a supported data type first, then index # Use float32 for indexing operations weights_float32 = org_module_weight.to(dtype=torch.float32) sampled_weights = weights_float32[indices].to(device=self.device) - + # Calculate sampled norms sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True) - + # Store the mean norm as our estimate self.org_weight_norm_estimate = sampled_norms.mean() - + # Optional: store standard deviation for confidence intervals self.org_weight_norm_std = sampled_norms.std() - + # Free memory del sampled_weights, weights_float32 @@ -223,37 +231,36 @@ class LoRAModule(torch.nn.Module): # Calculate the true norm (this will be slow but it's just for validation) true_norms = [] chunk_size = 1024 # Process in chunks to avoid OOM - + for i in range(0, org_module_weight.shape[0], chunk_size): end_idx = min(i + chunk_size, org_module_weight.shape[0]) chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype) chunk_norms = torch.norm(chunk, dim=1, keepdim=True) true_norms.append(chunk_norms.cpu()) del chunk - + true_norms = torch.cat(true_norms, dim=0) true_mean_norm = true_norms.mean().item() - + # Compare with our estimate estimated_norm = self.org_weight_norm_estimate.item() - + # Calculate error metrics absolute_error = abs(true_mean_norm - estimated_norm) relative_error = absolute_error / true_mean_norm * 100 # as percentage - + if verbose: logger.info(f"True mean norm: {true_mean_norm:.6f}") logger.info(f"Estimated norm: {estimated_norm:.6f}") logger.info(f"Absolute error: {absolute_error:.6f}") logger.info(f"Relative error: {relative_error:.2f}%") - - return { - 'true_mean_norm': true_mean_norm, - 'estimated_norm': 
estimated_norm, - 'absolute_error': absolute_error, - 'relative_error': relative_error - } + return { + "true_mean_norm": true_mean_norm, + "estimated_norm": estimated_norm, + "absolute_error": absolute_error, + "relative_error": relative_error, + } @torch.no_grad() def update_norms(self): @@ -261,7 +268,7 @@ class LoRAModule(torch.nn.Module): if self.ggpo_beta is None or self.ggpo_sigma is None: return - # only update norms when we are training + # only update norms when we are training if self.training is False: return @@ -269,8 +276,9 @@ class LoRAModule(torch.nn.Module): module_weights.mul(self.scale) self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True) - self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) + - torch.sum(module_weights**2, dim=1, keepdim=True)) + self.combined_weight_norms = torch.sqrt( + (self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True) + ) @torch.no_grad() def update_grad_norms(self): @@ -293,7 +301,6 @@ class LoRAModule(torch.nn.Module): approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight)) self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True) - @property def device(self): return next(self.parameters()).device @@ -564,7 +571,6 @@ def create_network( if ggpo_sigma is not None: ggpo_sigma = float(ggpo_sigma) - # train T5XXL train_t5xxl = kwargs.get("train_t5xxl", False) if train_t5xxl is not None: @@ -575,6 +581,42 @@ def create_network( if verbose is not None: verbose = True if verbose == "True" else False + # regex-specific learning rates + def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]: + """ + Parse a string of key-value pairs separated by commas. 
+ """ + pairs = {} + for pair in kv_pair_str.split(","): + pair = pair.strip() + if not pair: + continue + if "=" not in pair: + logger.warning(f"Invalid format: {pair}, expected 'key=value'") + continue + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip() + try: + pairs[key] = int(value) if is_int else float(value) + except ValueError: + logger.warning(f"Invalid value for {key}: {value}") + return pairs + + # parse regular expression based learning rates + network_reg_lrs = kwargs.get("network_reg_lrs", None) + if network_reg_lrs is not None: + reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False) + else: + reg_lrs = None + + # regex-specific dimensions (ranks) + network_reg_dims = kwargs.get("network_reg_dims", None) + if network_reg_dims is not None: + reg_dims = parse_kv_pairs(network_reg_dims, is_int=True) + else: + reg_dims = None + # すごく引数が多いな ( ^ω^)・・・ network = LoRANetwork( text_encoders, @@ -594,8 +636,10 @@ def create_network( in_dims=in_dims, train_double_block_indices=train_double_block_indices, train_single_block_indices=train_single_block_indices, + reg_dims=reg_dims, ggpo_beta=ggpo_beta, ggpo_sigma=ggpo_sigma, + reg_lrs=reg_lrs, verbose=verbose, ) @@ -613,7 +657,6 @@ def create_network( # Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs): - # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file, safe_open @@ -644,22 +687,6 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh if train_t5xxl is None: train_t5xxl = False - # # split qkv - # double_qkv_rank = None - # single_qkv_rank = None - # rank = None - # for lora_name, dim in modules_dim.items(): - # if "double" in lora_name and "qkv" 
in lora_name: - # double_qkv_rank = dim - # elif "single" in lora_name and "linear1" in lora_name: - # single_qkv_rank = dim - # elif rank is None: - # rank = dim - # if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None: - # break - # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or ( - # single_qkv_rank is not None and single_qkv_rank != rank - # ) split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined module_class = LoRAInfModule if for_inference else LoRAModule @@ -708,8 +735,10 @@ class LoRANetwork(torch.nn.Module): in_dims: Optional[List[int]] = None, train_double_block_indices: Optional[List[bool]] = None, train_single_block_indices: Optional[List[bool]] = None, + reg_dims: Optional[Dict[str, int]] = None, ggpo_beta: Optional[float] = None, ggpo_sigma: Optional[float] = None, + reg_lrs: Optional[Dict[str, float]] = None, verbose: Optional[bool] = False, ) -> None: super().__init__() @@ -730,6 +759,8 @@ class LoRANetwork(torch.nn.Module): self.in_dims = in_dims self.train_double_block_indices = train_double_block_indices self.train_single_block_indices = train_single_block_indices + self.reg_dims = reg_dims + self.reg_lrs = reg_lrs self.loraplus_lr_ratio = None self.loraplus_unet_lr_ratio = None @@ -757,7 +788,6 @@ class LoRANetwork(torch.nn.Module): if self.train_blocks is not None: logger.info(f"train {self.train_blocks} blocks only") - if train_t5xxl: logger.info(f"train T5XXL as well") @@ -803,8 +833,16 @@ class LoRANetwork(torch.nn.Module): if lora_name in modules_dim: dim = modules_dim[lora_name] alpha = modules_alpha[lora_name] - else: - # 通常、すべて対象とする + elif self.reg_dims is not None: + for reg, d in self.reg_dims.items(): + if re.search(reg, lora_name): + dim = d + alpha = self.alpha + logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}") + break + + # 通常、すべて対象とする + if dim is None: if is_linear or is_conv2d_1x1: dim = default_dim if default_dim 
is not None else self.lora_dim alpha = self.alpha @@ -979,7 +1017,6 @@ class LoRANetwork(torch.nn.Module): combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0)) return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None - def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file @@ -1166,17 +1203,77 @@ class LoRANetwork(torch.nn.Module): all_params = [] lr_descriptions = [] + reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else [] + def assemble_params(loras, lr, loraplus_ratio): param_groups = {"lora": {}, "plus": {}} + # regular expression param groups: {"reg_lr_0": {"lora": {}, "plus": {}}, ...} + reg_groups = {} + for lora in loras: + # check if this lora matches any regex learning rate + matched_reg_lr = None + for i, (regex_str, reg_lr) in enumerate(reg_lrs_list): + try: + if re.search(regex_str, lora.lora_name): + matched_reg_lr = (i, reg_lr) + logger.info(f"Module {lora.lora_name} matched regex '{regex_str}' -> LR {reg_lr}") + break + except re.error: + # regex error should have been caught during parsing, but just in case + continue + for name, param in lora.named_parameters(): - if loraplus_ratio is not None and "lora_up" in name: - param_groups["plus"][f"{lora.lora_name}.{name}"] = param + param_key = f"{lora.lora_name}.{name}" + is_plus = loraplus_ratio is not None and "lora_up" in name + + if matched_reg_lr is not None: + # use regex-specific learning rate + reg_idx, reg_lr = matched_reg_lr + group_key = f"reg_lr_{reg_idx}" + if group_key not in reg_groups: + reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr} + + if is_plus: + reg_groups[group_key]["plus"][param_key] = param + else: + reg_groups[group_key]["lora"][param_key] = param else: - param_groups["lora"][f"{lora.lora_name}.{name}"] = param + # use default learning rate + if is_plus: + param_groups["plus"][param_key] = param + else: + 
param_groups["lora"][param_key] = param params = [] descriptions = [] + + # process regex-specific groups first (higher priority) + for group_key in sorted(reg_groups.keys()): + group = reg_groups[group_key] + reg_lr = group["lr"] + + for param_type in ["lora", "plus"]: + if len(group[param_type]) == 0: + continue + + param_data = {"params": group[param_type].values()} + + if param_type == "plus" and loraplus_ratio is not None: + param_data["lr"] = reg_lr * loraplus_ratio + else: + param_data["lr"] = reg_lr + + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + continue + + params.append(param_data) + desc = f"reg_lr_{group_key.split('_')[-1]}" + if param_type == "plus": + desc += " plus" + descriptions.append(desc) + + # process default groups for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} diff --git a/train_network.py b/train_network.py index 6073c4c3..7861e740 100644 --- a/train_network.py +++ b/train_network.py @@ -645,7 +645,7 @@ class NetworkTrainer: net_kwargs = {} if args.network_args is not None: for net_arg in args.network_args: - key, value = net_arg.split("=") + key, value = net_arg.split("=", 1) net_kwargs[key] = value # if a new network is added in future, add if ~ then blocks for each network (;'∀')