Compare commits

...

4 Commits

Author SHA1 Message Date
Kohya S
9de721198a gelore impl 2024-03-07 23:31:57 +09:00
Kohya S
14c9372a38 add doc about Colab/rich issue 2024-03-03 21:47:37 +09:00
Kohya S
074d32af20 Merge branch 'main' into dev 2024-02-27 18:53:43 +09:00
Kohya S
577e9913ca add some new dataset settings 2024-02-26 20:01:25 +09:00
6 changed files with 994 additions and 61 deletions

164
README.md
View File

@@ -249,6 +249,112 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History
### Working in progress
- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.
- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
- Some features are added to the dataset subset settings.
- `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping.
- Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped. See the example below.
- `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. See the example below.
- `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end.
- The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order.
- The examples are [shown below](#example-of-dataset-settings--データセット設定の記述例).
- Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。
- `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix``caption_suffix``keep_tokens_separator``secondary_separator``enable_wildcard`)。
- データセットのサブセット設定にいくつかの機能を追加しました。
- シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。詳しくは記述例をご覧ください。
- `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。詳しくは記述例をご覧ください。
- `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。
- 既存の機能 `caption_prefix``caption_suffix` とあわせて使えます。`caption_prefix``caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。
#### Example of dataset settings / データセット設定の記述例:
```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]
[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator= "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above
[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
#### Example of caption, secondary_separator notation: `secondary_separator = ";;;"`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped).
#### Example of caption, enable_wildcard notation: `enable_wildcard = true`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`.
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`).
#### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"`
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc.
#### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` drop されたケース)などになります。
#### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
タグ文字列に `{``}` そのものを含めたい場合は `{{``}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。
#### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general``1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。
### Feb 24, 2024 / 2024/2/24: v0.8.4
- The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu!
@@ -304,64 +410,6 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
- 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。
- ControlNet-LLLite の学習がエラーになる不具合を修正しました。
### Jan 23, 2024 / 2024/1/23: v0.8.2
- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
- Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
- PyTorch 2.1 or later is required.
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.
- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc.
- This is an experimental option and may be removed or changed in the future.
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
- Fixed a bug in multi-GPU Textual Inversion training.
- 実験的 LoRA等の学習スクリプトで、ベースモデルU-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
- `train_network.py` または `sdxl_train_network.py``--fp8_base` を指定してください。
- PyTorch 2.1 以降が必要です。
- PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。
- 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。
- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。
- 実験的オプションのため、将来的に削除または仕様変更される可能性があります。
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
- また、五段階の状態を用意し、それぞれ `0.2``0.4``0.6``0.8``1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
- `.toml` ファイルで `[[datasets]]``network_multiplier` を指定してください。
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。
- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。
- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。
- マルチ GPU での Textual Inversion 学習の不具合を修正しました。
- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例
```toml
[general]
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = 1.0
... subset settings ...
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = -1.0
... subset settings ...
```
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。

View File

@@ -60,6 +60,8 @@ class BaseSubsetParams:
caption_separator: str = (",",)
keep_tokens: int = 0
keep_tokens_separator: str = (None,)
secondary_separator: Optional[str] = None
enable_wildcard: bool = False
color_aug: bool = False
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
@@ -181,6 +183,8 @@ class ConfigSanitizer:
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
"secondary_separator": str,
"enable_wildcard": bool,
"token_warmup_min": int,
"token_warmup_step": Any(float, int),
"caption_prefix": str,
@@ -504,6 +508,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}

798
library/galore_optimizer.py Normal file
View File

