feat: add regex-based rank and learning rate configuration for FLUX.1 LoRA

Kohya S
2025-07-26 19:35:42 +09:00
parent 32f06012a7
commit c28e7a47c3
3 changed files with 194 additions and 54 deletions


@@ -398,7 +398,50 @@ FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_s
</details>
### 6.5. Text Encoder LoRA Support / Text Encoder LoRAのサポート
</details>
### 6.5. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定
You can specify ranks (dims) and learning rates for LoRA modules using regular expressions. This allows for more flexible and fine-grained control than specifying by layer.
These settings are specified via the `network_args` argument.
* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`.
* Example: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"`
* This sets the rank to 4 for modules whose names contain `single` followed by `_modulation`, and to 8 for modules whose names contain `img_attn`.
* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`.
* Example: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"`
* This sets the learning rate to `1e-3` for modules whose names contain `single_blocks_` followed by a single digit (`0` to `9`) or `10` and an underscore, and to `2e-3` for modules whose names contain `double_blocks`.
**Notes:**
* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
* If a module name matches multiple patterns, the setting from the first matching pattern in the string is applied (matching stops at the first hit).
* These settings are applied after the block-specific training settings (`train_double_block_indices`, `train_single_block_indices`).
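As a quick sketch of how these patterns resolve in practice (the module names below are illustrative; matching uses `re.search`, so a pattern can match anywhere in the name, and the first matching pattern wins):

```python
import re

# Hypothetical parsed result of:  network_reg_dims=single.*_modulation.*=4,img_attn=8
reg_dims = {r"single.*_modulation.*": 4, r"img_attn": 8}
default_dim = 16  # stands in for the global --network_dim

def resolve_dim(lora_name: str) -> int:
    # Patterns are tried in the order written; the first one that
    # matches (re.search, i.e. anywhere in the name) wins.
    for pattern, dim in reg_dims.items():
        if re.search(pattern, lora_name):
            return dim
    return default_dim

print(resolve_dim("lora_unet_single_blocks_3_modulation_lin"))  # 4
print(resolve_dim("lora_unet_double_blocks_0_img_attn_qkv"))    # 8
print(resolve_dim("lora_unet_double_blocks_0_txt_mlp_0"))       # 16
```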
<details>
<summary>日本語</summary>
You can specify the rank (dim) and learning rate for each LoRA module using regular expressions. This allows more flexible and fine-grained control than per-layer specification.
These settings are specified via the `network_args` argument.
* `network_reg_dims`: Specifies ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank`.
* Example: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"`
* In this example, modules whose names contain `single` followed by `_modulation` get rank 4, and modules whose names contain `img_attn` get rank 8.
* `network_reg_lrs`: Specifies learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr`.
* Example: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"`
* In this example, modules whose names contain `single_blocks_` followed by a digit (`0` to `9`) or `10` get a learning rate of `1e-3`, and modules whose names contain `double_blocks` get `2e-3`.
**Notes:**
* Settings made via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
* If a module name matches multiple patterns, the setting from the first matching pattern in the string is applied.
* These settings are applied after the block-wise training settings (`train_double_block_indices`, `train_single_block_indices`).
</details>
### 6.6. Text Encoder LoRA Support / Text Encoder LoRAのサポート
FLUX.1 LoRA training also supports training LoRA for CLIP-L and T5XXL:
@@ -417,7 +460,7 @@ FLUX.1 LoRA学習は、CLIP-LとT5XXL LoRAのトレーニングもサポート
</details>
### 6.6. Multi-Resolution Training / マルチ解像度トレーニング
### 6.7. Multi-Resolution Training / マルチ解像度トレーニング
You can define multiple resolutions in the dataset configuration file, with different batch sizes for each resolution.
@@ -462,7 +505,7 @@ resolution = [768, 768]
</details>
### 6.7. Validation / 検証
### 6.8. Validation / 検証
You can calculate validation loss during training using a validation dataset to evaluate model generalization performance.


@@ -156,11 +156,19 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx)
# LoRA Gradient-Guided Perturbation Optimization
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
if (
self.training
and self.ggpo_sigma is not None
and self.ggpo_beta is not None
and self.combined_weight_norms is not None
and self.grad_norms is not None
):
with torch.no_grad():
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + (
self.ggpo_beta * (self.grad_norms**2)
)
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(perturbation_scale_factor)
perturbation_output = x @ perturbation.T # Result: (batch × n)
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
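The perturbation term above can be sketched in NumPy as follows; the shapes, the `sigma`/`beta` values, and the normalization constant are illustrative assumptions, not the trainer's defaults:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n, m = 4, 8, 8          # illustrative shapes: x is (batch, m), weight is (n, m)
sigma, beta = 0.03, 0.01       # stand-ins for ggpo_sigma / ggpo_beta

x = rng.standard_normal((batch, m))
combined_weight_norms = rng.random((n, 1)) + 1.0   # per-row combined weight norms
grad_norms = rng.random((n, 1))                    # per-row approximate gradient norms
perturbation_norm_factor = 1.0 / np.sqrt(m)        # assumed normalization factor

# perturbation scale per output row: sigma * ||W|| + beta * ||g||^2
perturbation_scale = sigma * np.sqrt(combined_weight_norms**2) + beta * grad_norms**2
scale_factor = perturbation_scale * perturbation_norm_factor

perturbation = rng.standard_normal((n, m)) * scale_factor  # row-wise scaled Gaussian noise
perturbation_output = x @ perturbation.T                   # (batch, n), added to the module output
print(perturbation_output.shape)
```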
@@ -197,24 +205,24 @@ class LoRAModule(torch.nn.Module):
# Choose a reasonable sample size
n_rows = org_module_weight.shape[0]
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
# Sample random indices across all rows
indices = torch.randperm(n_rows)[:sample_size]
# Convert to a supported data type first, then index
# Use float32 for indexing operations
weights_float32 = org_module_weight.to(dtype=torch.float32)
sampled_weights = weights_float32[indices].to(device=self.device)
# Calculate sampled norms
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
# Store the mean norm as our estimate
self.org_weight_norm_estimate = sampled_norms.mean()
# Optional: store standard deviation for confidence intervals
self.org_weight_norm_std = sampled_norms.std()
# Free memory
del sampled_weights, weights_float32
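The row-sampling estimate above trades a little accuracy for memory. A self-contained NumPy sketch of the same idea (matrix size and sample cap are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n_rows, n_cols = 4096, 512
weights = rng.standard_normal((n_rows, n_cols))  # stand-in for org_module_weight

# Estimate the mean per-row L2 norm from a random sample of rows,
# capped at 1000 samples as in the method above.
sample_size = min(1000, n_rows)
indices = rng.permutation(n_rows)[:sample_size]
estimate = np.linalg.norm(weights[indices], axis=1).mean()

# Compare against the exact mean norm over all rows.
true_mean = np.linalg.norm(weights, axis=1).mean()
rel_error = abs(true_mean - estimate) / true_mean
print(f"estimate={estimate:.3f} true={true_mean:.3f} rel_error={rel_error:.2%}")
```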
@@ -223,37 +231,36 @@ class LoRAModule(torch.nn.Module):
# Calculate the true norm (this will be slow but it's just for validation)
true_norms = []
chunk_size = 1024 # Process in chunks to avoid OOM
for i in range(0, org_module_weight.shape[0], chunk_size):
end_idx = min(i + chunk_size, org_module_weight.shape[0])
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
true_norms.append(chunk_norms.cpu())
del chunk
true_norms = torch.cat(true_norms, dim=0)
true_mean_norm = true_norms.mean().item()
# Compare with our estimate
estimated_norm = self.org_weight_norm_estimate.item()
# Calculate error metrics
absolute_error = abs(true_mean_norm - estimated_norm)
relative_error = absolute_error / true_mean_norm * 100 # as percentage
if verbose:
logger.info(f"True mean norm: {true_mean_norm:.6f}")
logger.info(f"Estimated norm: {estimated_norm:.6f}")
logger.info(f"Absolute error: {absolute_error:.6f}")
logger.info(f"Relative error: {relative_error:.2f}%")
return {
'true_mean_norm': true_mean_norm,
'estimated_norm': estimated_norm,
'absolute_error': absolute_error,
'relative_error': relative_error
}
return {
"true_mean_norm": true_mean_norm,
"estimated_norm": estimated_norm,
"absolute_error": absolute_error,
"relative_error": relative_error,
}
@torch.no_grad()
def update_norms(self):
@@ -261,7 +268,7 @@ class LoRAModule(torch.nn.Module):
if self.ggpo_beta is None or self.ggpo_sigma is None:
return
# only update norms when we are training
if self.training is False:
return
@@ -269,8 +276,9 @@ class LoRAModule(torch.nn.Module):
module_weights.mul(self.scale)
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
torch.sum(module_weights**2, dim=1, keepdim=True))
self.combined_weight_norms = torch.sqrt(
(self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True)
)
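The combined norm folds the frozen base-weight norm estimate together with the current LoRA delta, i.e. `sqrt(||W_org||^2 + sum(dW^2))` per row. A minimal NumPy sketch (values illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, r = 8, 4
org_weight_norm_estimate = 20.0                      # scalar estimate from row sampling
module_weights = 0.1 * rng.standard_normal((n, r))   # stand-in for scaled lora_up @ lora_down rows

# per-row combined norm: sqrt(est^2 + sum of squared LoRA weights)
combined = np.sqrt(org_weight_norm_estimate**2
                   + (module_weights**2).sum(axis=1, keepdims=True))
print(combined.shape)
```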
@torch.no_grad()
def update_grad_norms(self):
@@ -293,7 +301,6 @@ class LoRAModule(torch.nn.Module):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
@property
def device(self):
return next(self.parameters()).device
@@ -564,7 +571,6 @@ def create_network(
if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
@@ -575,6 +581,42 @@ def create_network(
if verbose is not None:
verbose = True if verbose == "True" else False
# regex-specific learning rates
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
"""
Parse a string of key-value pairs separated by commas.
"""
pairs = {}
for pair in kv_pair_str.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
continue
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip()
try:
pairs[key] = int(value) if is_int else float(value)
except ValueError:
logger.warning(f"Invalid value for {key}: {value}")
return pairs
# parse regular expression based learning rates
network_reg_lrs = kwargs.get("network_reg_lrs", None)
if network_reg_lrs is not None:
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
else:
reg_lrs = None
# regex-specific dimensions (ranks)
network_reg_dims = kwargs.get("network_reg_dims", None)
if network_reg_dims is not None:
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
else:
reg_dims = None
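For reference, the parser above behaves like this standalone sketch (identical logic, with `logger.warning` swapped for `print` so it runs on its own):

```python
from typing import Dict

def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
    # Parse "key=value" pairs separated by commas; values split on the
    # first '=' only, so regex patterns containing '=' are not supported.
    pairs = {}
    for pair in kv_pair_str.split(","):
        pair = pair.strip()
        if not pair:
            continue
        if "=" not in pair:
            print(f"Invalid format: {pair}, expected 'key=value'")
            continue
        key, value = pair.split("=", 1)
        key, value = key.strip(), value.strip()
        try:
            pairs[key] = int(value) if is_int else float(value)
        except ValueError:
            print(f"Invalid value for {key}: {value}")
    return pairs

print(parse_kv_pairs("single.*_modulation.*=4,img_attn=8", is_int=True))
# {'single.*_modulation.*': 4, 'img_attn': 8}
print(parse_kv_pairs(r"single_blocks_(\d|10)_=1e-3,double_blocks=2e-3", is_int=False))
# {'single_blocks_(\\d|10)_': 0.001, 'double_blocks': 0.002}
```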
# there are so many arguments... ( ^ω^)
network = LoRANetwork(
text_encoders,
@@ -594,8 +636,10 @@ def create_network(
in_dims=in_dims,
train_double_block_indices=train_double_block_indices,
train_single_block_indices=train_single_block_indices,
reg_dims=reg_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
reg_lrs=reg_lrs,
verbose=verbose,
)
@@ -613,7 +657,6 @@ def create_network(
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
@@ -644,22 +687,6 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
if train_t5xxl is None:
train_t5xxl = False
# # split qkv
# double_qkv_rank = None
# single_qkv_rank = None
# rank = None
# for lora_name, dim in modules_dim.items():
# if "double" in lora_name and "qkv" in lora_name:
# double_qkv_rank = dim
# elif "single" in lora_name and "linear1" in lora_name:
# single_qkv_rank = dim
# elif rank is None:
# rank = dim
# if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None:
# break
# split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or (
# single_qkv_rank is not None and single_qkv_rank != rank
# )
split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined
module_class = LoRAInfModule if for_inference else LoRAModule
@@ -708,8 +735,10 @@ class LoRANetwork(torch.nn.Module):
in_dims: Optional[List[int]] = None,
train_double_block_indices: Optional[List[bool]] = None,
train_single_block_indices: Optional[List[bool]] = None,
reg_dims: Optional[Dict[str, int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -730,6 +759,8 @@ class LoRANetwork(torch.nn.Module):
self.in_dims = in_dims
self.train_double_block_indices = train_double_block_indices
self.train_single_block_indices = train_single_block_indices
self.reg_dims = reg_dims
self.reg_lrs = reg_lrs
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
@@ -757,7 +788,6 @@ class LoRANetwork(torch.nn.Module):
if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only")
if train_t5xxl:
logger.info(f"train T5XXL as well")
@@ -803,8 +833,16 @@ class LoRANetwork(torch.nn.Module):
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
# normally, all modules are targeted
elif self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.search(reg, lora_name):
dim = d
alpha = self.alpha
logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}")
break
# normally, all modules are targeted
if dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
@@ -979,7 +1017,6 @@ class LoRANetwork(torch.nn.Module):
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -1166,17 +1203,77 @@ class LoRANetwork(torch.nn.Module):
all_params = []
lr_descriptions = []
reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
# regular expression param groups: {"reg_lr_0": {"lora": {}, "plus": {}}, ...}
reg_groups = {}
for lora in loras:
# check if this lora matches any regex learning rate
matched_reg_lr = None
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
try:
if re.search(regex_str, lora.lora_name):
matched_reg_lr = (i, reg_lr)
logger.info(f"Module {lora.lora_name} matched regex '{regex_str}' -> LR {reg_lr}")
break
except re.error:
# regex error should have been caught during parsing, but just in case
continue
for name, param in lora.named_parameters():
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
param_key = f"{lora.lora_name}.{name}"
is_plus = loraplus_ratio is not None and "lora_up" in name
if matched_reg_lr is not None:
# use regex-specific learning rate
reg_idx, reg_lr = matched_reg_lr
group_key = f"reg_lr_{reg_idx}"
if group_key not in reg_groups:
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
if is_plus:
reg_groups[group_key]["plus"][param_key] = param
else:
reg_groups[group_key]["lora"][param_key] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
# use default learning rate
if is_plus:
param_groups["plus"][param_key] = param
else:
param_groups["lora"][param_key] = param
params = []
descriptions = []
# process regex-specific groups first (higher priority)
for group_key in sorted(reg_groups.keys()):
group = reg_groups[group_key]
reg_lr = group["lr"]
for param_type in ["lora", "plus"]:
if len(group[param_type]) == 0:
continue
param_data = {"params": group[param_type].values()}
if param_type == "plus" and loraplus_ratio is not None:
param_data["lr"] = reg_lr * loraplus_ratio
else:
param_data["lr"] = reg_lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue
params.append(param_data)
desc = f"reg_lr_{group_key.split('_')[-1]}"
if param_type == "plus":
desc += " plus"
descriptions.append(desc)
# process default groups
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}


@@ -645,7 +645,7 @@ class NetworkTrainer:
net_kwargs = {}
if args.network_args is not None:
for net_arg in args.network_args:
key, value = net_arg.split("=")
key, value = net_arg.split("=", 1)
net_kwargs[key] = value
# if a new network is added in the future, add if/then blocks for each network (;'∀')
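The `maxsplit=1` change matters because `network_args` values for the new options themselves contain `=` characters:

```python
net_arg = "network_reg_dims=single.*_modulation.*=4,img_attn=8"

key, value = net_arg.split("=", 1)   # split only on the first '='
print(key)    # network_reg_dims
print(value)  # single.*_modulation.*=4,img_attn=8

# The old net_arg.split("=") would raise ValueError here:
# "too many values to unpack (expected 2)", since it yields four pieces.
```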