mirror of https://github.com/kohya-ss/sd-scripts.git
feat: add regex-based rank and learning rate configuration for FLUX.1 LoRA
@@ -398,7 +398,50 @@ FLUX.1 LoRA学習では、network_argsの`train_double_block_indices`と`train_s

</details>

### 6.5. Regular Expression-based Rank/LR Configuration / 正規表現によるランク・学習率の指定

You can specify ranks (dims) and learning rates for LoRA modules using regular expressions. This allows more flexible and fine-grained control than per-layer specification.

These settings are specified via the `network_args` argument.

* `network_reg_dims`: Specify ranks for modules matching a regular expression. The format is a comma-separated string of `pattern=rank` pairs.
    * Example: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"`
    * This sets the rank to 4 for modules whose names contain `single` followed by `_modulation`, and to 8 for modules whose names contain `img_attn`.
* `network_reg_lrs`: Specify learning rates for modules matching a regular expression. The format is a comma-separated string of `pattern=lr` pairs.
    * Example: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"`
    * This sets the learning rate to `1e-3` for modules whose names contain `single_blocks_` followed by a single digit (`0` to `9`) or `10`, and to `2e-3` for modules whose names contain `double_blocks`.

**Notes:**

* Settings via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
* If a module name matches multiple patterns, the setting from the first matching pattern in the string is applied (patterns are tried in order and matching stops at the first hit).
* These settings are applied after the block-specific training settings (`train_double_block_indices`, `train_single_block_indices`).

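To check how patterns will resolve before training, the matching logic can be reproduced directly. A minimal sketch; the module names below are hypothetical, and matching uses `re.search` (unanchored) with the first matching pattern winning, as in the implementation:

```python
import re

reg_dims = {"single.*_modulation.*": 4, "img_attn": 8}  # patterns from the example above

# hypothetical LoRA module names, for illustration only
for name in ["lora_unet_single_blocks_10_modulation_lin", "lora_unet_double_blocks_3_img_attn_proj"]:
    for pattern, rank in reg_dims.items():
        if re.search(pattern, name):  # unanchored search; first match wins
            print(f"{name} -> rank {rank}")
            break
```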
<details>
<summary>日本語</summary>

You can specify ranks (dims) and learning rates per LoRA module using regular expressions, which allows more flexible and fine-grained control than per-layer specification.

These settings are specified via the `network_args` argument.

* `network_reg_dims`: Specifies ranks for modules matching a regular expression. Specify a comma-separated string of `pattern=rank` pairs.
    * Example: `--network_args "network_reg_dims=single.*_modulation.*=4,img_attn=8"`
    * In this example, modules whose names contain `single` followed by `_modulation` get rank 4, and modules whose names contain `img_attn` get rank 8.
* `network_reg_lrs`: Specifies learning rates for modules matching a regular expression. Specify a comma-separated string of `pattern=lr` pairs.
    * Example: `--network_args "network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"`
    * In this example, modules whose names contain `single_blocks_` followed by a single digit (`0` to `9`) or `10` get a learning rate of `1e-3`, and modules whose names contain `double_blocks` get `2e-3`.

**Notes:**

* Settings made via `network_reg_dims` and `network_reg_lrs` take precedence over the global `--network_dim` and `--learning_rate` settings.
* If a module name matches multiple patterns, the setting of the first matching pattern in the string is applied.
* These settings are applied after the block-specific settings (`train_double_block_indices`, `train_single_block_indices`).

</details>
### 6.6. Text Encoder LoRA Support / Text Encoder LoRAのサポート
FLUX.1 LoRA training supports training CLIP-L and T5XXL LoRA:
@@ -417,7 +460,7 @@ FLUX.1 LoRA学習は、CLIP-LとT5XXL LoRAのトレーニングもサポート
</details>
### 6.7. Multi-Resolution Training / マルチ解像度トレーニング
You can define multiple resolutions in the dataset configuration file, with different batch sizes for each resolution.
@@ -462,7 +505,7 @@ resolution = [768, 768]
</details>
### 6.8. Validation / 検証
You can calculate validation loss during training using a validation dataset to evaluate model generalization performance.
```diff
@@ -156,11 +156,19 @@ class LoRAModule(torch.nn.Module):
         lx = self.lora_up(lx)

         # LoRA Gradient-Guided Perturbation Optimization
-        if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
+        if (
+            self.training
+            and self.ggpo_sigma is not None
+            and self.ggpo_beta is not None
+            and self.combined_weight_norms is not None
+            and self.grad_norms is not None
+        ):
             with torch.no_grad():
-                perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
+                perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + (
+                    self.ggpo_beta * (self.grad_norms**2)
+                )
                 perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
-                perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
+                perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
                 perturbation.mul_(perturbation_scale_factor)
                 perturbation_output = x @ perturbation.T  # Result: (batch × n)
             return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
```
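For context, a self-contained sketch of the perturbation computed above; the hyperparameter values, tensor shapes, and the `perturbation_norm_factor` stand-in are illustrative only, not taken from the repository:

```python
import torch

batch, in_features, out_features = 8, 32, 64
x = torch.randn(batch, in_features)

ggpo_sigma, ggpo_beta = 0.03, 0.01                   # illustrative hyperparameters
combined_weight_norms = torch.ones(out_features, 1)  # per-row norms, stand-in values
grad_norms = torch.full((out_features, 1), 0.1)      # per-row gradient norms, stand-in
perturbation_norm_factor = 1.0 / (in_features**0.5)  # simplified normalizer

# scale = sigma * ||W|| + beta * ||grad||^2, per output row
perturbation_scale = (ggpo_sigma * torch.sqrt(combined_weight_norms**2)) + (ggpo_beta * grad_norms**2)
perturbation = torch.randn(out_features, in_features)
perturbation.mul_(perturbation_scale * perturbation_norm_factor)  # broadcast over rows

perturbation_output = x @ perturbation.T  # (batch, out_features), added to the module output
```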
```diff
@@ -197,24 +205,24 @@ class LoRAModule(torch.nn.Module):
         # Choose a reasonable sample size
         n_rows = org_module_weight.shape[0]
         sample_size = min(1000, n_rows)  # Cap at 1000 samples or use all if smaller

         # Sample random indices across all rows
         indices = torch.randperm(n_rows)[:sample_size]

         # Convert to a supported data type first, then index
         # Use float32 for indexing operations
         weights_float32 = org_module_weight.to(dtype=torch.float32)
         sampled_weights = weights_float32[indices].to(device=self.device)

         # Calculate sampled norms
         sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)

         # Store the mean norm as our estimate
         self.org_weight_norm_estimate = sampled_norms.mean()

         # Optional: store standard deviation for confidence intervals
         self.org_weight_norm_std = sampled_norms.std()

         # Free memory
         del sampled_weights, weights_float32
```
```diff
@@ -223,37 +231,36 @@ class LoRAModule(torch.nn.Module):
         # Calculate the true norm (this will be slow but it's just for validation)
         true_norms = []
         chunk_size = 1024  # Process in chunks to avoid OOM

         for i in range(0, org_module_weight.shape[0], chunk_size):
             end_idx = min(i + chunk_size, org_module_weight.shape[0])
             chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
             chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
             true_norms.append(chunk_norms.cpu())
             del chunk

         true_norms = torch.cat(true_norms, dim=0)
         true_mean_norm = true_norms.mean().item()

         # Compare with our estimate
         estimated_norm = self.org_weight_norm_estimate.item()

         # Calculate error metrics
         absolute_error = abs(true_mean_norm - estimated_norm)
         relative_error = absolute_error / true_mean_norm * 100  # as percentage

         if verbose:
             logger.info(f"True mean norm: {true_mean_norm:.6f}")
             logger.info(f"Estimated norm: {estimated_norm:.6f}")
             logger.info(f"Absolute error: {absolute_error:.6f}")
             logger.info(f"Relative error: {relative_error:.2f}%")

-        return {
-            'true_mean_norm': true_mean_norm,
-            'estimated_norm': estimated_norm,
-            'absolute_error': absolute_error,
-            'relative_error': relative_error
-        }
+        return {
+            "true_mean_norm": true_mean_norm,
+            "estimated_norm": estimated_norm,
+            "absolute_error": absolute_error,
+            "relative_error": relative_error,
+        }

     @torch.no_grad()
     def update_norms(self):
```
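The method above validates the sampling-based norm estimate computed earlier. As a standalone illustration of the same idea (sizes arbitrary, not the repository's API):

```python
import torch

weight = torch.randn(4096, 3072)  # stand-in for a frozen module weight

# Estimate the mean per-row norm from a 1000-row random sample
indices = torch.randperm(weight.shape[0])[:1000]
estimate = torch.norm(weight[indices], dim=1, keepdim=True).mean()

# Exact value, computed in chunks to bound peak memory
true_mean = torch.cat(
    [torch.norm(weight[i : i + 1024], dim=1, keepdim=True) for i in range(0, weight.shape[0], 1024)]
).mean()

print(f"relative error: {abs(true_mean - estimate) / true_mean * 100:.3f}%")  # typically well under 1%
```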
```diff
@@ -261,7 +268,7 @@ class LoRAModule(torch.nn.Module):
         if self.ggpo_beta is None or self.ggpo_sigma is None:
             return

-        # only update norms when we are training
+        # only update norms when we are training
         if self.training is False:
             return

```
```diff
@@ -269,8 +276,9 @@ class LoRAModule(torch.nn.Module):
         module_weights.mul(self.scale)

         self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
-        self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
-                                                torch.sum(module_weights**2, dim=1, keepdim=True))
+        self.combined_weight_norms = torch.sqrt(
+            (self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True)
+        )

     @torch.no_grad()
     def update_grad_norms(self):
```
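The combined norm treats the frozen weight and the scaled LoRA delta as if they added in quadrature per output row, i.e. roughly `sqrt(norm_org**2 + norm_delta**2)`. A tiny sketch (values arbitrary):

```python
import torch

org_norm_estimate = torch.tensor(5.0)        # estimated per-row norm of the frozen weight
module_weights = torch.randn(64, 32) * 0.01  # stand-in for the scaled LoRA delta (up @ down)

combined = torch.sqrt(org_norm_estimate**2 + torch.sum(module_weights**2, dim=1, keepdim=True))
print(combined.shape)  # (64, 1); barely above 5.0 for a small delta
```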
```diff
@@ -293,7 +301,6 @@ class LoRAModule(torch.nn.Module):
         approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
         self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
-

     @property
     def device(self):
         return next(self.parameters()).device
```
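The `approx_grad` line in this hunk builds a full-size gradient proxy for the effective weight `scale * (lora_up @ lora_down)` via the product rule. A shape-level sketch (dimensions arbitrary, gradients faked with random tensors):

```python
import torch

out_dim, rank, in_dim = 16, 4, 32
lora_up = torch.randn(out_dim, rank)        # B
lora_down = torch.randn(rank, in_dim)       # A
lora_up_grad = torch.randn(out_dim, rank)   # stand-in for dL/dB
lora_down_grad = torch.randn(rank, in_dim)  # stand-in for dL/dA
scale = 1.0

# Product-rule proxy for the gradient of scale * B @ A: B @ dA + dB @ A
approx_grad = scale * ((lora_up @ lora_down_grad) + (lora_up_grad @ lora_down))
grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)  # per-output-row norms
print(grad_norms.shape)  # (out_dim, 1)
```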
```diff
@@ -564,7 +571,6 @@ def create_network(
     if ggpo_sigma is not None:
         ggpo_sigma = float(ggpo_sigma)
-

     # train T5XXL
     train_t5xxl = kwargs.get("train_t5xxl", False)
     if train_t5xxl is not None:
```
```diff
@@ -575,6 +581,42 @@ def create_network(
     if verbose is not None:
         verbose = True if verbose == "True" else False

+    # regex-specific learning rates
+    def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
+        """
+        Parse a string of key-value pairs separated by commas.
+        """
+        pairs = {}
+        for pair in kv_pair_str.split(","):
+            pair = pair.strip()
+            if not pair:
+                continue
+            if "=" not in pair:
+                logger.warning(f"Invalid format: {pair}, expected 'key=value'")
+                continue
+            key, value = pair.split("=", 1)
+            key = key.strip()
+            value = value.strip()
+            try:
+                pairs[key] = int(value) if is_int else float(value)
+            except ValueError:
+                logger.warning(f"Invalid value for {key}: {value}")
+        return pairs
+
+    # parse regular expression based learning rates
+    network_reg_lrs = kwargs.get("network_reg_lrs", None)
+    if network_reg_lrs is not None:
+        reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
+    else:
+        reg_lrs = None
+
+    # regex-specific dimensions (ranks)
+    network_reg_dims = kwargs.get("network_reg_dims", None)
+    if network_reg_dims is not None:
+        reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
+    else:
+        reg_dims = None
+
     # there are a lot of arguments... ( ^ω^)
     network = LoRANetwork(
         text_encoders,
```
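For reference, this is how the helper parses the two examples documented above (patterns are stored in insertion order, which is why the first listed pattern wins during matching):

```python
# parse_kv_pairs("single.*_modulation.*=4,img_attn=8", is_int=True)
#   -> {"single.*_modulation.*": 4, "img_attn": 8}
# parse_kv_pairs("single_blocks_(\d|10)_=1e-3,double_blocks=2e-3", is_int=False)
#   -> {"single_blocks_(\d|10)_": 1e-3, "double_blocks": 2e-3}
```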
```diff
@@ -594,8 +636,10 @@ create_network(
         in_dims=in_dims,
         train_double_block_indices=train_double_block_indices,
         train_single_block_indices=train_single_block_indices,
+        reg_dims=reg_dims,
         ggpo_beta=ggpo_beta,
         ggpo_sigma=ggpo_sigma,
+        reg_lrs=reg_lrs,
         verbose=verbose,
     )
```
```diff
@@ -613,7 +657,6 @@ def create_network(
-

 # Create network from weights for inference, weights are not loaded here (because can be merged)
 def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs):
     # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
     if weights_sd is None:
         if os.path.splitext(file)[1] == ".safetensors":
             from safetensors.torch import load_file, safe_open
```
```diff
@@ -644,22 +687,6 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs):
     if train_t5xxl is None:
         train_t5xxl = False

-    # # split qkv
-    # double_qkv_rank = None
-    # single_qkv_rank = None
-    # rank = None
-    # for lora_name, dim in modules_dim.items():
-    #     if "double" in lora_name and "qkv" in lora_name:
-    #         double_qkv_rank = dim
-    #     elif "single" in lora_name and "linear1" in lora_name:
-    #         single_qkv_rank = dim
-    #     elif rank is None:
-    #         rank = dim
-    #     if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None:
-    #         break
-    # split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or (
-    #     single_qkv_rank is not None and single_qkv_rank != rank
-    # )
     split_qkv = False  # no need to care about split_qkv, because the state_dict has qkv combined

     module_class = LoRAInfModule if for_inference else LoRAModule
```
```diff
@@ -708,8 +735,10 @@ class LoRANetwork(torch.nn.Module):
         in_dims: Optional[List[int]] = None,
         train_double_block_indices: Optional[List[bool]] = None,
         train_single_block_indices: Optional[List[bool]] = None,
+        reg_dims: Optional[Dict[str, int]] = None,
         ggpo_beta: Optional[float] = None,
         ggpo_sigma: Optional[float] = None,
+        reg_lrs: Optional[Dict[str, float]] = None,
         verbose: Optional[bool] = False,
     ) -> None:
         super().__init__()
```
```diff
@@ -730,6 +759,8 @@ class LoRANetwork(torch.nn.Module):
         self.in_dims = in_dims
         self.train_double_block_indices = train_double_block_indices
         self.train_single_block_indices = train_single_block_indices
+        self.reg_dims = reg_dims
+        self.reg_lrs = reg_lrs

         self.loraplus_lr_ratio = None
         self.loraplus_unet_lr_ratio = None
```
```diff
@@ -757,7 +788,6 @@ class LoRANetwork(torch.nn.Module):
         if self.train_blocks is not None:
             logger.info(f"train {self.train_blocks} blocks only")
-

         if train_t5xxl:
             logger.info(f"train T5XXL as well")

```
```diff
@@ -803,8 +833,16 @@ class LoRANetwork(torch.nn.Module):
                 if lora_name in modules_dim:
                     dim = modules_dim[lora_name]
                     alpha = modules_alpha[lora_name]
-                else:
-                    # normally, all modules are targeted
+                elif self.reg_dims is not None:
+                    for reg, d in self.reg_dims.items():
+                        if re.search(reg, lora_name):
+                            dim = d
+                            alpha = self.alpha
+                            logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}")
+                            break
+
+                # normally, all modules are targeted
+                if dim is None:
                     if is_linear or is_conv2d_1x1:
                         dim = default_dim if default_dim is not None else self.lora_dim
                         alpha = self.alpha
```
```diff
@@ -979,7 +1017,6 @@ class LoRANetwork(torch.nn.Module):
             combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
         return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
-

     def load_weights(self, file):
         if os.path.splitext(file)[1] == ".safetensors":
             from safetensors.torch import load_file
```
```diff
@@ -1166,17 +1203,77 @@ class LoRANetwork(torch.nn.Module):
         all_params = []
         lr_descriptions = []
+
+        reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []

         def assemble_params(loras, lr, loraplus_ratio):
             param_groups = {"lora": {}, "plus": {}}
+            # regular expression param groups: {"reg_lr_0": {"lora": {}, "plus": {}}, ...}
+            reg_groups = {}
+
             for lora in loras:
+                # check if this lora matches any regex learning rate
+                matched_reg_lr = None
+                for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
+                    try:
+                        if re.search(regex_str, lora.lora_name):
+                            matched_reg_lr = (i, reg_lr)
+                            logger.info(f"Module {lora.lora_name} matched regex '{regex_str}' -> LR {reg_lr}")
+                            break
+                    except re.error:
+                        # regex error should have been caught during parsing, but just in case
+                        continue
+
                 for name, param in lora.named_parameters():
-                    if loraplus_ratio is not None and "lora_up" in name:
-                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    param_key = f"{lora.lora_name}.{name}"
+                    is_plus = loraplus_ratio is not None and "lora_up" in name
+
+                    if matched_reg_lr is not None:
+                        # use regex-specific learning rate
+                        reg_idx, reg_lr = matched_reg_lr
+                        group_key = f"reg_lr_{reg_idx}"
+                        if group_key not in reg_groups:
+                            reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}
+
+                        if is_plus:
+                            reg_groups[group_key]["plus"][param_key] = param
+                        else:
+                            reg_groups[group_key]["lora"][param_key] = param
                     else:
-                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+                        # use default learning rate
+                        if is_plus:
+                            param_groups["plus"][param_key] = param
+                        else:
+                            param_groups["lora"][param_key] = param

             params = []
             descriptions = []
+
+            # process regex-specific groups first (higher priority)
+            for group_key in sorted(reg_groups.keys()):
+                group = reg_groups[group_key]
+                reg_lr = group["lr"]
+
+                for param_type in ["lora", "plus"]:
+                    if len(group[param_type]) == 0:
+                        continue
+
+                    param_data = {"params": group[param_type].values()}
+
+                    if param_type == "plus" and loraplus_ratio is not None:
+                        param_data["lr"] = reg_lr * loraplus_ratio
+                    else:
+                        param_data["lr"] = reg_lr
+
+                    if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                        continue
+
+                    params.append(param_data)
+                    desc = f"reg_lr_{group_key.split('_')[-1]}"
+                    if param_type == "plus":
+                        desc += " plus"
+                    descriptions.append(desc)
+
+            # process default groups
             for key in param_groups.keys():
                 param_data = {"params": param_groups[key].values()}
```
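As a sketch of the outcome, here is a hypothetical grouping for the `network_reg_lrs` example from the README (group contents and ordering are illustrative, not taken from an actual run):

```python
# network_reg_lrs = "single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"
#
# assemble_params(...) then yields optimizer param groups roughly like:
#   {"params": [...], "lr": 1e-3}   # reg_lr_0: modules matching single_blocks_(\d|10)_
#   {"params": [...], "lr": 2e-3}   # reg_lr_1: modules matching double_blocks
#   {"params": [...]}               # default "lora" group: remaining modules; the default
#                                   # lr is attached by the surrounding code (not in this hunk)
```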
```diff
@@ -645,7 +645,7 @@ class NetworkTrainer:
         net_kwargs = {}
         if args.network_args is not None:
             for net_arg in args.network_args:
-                key, value = net_arg.split("=")
+                key, value = net_arg.split("=", 1)
                 net_kwargs[key] = value

     # if a new network is added in the future, add an if/then block for each network (;'∀')
```
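The `maxsplit=1` change matters because regex values can themselves contain `=`. A quick illustration with a hypothetical argument string:

```python
net_arg = r"network_reg_lrs=single_blocks_(\d|10)_=1e-3,double_blocks=2e-3"

key, value = net_arg.split("=", 1)  # split only on the first "="
print(key)    # network_reg_lrs
print(value)  # single_blocks_(\d|10)_=1e-3,double_blocks=2e-3

# net_arg.split("=") would raise ValueError: too many values to unpack
```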