mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6b1520a46b | ||
|
|
f811b115ba |
168
README.md
168
README.md
@@ -249,111 +249,15 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
||||
|
||||
## Change History
|
||||
|
||||
### Working in progress
|
||||
### Mar 15, 2024 / 2024/3/15: v0.8.5
|
||||
|
||||
- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.
|
||||
- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
|
||||
- Some features are added to the dataset subset settings.
|
||||
- `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping.
|
||||
- Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped. See the example below.
|
||||
- `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. See the example below.
|
||||
- `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end.
|
||||
- The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order.
|
||||
- The examples are [shown below](#example-of-dataset-settings--データセット設定の記述例).
|
||||
|
||||
|
||||
- Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。
|
||||
- `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。
|
||||
- データセットのサブセット設定にいくつかの機能を追加しました。
|
||||
- シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。詳しくは記述例をご覧ください。
|
||||
- `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。詳しくは記述例をご覧ください。
|
||||
- `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。
|
||||
- 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。
|
||||
|
||||
#### Example of dataset settings / データセット設定の記述例:
|
||||
|
||||
```toml
|
||||
[general]
|
||||
flip_aug = true
|
||||
color_aug = false
|
||||
resolution = [1024, 1024]
|
||||
|
||||
[[datasets]]
|
||||
batch_size = 6
|
||||
enable_bucket = true
|
||||
bucket_no_upscale = true
|
||||
caption_extension = ".txt"
|
||||
keep_tokens_separator= "|||"
|
||||
shuffle_caption = true
|
||||
caption_tag_dropout_rate = 0.1
|
||||
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
|
||||
enable_wildcard = true # 同上 / same as above
|
||||
|
||||
[[datasets.subsets]]
|
||||
image_dir = "/path/to/image_dir"
|
||||
num_repeats = 1
|
||||
|
||||
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
|
||||
caption_prefix = "1girl, hatsune miku, vocaloid |||"
|
||||
|
||||
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
|
||||
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
|
||||
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
|
||||
```
|
||||
|
||||
#### Example of caption, secondary_separator notation: `secondary_separator = ";;;"`
|
||||
|
||||
```txt
|
||||
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
|
||||
```
|
||||
The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped).
|
||||
|
||||
#### Example of caption, enable_wildcard notation: `enable_wildcard = true`
|
||||
|
||||
```txt
|
||||
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
|
||||
```
|
||||
`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`.
|
||||
|
||||
```txt
|
||||
1girl, hatsune miku, vocaloid, {{retro style}}
|
||||
```
|
||||
If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`).
|
||||
|
||||
#### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"`
|
||||
|
||||
```txt
|
||||
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
|
||||
```
|
||||
It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc.
|
||||
|
||||
|
||||
#### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合
|
||||
|
||||
```txt
|
||||
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
|
||||
```
|
||||
`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。
|
||||
|
||||
#### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合
|
||||
|
||||
```txt
|
||||
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
|
||||
```
|
||||
ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。
|
||||
|
||||
```txt
|
||||
1girl, hatsune miku, vocaloid, {{retro style}}
|
||||
```
|
||||
タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。
|
||||
|
||||
#### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合
|
||||
|
||||
```txt
|
||||
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
|
||||
```
|
||||
`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。
|
||||
- Fixed a bug that the value of timestep embedding during SDXL training was incorrect.
|
||||
- The inference with the generation script is also fixed.
|
||||
- The impact is unknown, but please update for SDXL training.
|
||||
|
||||
- SDXL 学習時の timestep embedding の値が誤っていたのを修正しました。
|
||||
- 生成スクリプトでの推論時についてもあわせて修正しました。
|
||||
- 影響の度合いは不明ですが、SDXL の学習時にはアップデートをお願いいたします。
|
||||
|
||||
### Feb 24, 2024 / 2024/2/24: v0.8.4
|
||||
|
||||
@@ -410,6 +314,64 @@ It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best
|
||||
- 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。
|
||||
- ControlNet-LLLite の学習がエラーになる不具合を修正しました。
|
||||
|
||||
### Jan 23, 2024 / 2024/1/23: v0.8.2
|
||||
|
||||
- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
|
||||
- Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
|
||||
- PyTorch 2.1 or later is required.
|
||||
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
|
||||
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.
|
||||
|
||||
- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc.
|
||||
- This is an experimental option and may be removed or changed in the future.
|
||||
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
|
||||
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
|
||||
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
|
||||
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
|
||||
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
|
||||
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
|
||||
|
||||
- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
|
||||
- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
|
||||
- Fixed a bug in multi-GPU Textual Inversion training.
|
||||
|
||||
- (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
|
||||
- `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。
|
||||
- PyTorch 2.1 以降が必要です。
|
||||
- PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。
|
||||
- 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。
|
||||
- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。
|
||||
- 実験的オプションのため、将来的に削除または仕様変更される可能性があります。
|
||||
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
|
||||
- また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
|
||||
- `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。
|
||||
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
|
||||
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
|
||||
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。
|
||||
- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。
|
||||
- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。
|
||||
- マルチ GPU での Textual Inversion 学習の不具合を修正しました。
|
||||
|
||||
- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例
|
||||
|
||||
```toml
|
||||
[general]
|
||||
[[datasets]]
|
||||
resolution = 512
|
||||
batch_size = 8
|
||||
network_multiplier = 1.0
|
||||
|
||||
... subset settings ...
|
||||
|
||||
[[datasets]]
|
||||
resolution = 512
|
||||
batch_size = 8
|
||||
network_multiplier = -1.0
|
||||
|
||||
... subset settings ...
|
||||
```
|
||||
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
|
||||
@@ -60,8 +60,6 @@ class BaseSubsetParams:
|
||||
caption_separator: str = (",",)
|
||||
keep_tokens: int = 0
|
||||
keep_tokens_separator: str = (None,)
|
||||
secondary_separator: Optional[str] = None
|
||||
enable_wildcard: bool = False
|
||||
color_aug: bool = False
|
||||
flip_aug: bool = False
|
||||
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
||||
@@ -183,8 +181,6 @@ class ConfigSanitizer:
|
||||
"shuffle_caption": bool,
|
||||
"keep_tokens": int,
|
||||
"keep_tokens_separator": str,
|
||||
"secondary_separator": str,
|
||||
"enable_wildcard": bool,
|
||||
"token_warmup_min": int,
|
||||
"token_warmup_step": Any(float, int),
|
||||
"caption_prefix": str,
|
||||
@@ -508,8 +504,6 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
shuffle_caption: {subset.shuffle_caption}
|
||||
keep_tokens: {subset.keep_tokens}
|
||||
keep_tokens_separator: {subset.keep_tokens_separator}
|
||||
secondary_separator: {subset.secondary_separator}
|
||||
enable_wildcard: {subset.enable_wildcard}
|
||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||
|
||||
@@ -1,798 +0,0 @@
|
||||
# copy from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/
|
||||
# original license is Apache License 2.0
|
||||
import ast
|
||||
import math
|
||||
import warnings
|
||||
from typing import Callable, Dict, Iterable, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from library import train_util
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GaLoreProjector:
|
||||
def __init__(self, rank, verbose=False, update_proj_gap=200, scale=1.0, proj_type="std"):
|
||||
self.rank = rank
|
||||
self.verbose = verbose
|
||||
self.update_proj_gap = update_proj_gap
|
||||
self.scale = scale
|
||||
self.ortho_matrix = None
|
||||
self.proj_type = proj_type
|
||||
|
||||
def project(self, full_rank_grad, iter):
|
||||
|
||||
if self.proj_type == "std":
|
||||
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
|
||||
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||
else:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||
elif self.proj_type == "reverse_std":
|
||||
if full_rank_grad.shape[0] >= full_rank_grad.shape[1]:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||
else:
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
|
||||
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||
elif self.proj_type == "right":
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="right")
|
||||
low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t())
|
||||
elif self.proj_type == "left":
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="left")
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad)
|
||||
elif self.proj_type == "full":
|
||||
if self.ortho_matrix is None or iter % self.update_proj_gap == 0:
|
||||
self.ortho_matrix = self.get_orthogonal_matrix(full_rank_grad, self.rank, type="full")
|
||||
low_rank_grad = torch.matmul(self.ortho_matrix[0].t(), full_rank_grad) @ self.ortho_matrix[1].t()
|
||||
|
||||
return low_rank_grad
|
||||
|
||||
def project_back(self, low_rank_grad):
|
||||
|
||||
if self.proj_type == "std":
|
||||
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]:
|
||||
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||
else:
|
||||
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||
elif self.proj_type == "reverse_std":
|
||||
if low_rank_grad.shape[0] <= low_rank_grad.shape[1]: # note this is different from std
|
||||
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||
else:
|
||||
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||
elif self.proj_type == "right":
|
||||
full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix)
|
||||
elif self.proj_type == "left":
|
||||
full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad)
|
||||
elif self.proj_type == "full":
|
||||
full_rank_grad = torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1]
|
||||
|
||||
return full_rank_grad * self.scale
|
||||
|
||||
# svd decomposition
|
||||
def get_orthogonal_matrix(self, weights, rank, type):
|
||||
module_params = weights
|
||||
|
||||
if module_params.data.dtype != torch.float:
|
||||
float_data = False
|
||||
original_type = module_params.data.dtype
|
||||
original_device = module_params.data.device
|
||||
matrix = module_params.data.float()
|
||||
else:
|
||||
float_data = True
|
||||
matrix = module_params.data
|
||||
|
||||
U, s, Vh = torch.linalg.svd(matrix)
|
||||
|
||||
# make the smaller matrix always to be orthogonal matrix
|
||||
if type == "right":
|
||||
A = U[:, :rank] @ torch.diag(s[:rank])
|
||||
B = Vh[:rank, :]
|
||||
|
||||
if not float_data:
|
||||
B = B.to(original_device).type(original_type)
|
||||
return B
|
||||
elif type == "left":
|
||||
A = U[:, :rank]
|
||||
B = torch.diag(s[:rank]) @ Vh[:rank, :]
|
||||
if not float_data:
|
||||
A = A.to(original_device).type(original_type)
|
||||
return A
|
||||
elif type == "full":
|
||||
A = U[:, :rank]
|
||||
B = Vh[:rank, :]
|
||||
if not float_data:
|
||||
A = A.to(original_device).type(original_type)
|
||||
B = B.to(original_device).type(original_type)
|
||||
return [A, B]
|
||||
else:
|
||||
raise ValueError("type should be left, right or full")
|
||||
|
||||
|
||||
class GaLoreAdamW(Optimizer):
|
||||
"""
|
||||
Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
|
||||
Regularization](https://arxiv.org/abs/1711.05101).
|
||||
|
||||
Parameters:
|
||||
params (`Iterable[nn.parameter.Parameter]`):
|
||||
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
||||
lr (`float`, *optional*, defaults to 0.001):
|
||||
The learning rate to use.
|
||||
betas (`Tuple[float,float]`, *optional*, defaults to `(0.9, 0.999)`):
|
||||
Adam's betas parameters (b1, b2).
|
||||
eps (`float`, *optional*, defaults to 1e-06):
|
||||
Adam's epsilon for numerical stability.
|
||||
weight_decay (`float`, *optional*, defaults to 0.0):
|
||||
Decoupled weight decay to apply.
|
||||
correct_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
|
||||
no_deprecation_warning (`bool`, *optional*, defaults to `False`):
|
||||
A flag used to disable the deprecation warning (set to `True` to disable the warning).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params: Iterable[nn.parameter.Parameter],
|
||||
lr: float = 1e-3,
|
||||
betas: Tuple[float, float] = (0.9, 0.999),
|
||||
eps: float = 1e-6,
|
||||
weight_decay: float = 0.0,
|
||||
correct_bias: bool = True,
|
||||
no_deprecation_warning: bool = False,
|
||||
):
|
||||
if not no_deprecation_warning:
|
||||
warnings.warn(
|
||||
"This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch"
|
||||
" implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this"
|
||||
" warning",
|
||||
FutureWarning,
|
||||
)
|
||||
require_version("torch>=1.5.0") # add_ with alpha
|
||||
if lr < 0.0:
|
||||
raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
|
||||
defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "correct_bias": correct_bias}
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure: Callable = None):
|
||||
"""
|
||||
Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
if "step" not in state:
|
||||
state["step"] = 0
|
||||
|
||||
# GaLore Projection
|
||||
if "rank" in group:
|
||||
if "projector" not in state:
|
||||
state["projector"] = GaLoreProjector(
|
||||
group["rank"],
|
||||
update_proj_gap=group["update_proj_gap"],
|
||||
scale=group["scale"],
|
||||
proj_type=group["proj_type"],
|
||||
)
|
||||
|
||||
grad = state["projector"].project(grad, state["step"])
|
||||
|
||||
# State initialization
|
||||
if "exp_avg" not in state:
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(grad)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
state["step"] += 1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# In-place operations to update the averages at the same time
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
|
||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||
|
||||
step_size = group["lr"]
|
||||
if group["correct_bias"]: # No bias correction for Bert
|
||||
bias_correction1 = 1.0 - beta1 ** state["step"]
|
||||
bias_correction2 = 1.0 - beta2 ** state["step"]
|
||||
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
# compute norm gradient
|
||||
norm_grad = exp_avg / denom
|
||||
|
||||
# GaLore Projection Back
|
||||
if "rank" in group:
|
||||
norm_grad = state["projector"].project_back(norm_grad)
|
||||
|
||||
p.add_(norm_grad, alpha=-step_size)
|
||||
|
||||
# Just adding the square of the weights to the loss function is *not*
|
||||
# the correct way of using L2 regularization/weight decay with Adam,
|
||||
# since that will interact with the m and v parameters in strange ways.
|
||||
#
|
||||
# Instead we want to decay the weights in a manner that doesn't interact
|
||||
# with the m/v parameters. This is equivalent to adding the square
|
||||
# of the weights to the loss with plain (non-momentum) SGD.
|
||||
# Add weight decay at the end (fixed version)
|
||||
if group["weight_decay"] > 0.0:
|
||||
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class GaLoreAdafactor(Optimizer):
|
||||
"""
|
||||
AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
|
||||
https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
|
||||
|
||||
Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://arxiv.org/abs/1804.04235 Note that
|
||||
this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
|
||||
`warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
|
||||
`relative_step=False`.
|
||||
|
||||
Arguments:
|
||||
params (`Iterable[nn.parameter.Parameter]`):
|
||||
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
||||
lr (`float`, *optional*):
|
||||
The external learning rate.
|
||||
eps (`Tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
|
||||
Regularization constants for square gradient and parameter scale respectively
|
||||
clip_threshold (`float`, *optional*, defaults to 1.0):
|
||||
Threshold of root mean square of final gradient update
|
||||
decay_rate (`float`, *optional*, defaults to -0.8):
|
||||
Coefficient used to compute running averages of square
|
||||
beta1 (`float`, *optional*):
|
||||
Coefficient used for computing running averages of gradient
|
||||
weight_decay (`float`, *optional*, defaults to 0.0):
|
||||
Weight decay (L2 penalty)
|
||||
scale_parameter (`bool`, *optional*, defaults to `True`):
|
||||
If True, learning rate is scaled by root mean square
|
||||
relative_step (`bool`, *optional*, defaults to `True`):
|
||||
If True, time-dependent learning rate is computed instead of external learning rate
|
||||
warmup_init (`bool`, *optional*, defaults to `False`):
|
||||
Time-dependent learning rate computation depends on whether warm-up initialization is being used
|
||||
|
||||
This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.
|
||||
|
||||
Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
|
||||
|
||||
- Training without LR warmup or clip_threshold is not recommended.
|
||||
|
||||
- use scheduled LR warm-up to fixed LR
|
||||
- use clip_threshold=1.0 (https://arxiv.org/abs/1804.04235)
|
||||
- Disable relative updates
|
||||
- Use scale_parameter=False
|
||||
- Additional optimizer operations like gradient clipping should not be used alongside Adafactor
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
|
||||
```
|
||||
|
||||
Others reported the following combination to work well:
|
||||
|
||||
```python
|
||||
Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
||||
```
|
||||
|
||||
When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
|
||||
scheduler as following:
|
||||
|
||||
```python
|
||||
from transformers.optimization import Adafactor, AdafactorSchedule
|
||||
|
||||
optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
|
||||
lr_scheduler = AdafactorSchedule(optimizer)
|
||||
trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
|
||||
```
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
# replace AdamW with Adafactor
|
||||
optimizer = Adafactor(
|
||||
model.parameters(),
|
||||
lr=1e-3,
|
||||
eps=(1e-30, 1e-3),
|
||||
clip_threshold=1.0,
|
||||
decay_rate=-0.8,
|
||||
beta1=None,
|
||||
weight_decay=0.0,
|
||||
relative_step=False,
|
||||
scale_parameter=False,
|
||||
warmup_init=False,
|
||||
)
|
||||
```"""
|
||||
|
||||
# make default to be the same as trainer
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=None,
|
||||
eps=(1e-30, 1e-3),
|
||||
clip_threshold=1.0,
|
||||
decay_rate=-0.8,
|
||||
beta1=None,
|
||||
weight_decay=0.0,
|
||||
scale_parameter=False,
|
||||
relative_step=False,
|
||||
warmup_init=False,
|
||||
):
|
||||
# scale_parameter=True,
|
||||
# relative_step=True,
|
||||
|
||||
require_version("torch>=1.5.0") # add_ with alpha
|
||||
if lr is not None and relative_step:
|
||||
raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
|
||||
if warmup_init and not relative_step:
|
||||
raise ValueError("`warmup_init=True` requires `relative_step=True`")
|
||||
|
||||
defaults = {
|
||||
"lr": lr,
|
||||
"eps": eps,
|
||||
"clip_threshold": clip_threshold,
|
||||
"decay_rate": decay_rate,
|
||||
"beta1": beta1,
|
||||
"weight_decay": weight_decay,
|
||||
"scale_parameter": scale_parameter,
|
||||
"relative_step": relative_step,
|
||||
"warmup_init": warmup_init,
|
||||
}
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@staticmethod
|
||||
def _get_lr(param_group, param_state):
|
||||
rel_step_sz = param_group["lr"]
|
||||
if param_group["relative_step"]:
|
||||
min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
|
||||
rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
|
||||
param_scale = 1.0
|
||||
if param_group["scale_parameter"]:
|
||||
param_scale = max(param_group["eps"][1], param_state["RMS"])
|
||||
return param_scale * rel_step_sz
|
||||
|
||||
@staticmethod
|
||||
def _get_options(param_group, param_shape):
|
||||
factored = len(param_shape) >= 2
|
||||
use_first_moment = param_group["beta1"] is not None
|
||||
return factored, use_first_moment
|
||||
|
||||
@staticmethod
|
||||
def _rms(tensor):
|
||||
return tensor.norm(2) / (tensor.numel() ** 0.5)
|
||||
|
||||
@staticmethod
|
||||
def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
|
||||
# copy from fairseq's adafactor implementation:
|
||||
# https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
|
||||
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
|
||||
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
|
||||
return torch.mul(r_factor, c_factor)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""
|
||||
Performs a single optimization step
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
if grad.dtype in {torch.float16, torch.bfloat16}:
|
||||
grad = grad.float()
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("Adafactor does not support sparse gradients.")
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
if "step" not in state:
|
||||
state["step"] = 0
|
||||
|
||||
# GaLore Projection
|
||||
if "rank" in group:
|
||||
if "projector" not in state:
|
||||
state["projector"] = GaLoreProjector(
|
||||
group["rank"],
|
||||
update_proj_gap=group["update_proj_gap"],
|
||||
scale=group["scale"],
|
||||
proj_type=group["proj_type"],
|
||||
)
|
||||
|
||||
grad = state["projector"].project(grad, state["step"])
|
||||
|
||||
grad_shape = grad.shape
|
||||
|
||||
factored, use_first_moment = self._get_options(group, grad_shape)
|
||||
# State Initialization
|
||||
if "RMS" not in state:
|
||||
state["step"] = 0
|
||||
|
||||
if use_first_moment:
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(grad)
|
||||
if factored:
|
||||
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
|
||||
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
||||
else:
|
||||
state["exp_avg_sq"] = torch.zeros_like(grad)
|
||||
|
||||
state["RMS"] = 0
|
||||
else:
|
||||
if use_first_moment:
|
||||
state["exp_avg"] = state["exp_avg"].to(grad)
|
||||
if factored:
|
||||
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
|
||||
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
|
||||
else:
|
||||
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
||||
|
||||
p_data_fp32 = p
|
||||
if p.dtype in {torch.float16, torch.bfloat16}:
|
||||
p_data_fp32 = p_data_fp32.float()
|
||||
|
||||
state["step"] += 1
|
||||
state["RMS"] = self._rms(p_data_fp32)
|
||||
lr = self._get_lr(group, state)
|
||||
|
||||
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
||||
update = (grad**2) + group["eps"][0]
|
||||
if factored:
|
||||
exp_avg_sq_row = state["exp_avg_sq_row"]
|
||||
exp_avg_sq_col = state["exp_avg_sq_col"]
|
||||
|
||||
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
|
||||
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
|
||||
|
||||
# Approximation of exponential moving average of square of gradient
|
||||
update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
||||
update.mul_(grad)
|
||||
else:
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
|
||||
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
|
||||
update = exp_avg_sq.rsqrt().mul_(grad)
|
||||
|
||||
update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
||||
update.mul_(lr)
|
||||
|
||||
if use_first_moment:
|
||||
exp_avg = state["exp_avg"]
|
||||
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
|
||||
update = exp_avg
|
||||
|
||||
# GaLore Projection Back
|
||||
if "rank" in group:
|
||||
update = state["projector"].project_back(update)
|
||||
|
||||
if group["weight_decay"] != 0:
|
||||
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
|
||||
|
||||
p_data_fp32.add_(-update)
|
||||
|
||||
if p.dtype in {torch.float16, torch.bfloat16}:
|
||||
p.copy_(p_data_fp32)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
try:
|
||||
from bitsandbytes.optim.optimizer import Optimizer2State
|
||||
except ImportError:
|
||||
# define a dummy Optimizer2State class
|
||||
class Optimizer2State(Optimizer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise ImportError("Please install bitsandbytes to use this optimizer")
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def prefetch_state(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def init_state(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def update_step(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def check_overrides(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def to_gpu(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def to_cpu(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class GaLoreAdamW8bit(Optimizer2State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=1e-2,
|
||||
amsgrad=False,
|
||||
optim_bits=32,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
is_paged=False,
|
||||
):
|
||||
super().__init__(
|
||||
"adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
overflows = []
|
||||
|
||||
if not self.initialized:
|
||||
self.check_overrides()
|
||||
self.to_gpu() # needed for fairseq pure fp16 training
|
||||
self.initialized = True
|
||||
|
||||
# if self.is_paged: self.page_mng.prefetch_all()
|
||||
for gindex, group in enumerate(self.param_groups):
|
||||
for pindex, p in enumerate(group["params"]):
|
||||
if p.grad is None:
|
||||
continue
|
||||
state = self.state[p]
|
||||
|
||||
if "step" not in state:
|
||||
state["step"] = 0
|
||||
|
||||
# GaLore Projection
|
||||
if "rank" in group:
|
||||
if "projector" not in state:
|
||||
state["projector"] = GaLoreProjector(
|
||||
group["rank"],
|
||||
update_proj_gap=group["update_proj_gap"],
|
||||
scale=group["scale"],
|
||||
proj_type=group["proj_type"],
|
||||
)
|
||||
|
||||
if "weight_decay" in group and group["weight_decay"] > 0:
|
||||
# ensure that the weight decay is not applied to the norm grad
|
||||
group["weight_decay_saved"] = group["weight_decay"]
|
||||
group["weight_decay"] = 0
|
||||
|
||||
grad = state["projector"].project(p.grad, state["step"])
|
||||
|
||||
# suboptimal implementation
|
||||
p.saved_data = p.data.clone()
|
||||
p.data = grad.clone().to(p.data.dtype).to(p.data.device)
|
||||
p.data.zero_()
|
||||
p.grad = grad
|
||||
|
||||
if "state1" not in state:
|
||||
self.init_state(group, p, gindex, pindex)
|
||||
|
||||
self.prefetch_state(p)
|
||||
self.update_step(group, p, gindex, pindex)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# GaLore Projection Back
|
||||
if "rank" in group:
|
||||
p.data = p.saved_data.add_(state["projector"].project_back(p.data))
|
||||
|
||||
# apply weight decay
|
||||
if "weight_decay_saved" in group:
|
||||
p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay_saved"])
|
||||
group["weight_decay"] = group["weight_decay_saved"]
|
||||
del group["weight_decay_saved"]
|
||||
|
||||
if self.is_paged:
|
||||
# all paged operation are asynchronous, we need
|
||||
# to sync to make sure all tensors are in the right state
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def get_optimizer(args, optimizer_type, trainable_params, training_models, num_processes):
|
||||
# trainable_params is list of dict, each dict contains "params" and "lr"
|
||||
# list may contain multiple dicts: [unet] or [unet, te1] or [unet, te1, te2]
|
||||
# block lr is not supported
|
||||
assert len(trainable_params) == len(training_models), "block lr is not supported"
|
||||
|
||||
lr = args.learning_rate
|
||||
|
||||
optimizer_kwargs = {}
|
||||
if args.optimizer_args is not None and len(args.optimizer_args) > 0:
|
||||
for arg in args.optimizer_args:
|
||||
key, value = arg.split("=")
|
||||
value = ast.literal_eval(value)
|
||||
optimizer_kwargs[key] = value
|
||||
|
||||
rank = optimizer_kwargs.pop("rank", 128)
|
||||
update_proj_gap = optimizer_kwargs.pop("update_proj_gap", 50)
|
||||
galore_scale = optimizer_kwargs.pop("galore_scale", 1.0)
|
||||
proj_type = optimizer_kwargs.pop("proj_type", "std")
|
||||
weight_decay = optimizer_kwargs.get("weight_decay", 0.0) # do not pop, as it is used in the optimizer
|
||||
|
||||
# make parameters with "rank" to a single group, if param_name has "mlp" or "attn"
|
||||
# target_modules_list = ["attn", "mlp"]
|
||||
target_modules_list = ["attn", "mlp", "ff"] # for SDXL U-Net
|
||||
|
||||
param_groups = []
|
||||
param_lr = {}
|
||||
for model, params in zip(training_models, trainable_params):
|
||||
logger.info(f"model: {model.__class__.__name__}")
|
||||
galore_params = []
|
||||
group_lr = params.get("lr", lr)
|
||||
|
||||
for module_name, module in model.named_modules():
|
||||
if not isinstance(module, nn.Linear):
|
||||
continue
|
||||
|
||||
if not any(target_key in module_name for target_key in target_modules_list):
|
||||
continue
|
||||
|
||||
logger.info("enable GaLore for weights in module: " + module_name)
|
||||
galore_params.append(module.weight)
|
||||
|
||||
id_galore_params = [id(p) for p in galore_params]
|
||||
# make parameters without "rank" to another group
|
||||
regular_params = [p for p in params["params"] if id(p) not in id_galore_params]
|
||||
|
||||
# then call galore_adamw
|
||||
param_groups.append({"params": regular_params, "lr": group_lr})
|
||||
|
||||
param_groups.append(
|
||||
{
|
||||
"params": galore_params,
|
||||
"rank": rank,
|
||||
"update_proj_gap": update_proj_gap,
|
||||
"scale": galore_scale,
|
||||
"proj_type": proj_type,
|
||||
"lr": group_lr,
|
||||
}
|
||||
)
|
||||
|
||||
# record lr
|
||||
for p in regular_params + galore_params:
|
||||
param_lr[id(p)] = group_lr
|
||||
|
||||
# select optimizer
|
||||
scheduler = None
|
||||
if optimizer_type == "galore_adamw":
|
||||
optimizer = GaLoreAdamW(param_groups, lr=lr, **optimizer_kwargs)
|
||||
elif optimizer_type == "galore_adafactor":
|
||||
beta1 = None if optimizer_kwargs.get("beta1", 0.0) == 0.0 else optimizer_kwargs.pop("beta1")
|
||||
optimizer = GaLoreAdafactor(param_groups, lr=lr, beta1=beta1, **optimizer_kwargs)
|
||||
elif optimizer_type == "galore_adamw8bit":
|
||||
optimizer = GaLoreAdamW8bit(param_groups, lr=lr, **optimizer_kwargs)
|
||||
elif optimizer_type == "galore_adamw8bit_per_layer":
|
||||
# TODO: seems scheduler call twice in one update step, need to check, for now double the num_training_steps, warmup_steps and update_proj_gap
|
||||
optimizer_dict = {}
|
||||
all_params = []
|
||||
for params in trainable_params:
|
||||
all_params.extend(params["params"])
|
||||
|
||||
for p in all_params:
|
||||
if p.requires_grad:
|
||||
if id(p) in id_galore_params:
|
||||
optimizer_dict[p] = GaLoreAdamW8bit(
|
||||
[
|
||||
{
|
||||
"params": [p],
|
||||
"rank": rank,
|
||||
"update_proj_gap": update_proj_gap * 2,
|
||||
"scale": galore_scale,
|
||||
"proj_type": proj_type,
|
||||
}
|
||||
],
|
||||
lr=param_lr[id(p)],
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
else:
|
||||
import bitsandbytes as bnb
|
||||
|
||||
optimizer_dict[p] = bnb.optim.Adam8bit([p], lr=param_lr[id(p)], weight_decay=weight_decay)
|
||||
|
||||
# get scheduler dict
|
||||
# scheduler needs accelerate.prepare?
|
||||
scheduler_dict = {}
|
||||
for p in all_params:
|
||||
if p.requires_grad:
|
||||
scheduler_dict[p] = train_util.get_scheduler_fix(args, optimizer_dict[p], num_processes)
|
||||
|
||||
def optimizer_hook(p):
|
||||
if p.grad is None:
|
||||
return
|
||||
optimizer_dict[p].step()
|
||||
optimizer_dict[p].zero_grad()
|
||||
scheduler_dict[p].step()
|
||||
|
||||
# Register the hook onto every parameter
|
||||
for p in all_params:
|
||||
if p.requires_grad:
|
||||
p.register_post_accumulate_grad_hook(optimizer_hook)
|
||||
|
||||
# make dummy scheduler and optimizer
|
||||
class DummyScheduler:
|
||||
def step(self):
|
||||
pass
|
||||
|
||||
class DummyOptimizer:
|
||||
def __init__(self, optimizer_dict):
|
||||
self.optimizer_dict = optimizer_dict
|
||||
|
||||
def step(self):
|
||||
pass
|
||||
|
||||
def zero_grad(self, set_to_none=False):
|
||||
pass
|
||||
|
||||
scheduler = DummyScheduler(optimizer_dict[all_params[0]])
|
||||
optimizer = DummyOptimizer(optimizer_dict)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported optimizer type: {optimizer_type}")
|
||||
|
||||
return optimizer, scheduler
|
||||
@@ -31,8 +31,10 @@ from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from einops import rearrange
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IN_CHANNELS: int = 4
|
||||
@@ -1074,7 +1076,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
timesteps = timesteps.expand(x.shape[0])
|
||||
|
||||
hs = []
|
||||
t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
|
||||
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
|
||||
t_emb = t_emb.to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
@@ -1132,7 +1134,7 @@ class InferSdxlUNet2DConditionModel:
|
||||
# call original model's methods
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.delegate, name)
|
||||
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.delegate(*args, **kwargs)
|
||||
|
||||
@@ -1164,7 +1166,7 @@ class InferSdxlUNet2DConditionModel:
|
||||
timesteps = timesteps.expand(x.shape[0])
|
||||
|
||||
hs = []
|
||||
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
|
||||
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
|
||||
t_emb = t_emb.to(x.dtype)
|
||||
emb = _self.time_embed(t_emb)
|
||||
|
||||
|
||||
@@ -364,8 +364,6 @@ class BaseSubset:
|
||||
caption_separator: str,
|
||||
keep_tokens: int,
|
||||
keep_tokens_separator: str,
|
||||
secondary_separator: Optional[str],
|
||||
enable_wildcard: bool,
|
||||
color_aug: bool,
|
||||
flip_aug: bool,
|
||||
face_crop_aug_range: Optional[Tuple[float, float]],
|
||||
@@ -384,8 +382,6 @@ class BaseSubset:
|
||||
self.caption_separator = caption_separator
|
||||
self.keep_tokens = keep_tokens
|
||||
self.keep_tokens_separator = keep_tokens_separator
|
||||
self.secondary_separator = secondary_separator
|
||||
self.enable_wildcard = enable_wildcard
|
||||
self.color_aug = color_aug
|
||||
self.flip_aug = flip_aug
|
||||
self.face_crop_aug_range = face_crop_aug_range
|
||||
@@ -414,8 +410,6 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_separator: str,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -437,8 +431,6 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -474,8 +466,6 @@ class FineTuningSubset(BaseSubset):
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -497,8 +487,6 @@ class FineTuningSubset(BaseSubset):
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -531,8 +519,6 @@ class ControlNetSubset(BaseSubset):
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -554,8 +540,6 @@ class ControlNetSubset(BaseSubset):
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
secondary_separator,
|
||||
enable_wildcard,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -691,41 +675,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if is_drop_out:
|
||||
caption = ""
|
||||
else:
|
||||
# process wildcards
|
||||
if subset.enable_wildcard:
|
||||
# wildcard is like '{aaa|bbb|ccc...}'
|
||||
# escape the curly braces like {{ or }}
|
||||
replacer1 = "⦅"
|
||||
replacer2 = "⦆"
|
||||
while replacer1 in caption or replacer2 in caption:
|
||||
replacer1 += "⦅"
|
||||
replacer2 += "⦆"
|
||||
|
||||
caption = caption.replace("{{", replacer1).replace("}}", replacer2)
|
||||
|
||||
# replace the wildcard
|
||||
def replace_wildcard(match):
|
||||
return random.choice(match.group(1).split("|"))
|
||||
|
||||
caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption)
|
||||
|
||||
# unescape the curly braces
|
||||
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
|
||||
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
fixed_tokens = []
|
||||
flex_tokens = []
|
||||
fixed_suffix_tokens = []
|
||||
if (
|
||||
hasattr(subset, "keep_tokens_separator")
|
||||
and subset.keep_tokens_separator
|
||||
and subset.keep_tokens_separator in caption
|
||||
):
|
||||
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
|
||||
if subset.keep_tokens_separator in flex_part:
|
||||
flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1)
|
||||
fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()]
|
||||
|
||||
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
|
||||
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
|
||||
else:
|
||||
@@ -760,11 +718,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
flex_tokens = dropout_tags(flex_tokens)
|
||||
|
||||
caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens)
|
||||
|
||||
# process secondary separator
|
||||
if subset.secondary_separator:
|
||||
caption = caption.replace(subset.secondary_separator, subset.caption_separator)
|
||||
caption = ", ".join(fixed_tokens + flex_tokens)
|
||||
|
||||
# textual inversion対応
|
||||
for str_from, str_to in self.replacements.items():
|
||||
@@ -1820,8 +1774,6 @@ class ControlNetDataset(BaseDataset):
|
||||
subset.caption_separator,
|
||||
subset.keep_tokens,
|
||||
subset.keep_tokens_separator,
|
||||
subset.secondary_separator,
|
||||
subset.enable_wildcard,
|
||||
subset.color_aug,
|
||||
subset.flip_aug,
|
||||
subset.face_crop_aug_range,
|
||||
@@ -3332,18 +3284,6 @@ def add_dataset_arguments(
|
||||
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
|
||||
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--secondary_separator",
|
||||
type=str,
|
||||
default=None,
|
||||
help="a secondary separator for caption. This separator is replaced to caption_separator after dropping/shuffling caption"
|
||||
+ " / captionのセカンダリ区切り文字。この区切り文字はcaptionのドロップやシャッフル後にcaption_separatorに置き換えられる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_wildcard",
|
||||
action="store_true",
|
||||
help="enable wildcard for caption (e.g. '{image|picture|rendition}') / captionのワイルドカードを有効にする(例:'{image|picture|rendition}')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_prefix",
|
||||
type=str,
|
||||
@@ -3671,7 +3611,7 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = lion_pytorch.Lion
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.endswith("8bit".lower()) and not optimizer_type.startswith("GaLore".lower()):
|
||||
elif optimizer_type.endswith("8bit".lower()):
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
@@ -3880,11 +3820,6 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = torch.optim.AdamW
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.startswith("GaLore".lower()):
|
||||
logger.info(f"use GaLore optimizer | {optimizer_kwargs}")
|
||||
optimizer = "galore"
|
||||
return None, None, optimizer
|
||||
|
||||
if optimizer is None:
|
||||
# 任意のoptimizerを使う
|
||||
optimizer_type = args.optimizer_type # lowerでないやつ(微妙)
|
||||
|
||||
@@ -11,7 +11,6 @@ from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
@@ -379,17 +378,7 @@ def train(args):
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = None
|
||||
if optimizer == "galore":
|
||||
from library import galore_optimizer
|
||||
|
||||
# if lr_scheduler is not layerwise, it is None. if layerwise, it is a dummy scheduler
|
||||
optimizer, lr_scheduler = galore_optimizer.get_optimizer(
|
||||
args, args.optimizer_type, params_to_optimize, training_models, accelerator.num_processes
|
||||
)
|
||||
|
||||
if lr_scheduler is None:
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||
if args.full_fp16:
|
||||
|
||||
@@ -564,11 +564,6 @@ class NetworkTrainer:
|
||||
"random_crop": bool(subset.random_crop),
|
||||
"shuffle_caption": bool(subset.shuffle_caption),
|
||||
"keep_tokens": subset.keep_tokens,
|
||||
"keep_tokens_separator": subset.keep_tokens_separator,
|
||||
"secondary_separator": subset.secondary_separator,
|
||||
"enable_wildcard": bool(subset.enable_wildcard),
|
||||
"caption_prefix": subset.caption_prefix,
|
||||
"caption_suffix": subset.caption_suffix,
|
||||
}
|
||||
|
||||
image_dir_or_metadata_file = None
|
||||
|
||||
Reference in New Issue
Block a user