@@ -0,0 +1,798 @@
# copy from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/
# original license is Apache License 2.0
import ast
import math
import warnings
from typing import Callable, Dict, Iterable, Tuple
import torch
from torch import nn
from torch.optim import Optimizer
from transformers.utils.versions import require_version
from library import train_util
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class GaLoreProjector:
def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type="std"):
self.rank = rank
self.verbose = verbose
self.update_proj_gap = update_proj_gap
self.scale = scale
self.ortho_matrix = None
self.proj_type = proj_type
def project(self, full_rank_grad, iter):
if self.proj_type == "std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "reverse_std":
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
else:
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == "right":
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
elif self.proj_type == "left":
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
elif self.proj_type == "full":
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="full")
low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
return low_rank_grad
def project_back(self, low_rank_grad):
if self.proj_type == "std":
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
else:
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == "reverse_std":
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
else:
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == "right":
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
elif self.proj_type == "left":
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
elif self.proj_type == "full":
full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
return full_rank_grad * self.scale
# svd decomposition
def get_orthogonal_matrix(self, weights, rank, type):
module_params = weights
if module_params.data.dtype != torch.float:
float_data = False
original_type = module_params.data.dtype
original_device = module_params.data.device
matrix = module_params.data.float()
else:
float_data = True
matrix = module_params.data
U, s, Vh = torch.linalg.svd(matrix)
# make the smaller matrix always to be orthogonal matrix
if type == "right":
A = U[:, :rank] @ torch.diag(s[:rank])
B = Vh[:rank, :]
if not float_data:
B = B.to(original_device).type(original_type)
return B
elif type == "left":
A = U[:, :rank]
B = torch.diag(s[:rank]) @ Vh[:rank, :]
if not float_data:
A = A.to(original_device).type(original_type)
return A
elif type == "full":
A = U[:, :rank]
B = Vh[:rank, :]
if not float_data:
A = A.to(original_device).type(original_type)
B = B.to(original_device).type(original_type)
return [A, B]
else:
raise ValueError("type should be left, right or full")
class GaLoreAdamW(Optimizer):
"""
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
Regularization](https://arxiv.org/abs/1711.05101).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.001):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
Adam's betas parameters (b1, b2).
eps (`float`, *optional*, defaults to 1e-06):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.0):
Decoupled weight decay to apply.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
A flag used to disable the deprecation warning (set to `True` to disable the warning).
"""
def __init__(
self,
params: Iterable[nn.parameter.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float] = (0.9, 0.999),
eps: float = 1e-6,
weight_decay: float = 0.0,
correct_bias: bool = True,
no_deprecation_warning: bool = False,
):
if not no_deprecation_warning:
warnings.warn(
"This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
" implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
" warning",
FutureWarning,
)
require_version("torch>=1.5.0") # add_ with alpha
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
if not 0.0 <= eps:
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure: Callable = None):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
state = self.state[p]
if "step" not in state:
state["step"] = 0
# GaLore Projection
if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"],
)
grad = state["projector"].project(grad, state["step"])
# State initialization
if "exp_avg" not in state:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
denom = exp_avg_sq.sqrt().add_(group["eps"])
step_size = group["lr"]
if group["correct_bias"]: # No bias correction for Bert
bias_correction1 = 1.0 - beta1 ** state["step"]
bias_correction2 = 1.0 - beta2 ** state["step"]
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
# compute norm gradient
norm_grad = exp_avg / denom
# GaLore Projection Back
if "rank" in group:
norm_grad = state["projector"].project_back(norm_grad)
p.add_(norm_grad, alpha=-step_size)
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
return loss
class GaLoreAdafactor(Optimizer):
"""
AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
`warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
`relative_step=False`.
Arguments:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*):
The external learning rate.
eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
Regularization constants for square gradient and parameter scale respectively
clip_threshold (`float`, *optional*, defaults to 1.0):
Threshold of root mean square of final gradient update
decay_rate (`float`, *optional*, defaults to -0.8):
Coefficient used to compute running averages of square
beta1 (`float`, *optional*):
Coefficient used for computing running averages of gradient
weight_decay (`float`, *optional*, defaults to 0.0):
Weight decay (L2 penalty)
scale_parameter (`bool`, *optional*, defaults to `True`):
If True, learning rate is scaled by root mean square
relative_step (`bool`, *optional*, defaults to `True`):
If True, time-dependent learning rate is computed instead of external learning rate
warmup_init (`bool`, *optional*, defaults to `False`):
Time-dependent learning rate computation depends on whether warm-up initialization is being used
This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.
Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
- Training without LR warmup or clip_threshold is not recommended.
- use scheduled LR warm-up to fixed LR
- use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
- Disable relative updates
- Use scale_parameter=False
- Additional optimizer operations like gradient clipping should not be used alongside Adafactor
Example:
```python
Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
```
Others reported the following combination to work well:
```python
Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
```
When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
scheduler as following:
```python
from transformers.optimization import Adafactor, AdafactorSchedule
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
lr_scheduler = AdafactorSchedule(optimizer)
trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
```
Usage:
```python
# replace AdamW with Adafactor
optimizer = Adafactor(
model.parameters(),
lr=1e-3,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=None,
weight_decay=0.0,
relative_step=False,
scale_parameter=False,
warmup_init=False,
)
```"""
# make default to be the same as trainer
def __init__(
self,
params,
lr=None,
eps=(1e-30, 1e-3),
clip_threshold=1.0,
decay_rate=-0.8,
beta1=None,
weight_decay=0.0,
scale_parameter=False,
relative_step=False,
warmup_init=False,
):
# scale_parameter=True,
# relative_step=True,
require_version("torch>=1.5.0") # add_ with alpha
if lr is not None and relative_step:
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
if warmup_init and not relative_step:
raise ValueError("`warmup_init=True` requires `relative_step=True`")
defaults = {
"lr": lr,
"eps": eps,
"clip_threshold": clip_threshold,
"decay_rate": decay_rate,
"beta1": beta1,
"weight_decay": weight_decay,
"scale_parameter": scale_parameter,
"relative_step": relative_step,
"warmup_init": warmup_init,
}
super().__init__(params, defaults)
@staticmethod
def _get_lr(param_group, param_state):
rel_step_sz = param_group["lr"]
if param_group["relative_step"]:
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
param_scale = 1.0
if param_group["scale_parameter"]:
param_scale = max(param_group["eps"][1], param_state["RMS"])
return param_scale * rel_step_sz
@staticmethod
def _get_options(param_group, param_shape):
factored = len(param_shape) >= 2
use_first_moment = param_group["beta1"] is not None
return factored, use_first_moment
@staticmethod
def _rms(tensor):
return tensor.norm(2) / (tensor.numel() ** 0.5)
@staticmethod
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
# copy from fairseq's adafactor implementation:
# https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
return torch.mul(r_factor, c_factor)
@torch.no_grad()
def step(self, closure=None):
"""
Performs a single optimization step
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
if "step" not in state:
state["step"] = 0
# GaLore Projection
if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"],
)
grad = state["projector"].project(grad, state["step"])
grad_shape = grad.shape
factored, use_first_moment = self._get_options(group, grad_shape)
# State Initialization
if "RMS" not in state:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
state["step"] += 1
state["RMS"] = self._rms(p_data_fp32)
lr = self._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg
# GaLore Projection Back
if "rank" in group:
update = state["projector"].project_back(update)
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
if p.dtype in {torch.float16, torch.bfloat16}:
p.copy_(p_data_fp32)
return loss
try:
from bitsandbytes.optim.optimizer import Optimizer2State
except ImportError:
# define a dummy Optimizer2State class
class Optimizer2State(Optimizer):
def __init__(self, *args, **kwargs):
raise ImportError("Please install bitsandbytes to use this optimizer")
def step(self, *args, **kwargs):
pass
def prefetch_state(self, *args, **kwargs):
pass
def init_state(self, *args, **kwargs):
pass
def update_step(self, *args, **kwargs):
pass
def check_overrides(self, *args, **kwargs):
pass
def to_gpu(self, *args, **kwargs):
pass
def to_cpu(self, *args, **kwargs):
pass
class GaLoreAdamW8bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
is_paged=False,
):
super().__init__(
"adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged
)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
overflows = []
if not self.initialized:
self.check_overrides()
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
# if self.is_paged: self.page_mng.prefetch_all()
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
if "step" not in state:
state["step"] = 0
# GaLore Projection
if "rank" in group:
if "projector" not in state:
state["projector"] = GaLoreProjector(
group["rank"],
update_proj_gap=group["update_proj_gap"],
scale=group["scale"],
proj_type=group["proj_type"],
)
if "weight_decay" in group and group["weight_decay"] > 0:
# ensure that the weight decay is not applied to the norm grad
group["weight_decay_saved"] = group["weight_decay"]
group["weight_decay"] = 0
grad = state["projector"].project(p.grad, state["step"])
# suboptimal implementation
p.saved_data = p.data.clone()
p.data = grad.clone().to(p.data.dtype).to(p.data.device)
p.data.zero_()
p.grad = grad
if "state1" not in state:
self.init_state(group, p, gindex, pindex)
self.prefetch_state(p)
self.update_step(group, p, gindex, pindex)
torch.cuda.synchronize()
# GaLore Projection Back
if "rank" in group:
p.data = p.saved_data.add_(state["projector"].project_back(p.data))
# apply weight decay
if "weight_decay_saved" in group:
p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"])
group["weight_decay"] = group["weight_decay_saved"]
del group["weight_decay_saved"]
if self.is_paged:
# all paged operation are asynchronous, we need
# to sync to make sure all tensors are in the right state
torch.cuda.synchronize()
return loss
def get_optimizer(args, optimizer_type, trainable_params, training_models, num_processes):
# trainable_params is list of dict, each dict contains "params" and "lr"
# list may contain multiple dicts: [unet] or [unet, te1] or [unet, te1, te2]
# block lr is not supported
assert len(trainable_params) == len(training_models), "block lr is not supported"
lr = args.learning_rate
optimizer_kwargs = {}
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
for arg in args.optimizer_args:
key, value = arg.split("=")
value = ast.literal_eval(value)
optimizer_kwargs[key] = value
rank = optimizer_kwargs.pop("rank", 128)
update_proj_gap = optimizer_kwargs.pop("update_proj_gap", 50)
galore_scale = optimizer_kwargs.pop("galore_scale", 1.0)
proj_type = optimizer_kwargs.pop("proj_type", "std")
weight_decay = optimizer_kwargs.get("weight_decay", 0.0) # do not pop, as it is used in the optimizer
# make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
# target_modules_list = ["attn", "mlp"]
target_modules_list = ["attn", "mlp", "ff"] # for SDXL U-Net
param_groups = []
param_lr = {}
for model, params in zip(training_models, trainable_params):
logger.info(f"model: {model.__class__.__name__}")
galore_params = []
group_lr = params.get("lr", lr)
for module_name, module in model.named_modules():
if not isinstance(module, nn.Linear):
continue
if not any(target_key in module_name for target_key in target_modules_list):
continue
logger.info("enable GaLore for weights in module: " + module_name)
galore_params.append(module.weight)
id_galore_params = [id(p) for p in galore_params]
# make parameters without "rank" to another group
regular_params = [p for p in params["params"] if id(p) not in id_galore_params]
# then call galore_adamw
param_groups.append({"params": regular_params, "lr": group_lr})
param_groups.append(
{
"params": galore_params,
"rank": rank,
"update_proj_gap": update_proj_gap,
"scale": galore_scale,
"proj_type": proj_type,
"lr": group_lr,
}
)
# record lr
for p in regular_params + galore_params:
param_lr[id(p)] = group_lr
# select optimizer
scheduler = None
if optimizer_type == "galore_adamw":
optimizer = GaLoreAdamW(param_groups, lr=lr, **optimizer_kwargs)
elif optimizer_type == "galore_adafactor":
beta1 = None if optimizer_kwargs.get("beta1", 0.0) == 0.0 else optimizer_kwargs.pop("beta1")
optimizer = GaLoreAdafactor(param_groups, lr=lr, beta1=beta1, **optimizer_kwargs)
elif optimizer_type == "galore_adamw8bit":
optimizer = GaLoreAdamW8bit(param_groups, lr=lr, **optimizer_kwargs)
elif optimizer_type == "galore_adamw8bit_per_layer":
# TODO: seems scheduler call twice in one update step, need to check, for now double the num_training_steps, warmup_steps and update_proj_gap
optimizer_dict = {}
all_params = []
for params in trainable_params:
all_params.extend(params["params"])
for p in all_params:
if p.requires_grad:
if id(p) in id_galore_params:
optimizer_dict[p] = GaLoreAdamW8bit(
[
{
"params": [p],
"rank": rank,
"update_proj_gap": update_proj_gap * 2,
"scale": galore_scale,
"proj_type": proj_type,
}
],
lr=param_lr[id(p)],
weight_decay=weight_decay,
)
else:
import bitsandbytes as bnb
optimizer_dict[p] = bnb.optim.Adam8bit([p], lr=param_lr[id(p)], weight_decay=weight_decay)
# get scheduler dict
# scheduler needs accelerate.prepare?
scheduler_dict = {}
for p in all_params:
if p.requires_grad:
scheduler_dict[p] = train_util.get_scheduler_fix(args, optimizer_dict[p], num_processes)
def optimizer_hook(p):
if p.grad is None:
return
optimizer_dict[p].step()
optimizer_dict[p].zero_grad()
scheduler_dict[p].step()
# Register the hook onto every parameter
for p in all_params:
if p.requires_grad:
p.register_post_accumulate_grad_hook(optimizer_hook)
# make dummy scheduler and optimizer
class DummyScheduler:
def step(self):
pass
class DummyOptimizer:
def __init__(self, optimizer_dict):
self.optimizer_dict = optimizer_dict
def step(self):
pass
def zero_grad(self, set_to_none=False):
pass
scheduler = DummyScheduler(optimizer_dict[all_params[0]])
optimizer = DummyOptimizer(optimizer_dict)
else:
raise ValueError(f"Unsupported optimizer type: {optimizer_type}")
return optimizer, scheduler

View File

@@ -364,6 +364,8 @@ class BaseSubset:
caption_separator: str,
keep_tokens: int,
keep_tokens_separator: str,
secondary_separator: Optional[str],
enable_wildcard: bool,
color_aug: bool,
flip_aug: bool,
face_crop_aug_range: Optional[Tuple[float, float]],
@@ -382,6 +384,8 @@ class BaseSubset:
self.caption_separator = caption_separator
self.keep_tokens = keep_tokens
self.keep_tokens_separator = keep_tokens_separator
self.secondary_separator = secondary_separator
self.enable_wildcard = enable_wildcard
self.color_aug = color_aug
self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range
@@ -410,6 +414,8 @@ class DreamBoothSubset(BaseSubset):
caption_separator: str,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -431,6 +437,8 @@ class DreamBoothSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -466,6 +474,8 @@ class FineTuningSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -487,6 +497,8 @@ class FineTuningSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -519,6 +531,8 @@ class ControlNetSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -540,6 +554,8 @@ class ControlNetSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -675,15 +691,41 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out:
caption = ""
else:
# process wildcards
if subset.enable_wildcard:
# wildcard is like '{aaa|bbb|ccc...}'
# escape the curly braces like {{ or }}
replacer1 = ""
replacer2 = ""
while replacer1 in caption or replacer2 in caption:
replacer1 += ""
replacer2 += ""
caption = caption.replace("{{", replacer1).replace("}}", replacer2)
# replace the wildcard
def replace_wildcard(match):
return random.choice(match.group(1).split("|"))
caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption)
# unescape the curly braces
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
fixed_tokens = []
flex_tokens = []
fixed_suffix_tokens = []
if (
hasattr(subset, "keep_tokens_separator")
and subset.keep_tokens_separator
and subset.keep_tokens_separator in caption
):
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
if subset.keep_tokens_separator in flex_part:
flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1)
fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()]
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
else:
@@ -718,7 +760,11 @@ class BaseDataset(torch.utils.data.Dataset):
flex_tokens = dropout_tags(flex_tokens)
caption = ", ".join(fixed_tokens + flex_tokens)
caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens)
# process secondary separator
if subset.secondary_separator:
caption = caption.replace(subset.secondary_separator, subset.caption_separator)
# textual inversion対応
for str_from, str_to in self.replacements.items():
@@ -1774,6 +1820,8 @@ class ControlNetDataset(BaseDataset):
subset.caption_separator,
subset.keep_tokens,
subset.keep_tokens_separator,
subset.secondary_separator,
subset.enable_wildcard,
subset.color_aug,
subset.flip_aug,
subset.face_crop_aug_range,
@@ -3284,6 +3332,18 @@ def add_dataset_arguments(
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
)
parser.add_argument(
"--secondary_separator",
type=str,
default=None,
help="a secondary separator for caption. This separator is replaced to caption_separator after dropping/shuffling caption"
+ " / captionのセカンダリ区切り文字。この区切り文字はcaptionのドロップやシャッフル後にcaption_separatorに置き換えられる",
)
parser.add_argument(
"--enable_wildcard",
action="store_true",
help="enable wildcard for caption (e.g. '{image|picture|rendition}') / captionのワイルドカードを有効にする'{image|picture|rendition}'",
)
parser.add_argument(
"--caption_prefix",
type=str,
@@ -3611,7 +3671,7 @@ def get_optimizer(args, trainable_params):
optimizer_class = lion_pytorch.Lion
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.endswith("8bit".lower()):
elif optimizer_type.endswith("8bit".lower()) and not optimizer_type.startswith("GaLore".lower()):
try:
import bitsandbytes as bnb
except ImportError:
@@ -3820,6 +3880,11 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type.startswith("GaLore".lower()):
logger.info(f"use GaLore optimizer | {optimizer_kwargs}")
optimizer = "galore"
return None, None, optimizer
if optimizer is None:
# 任意のoptimizerを使う
optimizer_type = args.optimizer_type # lowerでないやつ微妙

View File

@@ -11,6 +11,7 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
@@ -378,7 +379,17 @@ def train(args):
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
lr_scheduler = None
if optimizer == "galore":
from library import galore_optimizer
# if lr_scheduler is not layerwise, it is None. if layerwise, it is a dummy scheduler
optimizer, lr_scheduler = galore_optimizer.get_optimizer(
args, args.optimizer_type, params_to_optimize, training_models, accelerator.num_processes
)
if lr_scheduler is None:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:

View File

@@ -564,6 +564,11 @@ class NetworkTrainer:
"random_crop": bool(subset.random_crop),
"shuffle_caption": bool(subset.shuffle_caption),
"keep_tokens": subset.keep_tokens,
"keep_tokens_separator": subset.keep_tokens_separator,
"secondary_separator": subset.secondary_separator,
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
}
image_dir_or_metadata_file = None