Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 14:45:19 +00:00)

Compare commits: 167 commits
README-ja.md (34 lines changed)

Documentation is available in this repository and on note.com; please refer to it (everything may eventually be moved here).

- * [Training guide, common](./train_README-ja.md): data preparation, options, etc.
- * [Dataset config](./config_README-ja.md)
- * [DreamBooth training](./train_db_README-ja.md)
- * [Fine-tuning guide](./fine_tune_README_ja.md)
- * [LoRA training](./train_network_README-ja.md)
- * [Textual Inversion training](./train_ti_README-ja.md)
- * note.com [Image generation script](https://note.com/kohya_ss/n/n2693183a798e)
+ * [Training guide, common](./docs/train_README-ja.md): data preparation, options, etc.
+ * [Dataset config](./docs/config_README-ja.md)
+ * [DreamBooth training](./docs/train_db_README-ja.md)
+ * [Fine-tuning guide](./docs/fine_tune_README_ja.md)
+ * [LoRA training](./docs/train_network_README-ja.md)
+ * [Textual Inversion training](./docs/train_ti_README-ja.md)
+ * [Image generation script](./docs/gen_img_README-ja.md)
  * note.com [Model conversion script](https://note.com/kohya_ss/n/n374f316fe4ad)

## Programs required on Windows

Training may not work correctly with other versions. Unless you have a specific reason, please use the specified versions.

### Optional: Use Lion8bit

To use Lion8bit, upgrade `bitsandbytes` to 0.38.0 or later. Uninstall `bitsandbytes` and, on Windows, install a Windows whl file, for example from [here](https://github.com/jllllll/bitsandbytes-windows-webui), like this:

```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl
```

When upgrading, update this repository with `pip install .` and upgrade other packages as needed.

### Optional: Use PagedAdamW8bit and PagedLion8bit

To use PagedAdamW8bit and PagedLion8bit, upgrade `bitsandbytes` to 0.39.0 or later. Uninstall `bitsandbytes` and, on Windows, install a Windows whl file, for example from [here](https://github.com/jllllll/bitsandbytes-windows-webui), like this:

```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```

When upgrading, update this repository with `pip install .` and upgrade other packages as needed.

## Upgrade

When a new release is available, you can update with the following command:
README.md (274 lines changed)

This repository contains the scripts for:

* DreamBooth training, including U-Net and Text Encoder
* Fine-tuning (native training), including U-Net and Text Encoder
* LoRA training
* Textual Inversion training
* Image generation
* Model conversion (supports 1.x and 2.x, Stable Diffusion ckpt/safetensors and Diffusers)

__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for the great work!!!

## About requirements.txt

The scripts are tested with PyTorch 1.12.1 and 1.13.0, Diffusers 0.10.2.

Most of the documents are written in Japanese.

[English translation by darkstorm2150 is here](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation). Thanks to darkstorm2150!

- * [Training guide - common](./train_README-ja.md): data preparation, options etc.
- * [Chinese version](./train_README-zh.md)
- * [Dataset config](./config_README-ja.md)
- * [DreamBooth training guide](./train_db_README-ja.md)
- * [Step by Step fine-tuning guide](./fine_tune_README_ja.md)
- * [Training LoRA](./train_network_README-ja.md)
- * [Training Textual Inversion](./train_ti_README-ja.md)
- * note.com [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
+ * [Training guide - common](./docs/train_README-ja.md): data preparation, options etc.
+ * [Chinese version](./docs/train_README-zh.md)
+ * [Dataset config](./docs/config_README-ja.md)
+ * [DreamBooth training guide](./docs/train_db_README-ja.md)
+ * [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md)
+ * [Training LoRA](./docs/train_network_README-ja.md)
+ * [Training Textual Inversion](./docs/train_ti_README-ja.md)
+ * [Image generation](./docs/gen_img_README-ja.md)
  * note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)

## Windows Required Dependencies

Other versions of PyTorch and xformers seem to have problems with training. If there is no other reason, please install the specified versions.

### Optional: Use Lion8bit

For Lion8bit, you need to upgrade `bitsandbytes` to 0.38.0 or later. Uninstall `bitsandbytes`, and for Windows, install the Windows whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or another source, like:

```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl
```

For upgrading, upgrade this repo with `pip install .`, and upgrade the necessary packages manually.

### Optional: Use PagedAdamW8bit and PagedLion8bit

For PagedAdamW8bit and PagedLion8bit, you need to upgrade `bitsandbytes` to 0.39.0 or later. Uninstall `bitsandbytes`, and for Windows, install the Windows whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or another source, like:

```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```

For upgrading, upgrade this repo with `pip install .`, and upgrade the necessary packages manually.

## Upgrade

When a new release comes out, you can upgrade your repo with the following command:

## Change History

### 15 Jun. 2023, 2023/06/15

- The Prodigy optimizer is now supported in each training script. It belongs to the D-Adaptation family and is said to be effective for DyLoRA training. [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) Please see the PR for details. Thanks to sdbds!
  - Install the package with `pip install prodigyopt`, then specify the option like `--optimizer_type="prodigy"`.
- An arbitrary Dataset is now supported in each training script (except XTI). You can use it by defining a Dataset class that returns images and captions.
  - Prepare a Python script and define a class that inherits `train_util.MinimalDataset`, then specify it with an option like `--dataset_class package.module.DatasetClass` in each training script.
  - Please refer to `MinimalDataset` for the implementation; I will prepare a sample later (a rough illustrative sketch follows this entry).
- The following features have been added to the generation script:
  - Added an option `--highres_fix_disable_control_net` to disable ControlNet in the second stage of Highres. fix. Please try it if the image is disturbed by some ControlNets such as Canny.
  - Added variants similar to sd-dynamic-prompts to the prompt (the selection behavior is sketched after this entry).
    - If you specify `{spring|summer|autumn|winter}`, one of them is randomly selected.
    - If you specify `{2$$chocolate|vanilla|strawberry}`, two of them are randomly selected.
    - If you specify `{1-2$$ and $$chocolate|vanilla|strawberry}`, one or two of them are randomly selected and joined by ` and `.
    - You can specify a range for the number of candidates, such as `0-2`. You cannot omit one side, as in `-2` or `1-`.
    - It can also be specified for the prompt options.
    - If you specify `e` or `E`, all candidates are selected and the prompt is repeated multiple times (`--images_per_prompt` is ignored). This may be useful for creating X/Y plots.
    - You can also specify `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`. In this case, 15 prompts are generated (5 * 3).
    - There is no weighting function.
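As a rough illustration of the custom-dataset entry above (not the repository's promised sample), the following sketch shows the general shape such a class might take. The constructor arguments accepted by `train_util.MinimalDataset` and the structure `__getitem__` must return are assumptions here, so check the class in `library/train_util.py` before relying on this.

```python
# Illustrative sketch of a custom dataset for --dataset_class.
# The exact interface expected by train_util.MinimalDataset is an assumption.
import glob
import os

from PIL import Image

from library import train_util  # importable when run from the repository root


class MyDataset(train_util.MinimalDataset):
    def __init__(self, *args, **kwargs):
        # Pass through whatever arguments the training script supplies.
        super().__init__(*args, **kwargs)
        # Pair each image with a same-named .txt caption file.
        self.image_paths = sorted(glob.glob("/path/to/images/*.png"))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        path = self.image_paths[index]
        image = Image.open(path).convert("RGB")
        caption_path = os.path.splitext(path)[0] + ".txt"
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        # Return whatever structure the training script expects; a dict of
        # image and caption per example is assumed here for illustration.
        return {"image": image, "caption": caption}
```

It would then be passed to the training script as, for example, `--dataset_class mymodule.MyDataset`, with the module importable from the working directory (both the module name and the returned structure here are hypothetical).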
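To make the `{...}` variant rules above concrete, here is a small standalone sketch of the documented selection behavior (count, optional separator, `e` for enumerate-all). It illustrates the syntax only and is not the parser used by `gen_img_diffusers.py`.

```python
# Illustration of the documented variant syntax, not the script's actual parser.
import random
import re


def expand_variant(spec: str, rng: random.Random):
    """spec is the text inside {...}; returns one substitution per generated prompt."""
    m = re.match(r"(?:(e|E|\d+(?:-\d+)?)\$\$)?(?:(.*?)\$\$)?(.+)", spec)
    count, sep, body = m.group(1), m.group(2) or ", ", m.group(3)
    options = body.split("|")
    if count in ("e", "E"):
        return options                      # every candidate -> one prompt each
    if count is None:
        return [rng.choice(options)]        # pick exactly one candidate
    lo, _, hi = count.partition("-")
    n = rng.randint(int(lo), int(hi or lo))
    return [sep.join(rng.sample(options, n))]


rng = random.Random(1234)
print(expand_variant("spring|summer|autumn|winter", rng))
print(expand_variant("2$$chocolate|vanilla|strawberry", rng))
print(expand_variant("1-2$$ and $$chocolate|vanilla|strawberry", rng))
```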
### 8 Jun. 2023, 2023/06/08

- Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training.

### 6 Jun. 2023, 2023/06/06

- Fixed `train_network.py` so that it probably works with older versions of LyCORIS.
- `gen_img_diffusers.py` now supports the `BREAK` syntax.

### 3 Jun. 2023, 2023/06/03

- Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!
  - Max Norm Regularization is a technique to stabilize network training by limiting the norm of the network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See the PR for details.
  - Specify it as `--scale_weight_norms=1.0`. It seems good to start from `1.0`.
  - The networks other than LoRA in this repository (such as LyCORIS) do not support this option.
- Three types of dropout have been added to `train_network.py` and the LoRA network (a rough sketch of the rank and module variants follows this entry).
  - Dropout is a technique to suppress overfitting and improve network performance by randomly setting some of the network outputs to 0.
  - `--network_dropout` is normal dropout at the neuron level. For LoRA, it is applied to the output of the down projection. Proposed in [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545). Thanks to AI-Casanova!
    - `--network_dropout=0.1` sets the dropout probability to `0.1`.
    - Note that the way it is specified differs from LyCORIS.
  - For the LoRA network, `--network_args` can specify `rank_dropout` to drop out each rank with the specified probability, and `module_dropout` to drop out each module with the specified probability.
    - Specify them as `--network_args "rank_dropout=0.2" "module_dropout=0.1"`.
  - `--network_dropout`, `rank_dropout`, and `module_dropout` can be specified at the same time.
  - Values of 0.1 to 0.3 may be good to try; values greater than 0.5 should not be specified.
  - `rank_dropout` and `module_dropout` are original techniques of this repository. Their effectiveness has not been verified yet.
  - The networks other than LoRA in this repository (such as LyCORIS) do not support these options.
- Added an option `--scale_v_pred_loss_like_noise_pred` to each training script to scale the v-prediction loss like the noise-prediction loss.
  - By scaling the loss according to the timestep, the weights of global noise prediction and local noise prediction become the same, which may improve details.
- See [this article](https://xrg.hatenablog.com/entry/2023/06/02/202418) by xrg for details (written in Japanese). Thanks to xrg for the great suggestion!
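As a rough conceptual sketch only (the repository's actual implementation in the LoRA network code may differ in details), the three dropout variants could be applied to a LoRA delta roughly like this:

```python
# Conceptual sketch of the three dropout variants for a LoRA module
# (illustrative only; not the repository's exact implementation).
import torch
import torch.nn.functional as F


def lora_forward(x, down_w, up_w, scale,
                 network_dropout=0.0, rank_dropout=0.0, module_dropout=0.0,
                 training=True):
    # module_dropout: with probability p, skip this LoRA module entirely.
    if training and module_dropout > 0 and torch.rand(()) < module_dropout:
        return torch.zeros(x.shape[0], up_w.shape[0])

    h = x @ down_w.t()                       # (batch, rank) down-projection output

    # network_dropout: ordinary element-wise dropout on the down output.
    if training and network_dropout > 0:
        h = F.dropout(h, p=network_dropout)

    # rank_dropout: zero out whole ranks (columns of h) with probability p.
    if training and rank_dropout > 0:
        mask = (torch.rand(h.shape[1]) > rank_dropout).float()
        h = h * mask / (1.0 - rank_dropout)  # rescale to keep the expectation

    return (h @ up_w.t()) * scale            # LoRA delta added to the base output


x = torch.randn(4, 320)
down_w, up_w = torch.randn(16, 320), torch.randn(320, 16)
delta = lora_forward(x, down_w, up_w, scale=1.0,
                     network_dropout=0.1, rank_dropout=0.2, module_dropout=0.1)
```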
### 31 May 2023, 2023/05/31

- Show a warning when an image caption file does not exist during training. [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) Thanks to TingTingin!
  - The warning is also displayed when using a class+identifier dataset. Please ignore it if that is intended.
- `train_network.py` now supports merging network weights before training. [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) Thanks to u-haru!
  - The `--base_weights` option specifies LoRA or other model files (multiple files are allowed) to merge.
  - The `--base_weights_multiplier` option specifies multipliers for the weights to merge (multiple values are allowed). If it is omitted, or if fewer values than `base_weights` are given, 1.0 is used.
  - This is useful for incremental learning. See the PR for details.
- Show a warning and continue training when uploading to HuggingFace fails.

### 25 May 2023, 2023/05/25

- [D-Adaptation v3.0](https://github.com/facebookresearch/dadaptation) is now supported. [PR #530](https://github.com/kohya-ss/sd-scripts/pull/530) Thanks to sdbds!
  - `--optimizer_type` now accepts `DAdaptAdamPreprint`, `DAdaptAdanIP`, and `DAdaptLion`.
  - `DAdaptAdam` is now new; the old `DAdaptAdam` is available as `DAdaptAdamPreprint`.
  - Simply specifying `DAdaptation` will use `DAdaptAdamPreprint` (the same behavior as before).
  - You need to install D-Adaptation v3.0. After activating the venv, run `pip install -U dadaptation`.
  - See the PR and the D-Adaptation documentation for details.

### 22 May 2023, 2023/05/22

- Fixed several bugs.
  - The state was saved even when the `--save_state` option was not specified in `fine_tune.py` and `train_db.py`. [PR #521](https://github.com/kohya-ss/sd-scripts/pull/521) Thanks to akshaal!
  - LoRA without `alpha` could not be loaded. [PR #527](https://github.com/kohya-ss/sd-scripts/pull/527) Thanks to Manjiz!
  - Minor changes to console output during sample generation. [PR #515](https://github.com/kohya-ss/sd-scripts/pull/515) Thanks to yanhuifair!
- The generation script now uses xformers for the VAE as well.

### 16 May 2023, 2023/05/16

- Fixed an issue where an error would occur if the encoding of the prompt file was different from the default. [PR #510](https://github.com/kohya-ss/sd-scripts/pull/510) Thanks to sdbds!
  - Please save the prompt file in UTF-8.

### 15 May 2023, 2023/05/15

- Added an [English translation of the documents](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation) by darkstorm2150. Thank you very much!
- The prompt for sample generation during training can now be specified in `.toml` or `.json`. [PR #504](https://github.com/kohya-ss/sd-scripts/pull/504) Thanks to Linaqruf!
  - For details on the prompt description, please see the PR.

### 11 May 2023, 2023/05/11

- Added an option `--dim_from_weights` to `train_network.py` to automatically determine the dim (rank) from the weight file. [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) Thanks to AI-Casanova!
  - It is useful in combination with `resize_lora.py`. Please see the PR for details.
- Fixed a bug where the noise resolution was incorrect with Multires noise. [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) Thanks to sdbds!
  - Please see the PR for details.
- The image generation scripts can now use img2img and highres fix at the same time.
- Fixed a bug in the image generation scripts where the hint image for ControlNet was incorrectly BGR instead of RGB.
- Added a feature to the image generation scripts to use a memory-efficient VAE.
  - If you specify a number with the `--vae_slices` option, the memory-efficient VAE is used. The maximum output size becomes larger, but generation is slower. Specify a value of about `16` or `32`.
  - The implementation of the VAE is in `library/slicing_vae.py`.

### 7 May 2023, 2023/05/07

- The documentation has been moved to the `docs` folder. If you have links, please update them.
- Removed `gradio` from `requirements.txt`.
- DAdaptAdaGrad, DAdaptAdan, and DAdaptSGD are now supported by DAdaptation. [PR #455](https://github.com/kohya-ss/sd-scripts/pull/455) Thanks to sdbds!
  - DAdaptation needs to be installed. Depending on the optimizer, DAdaptation may also need to be updated. Update it with `pip install --upgrade dadaptation`.
- Added support for pre-calculation of LoRA weights in the image generation scripts. Specify `--network_pre_calc`.
  - The prompt option `--am` is available. It is disabled when Regional LoRA is used.
- Added Adaptive noise scale to each training script. Specify a number with `--adaptive_noise_scale` to enable it.
  - __Experimental option. It may be removed or changed in the future.__
  - This is an original implementation that automatically adjusts the value of the noise offset according to the absolute value of the mean of each channel of the latents. Appropriate noise offsets are expected to be set for bright and dark images, respectively.
  - Specify it together with `--noise_offset`.
  - The actual noise offset is calculated as `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale`. Since the latents are close to a normal distribution, it may be a good idea to specify a value of about 1/10 of the noise offset up to roughly the same value.
  - Negative values can also be specified, in which case the noise offset is clipped to 0 or more.
- Other minor fixes.

### 30 Apr. 2023, 2023/04/30

- Added Chinese translations of the [DreamBooth guide](./train_db_README-zh.md) and the [LoRA guide](./train_network_README-zh.md). [PR #459](https://github.com/kohya-ss/sd-scripts/pull/459) Thanks to tomj2ee!
- Added [documentation](./gen_img_README-ja.md) for the image generation script `gen_img_diffusers.py` (Japanese version only).

### 26 Apr. 2023, 2023/04/26

- Added a [Chinese translation](./train_README-zh.md) of the training guide. [PR #445](https://github.com/kohya-ss/sd-scripts/pull/445) Thanks to tomj2ee!
- `tag_images_by_wd14_tagger.py` can now take arguments from outside. [PR #453](https://github.com/kohya-ss/sd-scripts/pull/453) Thanks to mio2333!

### 25 Apr. 2023, 2023/04/25

- Please do not update for a while if you cannot revert the repository to the previous version when something goes wrong, because the model saving part has been changed.
- Added a `--save_every_n_steps` option to each training script. The model is saved every specified number of steps.
  - The `--save_last_n_steps` option can be used to keep only the specified number of models (older models are deleted).
  - If you specify the `--save_state` option, the state is also saved at the same time. You can specify the number of steps to keep states with the `--save_last_n_steps_state` option (the same value as `--save_last_n_steps` is used if omitted).
  - The epoch-based model saving and state saving options can be used together.
  - Not tested in a multi-GPU environment. Please report any bugs.
- The `--cache_latents_to_disk` option now automatically enables the `--cache_latents` option when specified. [#438](https://github.com/kohya-ss/sd-scripts/issues/438)
- Fixed a bug in `gen_img_diffusers.py` where the latents upscaler would fail with a batch size of 2 or more.

Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.

The LoRA supported by `train_network.py` has been named to avoid confusion.

LoRA-LierLa is the default LoRA type for `train_network.py` (without the `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.

To use LoRA-C3Lier with the Web UI, please use our extension.

## Sample image generation during training

A prompt file might look like this, for example:
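An illustrative file only: the second prompt line reuses the prompt text that appears in this page's diff context, while the option letters `--n`, `--w`, `--h`, `--d`, `--l`, and `--s` are assumed here to mean negative prompt, width, height, seed, CFG scale, and steps, as in the sample-generation documentation; verify them against the current docs before use.

```
# prompt 1
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy, bad composition --w 768 --h 768 --d 1 --l 7.5 --s 28

# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy, bad composition --w 576 --h 832 --d 2 --l 5.5 --s 40
```

Lines starting with `#` are treated as comments.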
gen_img_README-ja.md

- `--network_mul`: Specifies the multiplier for the weights of the additional networks. The default is `1`. Specify it like `--network_mul 0.8`. When using multiple LoRAs, specify it like `--network_mul 0.4 0.5 0.7`. The number of arguments must match the number specified with `--network_module`.

- `--network_merge`: Merges the weights of the additional networks in advance, using the weights specified in `--network_mul`. It cannot be used together with `--network_pre_calc`. The prompt option `--am` and Regional LoRA can no longer be used, but generation is accelerated to about the same level as when no LoRA is used.

- `--network_pre_calc`: Pre-calculates the weights of the additional networks for each generation. The prompt option `--am` can be used. Generation is accelerated to about the same level as when no LoRA is used, but extra time is needed to calculate the weights before generation, and memory usage also increases slightly. It is disabled when Regional LoRA is used.

# Examples of major option settings

train_README-ja.md

Also, because training can be done at arbitrary resolutions (Aspect Ratio Bucketing), there is no longer any need to unify the aspect ratios of the image data in advance.

It can be enabled or disabled in the settings; in the configuration file examples so far it is enabled (`true` is set).

Training resolutions are adjusted and created in width and height in units of 64 pixels (the default, which can be changed), within a range that does not exceed the area (= memory usage) of the resolution given as a parameter.
Specifying the xformers option uses xformers' CrossAttention. If xformers is not installed or it causes an error (depending on the environment, for example with `mixed_precision="no"`), the `mem_eff_attn` option can be specified instead to use a memory-saving CrossAttention (slower than xformers).

- `--clip_skip`

  Specifying `2` uses the output of the second-to-last layer of the Text Encoder (CLIP). If 1 is specified or the option is omitted, the last layer is used.

  As with clip_skip, training with a token length different from the model's trained state will likely require a certain amount of training data and a longer training time.

- `--weighted_captions`

  When specified, weighted captions are enabled, as in AUTOMATIC1111's Web UI. They can be used for training other than "Textual Inversion and XTI". They are effective not only for captions but also for the token strings of the DreamBooth method.

  The notation for weighted captions is almost the same as in the Web UI: (abc), [abc], (abc:1.23), etc. can be used, and nesting is also possible. Do not include commas inside the parentheses; with prompt shuffle/dropout, a comma inside parentheses breaks the pairing of the parentheses.

- `--persistent_data_loader_workers`

  Specifying this in a Windows environment greatly reduces the waiting time between epochs.

After that, open a browser and access http://localhost:6006/ to view it (the TensorBoard logs).

- `--log_with` / `--log_tracker_name`

  Options for saving training logs. Logs can be saved not only with `tensorboard` but also with `wandb`. See [PR #428](https://github.com/kohya-ss/sd-scripts/pull/428) for details.

- `--noise_offset`

  This is an implementation of this article: https://www.crosslabs.org//blog/diffusion-with-offset-noise

  It may improve the results when generating overall dark or bright images. It also seems to be effective for LoRA training. Specifying a value of about `0.1` seems to work well.

- `--adaptive_noise_scale` (experimental option)

  An option that automatically adjusts the noise offset value according to the absolute value of the per-channel mean of the latents. It is enabled by specifying it together with `--noise_offset`. The noise offset value is calculated as `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale`. Since the latents are close to a normal distribution, a value of about 1/10 of the noise offset up to roughly the same value may be appropriate.

  Negative values can also be specified; in that case the noise offset is clipped to 0 or more.
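As a rough sketch of how these two options combine, following the formula above (this is an illustration, not the repository's exact implementation):

```python
# Sketch of offset noise with adaptive noise scale, following
# noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale.
# Illustrative only; details may differ from the training scripts.
import torch


def apply_noise_offset(latents, noise, noise_offset=0.1, adaptive_noise_scale=None):
    offset = noise_offset
    if adaptive_noise_scale is not None:
        # The per-sample, per-channel mean of the latents drives the extra offset.
        mean = latents.mean(dim=(2, 3), keepdim=True)        # (B, C, 1, 1)
        offset = noise_offset + mean.abs() * adaptive_noise_scale
        offset = offset.clamp(min=0.0)                        # negative results clip to 0
    # Offset noise shifts each channel of the noise by a random constant.
    return noise + offset * torch.randn(latents.shape[0], latents.shape[1], 1, 1)


latents = torch.randn(4, 4, 64, 64)
noise = torch.randn_like(latents)
noised = apply_noise_offset(latents, noise, noise_offset=0.1, adaptive_noise_scale=0.01)
```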
- `--multires_noise_iterations` / `--multires_noise_discount`

  Settings for multi-resolution noise (pyramid noise). See [PR #471](https://github.com/kohya-ss/sd-scripts/pull/471) and this page, [Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2), for details.

  It is enabled by specifying a number for `--multires_noise_iterations`. Values around 6 to 10 seem to work well. For `--multires_noise_discount`, specify a value of about 0.1 to 0.3 (the PR author's recommendation for relatively small datasets such as LoRA training) or about 0.8 (the recommendation of the original article); the default is 0.3.
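The linked PR and article describe the method fully; purely as a rough illustration of the idea (progressively coarser noise added with geometrically decreasing weight), a sketch might look like this. It is not the repository's implementation, and details such as the interpolation mode and normalization are assumptions.

```python
# Conceptual sketch of multi-resolution (pyramid) noise; illustrative only.
import torch
import torch.nn.functional as F


def pyramid_noise(shape, iterations=6, discount=0.3):
    b, c, h, w = shape
    noise = torch.randn(b, c, h, w)
    for i in range(1, iterations):
        r = 2 ** i                      # progressively coarser resolution
        if h // r < 1 or w // r < 1:
            break
        low = torch.randn(b, c, h // r, w // r)
        noise += (discount ** i) * F.interpolate(
            low, size=(h, w), mode="bilinear", align_corners=False)
    return noise / noise.std()          # renormalize to roughly unit variance


noise = pyramid_noise((4, 4, 64, 64), iterations=6, discount=0.3)
```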
- `--debug_dataset`

  Adding this option lets you check, before training, which image data and captions will be used for training. Press the Esc key to exit and return to the command line. Press the `S` key to advance to the next step (batch) and the `E` key to advance to the next epoch.

For the `--vae` option: in DreamBooth and fine tuning, the saved model incorporates this VAE.

- `--cache_latents` / `--cache_latents_to_disk`

  Caches the VAE output in main memory to reduce VRAM usage. Augmentations other than `flip_aug` become unavailable. Overall training also becomes slightly faster.

  Specifying cache_latents_to_disk saves the cache to disk. The cache remains valid even if the script is exited and started again.

- `--min_snr_gamma`

  Specifies the Min-SNR Weighting strategy. See [here](https://github.com/kohya-ss/sd-scripts/pull/308) for details. The paper recommends `5`.
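For reference, the Min-SNR weighting from that PR scales the per-timestep loss by min(SNR_t, gamma) / SNR_t for noise prediction, which caps the weight of low-noise timesteps. A one-line sketch, assuming a tensor `snr_t` of per-sample SNR values:

```python
# Min-SNR-gamma loss weighting sketch (gamma = 5 as recommended in the paper).
import torch


def min_snr_weight(snr_t: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    return torch.minimum(snr_t, torch.full_like(snr_t, gamma)) / snr_t


weights = min_snr_weight(torch.tensor([0.5, 1.0, 10.0, 100.0]))
# -> [1.0, 1.0, 0.5, 0.05]; high-SNR (low-noise) timesteps are down-weighted.
```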
## Settings related to model saving

- `--save_precision`

  Specifies the data precision when saving. Specifying float, fp16, or bf16 for the save_precision option saves the model in that format (not valid when saving the model in Diffusers format in DreamBooth or fine tuning). Use it when you want to reduce the model size, for example.

- `--save_every_n_epochs` / `--save_state` / `--resume`

  Specifying a number for the save_every_n_epochs option saves the model during training every that many epochs.

  If the save_state option is specified at the same time, the training state, including the optimizer state and so on, is saved as well (compared with resuming from a saved model, you can expect better accuracy and shorter training time). The save destination is a folder.

  The training state is output to a folder named `<output_name>-??????-state` (?????? is the epoch number) in the destination folder. Use it for long training runs.

  To resume training from a saved training state, use the resume option. Specify the training state folder (the state folder inside `output_dir`, not `output_dir` itself).

  Note that, due to the Accelerator specification, the epoch number and global step are not saved; they start from 1 again when you resume.

- `--save_every_n_steps`

  Specifying a number for the save_every_n_steps option saves the model during training every that many steps. It can be specified together with save_every_n_epochs.

- `--save_model_as` (DreamBooth and fine tuning only)

  You can choose the model save format from `ckpt, safetensors, diffusers, diffusers_safetensors`.

  Specify it like `--save_model_as=safetensors`. When loading a Stable Diffusion format model (ckpt or safetensors) and saving it in Diffusers format, the missing information is filled in by downloading the v1.5 or v2.1 information from Hugging Face.

- `--huggingface_repo_id` etc.

  If huggingface_repo_id is specified, the model is uploaded to HuggingFace at the same time it is saved. Be careful with the handling of the access token (see the HuggingFace documentation).

  Specify the other arguments as in the following example:

  - `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`

  If you specify `public` for huggingface_repo_visibility, the repository is made public. If it is omitted or anything other than public (such as `private`) is specified, it is private.

  If you specify `--save_state_to_huggingface` together with the `--save_state` option, the state is also uploaded.

  If you specify `--resume_from_huggingface` together with the `--resume` option, the state is downloaded from HuggingFace and training resumes. In that case, the --resume option becomes `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`.

  Example: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`

  Specifying the `--async_upload` option makes the upload asynchronous.

## Optimizer settings

- `--optimizer_type`
  - AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
    - Same as when no optimizer option was specified in past versions
  - AdamW8bit : same arguments as above
  - PagedAdamW8bit : same arguments as above
    - Same as when --use_8bit_adam was specified in past versions
  - Lion : https://github.com/lucidrains/lion-pytorch
    - Same as when --use_lion_optimizer was specified in past versions
  - Lion8bit : same arguments as above
  - PagedLion8bit : same arguments as above
  - SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
  - SGDNesterov8bit : same arguments as above
  - DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
  - DAdaptAdam : same arguments as above
  - DAdaptAdaGrad : same arguments as above
  - DAdaptAdan : same arguments as above
  - DAdaptAdanIP : same arguments as above
  - DAdaptLion : same arguments as above
  - DAdaptSGD : same arguments as above
  - Prodigy : https://github.com/konstmish/prodigy
  - AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
  - Any optimizer
train_README-zh.md

__Since this documentation is still being updated, the descriptions may contain errors.__

# About training: common guide

This repository supports model fine tuning, DreamBooth, LoRA training, and Textual Inversion (including [XTI: P+](https://github.com/kohya-ss/sd-scripts/pull/327)).
This document explains the preparation of training data and the options that are common to all of them.

# Overview

This section explains the following:

1. Preparing the training data (the new format that uses a configuration file)
1. A brief explanation of the terms used in training
1. The previous specification format (specifying from the command line instead of a configuration file)
1. Generating sample images during training
1. Common options used in each script
1. Preparing metadata for the fine tuning method: captions, tagging, etc.

1. Reading only item 1 is enough to start training (see each script's documentation for the training itself). Refer to the later items as needed.

Prepare the training image files in any folder (multiple folders are also fine). `.png`, `.jpg`, `.jpeg`, `.webp`, and `.bmp` files are supported. Preprocessing such as resizing is usually not needed.

However, do not use images that are extremely small, smaller than the training resolution (described later); it is recommended to enlarge them in advance with a super-resolution AI or similar. Also avoid images that are too large (roughly 3000 x 3000 pixels or more), as they may cause errors; it is recommended to shrink them in advance.

During training, you need to organize the image data used to train the model and specify it to the script. Depending on the amount of training data, the training target, and whether captions (image descriptions) are available, the training data can be specified in several ways. The following methods are available (the names are not general terms but definitions specific to this repository). Regularization images are described later.

1. DreamBooth, class + identifier method (regularization images can be used)

   The training target is associated with a specific word (an identifier) and trained. There is no need to prepare captions. This is convenient, for example, when training a specific character, because captions are not needed; but because every element of the training data is associated with the identifier, such as hairstyle, clothing, and background, it may become impossible to change, say, the clothing at generation time.

2. DreamBooth, caption method (regularization images can be used)

   Write a caption for each image in advance, store it in a text file, and train with it. For example, by recording the image details in the caption (character A in white clothes, character A in red clothes, and so on), the character and the other elements are separated, and the model can be expected to learn the character more precisely.

3. fine tuning method (regularization images cannot be used)

   The captions are first collected into a metadata file. Functions such as keeping tags and captions separate and pre-caching latents are supported to speed up training (described in a separate document). (Although it is called the fine tuning method, it is not limited to fine tuning.)

The combinations of what you want to train and the specification methods you can use are as follows.

| Training target or method | Script | DB / class+identifier | DB / caption | fine tuning |
|----------------| ----- | ----- | ----- | ----- |
| Fine-tune the model | `fine_tune.py`| x | x | o |
| Train the model with DreamBooth | `train_db.py`| o | o | x |

## Which one to choose

If you want to train LoRA or Textual Inversion without preparing caption files, the DreamBooth class+identifier method is recommended. If you can prepare caption files, the DreamBooth caption method is better. If you have a large amount of training data and do not use regularization images, consider the fine tuning method.

The same applies to DreamBooth itself, except that the fine tuning method cannot be used. For fine tuning, only the fine tuning method can be used.

# How to specify each method

Here, only the typical patterns of each specification method are introduced. For more detailed specification, see [Dataset config](./config_README-ja.md).

# DreamBooth, class+identifier method (regularization images can be used)

In this method, each image is trained as if it had the caption `class identifier` (for example, `shs dog`).

## step 1. Decide the identifier and class

Associate the training target with an identifier word and the class the target belongs to.

(There are various names for these, but for now we follow the original paper.)

The following is a brief explanation (look up the details yourself).

The class is the general category of the training target. For example, to learn a specific dog breed, the class is "dog". For anime characters, depending on the model, it will be "boy" or "girl", or "1boy" or "1girl".

The identifier is a word used to identify and learn the training target. Any word can be used, but according to the original paper, "a rare word of 3 characters or fewer produced by the tokenizer" is best.

By using the identifier and class, for example "shs dog", the model can be trained to recognize and learn the desired target within its class.

(For reference, some identifiers I have used recently are "shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny" and so on. Words not included in the Danbooru tags are preferable.)

## step 2. Decide whether to use regularization images, and generate them if so

Regularization images are images generated to prevent the language drift mentioned above, in which the entire class is pulled toward the training target. If regularization images are not used, for example when learning a specific character with `shs 1girl`, the results will look more and more like that character even when generating with a plain `1girl` prompt. This is because `1girl` is included in the captions used to train that character.

By training the target images and the regularization images together, the class stays as it is, and the target is generated only when the identifier is added to the prompt.

(Since regularization images are also trained on, their quality affects the model.)

Usually, preparing several hundred images is ideal (with too few images, the class images are not generalized and the model ends up learning their specific features).

If you use generated images, their size should usually match the training resolution (more precisely, the resolution of the bucket; see below).

## step 2. Writing the configuration file

Create a text file and change its extension to `.toml`. For example, you can write it as follows:

(The parts starting with `#` are comments, so you can copy and paste them as they are, or delete them.)

```toml
[general]
enable_bucket = true # Whether to use Aspect Ratio Bucketing

[[datasets]]
resolution = 512 # Training resolution
batch_size = 4 # Batch size

[[datasets.subsets]]
image_dir = 'C:\hoge' # Folder containing the training images
class_tokens = 'hoge girl' # Identifier and class
num_repeats = 10 # Number of repeats of the training images

# Write the following only when using regularization images; delete it otherwise
[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg' # Folder containing the regularization images
class_tokens = 'girl' # Class
num_repeats = 1 # Number of repeats of the regularization images; 1 is usually enough
```

Basically you only need to change the following items to train.

1. Training resolution

   Specify a single number for a square resolution (`512` means 512x512), or two numbers in brackets separated by a comma for width x height (`[512,768]` means 512x768). In the SD1.x series the original training resolution is 512. Specifying a larger resolution such as `[512,768]` may reduce artifacts when generating portrait and landscape images. In the SD2.x 768 series the resolution is `768`.

1. Batch size

   Specifies how many data items are trained at the same time. It depends on the GPU's VRAM size and the training resolution. Details are given later. It also differs between fine tuning, DreamBooth, LoRA and so on, so see each script's explanation.

1. Folder specification

   As mentioned above, the same as in the example.

1. Number of repeats

   Described later.

Specify the number of repeats so that __(repeats of the training images x number of training images) >= (repeats of the regularization images x number of regularization images)__.

(The amount of data in one epoch (one full pass over the training data) is "repeats of the training images x number of training images". If there are more regularization images than that, the extra regularization images are not used.)

## step 3. Training

Refer to the relevant documents and run the training.

# DreamBooth, caption method (regularization images can be used)

In this method, each image is trained with its caption.

## step 1. Prepare the caption files

Place files that have the same file name as each image and the extension `.caption` (changeable in the settings) in the training image folder. Each file should be a single line. The encoding is `UTF-8`.

## step 2. Decide whether to use regularization images, and generate them if so

Same as in the class+identifier format. A caption can be attached to regularization images, but it is usually not necessary.

## step 2. Writing the configuration file

Create a text file and change its extension to `.toml`. For example, you can write it as follows:

```toml
[general]
enable_bucket = true # Whether to use Aspect Ratio Bucketing

[[datasets]]
resolution = 512 # Training resolution
batch_size = 4 # Batch size

[[datasets.subsets]]
image_dir = 'C:\hoge' # Folder containing the training images
caption_extension = '.caption' # Change this if you use .txt caption files
num_repeats = 10 # Number of repeats of the training images

# Write the following only when using regularization images; delete it otherwise
[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg' # Folder containing the regularization images
class_tokens = 'girl' # Class
num_repeats = 1 # Number of repeats of the regularization images; 1 is usually enough
```

Basically you only need to change the following items to train. Unless otherwise noted, it is the same as the class+identifier method.

1. Training resolution
2. Batch size
3. Folder specification
4. Extension of the caption files

   Any extension can be specified.
5. Number of repeats

## step 3. Training

Refer to the relevant documents and run the training.

# fine tuning method

## step 1. Prepare the metadata

The management file that gathers the captions and tags is called metadata. Its extension is `.json` and its format is JSON. Since the creation procedure is long, it is described at the end of this document.

## step 2. Writing the configuration file

```toml
[[datasets]]
resolution = 512 # Image resolution
batch_size = 4 # Batch size

[[datasets.subsets]]
image_dir = 'C:\piyo' # Folder containing the training images
metadata_file = 'C:\piyo\piyo_md.json' # Metadata file name
```

Basically you only need to change the following items to train. Unless otherwise noted, it is the same as the DreamBooth class+identifier method.

1. Training resolution
2. Batch size
3. Folder specification
4. Metadata file name

   Specify the metadata file created by the method described later.

## step 3. Training

Refer to the relevant documents and run the training.

# Brief explanation of the terms used in training

Details are omitted, and I do not fully understand them myself either, so please look up the details on your own.

## Fine tuning

Training a model to adjust its performance. The exact meaning depends on usage, but in Stable Diffusion, fine tuning in the narrow sense means training the model with images and captions. DreamBooth can be regarded as one special method of narrow-sense fine tuning. Fine tuning in the broad sense includes LoRA, Textual Inversion, Hypernetworks and so on, and covers everything that trains the model.

## Step

Roughly speaking, one step is one calculation on the training data. Specifically, one step is: "pass the caption of the training data to the current model, compare the generated image with the image of the training data, and change the model slightly so that it is closer to the training data."

## Batch size

The batch size specifies how many data items are computed together in each step. Computing in batches improves speed. In general, a larger batch size also tends to give higher accuracy.

"Batch size x number of steps" is the amount of data used for training. Therefore, it is a good idea to reduce the number of steps when increasing the batch size.

The larger the batch size, the more GPU memory is consumed. If memory runs out, an error occurs; just below that limit, training speed drops. It is a good idea to check the amount of memory used with the Task Manager or the `nvidia-smi` command and adjust accordingly.

Note that a batch means "one unit of data".

## Learning rate

The learning rate is how much the model changes at each step. If a large value is specified, learning proceeds faster, but the model may change too much and break, or fail to reach an optimal state. If a small value is specified, learning is slower and may also fail to reach an optimal state.

The learning rate differs greatly between fine tuning, DreamBooth, LoRA and so on, and is also affected by the training data, the model to be trained, the batch size, the number of steps and so on. It is a good idea to start from a typical value, watch the state of training, and adjust gradually.

By default, the learning rate is fixed over the whole training run. The scheduler specifies how the learning rate changes, so the results also depend on it.

## Epoch

An epoch means that the training data has been trained once in full (the data has made one round). If a number of repeats is specified, one epoch is one round of the data after repetition.

The number of steps in one epoch is normally "amount of data / batch size", but it increases slightly when Aspect Ratio Bucketing is used (data from different buckets cannot be placed in the same batch, so the number of steps increases); a small worked example follows.
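As a small worked example of that arithmetic (illustrative numbers, and ignoring the extra steps that bucketing can add):

```python
# Steps per epoch for an example dataset matching the .toml above.
num_train_images = 20
num_repeats = 10          # num_repeats in the .toml
batch_size = 4

data_per_epoch = num_train_images * num_repeats   # 200 images per epoch
steps_per_epoch = data_per_epoch // batch_size     # 50 steps per epoch
print(steps_per_epoch)
```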
## Aspect Ratio Bucketing

Stable Diffusion v1 was trained at a resolution of 512*512, but it is also trained at other resolutions at the same time, such as 256*1024 and 384*640. This reduces the amount that is cropped off, and the relationship between images and captions can be expected to be learned more accurately.

Also, because training can be done at arbitrary resolutions, it is no longer necessary to unify the aspect ratios of the image data in advance.

It can be enabled or disabled in the settings; in the configuration file examples so far it is enabled (`true` is set).

Training resolutions are adjusted and created in the vertical and horizontal directions in increments of 64 pixels (the default, which can be changed), as long as they do not exceed the area (= memory usage) of the resolution given as a parameter.

In machine learning it is common to unify all input sizes, but in fact they only need to be unified within the same batch. The bucketing that NovelAI refers to means classifying the training data in advance into buckets by aspect ratio for each training resolution, and unifying the image sizes within a batch by building each batch from images inside the same bucket.
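As a rough sketch only (not the repository's bucket-generation code), buckets under a maximum pixel area could be enumerated in 64-pixel steps like this:

```python
# Illustrative enumeration of aspect-ratio buckets under a maximum pixel area.
# A sketch of the idea, not the repository's actual implementation.
def make_buckets(max_area=512 * 512, step=64, min_size=256, max_size=1024):
    buckets = set()
    w = min_size
    while w <= max_size:
        # Largest height (multiple of `step`) that keeps w*h within max_area.
        h = min(max_size, (max_area // w) // step * step)
        if h >= min_size:
            buckets.add((w, h))
            buckets.add((h, w))  # also keep the rotated (portrait/landscape) bucket
        w += step
    return sorted(buckets)


for wh in make_buckets():
    print(wh)   # includes (256, 1024), (384, 640), (512, 512), ...
```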
# The previous specification format (specifying with command-line options instead of a .toml file)

This is the way of specifying with command-line options instead of a .toml file. There are the DreamBooth class+identifier method, the DreamBooth caption method, and the fine tuning method.

## DreamBooth, class+identifier method

### Training multiple classes or multiple identifiers

The method is simple: inside the folder for the training images, prepare several folders each named "<repeat count>_<identifier> <class>", and likewise, inside the folder for the regularization images, prepare several folders each named "<repeat count>_<class>".

### step 2. Prepare the regularization images

This is the procedure when regularization images are used.

Create a folder to store the regularization images. __In addition,__ create a directory named ``<repeat count>_<class>`` inside it.

For example, with the prompt "frog" and without repeating the data (only once):
|
||||
|
||||
|
||||
步骤3. 执行学习
|
||||
步骤3. 执行训练
|
||||
|
||||
执行每个学习脚本。使用 `--train_data_dir` 选项指定包含训练数据文件夹的父文件夹(不是包含图像的文件夹),使用 `--reg_data_dir` 选项指定包含正则化图像的父文件夹(不是包含图像的文件夹)。
|
||||
执行每个训练脚本。使用 `--train_data_dir` 选项指定包含训练数据文件夹的父文件夹(不是包含图像的文件夹),使用 `--reg_data_dir` 选项指定包含正则化图像的父文件夹(不是包含图像的文件夹)。
|
||||
|
||||
## DreamBooth,带标题方式
|
||||
## DreamBooth,带文本说明(caption)的方式
|
||||
|
||||
在包含训练图像和正则化图像的文件夹中,将与图像具有相同文件名的文件.caption(可以使用选项进行更改)放置在该文件夹中,然后从该文件中加载标题作为提示进行学习。
|
||||
在包含训练图像和正则化图像的文件夹中,将与图像具有相同文件名的文件.caption(可以使用选项进行更改)放置在该文件夹中,然后从该文件中加载caption所作为提示进行训练。
|
||||
|
||||
※文件夹名称(标识符类)不再用于这些图像的训练。
|
||||
|
||||
默认的标题文件扩展名为.caption。可以使用学习脚本的 `--caption_extension` 选项进行更改。 使用 `--shuffle_caption` 选项,同时对每个逗号分隔的部分进行学习时会对学习时的标题进行混洗。
|
||||
默认的caption文件扩展名为.caption。可以使用训练脚本的 `--caption_extension` 选项进行更改。 使用 `--shuffle_caption` 选项,同时对每个逗号分隔的部分进行训练时会对训练时的caption进行混洗。
|
||||
|
||||
## 微调方式
|
||||
|
||||
创建元数据的方式与使用配置文件相同。 使用 `in_json` 选项指定元数据文件。
|
||||
|
||||
# 学习过程中的样本输出
|
||||
# 训练过程中的样本输出
|
||||
|
||||
通过在训练中使用模型生成图像,可以检查学习进度。将以下选项指定为学习脚本。
|
||||
通过在训练中使用模型生成图像,可以检查训练进度。将以下选项指定为训练脚本。
|
||||
|
||||
- `--sample_every_n_steps` / `--sample_every_n_epochs`
|
||||
|
||||
指定要采样的步数或纪元数。为这些数字中的每一个输出样本。如果两者都指定,则 epoch 数优先。
|
||||
指定要采样的步数或epoch数。为这些数字中的每一个输出样本。如果两者都指定,则 epoch 数优先。
|
||||
- `--sample_prompts`
|
||||
|
||||
指定示例输出的提示文件。
|
||||
@@ -421,11 +423,11 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
4. U-Net的结构(CrossAttention的头数等)
|
||||
5. v-parameterization(采样方式好像变了)
|
||||
|
||||
其中碱基使用1-4个,非碱基使用1-5个(768-v)。使用 1-4 进行 v2 选择,使用 5 进行 v_parameterization 选择。
|
||||
-`--pretrained_model_name_or_path`
|
||||
其中带 base 的模型使用 1-4,不带 base 的模型(768-v)使用 1-5。启用 1-4 对应 v2 选项,启用 5 对应 v_parameterization 选项。
|
||||
- `--pretrained_model_name_or_path`
|
||||
|
||||
指定要从中执行额外训练的模型。您可以指定稳定扩散检查点文件(.ckpt 或 .safetensors)、扩散器本地磁盘上的模型目录或扩散器模型 ID(例如“stabilityai/stable-diffusion-2”)。
|
||||
## 学习设置
|
||||
指定要从中执行额外训练的模型。您可以指定Stable Diffusion检查点文件(.ckpt 或 .safetensors)、diffusers本地磁盘上的模型目录或diffusers模型 ID(例如“stabilityai/stable-diffusion-2”)。
|
||||
## 训练设置
|
||||
|
||||
- `--output_dir`
|
||||
|
||||
@@ -441,7 +443,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
- `--max_train_steps` / `--max_train_epochs`
|
||||
|
||||
指定要学习的步数或纪元数。如果两者都指定,则 epoch 数优先。
|
||||
指定要训练的步数或epoch数。如果两者都指定,则 epoch 数优先。
|
||||
-
|
||||
- `--mixed_precision`
|
||||
|
||||
@@ -450,9 +452,9 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
(在RTX30系列以后也可以指定`bf16`,请配合您在搭建环境时做的加速设置)。
|
||||
- `--gradient_checkpointing`
|
||||
|
||||
通过逐步计算权重而不是在训练期间一次计算所有权重来减少训练所需的 GPU 内存量。关闭它不会影响准确性,但打开它允许更大的批量大小,所以那里有影响。
|
||||
通过在训练期间逐步进行权重计算,而不是一次性全部计算,来减少训练所需的 GPU 内存量。关闭它不会影响精度;打开它可以使用更大的批次大小,因此会间接影响训练效果。
|
||||
|
||||
另外,打开它通常会减慢速度,但可以增加批量大小,因此总的学习时间实际上可能会更快。
|
||||
另外,打开它通常会减慢速度,但可以增加批次大小,因此总的训练时间实际上可能会更快。
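作为概念说明,下面是一个使用 PyTorch 自带 `torch.utils.checkpoint` 的最小示意(假设性示例,并非训练脚本内部的实现),它体现了“反向传播时重新计算中间激活、以时间换显存”的思路:

```
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
)

x = torch.randn(8, 64, requires_grad=True)
# 前向时不保存中间激活,反向传播时重新计算
y = checkpoint(block, x)
y.sum().backward()
```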
|
||||
|
||||
- `--xformers` / `--mem_eff_attn`
|
||||
|
||||
@@ -463,35 +465,35 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
- `--save_every_n_epochs` / `--save_state` / `--resume`
|
||||
为 save_every_n_epochs 选项指定一个数字,训练期间每隔该数量的 epoch 就会保存一次模型。
|
||||
|
||||
如果同时指定save_state选项,学习状态包括优化器的状态等都会一起保存。。保存目的地将是一个文件夹。
|
||||
如果同时指定 save_state 选项,包括优化器状态等在内的训练状态也会一起保存。保存目标将是一个文件夹。
|
||||
|
||||
学习状态输出到目标文件夹中名为“<output_name>-??????-state”(??????是纪元数)的文件夹中。长时间学习时请使用。
|
||||
训练状态输出到目标文件夹中名为“<output_name>-??????-state”(??????是epoch数)的文件夹中。长时间训练时请使用。
|
||||
|
||||
使用 resume 选项从保存的训练状态恢复训练。指定学习状态文件夹(其中的状态文件夹,而不是 `output_dir`)。
|
||||
使用 resume 选项从保存的训练状态恢复训练。指定训练状态文件夹(其中的状态文件夹,而不是 `output_dir`)。
|
||||
|
||||
请注意,由于 Accelerator 规范,epoch 数和全局步数不会保存,即使恢复时它们也从 1 开始。
|
||||
- `--save_model_as` (DreamBooth, fine tuning 仅有的)
|
||||
|
||||
您可以从 `ckpt, safetensors, diffusers, diffusers_safetensors` 中选择模型保存格式。
|
||||
|
||||
- `--save_model_as=safetensors` 指定喜欢当读取稳定扩散格式(ckpt 或安全张量)并以扩散器格式保存时,缺少的信息通过从 Hugging Face 中删除 v1.5 或 v2.1 信息来补充。
|
||||
- 例如指定 `--save_model_as=safetensors`。当读取 Stable Diffusion 格式(ckpt 或 safetensors)并以 diffusers 格式保存时,缺少的信息会通过从 Hugging Face 获取 v1.5 或 v2.1 的信息来补充。
|
||||
|
||||
- `--clip_skip`
|
||||
|
||||
如果指定 `2`,则使用文本编码器 (CLIP) 的倒数第二层的输出。如果指定 1 或省略该选项,则使用最后一层。
|
||||
|
||||
*SD2.0默认使用倒数第二层,学习SD2.0时请不要指定。
|
||||
*SD2.0默认使用倒数第二层,训练SD2.0时请不要指定。
|
||||
|
||||
如果被训练的模型最初被训练为使用第二层,则 2 是一个很好的值。
|
||||
|
||||
如果您使用的是最后一层,那么整个模型都会根据该假设进行训练。因此,如果再次使用第二层进行训练,可能需要一定数量的teacher数据和更长时间的学习才能得到想要的学习结果。
|
||||
如果您使用的是最后一层,那么整个模型都会根据该假设进行训练。因此,如果再次使用第二层进行训练,可能需要一定数量的teacher数据和更长时间的训练才能得到想要的训练结果。
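下面用 Transformers 的 CLIP 文本编码器给出一个“取倒数第二层输出”的示意(假设性示例,模型 ID 仅为说明;本页后面 `get_unweighted_text_embeddings` 的修改采用了类似的处理):

```
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# 假设性示例:clip_skip=2 时取倒数第二层的 hidden state
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

clip_skip = 2
tokens = tokenizer("a photo of a frog", return_tensors="pt")
with torch.no_grad():
    out = text_encoder(**tokens, output_hidden_states=True, return_dict=True)
    hidden = out["hidden_states"][-clip_skip]                   # 倒数第二层
    hidden = text_encoder.text_model.final_layer_norm(hidden)   # 与最后一层一样再做 LayerNorm
```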
|
||||
- `--max_token_length`
|
||||
|
||||
默认值为 75。您可以通过指定“150”或“225”来扩展令牌长度来学习。使用长字幕学习时指定。
|
||||
默认值为 75。您可以通过指定“150”或“225”来扩展令牌长度来训练。使用长字幕训练时指定。
|
||||
|
||||
但由于学习时token展开的规范与Automatic1111的web UI(除法等规范)略有不同,如非必要建议用75学习。
|
||||
但由于训练时token展开的规范与Automatic1111的web UI(除法等规范)略有不同,如非必要建议用75训练。
|
||||
|
||||
与clip_skip一样,学习与模型学习状态不同的长度可能需要一定量的teacher数据和更长的学习时间。
|
||||
与clip_skip一样,训练与模型训练状态不同的长度可能需要一定量的teacher数据和更长的学习时间。
|
||||
|
||||
- `--persistent_data_loader_workers`
|
||||
|
||||
@@ -502,7 +504,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
指定数据加载的进程数。更多的进程会更快地加载数据并更有效地使用 GPU,但会消耗更多的主内存。默认是“`8` 或 `CPU 并发执行线程数 - 1` 中的较小者”,所以如果主内存不足,或者 GPU 使用率已经在 90% 左右以上,可以参考这些数值,将其降低到 `2` 或 `1` 左右。
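作为概念示意(假设性示例,并非训练脚本的实际代码),这两个选项大致对应 PyTorch `DataLoader` 的以下参数:

```
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 4))

loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=4,            # 对应 --max_data_loader_n_workers:数据加载进程数
    persistent_workers=True,  # 对应 --persistent_data_loader_workers:epoch 之间不销毁 worker
)
```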
|
||||
- `--logging_dir` / `--log_prefix`
|
||||
|
||||
保存学习日志的选项。在 logging_dir 选项中指定日志保存目标文件夹。以 TensorBoard 格式保存日志。
|
||||
保存训练日志的选项。在 logging_dir 选项中指定日志保存目标文件夹。以 TensorBoard 格式保存日志。
|
||||
|
||||
例如,如果您指定 --logging_dir=logs,将在您的工作文件夹中创建一个日志文件夹,并将日志保存在日期/时间文件夹中。
|
||||
此外,如果您指定 --log_prefix 选项,则指定的字符串将添加到日期和时间之前。使用“--logging_dir=logs --log_prefix=db_style1_”进行识别。
|
||||
@@ -518,23 +520,23 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
- `--noise_offset`
|
||||
本文的实现:https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
|
||||
看起来它可能会为整体更暗和更亮的图像产生更好的结果。它似乎对 LoRA 学习也有效。指定一个大约 0.1 的值似乎很好。
|
||||
看起来它对生成整体偏暗或偏亮的图像可能会产生更好的结果。它似乎对 LoRA 训练也有效。指定 0.1 左右的值似乎比较合适。
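其思路可以用下面的简化示意表示(假设性示例;本页后面 `apply_noise_offset` 的新实现与此思路一致):为每个样本的每个通道叠加一个共同的偏移噪声。

```
import torch

# 假设性示意:noise offset,为 (batch, channel) 叠加一个共同的偏移噪声
def add_noise_offset(noise, latents, noise_offset=0.1):
    return noise + noise_offset * torch.randn(
        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
    )

latents = torch.randn(4, 4, 64, 64)
noise = torch.randn_like(latents)
noise = add_noise_offset(noise, latents)
```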
|
||||
|
||||
- `--debug_dataset`
|
||||
|
||||
通过添加此选项,您可以在学习之前检查将学习什么样的图像数据和标题。按 Esc 退出并返回命令行。按 `S` 进入下一步(批次),按 `E` 进入下一个纪元。
|
||||
通过添加此选项,您可以在训练之前检查将训练什么样的图像数据和标题。按 Esc 退出并返回命令行。按 `S` 进入下一步(批次),按 `E` 进入下一个epoch。
|
||||
|
||||
*图片在 Linux 环境(包括 Colab)下不显示。
|
||||
|
||||
- `--vae`
|
||||
|
||||
如果您在 vae 选项中指定稳定扩散检查点、VAE 检查点文件、扩散模型或 VAE(两者都可以指定本地或拥抱面模型 ID),则该 VAE 用于学习(缓存时的潜伏)或在学习过程中获得潜伏)。
|
||||
如果在 vae 选项中指定 Stable Diffusion 检查点、VAE 检查点文件、diffusers 模型或 diffusers 的 VAE(本地路径或 Hugging Face 模型 ID 均可),则训练会使用该 VAE(缓存 latents 时或训练过程中获取 latents 时)。
|
||||
|
||||
对于 DreamBooth 和微调,保存的模型将包含此 VAE
|
||||
|
||||
- `--cache_latents`
|
||||
|
||||
在主内存中缓存 VAE 输出以减少 VRAM 使用。除 flip_aug 之外的任何增强都将不可用。此外,整体学习速度略快。
|
||||
在主内存中缓存 VAE 输出以减少 VRAM 使用。除 flip_aug 之外的任何增强都将不可用。此外,整体训练速度略快。
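缓存 latents 的做法大致如下(假设性示意,使用 diffusers 的 `AutoencoderKL`,其中的模型 ID 和函数名仅为示例,并非训练脚本的原样代码):

```
import torch
from diffusers import AutoencoderKL

# 假设性示例:把 VAE 编码结果缓存起来,训练时直接使用,避免每一步都前向 VAE
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

@torch.no_grad()
def encode_to_latents(images):  # images: (B, 3, H, W),取值范围约为 [-1, 1]
    latents = vae.encode(images).latent_dist.sample()
    return latents * 0.18215  # SD v1 的缩放系数

cache = {}
images = torch.randn(2, 3, 512, 512)
cache["batch_0"] = encode_to_latents(images).cpu()
```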
|
||||
- `--min_snr_gamma`
|
||||
|
||||
指定最小 SNR 加权策略。详情请参阅[这里](https://github.com/kohya-ss/sd-scripts/pull/308)。论文中推荐 `5`。
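其加权方式可以概括为“loss 乘以 min(γ/SNR, 1)”,示意如下(与本页后面 `apply_snr_weight` 的实现思路一致,函数名为说明用的示例):

```
import torch

# 示意:min-SNR 加权,对高 SNR(低 timestep)的样本降低 loss 权重
def apply_min_snr_weight(loss, snr, gamma=5.0):
    # loss, snr: 形状为 (batch_size,) 的张量
    weight = torch.minimum(gamma / snr, torch.ones_like(snr))
    return loss * weight
```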
|
||||
@@ -545,19 +547,29 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
-- 指定优化器类型。您可以指定
|
||||
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
|
||||
- 与过去版本中未指定选项时相同
|
||||
- AdamW8bit : 同上
|
||||
- AdamW8bit : 参数同上
|
||||
- PagedAdamW8bit : 参数同上
|
||||
- 与过去版本中指定的 --use_8bit_adam 相同
|
||||
- Lion : https://github.com/lucidrains/lion-pytorch
|
||||
- Lion8bit : 参数同上
|
||||
- PagedLion8bit : 参数同上
|
||||
- 与过去版本中指定的 --use_lion_optimizer 相同
|
||||
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
|
||||
- SGDNesterov8bit : 引数同上
|
||||
- DAdaptation : https://github.com/facebookresearch/dadaptation
|
||||
- SGDNesterov8bit : 参数同上
|
||||
- DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
|
||||
- DAdaptAdam : 参数同上
|
||||
- DAdaptAdaGrad : 参数同上
|
||||
- DAdaptAdan : 参数同上
|
||||
- DAdaptAdanIP : 参数同上
|
||||
- DAdaptLion : 参数同上
|
||||
- DAdaptSGD : 参数同上
|
||||
- Prodigy : https://github.com/konstmish/prodigy
|
||||
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
|
||||
- 任何优化器
|
||||
|
||||
- `--learning_rate`
|
||||
|
||||
指定学习率。合适的学习率取决于学习脚本,所以请参考每个解释。
|
||||
指定学习率。合适的学习率取决于训练脚本,所以请参考每个解释。
|
||||
- `--lr_scheduler` / `--lr_warmup_steps` / `--lr_scheduler_num_cycles` / `--lr_scheduler_power`
|
||||
|
||||
学习率的调度程序相关规范。
|
||||
@@ -592,14 +604,14 @@ D-Adaptation 优化器自动调整学习率。学习率选项指定的值不是
|
||||
(内部仅通过 importlib 未确认操作。如果需要,请安装包。)
|
||||
<!--
|
||||
## 使用任意大小的图像进行训练 --resolution
|
||||
你可以在广场外学习。请在分辨率中指定“宽度、高度”,如“448,640”。宽度和高度必须能被 64 整除。匹配训练图像和正则化图像的大小。
|
||||
你可以在非正方形的分辨率下训练。请在分辨率中指定“宽度,高度”,如“448,640”。宽度和高度必须能被 64 整除。请使训练图像和正则化图像的大小与之匹配。
|
||||
|
||||
就我个人而言,我经常生成垂直长的图像,所以我有时会用“448、640”来学习。
|
||||
就我个人而言,我经常生成垂直长的图像,所以我有时会用“448、640”来训练。
|
||||
|
||||
## 纵横比分桶 --enable_bucket / --min_bucket_reso / --max_bucket_reso
|
||||
它通过指定 enable_bucket 选项来启用。 Stable Diffusion 在 512x512 分辨率下训练,但也在 256x768 和 384x640 等分辨率下训练。
|
||||
|
||||
如果指定此选项,则不需要将训练图像和正则化图像统一为特定分辨率。从多种分辨率(纵横比)中进行选择,并在该分辨率下学习。
|
||||
如果指定此选项,则不需要将训练图像和正则化图像统一为特定分辨率。从多种分辨率(纵横比)中进行选择,并在该分辨率下训练。
|
||||
由于分辨率为 64 像素,纵横比可能与原始图像不完全相同。
|
||||
|
||||
您可以使用 min_bucket_reso 选项指定分辨率的最小大小,使用 max_bucket_reso 指定最大大小。默认值分别为 256 和 1024。
|
||||
@@ -611,13 +623,13 @@ D-Adaptation 优化器自动调整学习率。学习率选项指定的值不是
|
||||
(因为一批中的图像不偏向于训练图像和正则化图像。
|
||||
|
||||
## 扩充 --color_aug / --flip_aug
|
||||
增强是一种通过在学习过程中动态改变数据来提高模型性能的方法。在使用 color_aug 巧妙地改变色调并使用 flip_aug 左右翻转的同时学习。
|
||||
增强是一种通过在训练过程中动态改变数据来提高模型性能的方法。可以在使用 color_aug 轻微改变色调、使用 flip_aug 左右翻转的同时进行训练。
|
||||
|
||||
由于数据是动态变化的,因此不能与 cache_latents 选项一起指定。
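用 torchvision 可以表示类似的增强思路(假设性示意,参数数值仅为示例,并非脚本内部实现):

```
from torchvision import transforms

# 假设性示意:color_aug 对应轻微的颜色抖动,flip_aug 对应随机左右翻转
augment = transforms.Compose([
    transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.03),
    transforms.RandomHorizontalFlip(p=0.5),
])
```

正因为每一步得到的像素都不同,latents 无法预先缓存,所以不能与 cache_latents 同时使用。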
|
||||
|
||||
## 使用 fp16 梯度训练(实验特征)--full_fp16
|
||||
如果指定 full_fp16 选项,梯度从普通 float32 变为 float16 (fp16) 并学习(它似乎是 full fp16 学习而不是混合精度)。
|
||||
结果,似乎 SD1.x 512x512 大小可以在 VRAM 使用量小于 8GB 的情况下学习,而 SD2.x 512x512 大小可以在 VRAM 使用量小于 12GB 的情况下学习。
|
||||
如果指定 full_fp16 选项,梯度从普通 float32 变为 float16 (fp16) 并训练(它似乎是 full fp16 训练而不是混合精度)。
|
||||
结果,似乎 SD1.x 512x512 大小可以在 VRAM 使用量小于 8GB 的情况下训练,而 SD2.x 512x512 大小可以在 VRAM 使用量小于 12GB 的情况下训练。
|
||||
|
||||
预先在加速配置中指定 fp16,并可选择设置 ``mixed_precision="fp16"``(bf16 不起作用)。
|
||||
|
||||
@@ -631,20 +643,20 @@ D-Adaptation 优化器自动调整学习率。学习率选项指定的值不是
|
||||
|
||||
# 创建元数据文件
|
||||
|
||||
## 准备教师资料
|
||||
## 准备训练数据
|
||||
|
||||
如上所述准备好你要学习的图像数据,放在任意文件夹中。
|
||||
如上所述准备好你要训练的图像数据,放在任意文件夹中。
|
||||
|
||||
例如,存储这样的图像:
|
||||
|
||||

|
||||
|
||||
## 自动字幕
|
||||
## 自动captioning
|
||||
|
||||
如果您只想学习没有标题的标签,请跳过。
|
||||
如果您只使用标签而不使用 caption 进行训练,请跳过此步骤。
|
||||
|
||||
另外,手动准备字幕时,请准备在与教师数据图像相同的目录下,文件名相同,扩展名.caption等。每个文件应该是只有一行的文本文件。
|
||||
### 使用 BLIP 添加字幕
|
||||
另外,手动准备 caption 时,请将其放在与教师数据图像相同的目录下,文件名与图像相同、扩展名为 .caption 等。每个文件应该是只有一行的文本文件。
|
||||
### 使用 BLIP 添加caption
|
||||
|
||||
最新版本不再需要 BLIP 下载、权重下载和额外的虚拟环境。按原样工作。
|
||||
|
||||
@@ -659,24 +671,24 @@ python finetune\make_captions.py --batch_size <バッチサイズ> <教師デー
|
||||
python finetune\make_captions.py --batch_size 8 ..\train_data
|
||||
```
|
||||
|
||||
字幕文件创建在与教师数据图像相同的目录中,具有相同的文件名和扩展名.caption。
|
||||
caption文件创建在与教师数据图像相同的目录中,具有相同的文件名和扩展名.caption。
|
||||
|
||||
根据 GPU 的 VRAM 容量增加或减少 batch_size。越大越快(我认为 12GB 的 VRAM 可以多一点)。
|
||||
您可以使用 max_length 选项指定标题的最大长度。默认值为 75。如果使用 225 的令牌长度训练模型,它可能会更长。
|
||||
您可以使用 caption_extension 选项更改标题扩展名。默认为 .caption(.txt 与稍后描述的 DeepDanbooru 冲突)。
|
||||
您可以使用 max_length 选项指定caption的最大长度。默认值为 75。如果使用 225 的令牌长度训练模型,它可能会更长。
|
||||
您可以使用 caption_extension 选项更改caption扩展名。默认为 .caption(.txt 与稍后描述的 DeepDanbooru 冲突)。
|
||||
如果有多个教师数据文件夹,则对每个文件夹执行。
|
||||
|
||||
请注意,推理是随机的,因此每次运行时结果都会发生变化。如果要修复它,请使用 --seed 选项指定一个随机数种子,例如 `--seed 42`。
|
||||
|
||||
其他的选项,请参考help with `--help`(好像没有文档说明参数的含义,得看源码)。
|
||||
|
||||
默认情况下,会生成扩展名为 .caption 的字幕文件。
|
||||
默认情况下,会生成扩展名为 .caption 的caption文件。
|
||||
|
||||

|
||||
|
||||
例如,标题如下:
|
||||
|
||||

|
||||

|
||||
|
||||
## 由 DeepDanbooru 标记
|
||||
|
||||
@@ -695,7 +707,7 @@ python finetune\make_captions.py --batch_size 8 ..\train_data
|
||||
做一个这样的目录结构
|
||||
|
||||

|
||||
为扩散器环境安装必要的库。进入 DeepDanbooru 文件夹并安装它(我认为它实际上只是添加了 tensorflow-io)。
|
||||
为diffusers环境安装必要的库。进入 DeepDanbooru 文件夹并安装它(我认为它实际上只是添加了 tensorflow-io)。
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
@@ -768,12 +780,12 @@ python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
|
||||
|
||||
如果有多个教师数据文件夹,则对每个文件夹执行。
|
||||
|
||||
## 预处理字幕和标签信息
|
||||
## 预处理caption和标签信息
|
||||
|
||||
将字幕和标签作为元数据合并到一个文件中,以便从脚本中轻松处理。
|
||||
### 字幕预处理
|
||||
将caption和标签作为元数据合并到一个文件中,以便从脚本中轻松处理。
|
||||
### caption预处理
|
||||
|
||||
要将字幕放入元数据,请在您的工作文件夹中运行以下命令(如果您不使用字幕进行学习,则不需要运行它)(它实际上是一行,依此类推)。指定 `--full_path` 选项以将图像文件的完整路径存储在元数据中。如果省略此选项,则会记录相对路径,但 .toml 文件中需要单独的文件夹规范。
|
||||
要将caption放入元数据,请在您的工作文件夹中运行以下命令(如果您不使用caption进行训练,则不需要运行它)(它实际上是一行,依此类推)。指定 `--full_path` 选项以将图像文件的完整路径存储在元数据中。如果省略此选项,则会记录相对路径,但 .toml 文件中需要单独的文件夹规范。
|
||||
```
|
||||
python merge_captions_to_metadata.py --full_path <教师资料夹>
|
||||
--in_json <要读取的元数据文件名> <元数据文件名>
|
||||
@@ -799,7 +811,7 @@ python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
|
||||
__* 每次重写 in_json 选项和写入目标并写入单独的元数据文件是安全的。 __
|
||||
### 标签预处理
|
||||
|
||||
同样,标签也收集在元数据中(如果标签不用于学习,则无需这样做)。
|
||||
同样,标签也收集在元数据中(如果标签不用于训练,则无需这样做)。
|
||||
```
|
||||
python merge_dd_tags_to_metadata.py --full_path <教师资料夹>
|
||||
--in_json <要读取的元数据文件名> <要写入的元数据文件名>
|
||||
@@ -855,7 +867,7 @@ python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
|
||||
python prepare_buckets_latents.py --full_path <教师资料夹>
|
||||
<要读取的元数据文件名> <要写入的元数据文件名>
|
||||
<要微调的模型名称或检查点>
|
||||
--batch_size <批量大小>
|
||||
--batch_size <批次大小>
|
||||
--max_resolution <分辨率宽、高>
|
||||
--mixed_precision <准确性>
|
||||
```
|
||||
@@ -875,7 +887,7 @@ python prepare_buckets_latents.py --full_path
|
||||
|
||||
对于翻转的图像也会获取 latents,并保存名为 \*_flip.npz 的文件,这是一个简单的实现。在 fine_tune.py 中不需要特定的选项。如果存在带有 \_flip 的文件,则会随机加载带 flip 和不带 flip 的文件。
|
||||
|
||||
即使VRAM为12GB,批量大小也可以稍微增加。分辨率以“宽度,高度”的形式指定,必须是64的倍数。分辨率直接影响fine tuning时的内存大小。在12GB VRAM中,512,512似乎是极限(*)。如果有16GB,则可以将其提高到512,704或512,768。即使分辨率为256,256等,VRAM 8GB也很难承受(因为参数、优化器等与分辨率无关,需要一定的内存)。
|
||||
即使VRAM为12GB,批次大小也可以稍微增加。分辨率以“宽度,高度”的形式指定,必须是64的倍数。分辨率直接影响fine tuning时的内存大小。在12GB VRAM中,512,512似乎是极限(*)。如果有16GB,则可以将其提高到512,704或512,768。即使分辨率为256,256等,VRAM 8GB也很难承受(因为参数、优化器等与分辨率无关,需要一定的内存)。
|
||||
|
||||
*有报道称,在batch size为1的训练中,使用12GB VRAM和640,640的分辨率。
|
||||
|
||||
@@ -276,7 +276,9 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
||||
|
||||
### 複数のLoRAのモデルをマージする
|
||||
|
||||
複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。
|
||||
__複数のLoRAをマージする場合は原則として `svd_merge_lora.py` を使用してください。__ 単純なup同士やdown同士のマージでは、計算結果が正しくなくなるためです。
|
||||
|
||||
`merge_lora.py` によるマージは差分抽出法でLoRAを生成する場合等、ごく限られた場合でのみ有効です。
|
||||
|
||||
たとえば以下のようなコマンドラインになります。
|
||||
|
||||
@@ -294,7 +296,7 @@ python networks\merge_lora.py
|
||||
|
||||
--ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージする場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
|
||||
|
||||
v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha``の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
|
||||
v1で学習したLoRAとv2で学習したLoRA、rank(次元数)の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
|
||||
|
||||
|
||||
### その他のオプション
|
||||
fine_tune.py
@@ -21,7 +21,14 @@ from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -35,33 +42,37 @@ def train(args):
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -90,7 +101,7 @@ def train(args):
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# verify load/save model formats
|
||||
if load_stable_diffusion_format:
|
||||
@@ -228,6 +239,9 @@ def train(args):
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
@@ -258,12 +272,13 @@ def train(args):
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for m in training_models:
|
||||
@@ -302,8 +317,9 @@ def train(args):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||
elif args.multires_noise_iterations:
|
||||
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
@@ -323,11 +339,16 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
if args.min_snr_gamma:
|
||||
# do not mean over batch dimension for snr weight
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||
@@ -376,7 +397,7 @@ def train(args):
|
||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ import glob
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
@@ -11,6 +12,7 @@ import numpy as np
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
from blip.blip import blip_decoder
|
||||
import library.train_util as train_util
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ VGG(
|
||||
)
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import json
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||
import glob
|
||||
@@ -311,6 +312,7 @@ class FlashAttentionFunction(torch.autograd.Function):
|
||||
return dq, dk, dv, None, None, None, None
|
||||
|
||||
|
||||
# TODO common train_util.py
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
if mem_eff_attn:
|
||||
replace_unet_cross_attn_to_memory_efficient()
|
||||
@@ -319,7 +321,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_memory_efficient():
|
||||
print("Replace CrossAttention.forward to use NAI style Hypernetwork and FlashAttention")
|
||||
print("CrossAttention.forward has been replaced to FlashAttention (not xformers) and NAI style Hypernetwork")
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
def forward_flash_attn(self, x, context=None, mask=None):
|
||||
@@ -359,7 +361,7 @@ def replace_unet_cross_attn_to_memory_efficient():
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers():
|
||||
print("Replace CrossAttention.forward to use NAI style Hypernetwork and xformers")
|
||||
print("CrossAttention.forward has been replaced to enable xformers and NAI style Hypernetwork")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
@@ -401,6 +403,104 @@ def replace_unet_cross_attn_to_xformers():
|
||||
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
|
||||
|
||||
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
||||
if mem_eff_attn:
|
||||
replace_vae_attn_to_memory_efficient()
|
||||
elif xformers:
|
||||
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
|
||||
print("Use Diffusers xformers for VAE")
|
||||
vae.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
"""
|
||||
# VAEがbfloat16でメモリ消費が大きい問題を解決する
|
||||
upsamplers = []
|
||||
for block in vae.decoder.up_blocks:
|
||||
if block.upsamplers is not None:
|
||||
upsamplers.extend(block.upsamplers)
|
||||
|
||||
def forward_upsample(_self, hidden_states, output_size=None):
|
||||
assert hidden_states.shape[1] == _self.channels
|
||||
if _self.use_conv_transpose:
|
||||
return _self.conv(hidden_states)
|
||||
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
assert output_size is None
|
||||
# repeat_interleaveはすごく遅いが、回数はあまり呼ばれないので許容する
|
||||
hidden_states = hidden_states.repeat_interleave(2, dim=-1)
|
||||
hidden_states = hidden_states.repeat_interleave(2, dim=-2)
|
||||
else:
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# if `output_size` is passed we force the interpolation output
|
||||
# size and do not make use of `scale_factor=2`
|
||||
if output_size is None:
|
||||
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
else:
|
||||
hidden_states = torch.nn.functional.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
if _self.use_conv:
|
||||
if _self.name == "conv":
|
||||
hidden_states = _self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = _self.Conv2d_0(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
# replace upsamplers
|
||||
for upsampler in upsamplers:
|
||||
# make new scope
|
||||
def make_replacer(upsampler):
|
||||
def forward(hidden_states, output_size=None):
|
||||
return forward_upsample(upsampler, hidden_states, output_size)
|
||||
|
||||
return forward
|
||||
|
||||
upsampler.forward = make_replacer(upsampler)
|
||||
"""
|
||||
|
||||
|
||||
def replace_vae_attn_to_memory_efficient():
|
||||
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
def forward_flash_attn(self, hidden_states):
|
||||
print("forward_flash_attn")
|
||||
q_bucket_size = 512
|
||||
k_bucket_size = 1024
|
||||
|
||||
residual = hidden_states
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
|
||||
# norm
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
|
||||
|
||||
# proj to q, k, v
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
query_proj, key_proj, value_proj = map(
|
||||
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
|
||||
)
|
||||
|
||||
out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(out)  # FlashAttention の出力を射影する(元のままだと attention 結果が捨てられてしまう)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
||||
|
||||
# res connect and rescale
|
||||
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
||||
return hidden_states
|
||||
|
||||
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region 画像生成の本体:lpw_stable_diffusion.py (ASL)からコピーして修正
|
||||
@@ -515,11 +615,15 @@ class PipelineLike:
|
||||
|
||||
# ControlNet
|
||||
self.control_nets: List[ControlNetInfo] = []
|
||||
self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない
|
||||
|
||||
# Textual Inversion
|
||||
def add_token_replacement(self, target_token_id, rep_token_ids):
|
||||
self.token_replacements[target_token_id] = rep_token_ids
|
||||
|
||||
def set_enable_control_net(self, en: bool):
|
||||
self.control_net_enabled = en
|
||||
|
||||
def replace_token(self, tokens, layer=None):
|
||||
new_tokens = []
|
||||
for token in tokens:
|
||||
@@ -955,7 +1059,7 @@ class PipelineLike:
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
init_latents = []
|
||||
for i in tqdm(range(0, batch_size, vae_batch_size)):
|
||||
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
|
||||
init_latent_dist = self.vae.encode(
|
||||
init_image[i : i + vae_batch_size] if vae_batch_size > 1 else init_image[i].unsqueeze(0)
|
||||
).latent_dist
|
||||
@@ -1012,7 +1116,7 @@ class PipelineLike:
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
if self.control_nets:
|
||||
if self.control_nets and self.control_net_enabled:
|
||||
if reginonal_network:
|
||||
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
|
||||
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
|
||||
@@ -1696,6 +1800,9 @@ def parse_prompt_attention(text):
|
||||
for p in range(start_position, len(res)):
|
||||
res[p][1] *= multiplier
|
||||
|
||||
# keep break as separate token
|
||||
text = text.replace("BREAK", "\\BREAK\\")
|
||||
|
||||
for m in re_attention.finditer(text):
|
||||
text = m.group(0)
|
||||
weight = m.group(1)
|
||||
@@ -1727,7 +1834,7 @@ def parse_prompt_attention(text):
|
||||
# merge runs of identical weights
|
||||
i = 0
|
||||
while i + 1 < len(res):
|
||||
if res[i][1] == res[i + 1][1]:
|
||||
if res[i][1] == res[i + 1][1] and res[i][0].strip() != "BREAK" and res[i + 1][0].strip() != "BREAK":
|
||||
res[i][0] += res[i + 1][0]
|
||||
res.pop(i + 1)
|
||||
else:
|
||||
@@ -1744,11 +1851,25 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
|
||||
tokens = []
|
||||
weights = []
|
||||
truncated = False
|
||||
|
||||
for text in prompt:
|
||||
texts_and_weights = parse_prompt_attention(text)
|
||||
text_token = []
|
||||
text_weight = []
|
||||
for word, weight in texts_and_weights:
|
||||
if word.strip() == "BREAK":
|
||||
# pad until next multiple of tokenizer's max token length
|
||||
pad_len = pipe.tokenizer.model_max_length - (len(text_token) % pipe.tokenizer.model_max_length)
|
||||
print(f"BREAK pad_len: {pad_len}")
|
||||
for i in range(pad_len):
|
||||
# v2のときEOSをつけるべきかどうかわからないぜ
|
||||
# if i == 0:
|
||||
# text_token.append(pipe.tokenizer.eos_token_id)
|
||||
# else:
|
||||
text_token.append(pipe.tokenizer.pad_token_id)
|
||||
text_weight.append(1.0)
|
||||
continue
|
||||
|
||||
# tokenize and discard the starting and the ending token
|
||||
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||
|
||||
@@ -2043,6 +2164,110 @@ def preprocess_mask(mask):
|
||||
return mask
|
||||
|
||||
|
||||
# regular expression for dynamic prompt:
|
||||
# starts and ends with "{" and "}"
|
||||
# contains at least one variant divided by "|"
|
||||
# optional framgments divided by "$$" at start
|
||||
# if the first fragment is "E" or "e", enumerate all variants
|
||||
# if the second fragment is a number or two numbers, repeat the variants in the range
|
||||
# if the third fragment is a string, use it as a separator
|
||||
|
||||
RE_DYNAMIC_PROMPT = re.compile(r"\{((e|E)\$\$)?(([\d\-]+)\$\$)?(([^\|\}]+?)\$\$)?(.+?((\|).+?)*?)\}")
|
||||
|
||||
|
||||
def handle_dynamic_prompt_variants(prompt, repeat_count):
|
||||
founds = list(RE_DYNAMIC_PROMPT.finditer(prompt))
|
||||
if not founds:
|
||||
return [prompt]
|
||||
|
||||
# make each replacement for each variant
|
||||
enumerating = False
|
||||
replacers = []
|
||||
for found in founds:
|
||||
# if "e$$" is found, enumerate all variants
|
||||
found_enumerating = found.group(2) is not None
|
||||
enumerating = enumerating or found_enumerating
|
||||
|
||||
separator = ", " if found.group(6) is None else found.group(6)
|
||||
variants = found.group(7).split("|")
|
||||
|
||||
# parse count range
|
||||
count_range = found.group(4)
|
||||
if count_range is None:
|
||||
count_range = [1, 1]
|
||||
else:
|
||||
count_range = count_range.split("-")
|
||||
if len(count_range) == 1:
|
||||
count_range = [int(count_range[0]), int(count_range[0])]
|
||||
elif len(count_range) == 2:
|
||||
count_range = [int(count_range[0]), int(count_range[1])]
|
||||
else:
|
||||
print(f"invalid count range: {count_range}")
|
||||
count_range = [1, 1]
|
||||
if count_range[0] > count_range[1]:
|
||||
count_range = [count_range[1], count_range[0]]
|
||||
if count_range[0] < 0:
|
||||
count_range[0] = 0
|
||||
if count_range[1] > len(variants):
|
||||
count_range[1] = len(variants)
|
||||
|
||||
if found_enumerating:
|
||||
# make function to enumerate all combinations
|
||||
def make_replacer_enum(vari, cr, sep):
|
||||
def replacer():
|
||||
values = []
|
||||
for count in range(cr[0], cr[1] + 1):
|
||||
for comb in itertools.combinations(vari, count):
|
||||
values.append(sep.join(comb))
|
||||
return values
|
||||
|
||||
return replacer
|
||||
|
||||
replacers.append(make_replacer_enum(variants, count_range, separator))
|
||||
else:
|
||||
# make function to choose random combinations
|
||||
def make_replacer_single(vari, cr, sep):
|
||||
def replacer():
|
||||
count = random.randint(cr[0], cr[1])
|
||||
comb = random.sample(vari, count)
|
||||
return [sep.join(comb)]
|
||||
|
||||
return replacer
|
||||
|
||||
replacers.append(make_replacer_single(variants, count_range, separator))
|
||||
|
||||
# make each prompt
|
||||
if not enumerating:
|
||||
# if not enumerating, repeat the prompt, replace each variant randomly
|
||||
prompts = []
|
||||
for _ in range(repeat_count):
|
||||
current = prompt
|
||||
for found, replacer in zip(founds, replacers):
|
||||
current = current.replace(found.group(0), replacer()[0], 1)
|
||||
prompts.append(current)
|
||||
else:
|
||||
# if enumerating, iterate all combinations for previous prompts
|
||||
prompts = [prompt]
|
||||
|
||||
for found, replacer in zip(founds, replacers):
|
||||
if found.group(2) is not None:
|
||||
# make all combinations for existing prompts
|
||||
new_prompts = []
|
||||
for current in prompts:
|
||||
replecements = replacer()
|
||||
for replecement in replecements:
|
||||
new_prompts.append(current.replace(found.group(0), replecement, 1))
|
||||
prompts = new_prompts
|
||||
|
||||
for found, replacer in zip(founds, replacers):
|
||||
# make random selection for existing prompts
|
||||
if found.group(2) is None:
|
||||
for i in range(len(prompts)):
|
||||
prompts[i] = prompts[i].replace(found.group(0), replacer()[0], 1)
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -2091,7 +2316,7 @@ def main(args):
|
||||
dtype = torch.float32
|
||||
|
||||
highres_fix = args.highres_fix_scale is not None
|
||||
assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
||||
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
|
||||
|
||||
if args.v_parameterization and not args.v2:
|
||||
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
|
||||
@@ -2142,6 +2367,7 @@ def main(args):
|
||||
# xformers、Hypernetwork対応
|
||||
if not args.diffusers_xformers:
|
||||
replace_unet_modules(unet, not args.xformers, args.xformers)
|
||||
replace_vae_modules(vae, not args.xformers, args.xformers)
|
||||
|
||||
# tokenizerを読み込む
|
||||
print("loading tokenizer")
|
||||
@@ -2250,7 +2476,27 @@ def main(args):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない
|
||||
|
||||
# custom pipelineをコピったやつを生成する
|
||||
if args.vae_slices:
|
||||
from library.slicing_vae import SlicingAutoencoderKL
|
||||
|
||||
sli_vae = SlicingAutoencoderKL(
|
||||
act_fn="silu",
|
||||
block_out_channels=(128, 256, 512, 512),
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
in_channels=3,
|
||||
latent_channels=4,
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
out_channels=3,
|
||||
sample_size=512,
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
num_slices=args.vae_slices,
|
||||
)
|
||||
sli_vae.load_state_dict(vae.state_dict()) # vaeのパラメータをコピーする
|
||||
vae = sli_vae
|
||||
del sli_vae
|
||||
vae.to(dtype).to(device)
|
||||
|
||||
text_encoder.to(dtype).to(device)
|
||||
unet.to(dtype).to(device)
|
||||
if clip_model is not None:
|
||||
@@ -2262,6 +2508,8 @@ def main(args):
|
||||
if args.network_module:
|
||||
networks = []
|
||||
network_default_muls = []
|
||||
network_pre_calc = args.network_pre_calc
|
||||
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
@@ -2298,11 +2546,11 @@ def main(args):
|
||||
if network is None:
|
||||
return
|
||||
|
||||
mergiable = hasattr(network, "merge_to")
|
||||
if args.network_merge and not mergiable:
|
||||
mergeable = network.is_mergeable()
|
||||
if args.network_merge and not mergeable:
|
||||
print("network is not mergiable. ignore merge option.")
|
||||
|
||||
if not args.network_merge or not mergiable:
|
||||
if not args.network_merge or not mergeable:
|
||||
network.apply_to(text_encoder, unet)
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
print(f"weights are loaded: {info}")
|
||||
@@ -2311,6 +2559,10 @@ def main(args):
|
||||
network.to(memory_format=torch.channels_last)
|
||||
network.to(dtype).to(device)
|
||||
|
||||
if network_pre_calc:
|
||||
print("backup original weights")
|
||||
network.backup_weights()
|
||||
|
||||
networks.append(network)
|
||||
else:
|
||||
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
|
||||
@@ -2586,12 +2838,18 @@ def main(args):
|
||||
|
||||
# 画像サイズにオプション指定があるときはリサイズする
|
||||
if args.W is not None and args.H is not None:
|
||||
# highres fix を考慮に入れる
|
||||
w, h = args.W, args.H
|
||||
if highres_fix:
|
||||
w = int(w * args.highres_fix_scale + 0.5)
|
||||
h = int(h * args.highres_fix_scale + 0.5)
|
||||
|
||||
if init_images is not None:
|
||||
print(f"resize img2img source images to {args.W}*{args.H}")
|
||||
init_images = resize_images(init_images, (args.W, args.H))
|
||||
print(f"resize img2img source images to {w}*{h}")
|
||||
init_images = resize_images(init_images, (w, h))
|
||||
if mask_images is not None:
|
||||
print(f"resize img2img mask images to {args.W}*{args.H}")
|
||||
mask_images = resize_images(mask_images, (args.W, args.H))
|
||||
print(f"resize img2img mask images to {w}*{h}")
|
||||
mask_images = resize_images(mask_images, (w, h))
|
||||
|
||||
regional_network = False
|
||||
if networks and mask_images:
|
||||
@@ -2627,6 +2885,7 @@ def main(args):
|
||||
|
||||
# seed指定時はseedを決めておく
|
||||
if args.seed is not None:
|
||||
# dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう
|
||||
random.seed(args.seed)
|
||||
predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
|
||||
if len(predefined_seeds) == 1:
|
||||
@@ -2665,17 +2924,21 @@ def main(args):
|
||||
width_1st = width_1st - width_1st % 32
|
||||
height_1st = height_1st - height_1st % 32
|
||||
|
||||
strength_1st = ext.strength if args.highres_fix_strength is None else args.highres_fix_strength
|
||||
|
||||
ext_1st = BatchDataExt(
|
||||
width_1st,
|
||||
height_1st,
|
||||
args.highres_fix_steps,
|
||||
ext.scale,
|
||||
ext.negative_scale,
|
||||
ext.strength,
|
||||
strength_1st,
|
||||
ext.network_muls,
|
||||
ext.num_sub_prompts,
|
||||
)
|
||||
batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
|
||||
|
||||
pipe.set_enable_control_net(True) # 1st stageではControlNetを有効にする
|
||||
images_1st = process_batch(batch_1st, True, True)
|
||||
|
||||
# 2nd stageのバッチを作成して以下処理する
|
||||
@@ -2719,6 +2982,9 @@ def main(args):
|
||||
batch_2nd.append(bd_2nd)
|
||||
batch = batch_2nd
|
||||
|
||||
if args.highres_fix_disable_control_net:
|
||||
pipe.set_enable_control_net(False) # オプション指定時、2nd stageではControlNetを無効にする
|
||||
|
||||
# このバッチの情報を取り出す
|
||||
(
|
||||
return_latents,
|
||||
@@ -2815,12 +3081,20 @@ def main(args):
|
||||
|
||||
# generate
|
||||
if networks:
|
||||
# 追加ネットワークの処理
|
||||
shared = {}
|
||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||
n.set_multiplier(m)
|
||||
if regional_network:
|
||||
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
||||
|
||||
if not regional_network and network_pre_calc:
|
||||
for n in networks:
|
||||
n.restore_weights()
|
||||
for n in networks:
|
||||
n.pre_calculation()
|
||||
print("pre-calculation... done")
|
||||
|
||||
images = pipe(
|
||||
prompts,
|
||||
negative_prompts,
|
||||
@@ -2899,133 +3173,152 @@ def main(args):
|
||||
while not valid:
|
||||
print("\nType prompt:")
|
||||
try:
|
||||
prompt = input()
|
||||
raw_prompt = input()
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
valid = len(prompt.strip().split(" --")[0].strip()) > 0
|
||||
valid = len(raw_prompt.strip().split(" --")[0].strip()) > 0
|
||||
if not valid: # EOF, end app
|
||||
break
|
||||
else:
|
||||
prompt = prompt_list[prompt_index]
|
||||
raw_prompt = prompt_list[prompt_index]
|
||||
|
||||
# parse prompt
|
||||
width = args.W
|
||||
height = args.H
|
||||
scale = args.scale
|
||||
negative_scale = args.negative_scale
|
||||
steps = args.steps
|
||||
seeds = None
|
||||
strength = 0.8 if args.strength is None else args.strength
|
||||
negative_prompt = ""
|
||||
clip_prompt = None
|
||||
network_muls = None
|
||||
# sd-dynamic-prompts like variants:
|
||||
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
|
||||
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
|
||||
|
||||
prompt_args = prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
# repeat prompt
|
||||
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
|
||||
raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
|
||||
|
||||
for parg in prompt_args[1:]:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
width = int(m.group(1))
|
||||
print(f"width: {width}")
|
||||
continue
|
||||
if pi == 0 or len(raw_prompts) > 1:
|
||||
# parse prompt: if prompt is not changed, skip parsing
|
||||
width = args.W
|
||||
height = args.H
|
||||
scale = args.scale
|
||||
negative_scale = args.negative_scale
|
||||
steps = args.steps
|
||||
seed = None
|
||||
seeds = None
|
||||
strength = 0.8 if args.strength is None else args.strength
|
||||
negative_prompt = ""
|
||||
clip_prompt = None
|
||||
network_muls = None
|
||||
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
height = int(m.group(1))
|
||||
print(f"height: {height}")
|
||||
continue
|
||||
prompt_args = raw_prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
steps = max(1, min(1000, int(m.group(1))))
|
||||
print(f"steps: {steps}")
|
||||
continue
|
||||
for parg in prompt_args[1:]:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
width = int(m.group(1))
|
||||
print(f"width: {width}")
|
||||
continue
|
||||
|
||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
if m: # seed
|
||||
seeds = [int(d) for d in m.group(1).split(",")]
|
||||
print(f"seeds: {seeds}")
|
||||
continue
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
height = int(m.group(1))
|
||||
print(f"height: {height}")
|
||||
continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
scale = float(m.group(1))
|
||||
print(f"scale: {scale}")
|
||||
continue
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
steps = max(1, min(1000, int(m.group(1))))
|
||||
print(f"steps: {steps}")
|
||||
continue
|
||||
|
||||
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
|
||||
if m: # negative scale
|
||||
if m.group(1).lower() == "none":
|
||||
negative_scale = None
|
||||
else:
|
||||
negative_scale = float(m.group(1))
|
||||
print(f"negative scale: {negative_scale}")
|
||||
continue
|
||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
||||
if m: # seed
|
||||
seeds = [int(d) for d in m.group(1).split(",")]
|
||||
print(f"seeds: {seeds}")
|
||||
continue
|
||||
|
||||
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # strength
|
||||
strength = float(m.group(1))
|
||||
print(f"strength: {strength}")
|
||||
continue
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
scale = float(m.group(1))
|
||||
print(f"scale: {scale}")
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
negative_prompt = m.group(1)
|
||||
print(f"negative prompt: {negative_prompt}")
|
||||
continue
|
||||
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
|
||||
if m: # negative scale
|
||||
if m.group(1).lower() == "none":
|
||||
negative_scale = None
|
||||
else:
|
||||
negative_scale = float(m.group(1))
|
||||
print(f"negative scale: {negative_scale}")
|
||||
continue
|
||||
|
||||
m = re.match(r"c (.+)", parg, re.IGNORECASE)
|
||||
if m: # clip prompt
|
||||
clip_prompt = m.group(1)
|
||||
print(f"clip prompt: {clip_prompt}")
|
||||
continue
|
||||
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # strength
|
||||
strength = float(m.group(1))
|
||||
print(f"strength: {strength}")
|
||||
continue
|
||||
|
||||
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||
if m: # network multiplies
|
||||
network_muls = [float(v) for v in m.group(1).split(",")]
|
||||
while len(network_muls) < len(networks):
|
||||
network_muls.append(network_muls[-1])
|
||||
print(f"network mul: {network_muls}")
|
||||
continue
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
negative_prompt = m.group(1)
|
||||
print(f"negative prompt: {negative_prompt}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
m = re.match(r"c (.+)", parg, re.IGNORECASE)
|
||||
if m: # clip prompt
|
||||
clip_prompt = m.group(1)
|
||||
print(f"clip prompt: {clip_prompt}")
|
||||
continue
|
||||
|
||||
if seeds is not None:
|
||||
# 数が足りないなら繰り返す
|
||||
if len(seeds) < args.images_per_prompt:
|
||||
seeds = seeds * int(math.ceil(args.images_per_prompt / len(seeds)))
|
||||
seeds = seeds[: args.images_per_prompt]
|
||||
else:
|
||||
if predefined_seeds is not None:
|
||||
seeds = predefined_seeds[-args.images_per_prompt :]
|
||||
predefined_seeds = predefined_seeds[: -args.images_per_prompt]
|
||||
elif args.iter_same_seed:
|
||||
seeds = [iter_seed] * args.images_per_prompt
|
||||
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
|
||||
if m: # network multiplies
|
||||
network_muls = [float(v) for v in m.group(1).split(",")]
|
||||
while len(network_muls) < len(networks):
|
||||
network_muls.append(network_muls[-1])
|
||||
print(f"network mul: {network_muls}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# 数が足りないなら前のをそのまま使う
|
||||
if len(seeds) > 0:
|
||||
seed = seeds.pop(0)
|
||||
else:
|
||||
seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.images_per_prompt)]
|
||||
if args.interactive:
|
||||
print(f"seed: {seeds}")
|
||||
if predefined_seeds is not None:
|
||||
if len(predefined_seeds) > 0:
|
||||
seed = predefined_seeds.pop(0)
|
||||
else:
|
||||
print("predefined seeds are exhausted")
|
||||
seed = None
|
||||
elif args.iter_same_seed:
|
||||
seeds = iter_seed
|
||||
else:
|
||||
seed = None # 前のを消す
|
||||
|
||||
if seed is None:
|
||||
seed = random.randint(0, 0x7FFFFFFF)
|
||||
if args.interactive:
|
||||
print(f"seed: {seed}")
|
||||
|
||||
# prepare init image, guide image and mask
|
||||
init_image = mask_image = guide_image = None
|
||||
|
||||
init_image = mask_image = guide_image = None
|
||||
for seed in seeds: # images_per_promptの数だけ
|
||||
# 同一イメージを使うとき、本当はlatentに変換しておくと無駄がないが面倒なのでとりあえず毎回処理する
|
||||
if init_images is not None:
|
||||
init_image = init_images[global_step % len(init_images)]
|
||||
|
||||
# img2imgの場合は、基本的に元画像のサイズで生成する。highres fixの場合はargs.W, args.Hとscaleに従いリサイズ済みなので無視する
|
||||
# 32単位に丸めたやつにresizeされるので踏襲する
|
||||
width, height = init_image.size
|
||||
width = width - width % 32
|
||||
height = height - height % 32
|
||||
if width != init_image.size[0] or height != init_image.size[1]:
|
||||
print(
|
||||
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
|
||||
)
|
||||
if not highres_fix:
|
||||
width, height = init_image.size
|
||||
width = width - width % 32
|
||||
height = height - height % 32
|
||||
if width != init_image.size[0] or height != init_image.size[1]:
|
||||
print(
|
||||
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
|
||||
)
|
||||
|
||||
if mask_images is not None:
|
||||
mask_image = mask_images[global_step % len(mask_images)]
|
||||
@@ -3127,6 +3420,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="batch size for VAE, < 1.0 for ratio / VAE処理時のバッチサイズ、1未満の値の場合は通常バッチサイズの比率",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_slices",
|
||||
type=int,
|
||||
default=None,
|
||||
help="number of slices to split image into for VAE to reduce VRAM usage, None for no splitting (default), slower if specified. 16 or 32 recommended / VAE処理時にVRAM使用量削減のため画像を分割するスライス数、Noneの場合は分割しない(デフォルト)、指定すると遅くなる。16か32程度を推奨",
|
||||
)
|
||||
parser.add_argument("--steps", type=int, default=50, help="number of ddim sampling steps / サンプリングステップ数")
|
||||
parser.add_argument(
|
||||
"--sampler",
|
||||
@@ -3204,6 +3503,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
)
|
||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||
parser.add_argument(
|
||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--textual_inversion_embeddings",
|
||||
type=str,
|
||||
@@ -3261,6 +3563,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--highres_fix_steps", type=int, default=28, help="1st stage steps for highres fix / highres fixの最初のステージのステップ数"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_strength",
|
||||
type=float,
|
||||
default=None,
|
||||
help="1st stage img2img strength for highres fix / highres fixの最初のステージのimg2img時のstrength、省略時はstrengthと同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_save_1st", action="store_true", help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する"
|
||||
)
|
||||
@@ -3278,6 +3586,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--highres_fix_disable_control_net",
|
||||
action="store_true",
|
||||
help="disable ControlNet for highres fix / highres fixでControlNetを使わない",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
|
||||
|
||||
@@ -1,23 +1,44 @@
|
||||
import torch
|
||||
import argparse
|
||||
import random
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
if hasattr(noise_scheduler, "all_snr"):
|
||||
return
|
||||
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
||||
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
||||
alpha = sqrt_alphas_cumprod
|
||||
sigma = sqrt_one_minus_alphas_cumprod
|
||||
all_snr = (alpha / sigma) ** 2
|
||||
snr = torch.stack([all_snr[t] for t in timesteps])
|
||||
|
||||
noise_scheduler.all_snr = all_snr.to(device)
|
||||
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper
|
||||
loss = loss * snr_weight
|
||||
return loss
|
||||
|
||||
|
||||
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
scale = snr_t / (snr_t + 1)
|
||||
|
||||
loss = loss * scale
|
||||
return loss
|
||||
|
||||
|
||||
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||
|
||||
|
||||
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
||||
parser.add_argument(
|
||||
"--min_snr_gamma",
|
||||
@@ -25,6 +46,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
default=None,
|
||||
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_v_pred_loss_like_noise_pred",
|
||||
action="store_true",
|
||||
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
|
||||
)
|
||||
if support_weighted_captions:
|
||||
parser.add_argument(
|
||||
"--weighted_captions",
|
||||
@@ -239,11 +265,6 @@ def get_unweighted_text_embeddings(
|
||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
# discard the ending token
|
||||
@@ -258,7 +279,12 @@ def get_unweighted_text_embeddings(
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
text_embeddings = text_encoder(text_input)[0]
|
||||
if clip_skip is None or clip_skip == 1:
|
||||
text_embeddings = text_encoder(text_input)[0]
|
||||
else:
|
||||
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||
return text_embeddings
|
||||
|
||||
|
||||
@@ -342,3 +368,91 @@ def get_weighted_text_embeddings(
|
||||
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
return text_embeddings
|
||||
|
||||
|
||||
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
||||
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
||||
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
||||
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
||||
for i in range(iterations):
|
||||
r = random.random() * 2 + 2 # Rather than always going 2x,
|
||||
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
|
||||
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
|
||||
if wn == 1 or hn == 1:
|
||||
break # Lowest resolution is 1x1
|
||||
return noise / noise.std() # Scaled back to roughly unit variance
|
||||
|
||||
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
||||
if noise_offset is None:
|
||||
return noise
|
||||
if adaptive_noise_scale is not None:
|
||||
# latent shape: (batch_size, channels, height, width)
|
||||
# abs mean value for each channel
|
||||
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
|
||||
|
||||
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
||||
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
||||
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
||||
|
||||
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
return noise
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
||||
delta = (res[0] / shape[0], res[1] / shape[1])
|
||||
d = (shape[0] // res[0], shape[1] // res[1])
|
||||
|
||||
grid = (
|
||||
torch.stack(
|
||||
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
||||
dim=-1,
|
||||
)
|
||||
% 1
|
||||
)
|
||||
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
||||
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
||||
|
||||
tile_grads = (
|
||||
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
||||
.repeat_interleave(d[0], 0)
|
||||
.repeat_interleave(d[1], 1)
|
||||
)
|
||||
dot = lambda grad, shift: (
|
||||
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
||||
* grad[: shape[0], : shape[1]]
|
||||
).sum(dim=-1)
|
||||
|
||||
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
||||
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
||||
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
||||
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
||||
t = fade(grid[: shape[0], : shape[1]])
|
||||
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
||||
|
||||
|
||||
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
||||
noise = torch.zeros(shape, device=device)
|
||||
frequency = 1
|
||||
amplitude = 1
|
||||
for _ in range(octaves):
|
||||
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
||||
frequency *= 2
|
||||
amplitude *= persistence
|
||||
return noise
|
||||
|
||||
|
||||
def perlin_noise(noise, device, octaves):
|
||||
_, c, w, h = noise.shape
|
||||
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
||||
noise_perlin = []
|
||||
for _ in range(c):
|
||||
noise_perlin.append(perlin())
|
||||
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
||||
noise += noise_perlin # broadcast for each batch
|
||||
return noise / noise.std() # Scaled back to roughly unit variance
|
||||
"""
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
from typing import *
|
||||
from typing import Union, BinaryIO
|
||||
from huggingface_hub import HfApi
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from library.utils import fire_in_thread
|
||||
|
||||
|
||||
def exists_repo(
|
||||
repo_id: str, repo_type: str, revision: str = "main", token: str = None
|
||||
):
|
||||
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
|
||||
api = HfApi(
|
||||
token=token,
|
||||
)
|
||||
@@ -33,27 +30,35 @@ def upload(
|
||||
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
|
||||
api = HfApi(token=token)
|
||||
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
||||
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||
try:
|
||||
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
|
||||
print("===========================================")
|
||||
print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
|
||||
print("===========================================")
|
||||
|
||||
is_folder = (type(src) == str and os.path.isdir(src)) or (
|
||||
isinstance(src, Path) and src.is_dir()
|
||||
)
|
||||
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
|
||||
|
||||
def uploader():
|
||||
if is_folder:
|
||||
api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
folder_path=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
else:
|
||||
api.upload_file(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
path_or_fileobj=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
try:
|
||||
if is_folder:
|
||||
api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
folder_path=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
else:
|
||||
api.upload_file(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
path_or_fileobj=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
|
||||
print("===========================================")
|
||||
print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
|
||||
print("===========================================")
|
||||
|
||||
if args.async_upload and not force_sync_upload:
|
||||
fire_in_thread(uploader)
|
||||
@@ -72,7 +77,5 @@ def list_dir(
|
||||
token=token,
|
||||
)
|
||||
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
||||
file_list = [
|
||||
file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
|
||||
]
|
||||
file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
|
||||
return file_list
|
||||
|
||||
@@ -245,11 +245,6 @@ def get_unweighted_text_embeddings(
|
||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
# discard the ending token
|
||||
@@ -264,7 +259,12 @@ def get_unweighted_text_embeddings(
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
text_embeddings = pipe.text_encoder(text_input)[0]
|
||||
if clip_skip is None or clip_skip == 1:
|
||||
text_embeddings = pipe.text_encoder(text_input)[0]
|
||||
else:
|
||||
enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||
return text_embeddings
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ UNET_PARAMS_OUT_CHANNELS = 4
|
||||
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
||||
UNET_PARAMS_CONTEXT_DIM = 768
|
||||
UNET_PARAMS_NUM_HEADS = 8
|
||||
# UNET_PARAMS_USE_LINEAR_PROJECTION = False
|
||||
|
||||
VAE_PARAMS_Z_CHANNELS = 4
|
||||
VAE_PARAMS_RESOLUTION = 256
|
||||
@@ -34,6 +35,7 @@ VAE_PARAMS_NUM_RES_BLOCKS = 2
|
||||
# V2
|
||||
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
|
||||
V2_UNET_PARAMS_CONTEXT_DIM = 1024
|
||||
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
|
||||
|
||||
# Diffusersの設定を読み込むための参照モデル
|
||||
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
|
||||
@@ -357,8 +359,9 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
|
||||
|
||||
new_checkpoint[new_path] = unet_state_dict[old_path]
|
||||
|
||||
# SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
|
||||
if v2:
|
||||
# SDのv2では1*1のconv2dがlinearに変わっている
|
||||
# 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
|
||||
if v2 and not config.get('use_linear_projection', False):
|
||||
linear_transformer_to_conv(new_checkpoint)
|
||||
|
||||
return new_checkpoint
|
||||
@@ -468,7 +471,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def create_unet_diffusers_config(v2):
|
||||
def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
"""
|
||||
@@ -500,7 +503,10 @@ def create_unet_diffusers_config(v2):
|
||||
layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
|
||||
cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
|
||||
attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
|
||||
# use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
|
||||
)
|
||||
if v2 and use_linear_projection_in_v2:
|
||||
config["use_linear_projection"] = True
|
||||
|
||||
return config
|
||||
|
||||
@@ -534,6 +540,11 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
|
||||
# support checkpoint without position_ids (invalid checkpoint)
|
||||
if "text_model.embeddings.position_ids" not in text_model_dict:
|
||||
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
|
||||
|
||||
return text_model_dict
|
||||
|
||||
|
||||
@@ -846,11 +857,11 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
|
||||
|
||||
|
||||
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
|
||||
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
|
||||
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
|
||||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(v2)
|
||||
unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
|
||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
|
||||
|
||||
unet = UNet2DConditionModel(**unet_config).to(device)
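A hedged sketch of the new flag in use (the checkpoint path is a placeholder): for a v2 model, passing `unet_use_linear_projection_in_v2=True` makes `create_unet_diffusers_config` emit `use_linear_projection=True`, so the linear-to-conv conversion above is skipped.

```python
# Hypothetical call; "v2-model.ckpt" is a placeholder path.
text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(
    v2=True,
    ckpt_path="v2-model.ckpt",
    device="cpu",
    unet_use_linear_projection_in_v2=True,
)
```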
|
||||
|
||||
library/slicing_vae.py (new file, 679 lines)
@@ -0,0 +1,679 @@
|
||||
# Modified from Diffusers to reduce VRAM usage
|
||||
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.modeling_utils import ModelMixin
|
||||
from diffusers.utils import BaseOutput
|
||||
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block, ResnetBlock2D
|
||||
from diffusers.models.vae import DecoderOutput, Encoder, AutoencoderKLOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
def slice_h(x, num_slices):
|
||||
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
||||
# Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする
|
||||
# NCHWでもNHWCでもどちらでも動く
|
||||
size = (x.shape[2] + num_slices - 1) // num_slices
|
||||
sliced = []
|
||||
for i in range(num_slices):
|
||||
if i == 0:
|
||||
sliced.append(x[:, :, : size + 1, :])
|
||||
else:
|
||||
end = size * (i + 1) + 1
|
||||
if x.shape[2] - end < 3: # if the last slice is too small, use the rest of the tensor 最後が細すぎるとconv2dできないので全部使う
|
||||
end = x.shape[2]
|
||||
sliced.append(x[:, :, size * i - 1 : end, :])
|
||||
if end >= x.shape[2]:
|
||||
break
|
||||
return sliced
|
||||
|
||||
|
||||
def cat_h(sliced):
|
||||
# padding分を除いて結合する
|
||||
cat = []
|
||||
for i, x in enumerate(sliced):
|
||||
if i == 0:
|
||||
cat.append(x[:, :, :-1, :])
|
||||
elif i == len(sliced) - 1:
|
||||
cat.append(x[:, :, 1:, :])
|
||||
else:
|
||||
cat.append(x[:, :, 1:-1, :])
|
||||
del x
|
||||
x = torch.cat(cat, dim=2)
|
||||
return x
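The two helpers are designed so that slicing with a one-row overlap and trimming on re-concatenation is lossless when no convolution happens in between; a small sanity-check sketch (sizes arbitrary):

```python
import torch

x = torch.randn(1, 4, 37, 37)
# slice_h pads each chunk by one row on either side; cat_h trims the overlap back off.
assert torch.equal(cat_h(slice_h(x, 4)), x)
```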
|
||||
|
||||
|
||||
def resblock_forward(_self, num_slices, input_tensor, temb):
|
||||
assert _self.upsample is None and _self.downsample is None
|
||||
assert _self.norm1.num_groups == _self.norm2.num_groups
|
||||
assert temb is None
|
||||
|
||||
# make sure norms are on cpu
|
||||
org_device = input_tensor.device
|
||||
cpu_device = torch.device("cpu")
|
||||
_self.norm1.to(cpu_device)
|
||||
_self.norm2.to(cpu_device)
|
||||
|
||||
# GroupNormがCPUでfp16で動かない対策
|
||||
org_dtype = input_tensor.dtype
|
||||
if org_dtype == torch.float16:
|
||||
_self.norm1.to(torch.float32)
|
||||
_self.norm2.to(torch.float32)
|
||||
|
||||
# すべてのテンソルをCPUに移動する
|
||||
input_tensor = input_tensor.to(cpu_device)
|
||||
hidden_states = input_tensor
|
||||
|
||||
# どうもこれは結果が異なるようだ……
|
||||
# def sliced_norm1(norm, x):
|
||||
# num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
|
||||
# sliced_tensor = torch.chunk(x, num_div, dim=1)
|
||||
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
|
||||
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
|
||||
# print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
|
||||
# normed_tensor = []
|
||||
# for i in range(num_div):
|
||||
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
|
||||
# normed_tensor.append(n)
|
||||
# del n
|
||||
# x = torch.cat(normed_tensor, dim=1)
|
||||
# return num_div, x
|
||||
|
||||
# normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
hidden_states = _self.norm1(hidden_states) # run on cpu
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float16)
|
||||
|
||||
sliced = slice_h(hidden_states, num_slices)
|
||||
del hidden_states
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
# 計算する部分だけGPUに移動する、以下同様
|
||||
x = x.to(org_device)
|
||||
x = _self.nonlinearity(x)
|
||||
x = _self.conv1(x)
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
del x
|
||||
|
||||
hidden_states = cat_h(sliced)
|
||||
del sliced
|
||||
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
hidden_states = _self.norm2(hidden_states) # run on cpu
|
||||
if org_dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float16)
|
||||
|
||||
sliced = slice_h(hidden_states, num_slices)
|
||||
del hidden_states
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
x = _self.nonlinearity(x)
|
||||
x = _self.dropout(x)
|
||||
x = _self.conv2(x)
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
del x
|
||||
|
||||
hidden_states = cat_h(sliced)
|
||||
del sliced
|
||||
|
||||
# make shortcut
|
||||
if _self.conv_shortcut is not None:
|
||||
sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut パディングがないので普通にスライスする
|
||||
del input_tensor
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
x = _self.conv_shortcut(x)
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
del x
|
||||
|
||||
input_tensor = torch.cat(sliced, dim=2)
|
||||
del sliced
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor
|
||||
|
||||
output_tensor = output_tensor.to(org_device) # 次のレイヤーがGPUで計算する
|
||||
return output_tensor
|
||||
|
||||
|
||||
class SlicingEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
block_out_channels=(64,),
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
act_fn="silu",
|
||||
double_z=True,
|
||||
num_slices=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid_block = None
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=self.layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=1e-6,
|
||||
downsample_padding=0,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_num_head_channels=None,
|
||||
temb_channels=None,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=1,
|
||||
resnet_time_scale_shift="default",
|
||||
attn_num_head_channels=None,
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=None,
|
||||
)
|
||||
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
conv_out_channels = 2 * out_channels if double_z else out_channels
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
||||
|
||||
# replace forward of ResBlocks
|
||||
def wrapper(func, module, num_slices):
|
||||
def forward(*args, **kwargs):
|
||||
return func(module, num_slices, *args, **kwargs)
|
||||
|
||||
return forward
|
||||
|
||||
self.num_slices = num_slices
|
||||
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
|
||||
# print(f"initial divisor: {div}")
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
for resnet in self.mid_block.resnets:
|
||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||
# midblock doesn't have downsample
|
||||
|
||||
for i, down_block in enumerate(self.down_blocks[::-1]):
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
# print(f"down block: {i} divisor: {div}")
|
||||
for resnet in down_block.resnets:
|
||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||
if down_block.downsamplers is not None:
|
||||
# print("has downsample")
|
||||
for downsample in down_block.downsamplers:
|
||||
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
|
||||
div *= 2
|
||||
|
||||
def forward(self, x):
|
||||
sample = x
|
||||
del x
|
||||
|
||||
org_device = sample.device
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
# sample = self.conv_in(sample)
|
||||
sample = sample.to(cpu_device)
|
||||
sliced = slice_h(sample, self.num_slices)
|
||||
del sample
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
x = self.conv_in(x)
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
del x
|
||||
|
||||
sample = cat_h(sliced)
|
||||
del sliced
|
||||
|
||||
sample = sample.to(org_device)
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
sample = down_block(sample)
|
||||
|
||||
# middle
|
||||
sample = self.mid_block(sample)
|
||||
|
||||
# post-process
|
||||
# ここも省メモリ化したいが、恐らくそこまでメモリを食わないので省略
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
return sample
|
||||
|
||||
def downsample_forward(self, _self, num_slices, hidden_states):
|
||||
assert hidden_states.shape[1] == _self.channels
|
||||
assert _self.use_conv and _self.padding == 0
|
||||
print("downsample forward", num_slices, hidden_states.shape)
|
||||
|
||||
org_device = hidden_states.device
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
hidden_states = hidden_states.to(cpu_device)
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
# slice with even number because of stride 2
|
||||
# strideが2なので偶数でスライスする
|
||||
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
||||
size = (hidden_states.shape[2] + num_slices - 1) // num_slices
|
||||
size = size + 1 if size % 2 == 1 else size
|
||||
|
||||
sliced = []
|
||||
for i in range(num_slices):
|
||||
if i == 0:
|
||||
sliced.append(hidden_states[:, :, : size + 1, :])
|
||||
else:
|
||||
end = size * (i + 1) + 1
|
||||
if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor
|
||||
end = hidden_states.shape[2]
|
||||
sliced.append(hidden_states[:, :, size * i - 1 : end, :])
|
||||
if end >= hidden_states.shape[2]:
|
||||
break
|
||||
del hidden_states
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
x = _self.conv(x)
|
||||
x = x.to(cpu_device)
|
||||
|
||||
# ここだけ雰囲気が違うのはCopilotのせい
|
||||
if i == 0:
|
||||
hidden_states = x
|
||||
else:
|
||||
hidden_states = torch.cat([hidden_states, x], dim=2)
|
||||
|
||||
hidden_states = hidden_states.to(org_device)
|
||||
# print("downsample forward done", hidden_states.shape)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SlicingDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(64,),
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
act_fn="silu",
|
||||
num_slices=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=1,
|
||||
resnet_time_scale_shift="default",
|
||||
attn_num_head_channels=None,
|
||||
resnet_groups=norm_num_groups,
|
||||
temb_channels=None,
|
||||
)
|
||||
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=self.layers_per_block + 1,
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=None,
|
||||
add_upsample=not is_final_block,
|
||||
resnet_eps=1e-6,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_num_head_channels=None,
|
||||
temb_channels=None,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
# replace forward of ResBlocks
|
||||
def wrapper(func, module, num_slices):
|
||||
def forward(*args, **kwargs):
|
||||
return func(module, num_slices, *args, **kwargs)
|
||||
|
||||
return forward
|
||||
|
||||
self.num_slices = num_slices
|
||||
div = num_slices / (2 ** (len(self.up_blocks) - 1))
|
||||
print(f"initial divisor: {div}")
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
for resnet in self.mid_block.resnets:
|
||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||
# midblock doesn't have upsample
|
||||
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
# print(f"up block: {i} divisor: {div}")
|
||||
for resnet in up_block.resnets:
|
||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||
if up_block.upsamplers is not None:
|
||||
# print("has upsample")
|
||||
for upsample in up_block.upsamplers:
|
||||
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
|
||||
div *= 2
|
||||
|
||||
def forward(self, z):
|
||||
sample = z
|
||||
del z
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# middle
|
||||
sample = self.mid_block(sample)
|
||||
|
||||
# up
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
sample = up_block(sample)
|
||||
|
||||
# post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
|
||||
# conv_out with slicing because of VRAM usage
|
||||
# conv_outはとてもVRAM使うのでスライスして対応
|
||||
org_device = sample.device
|
||||
cpu_device = torch.device("cpu")
|
||||
sample = sample.to(cpu_device)
|
||||
|
||||
sliced = slice_h(sample, self.num_slices)
|
||||
del sample
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
x = self.conv_out(x)
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
sample = cat_h(sliced)
|
||||
del sliced
|
||||
|
||||
sample = sample.to(org_device)
|
||||
return sample
|
||||
|
||||
def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
|
||||
assert hidden_states.shape[1] == _self.channels
|
||||
assert _self.use_conv_transpose == False and _self.use_conv
|
||||
|
||||
org_dtype = hidden_states.dtype
|
||||
org_device = hidden_states.device
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
hidden_states = hidden_states.to(cpu_device)
|
||||
sliced = slice_h(hidden_states, num_slices)
|
||||
del hidden_states
|
||||
|
||||
for i in range(len(sliced)):
|
||||
x = sliced[i]
|
||||
sliced[i] = None
|
||||
|
||||
x = x.to(org_device)
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
||||
# https://github.com/pytorch/pytorch/issues/86679
|
||||
# PyTorch 2で直らないかね……
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.to(torch.float32)
|
||||
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.to(org_dtype)
|
||||
|
||||
x = _self.conv(x)
|
||||
|
||||
# upsampleされてるのでpadは2になる
|
||||
if i == 0:
|
||||
x = x[:, :, :-2, :]
|
||||
elif i == num_slices - 1:
|
||||
x = x[:, :, 2:, :]
|
||||
else:
|
||||
x = x[:, :, 2:-2, :]
|
||||
|
||||
x = x.to(cpu_device)
|
||||
sliced[i] = x
|
||||
del x
|
||||
|
||||
hidden_states = torch.cat(sliced, dim=2)
|
||||
# print("us hidden_states", hidden_states.shape)
|
||||
del sliced
|
||||
|
||||
hidden_states = hidden_states.to(org_device)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
|
||||
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
|
||||
and Max Welling.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the model (such as downloading or saving, etc.)
|
||||
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to :
|
||||
obj:`(64,)`): Tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): TODO
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 4,
|
||||
norm_num_groups: int = 32,
|
||||
sample_size: int = 32,
|
||||
num_slices: int = 16,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = SlicingEncoder(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
double_z=True,
|
||||
num_slices=num_slices,
|
||||
)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = SlicingDecoder(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
norm_num_groups=norm_num_groups,
|
||||
act_fn=act_fn,
|
||||
num_slices=num_slices,
|
||||
)
|
||||
|
||||
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
self.use_slicing = False
|
||||
|
||||
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
# これはバッチ方向のスライシング 紛らわしい
|
||||
def enable_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
|
||||
steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
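To illustrate how this class could stand in for the regular Diffusers `AutoencoderKL`, a hedged construction sketch: the config values below are the usual SD 1.x VAE settings, stated here as an assumption rather than read from this diff, and xformers is assumed to be installed because the encoder/decoder enable Diffusers' xformers attention at construction time.

```python
import torch

vae = SlicingAutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D",) * 4,
    up_block_types=("UpDecoderBlock2D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
    latent_channels=4,
    sample_size=512,
    num_slices=16,  # larger values lower peak VRAM per conv at the cost of more transfers
)

with torch.no_grad():
    image = vae.decode(torch.randn(1, 4, 64, 64)).sample  # -> (1, 3, 512, 512)
```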
|
||||
@@ -19,6 +19,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
from accelerate import Accelerator
|
||||
import gc
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
@@ -30,6 +31,7 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTokenizer
|
||||
@@ -346,6 +348,8 @@ class DreamBoothSubset(BaseSubset):
|
||||
self.is_reg = is_reg
|
||||
self.class_tokens = class_tokens
|
||||
self.caption_extension = caption_extension
|
||||
if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
self.caption_extension = "." + self.caption_extension
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, DreamBoothSubset):
|
||||
@@ -1079,16 +1083,37 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
||||
if cap_for_img is None and subset.class_tokens is None:
|
||||
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
|
||||
print(
|
||||
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
|
||||
)
|
||||
captions.append("")
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
||||
if cap_for_img is None:
|
||||
captions.append(subset.class_tokens)
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(cap_for_img)
|
||||
|
||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||
|
||||
if missing_captions:
|
||||
number_of_missing_captions = len(missing_captions)
|
||||
number_of_missing_captions_to_show = 5
|
||||
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
|
||||
|
||||
print(
|
||||
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
|
||||
)
|
||||
for i, missing_caption in enumerate(missing_captions):
|
||||
if i >= number_of_missing_captions_to_show:
|
||||
print(missing_caption + f"... and {remaining_missing_captions} more")
|
||||
break
|
||||
print(missing_caption)
|
||||
return img_paths, captions
|
||||
|
||||
print("prepare images.")
|
||||
@@ -1422,7 +1447,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
|
||||
epoch = 1
|
||||
while True:
|
||||
print(f"epoch: {epoch}")
|
||||
print(f"\nepoch: {epoch}")
|
||||
|
||||
steps = (epoch - 1) * len(train_dataset) + 1
|
||||
indices = list(range(len(train_dataset)))
|
||||
@@ -1493,6 +1518,76 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
return image_paths
|
||||
|
||||
|
||||
class MinimalDataset(BaseDataset):
|
||||
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
|
||||
self.num_train_images = 0 # update in subclass
|
||||
self.num_reg_images = 0 # update in subclass
|
||||
self.datasets = [self]
|
||||
self.batch_size = 1 # update in subclass
|
||||
|
||||
self.subsets = [self]
|
||||
self.num_repeats = 1 # update in subclass if needed
|
||||
self.img_count = 1 # update in subclass if needed
|
||||
self.bucket_info = {}
|
||||
self.is_reg = False
|
||||
self.image_dir = "dummy" # for metadata
|
||||
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return False
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError
|
||||
|
||||
# override to avoid shuffling buckets
|
||||
def set_current_epoch(self, epoch):
|
||||
self.current_epoch = epoch
|
||||
|
||||
def __getitem__(self, idx):
|
||||
r"""
|
||||
The subclass may have image_data for debug_dataset, which is a dict of ImageInfo objects.
|
||||
|
||||
Returns: example like this:
|
||||
|
||||
for i in range(batch_size):
|
||||
image_key = ... # whatever hashable
|
||||
image_keys.append(image_key)
|
||||
|
||||
image = ... # PIL Image
|
||||
img_tensor = self.image_transforms(img)
|
||||
images.append(img_tensor)
|
||||
|
||||
caption = ... # str
|
||||
input_ids = self.get_input_ids(caption)
|
||||
input_ids_list.append(input_ids)
|
||||
|
||||
captions.append(caption)
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
input_ids_list = torch.stack(input_ids_list, dim=0)
|
||||
example = {
|
||||
"images": images,
|
||||
"input_ids": input_ids_list,
|
||||
"captions": captions, # for debug_dataset
|
||||
"latents": None,
|
||||
"image_keys": image_keys, # for debug_dataset
|
||||
"loss_weights": torch.ones(batch_size, dtype=torch.float32),
|
||||
}
|
||||
return example
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
|
||||
module = ".".join(args.dataset_class.split(".")[:-1])
|
||||
dataset_class = args.dataset_class.split(".")[-1]
|
||||
module = importlib.import_module(module)
|
||||
dataset_class = getattr(module, dataset_class)
|
||||
train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)
|
||||
return train_dataset_group
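A hedged sketch of a user-side dataset that satisfies the contract documented in `MinimalDataset.__getitem__` above. The module and class names are placeholders (it would be selected with `--dataset_class my_dataset.TinyDataset`), and `get_input_ids` is assumed to be available from the base class as the docstring suggests.

```python
# my_dataset.py (hypothetical)
import torch
from library.train_util import MinimalDataset


class TinyDataset(MinimalDataset):
    def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
        super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
        self.batch_size = 1
        self.num_train_images = 1

    def __len__(self):
        return self.num_train_images

    def __getitem__(self, idx):
        caption = "a placeholder caption"
        images = torch.randn(self.batch_size, 3, 512, 512)  # stand-in for real image tensors
        input_ids = torch.stack([self.get_input_ids(caption)], dim=0)
        return {
            "images": images,
            "input_ids": input_ids,
            "captions": [caption],
            "latents": None,
            "image_keys": ["dummy_key"],
            "loss_weights": torch.ones(self.batch_size, dtype=torch.float32),
        }
```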
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region モジュール入れ替え部
|
||||
@@ -1763,6 +1858,7 @@ class FlashAttentionFunction(torch.autograd.function.Function):
|
||||
|
||||
|
||||
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
|
||||
# unet is not used currently, but it is here for future use
|
||||
if mem_eff_attn:
|
||||
replace_unet_cross_attn_to_memory_efficient()
|
||||
elif xformers:
|
||||
@@ -1770,7 +1866,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_memory_efficient():
|
||||
print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
|
||||
print("CrossAttention.forward has been replaced to FlashAttention (not xformers)")
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
def forward_flash_attn(self, x, context=None, mask=None):
|
||||
@@ -1810,7 +1906,7 @@ def replace_unet_cross_attn_to_memory_efficient():
|
||||
|
||||
|
||||
def replace_unet_cross_attn_to_xformers():
|
||||
print("Replace CrossAttention.forward to use xformers")
|
||||
print("CrossAttention.forward has been replaced to enable xformers.")
|
||||
try:
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
@@ -1852,6 +1948,60 @@ def replace_unet_cross_attn_to_xformers():
|
||||
diffusers.models.attention.CrossAttention.forward = forward_xformers
|
||||
|
||||
|
||||
"""
|
||||
def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xformers):
|
||||
# vae is not used currently, but it is here for future use
|
||||
if mem_eff_attn:
|
||||
replace_vae_attn_to_memory_efficient()
|
||||
elif xformers:
|
||||
# とりあえずDiffusersのxformersを使う。AttentionがあるのはMidBlockのみ
|
||||
print("Use Diffusers xformers for VAE")
|
||||
vae.encoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
|
||||
vae.decoder.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
|
||||
def replace_vae_attn_to_memory_efficient():
|
||||
print("AttentionBlock.forward has been replaced to FlashAttention (not xformers)")
|
||||
flash_func = FlashAttentionFunction
|
||||
|
||||
def forward_flash_attn(self, hidden_states):
|
||||
print("forward_flash_attn")
|
||||
q_bucket_size = 512
|
||||
k_bucket_size = 1024
|
||||
|
||||
residual = hidden_states
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
|
||||
# norm
|
||||
hidden_states = self.group_norm(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
|
||||
|
||||
# proj to q, k, v
|
||||
query_proj = self.query(hidden_states)
|
||||
key_proj = self.key(hidden_states)
|
||||
value_proj = self.value(hidden_states)
|
||||
|
||||
query_proj, key_proj, value_proj = map(
|
||||
lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.num_heads), (query_proj, key_proj, value_proj)
|
||||
)
|
||||
|
||||
out = flash_func.apply(query_proj, key_proj, value_proj, None, False, q_bucket_size, k_bucket_size)
|
||||
|
||||
out = rearrange(out, "b h n d -> b n (h d)")
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
||||
|
||||
# res connect and rescale
|
||||
hidden_states = (hidden_states + residual) / self.rescale_output_factor
|
||||
return hidden_states
|
||||
|
||||
diffusers.models.attention.AttentionBlock.forward = forward_flash_attn
|
||||
"""
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -1883,7 +2033,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
"--optimizer_type",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor",
|
||||
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
|
||||
)
|
||||
|
||||
# backward compatibility
|
||||
@@ -2119,6 +2269,30 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multires_noise_iterations",
|
||||
type=int,
|
||||
default=None,
|
||||
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--perlin_noise",
|
||||
# type=int,
|
||||
# default=None,
|
||||
# help="enable perlin noise and set the octaves / perlin noiseを有効にしてoctavesをこの値に設定する",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--multires_noise_discount",
|
||||
type=float,
|
||||
default=0.3,
|
||||
help="set discount value for multires noise (has no effect without --multires_noise_iterations) / Multires noiseのdiscount値を設定する(--multires_noise_iterations指定時のみ有効)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adaptive_noise_scale",
|
||||
type=float,
|
||||
default=None,
|
||||
help="add `latent mean absolute value * this value` to noise_offset (disabled if None, default) / latentの平均値の絶対値 * この値をnoise_offsetに加算する(Noneの場合は無効、デフォルト)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lowram",
|
||||
action="store_true",
|
||||
@@ -2191,6 +2365,27 @@ def verify_training_args(args: argparse.Namespace):
|
||||
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
|
||||
)
|
||||
|
||||
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
|
||||
# Listを使って数えてもいいけど並べてしまえ
|
||||
if args.noise_offset is not None and args.multires_noise_iterations is not None:
|
||||
raise ValueError(
|
||||
"noise_offset and multires_noise_iterations cannot be enabled at the same time / noise_offsetとmultires_noise_iterationsを同時に有効にできません"
|
||||
)
|
||||
# if args.noise_offset is not None and args.perlin_noise is not None:
|
||||
# raise ValueError("noise_offset and perlin_noise cannot be enabled at the same time / noise_offsetとperlin_noiseは同時に有効にできません")
|
||||
# if args.perlin_noise is not None and args.multires_noise_iterations is not None:
|
||||
# raise ValueError(
|
||||
# "perlin_noise and multires_noise_iterations cannot be enabled at the same time / perlin_noiseとmultires_noise_iterationsを同時に有効にできません"
|
||||
# )
|
||||
|
||||
if args.adaptive_noise_scale is not None and args.noise_offset is None:
|
||||
raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です")
|
||||
|
||||
if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization:
|
||||
raise ValueError(
|
||||
"scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます"
|
||||
)
|
||||
|
||||
|
||||
def add_dataset_arguments(
|
||||
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
|
||||
@@ -2269,7 +2464,6 @@ def add_dataset_arguments(
|
||||
default=1,
|
||||
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--token_warmup_step",
|
||||
type=float,
|
||||
@@ -2277,6 +2471,13 @@ def add_dataset_arguments(
|
||||
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dataset_class",
|
||||
type=str,
|
||||
default=None,
|
||||
help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
|
||||
)
|
||||
|
||||
if support_caption_dropout:
|
||||
# Textual Inversion はcaptionのdropoutをsupportしない
|
||||
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
||||
@@ -2448,7 +2649,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
|
||||
|
||||
|
||||
def get_optimizer(args, trainable_params):
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
|
||||
|
||||
optimizer_type = args.optimizer_type
|
||||
if args.use_8bit_adam:
|
||||
@@ -2525,6 +2726,39 @@ def get_optimizer(args, trainable_params):
|
||||
print(f"use Lion optimizer | {optimizer_kwargs}")
|
||||
optimizer_class = lion_pytorch.Lion
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type.endswith("8bit".lower()):
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
|
||||
|
||||
if optimizer_type == "Lion8bit".lower():
|
||||
print(f"use 8-bit Lion optimizer | {optimizer_kwargs}")
|
||||
try:
|
||||
optimizer_class = bnb.optim.Lion8bit
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください"
|
||||
)
|
||||
elif optimizer_type == "PagedAdamW8bit".lower():
|
||||
print(f"use 8-bit PagedAdamW optimizer | {optimizer_kwargs}")
|
||||
try:
|
||||
optimizer_class = bnb.optim.PagedAdamW8bit
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"No PagedAdamW8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
|
||||
)
|
||||
elif optimizer_type == "PagedLion8bit".lower():
|
||||
print(f"use 8-bit Paged Lion optimizer | {optimizer_kwargs}")
|
||||
try:
|
||||
optimizer_class = bnb.optim.PagedLion8bit
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
|
||||
)
|
||||
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "SGDNesterov".lower():
|
||||
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
|
||||
@@ -2535,13 +2769,8 @@ def get_optimizer(args, trainable_params):
|
||||
optimizer_class = torch.optim.SGD
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "DAdaptation".lower():
|
||||
try:
|
||||
import dadaptation
|
||||
except ImportError:
|
||||
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
||||
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
||||
|
||||
elif optimizer_type.startswith("DAdapt".lower()) or optimizer_type == "Prodigy".lower():
|
||||
# check lr and lr_count, and print warning
|
||||
actual_lr = lr
|
||||
lr_count = 1
|
||||
if type(trainable_params) == list and type(trainable_params[0]) == dict:
|
||||
@@ -2553,16 +2782,60 @@ def get_optimizer(args, trainable_params):
|
||||
|
||||
if actual_lr <= 0.1:
|
||||
print(
|
||||
f"learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: lr={actual_lr}"
|
||||
f"learning rate is too low. If using D-Adaptation or Prodigy, set learning rate around 1.0 / 学習率が低すぎるようです。D-AdaptationまたはProdigyの使用時は1.0前後の値を指定してください: lr={actual_lr}"
|
||||
)
|
||||
print("recommend option: lr=1.0 / 推奨は1.0です")
|
||||
if lr_count > 1:
|
||||
print(
|
||||
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-Adaptationで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
|
||||
f"when multiple learning rates are specified with dadaptation (e.g. for Text Encoder and U-Net), only the first one will take effect / D-AdaptationまたはProdigyで複数の学習率を指定した場合(Text EncoderとU-Netなど)、最初の学習率のみが有効になります: lr={actual_lr}"
|
||||
)
|
||||
|
||||
optimizer_class = dadaptation.DAdaptAdam
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
if optimizer_type.startswith("DAdapt".lower()):
|
||||
# DAdaptation family
|
||||
# check dadaptation is installed
|
||||
try:
|
||||
import dadaptation
|
||||
import dadaptation.experimental as experimental
|
||||
except ImportError:
|
||||
raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
|
||||
|
||||
# set optimizer
|
||||
if optimizer_type == "DAdaptation".lower() or optimizer_type == "DAdaptAdamPreprint".lower():
|
||||
optimizer_class = experimental.DAdaptAdamPreprint
|
||||
print(f"use D-Adaptation AdamPreprint optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdaGrad".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdaGrad
|
||||
print(f"use D-Adaptation AdaGrad optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdam".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdam
|
||||
print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdan".lower():
|
||||
optimizer_class = dadaptation.DAdaptAdan
|
||||
print(f"use D-Adaptation Adan optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptAdanIP".lower():
|
||||
optimizer_class = experimental.DAdaptAdanIP
|
||||
print(f"use D-Adaptation AdanIP optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptLion".lower():
|
||||
optimizer_class = dadaptation.DAdaptLion
|
||||
print(f"use D-Adaptation Lion optimizer | {optimizer_kwargs}")
|
||||
elif optimizer_type == "DAdaptSGD".lower():
|
||||
optimizer_class = dadaptation.DAdaptSGD
|
||||
print(f"use D-Adaptation SGD optimizer | {optimizer_kwargs}")
|
||||
else:
|
||||
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
|
||||
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
else:
|
||||
# Prodigy
|
||||
# check Prodigy is installed
|
||||
try:
|
||||
import prodigyopt
|
||||
except ImportError:
|
||||
raise ImportError("No Prodigy / Prodigy がインストールされていないようです")
|
||||
|
||||
print(f"use Prodigy optimizer | {optimizer_kwargs}")
|
||||
optimizer_class = prodigyopt.Prodigy
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "Adafactor".lower():
|
||||
# 引数を確認して適宜補正する
|
||||
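A hedged sketch of the optimizer classes behind the newly listed `--optimizer_type` values (assumptions: bitsandbytes >= 0.39.0, a dadaptation release that includes `DAdaptLion`, and prodigyopt are installed; the parameters below are dummies):

```python
import torch
import bitsandbytes as bnb
import dadaptation
import prodigyopt

params = [torch.nn.Parameter(torch.zeros(4))]

opt_paged_adamw = bnb.optim.PagedAdamW8bit(params, lr=1e-4)  # --optimizer_type PagedAdamW8bit
opt_paged_lion = bnb.optim.PagedLion8bit(params, lr=1e-4)    # --optimizer_type PagedLion8bit
opt_dadapt_lion = dadaptation.DAdaptLion(params, lr=1.0)     # --optimizer_type DAdaptLion (lr around 1.0)
opt_prodigy = prodigyopt.Prodigy(params, lr=1.0)             # --optimizer_type Prodigy (lr around 1.0)
```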
@@ -2850,16 +3123,16 @@ def prepare_dtype(args: argparse.Namespace):
|
||||
return weight_dtype, save_dtype
|
||||
|
||||
|
||||
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
||||
def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
||||
name_or_path = args.pretrained_model_name_or_path
|
||||
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||
name_or_path = os.path.realpath(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||
if load_stable_diffusion_format:
|
||||
print("load StableDiffusion checkpoint")
|
||||
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
|
||||
else:
|
||||
# Diffusers model is loaded to CPU
|
||||
print("load Diffusers pretrained models")
|
||||
print(f"load Diffusers pretrained models: {name_or_path}")
|
||||
try:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
|
||||
except EnvironmentError as ex:
|
||||
@@ -2879,6 +3152,36 @@ def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
def transform_if_model_is_DDP(text_encoder, unet, network=None):
|
||||
# Transform text_encoder, unet and network from DistributedDataParallel
|
||||
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
|
||||
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
|
||||
args, weight_dtype, accelerator.device if args.lowram else "cpu"
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
def patch_accelerator_for_fp16_training(accelerator):
|
||||
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||
|
||||
@@ -3018,7 +3321,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
ckpt_name = get_step_ckpt_name(args, ext, global_step)
|
||||
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
@@ -3044,7 +3347,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
else:
|
||||
out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step))
|
||||
|
||||
print(f"saving model: {out_dir}")
|
||||
print(f"\nsaving model: {out_dir}")
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
@@ -3062,16 +3365,17 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
print(f"removing old model: {remove_out_dir}")
|
||||
shutil.rmtree(remove_out_dir)
|
||||
|
||||
if on_epoch_end:
|
||||
save_and_remove_state_on_epoch_end(args, accelerator, epoch_no)
|
||||
else:
|
||||
save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
if args.save_state:
|
||||
if on_epoch_end:
|
||||
save_and_remove_state_on_epoch_end(args, accelerator, epoch_no)
|
||||
else:
|
||||
save_and_remove_state_stepwise(args, accelerator, global_step)
|
||||
|
||||
|
||||
def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, epoch_no):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_EPOCH_NAME)
|
||||
|
||||
print(f"saving state at epoch {epoch_no}")
|
||||
print(f"\nsaving state at epoch {epoch_no}")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
|
||||
@@ -3092,7 +3396,7 @@ def save_and_remove_state_on_epoch_end(args: argparse.Namespace, accelerator, ep
|
||||
def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_no):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_STEP_NAME)
|
||||
|
||||
print(f"saving state at step {step_no}")
|
||||
print(f"\nsaving state at step {step_no}")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
state_dir = os.path.join(args.output_dir, STEP_STATE_NAME.format(model_name, step_no))
|
||||
@@ -3117,7 +3421,7 @@ def save_and_remove_state_stepwise(args: argparse.Namespace, accelerator, step_n
|
||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
|
||||
|
||||
print("saving last state.")
|
||||
print("\nsaving last state.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
|
||||
@@ -3189,7 +3493,7 @@ def sample_images(
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
return
|
||||
|
||||
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
if not os.path.isfile(args.sample_prompts):
|
||||
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
||||
return
|
||||
@@ -3198,8 +3502,21 @@ def sample_images(
|
||||
vae.to(device)
|
||||
|
||||
# read prompts
|
||||
with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||
prompts = f.readlines()
|
||||
|
||||
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||
# prompts = f.readlines()
|
||||
|
||||
if args.sample_prompts.endswith(".txt"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif args.sample_prompts.endswith(".toml"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif args.sample_prompts.endswith(".json"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
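Note: with this change `--sample_prompts` accepts a plain `.txt` file (one prompt per line, `#` for comments), a `.toml` file, or a `.json` file. The `.toml` branch merges the shared `[prompt]` table into each `[[prompt.subset]]` entry, so per-image overrides sit next to common defaults. A small self-contained sketch of that merge with illustrative values; the key names match the ones read a few lines below (`negative_prompt`, `width`, `height`, `scale`, `seed`, `sample_steps`):

```python
import toml

# illustrative layout; keys must not repeat between [prompt] and a subset,
# because dict(**a, **b) raises on duplicate keyword arguments
data = toml.loads("""
[prompt]
width = 512
height = 512
scale = 7.5
sample_steps = 28

[[prompt.subset]]
prompt = "a photo of a cat"
seed = 42

[[prompt.subset]]
prompt = "a photo of a dog"
negative_prompt = "low quality"
""")

# same merge as the hunk above: shared keys plus per-subset keys
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
print(prompts[0]["prompt"], prompts[0]["width"], prompts[0]["seed"])
```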
|
||||
|
||||
# schedulerを用意する
|
||||
sched_init_args = {}
|
||||
@@ -3262,60 +3579,70 @@ def sample_images(
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = torch.cuda.get_rng_state()
|
||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||
|
||||
with torch.no_grad():
|
||||
with accelerator.autocast():
|
||||
for i, prompt in enumerate(prompts):
|
||||
if not accelerator.is_main_process:
|
||||
continue
|
||||
prompt = prompt.strip()
|
||||
if len(prompt) == 0 or prompt[0] == "#":
|
||||
continue
|
||||
|
||||
# subset of gen_img_diffusers
|
||||
prompt_args = prompt.split(" --")
|
||||
prompt = prompt_args[0]
|
||||
negative_prompt = None
|
||||
sample_steps = 30
|
||||
width = height = 512
|
||||
scale = 7.5
|
||||
seed = None
|
||||
for parg in prompt_args:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
width = int(m.group(1))
|
||||
continue
|
||||
if isinstance(prompt, dict):
|
||||
negative_prompt = prompt.get("negative_prompt")
|
||||
sample_steps = prompt.get("sample_steps", 30)
|
||||
width = prompt.get("width", 512)
|
||||
height = prompt.get("height", 512)
|
||||
scale = prompt.get("scale", 7.5)
|
||||
seed = prompt.get("seed")
|
||||
prompt = prompt.get("prompt")
|
||||
else:
|
||||
# prompt = prompt.strip()
|
||||
# if len(prompt) == 0 or prompt[0] == "#":
|
||||
# continue
|
||||
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
height = int(m.group(1))
|
||||
continue
|
||||
# subset of gen_img_diffusers
|
||||
prompt_args = prompt.split(" --")
|
||||
prompt = prompt_args[0]
|
||||
negative_prompt = None
|
||||
sample_steps = 30
|
||||
width = height = 512
|
||||
scale = 7.5
|
||||
seed = None
|
||||
for parg in prompt_args:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
width = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
seed = int(m.group(1))
|
||||
continue
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
height = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
||||
continue
|
||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
seed = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
scale = float(m.group(1))
|
||||
continue
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
negative_prompt = m.group(1)
|
||||
continue
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
scale = float(m.group(1))
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
negative_prompt = m.group(1)
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
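Note: for `.txt` prompt files each line can still carry gen_img_diffusers-style switches, parsed by the regexes above: `--w`/`--h` (size), `--d` (seed), `--s` (steps, clamped to 1 to 1000), `--l` (CFG scale) and `--n` (negative prompt). A standalone sketch of that parsing for one illustrative line:

```python
import re

line = "a photo of a cat --w 640 --h 448 --d 42 --s 28 --l 7.0 --n low quality, blurry"

prompt, *opts = line.split(" --")
width = height = 512
sample_steps, scale = 30, 7.5
seed = negative_prompt = None

for parg in opts:
    if m := re.match(r"w (\d+)", parg, re.IGNORECASE):
        width = int(m.group(1))
    elif m := re.match(r"h (\d+)", parg, re.IGNORECASE):
        height = int(m.group(1))
    elif m := re.match(r"d (\d+)", parg, re.IGNORECASE):
        seed = int(m.group(1))
    elif m := re.match(r"s (\d+)", parg, re.IGNORECASE):
        sample_steps = max(1, min(1000, int(m.group(1))))
    elif m := re.match(r"l ([\d\.]+)", parg, re.IGNORECASE):
        scale = float(m.group(1))
    elif m := re.match(r"n (.+)", parg, re.IGNORECASE):
        negative_prompt = m.group(1)

print(prompt, width, height, seed, sample_steps, scale, negative_prompt)
```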
|
||||
@@ -3369,7 +3696,8 @@ def sample_images(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
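Note: snapshotting the RNG state around sampling, and only touching the CUDA state when CUDA is actually available (the point of the `if cuda_rng_state is not None` guard), keeps preview generation from shifting the noise sequence used for training. A minimal sketch:

```python
import torch

rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

torch.manual_seed(1234)   # sampling code is free to reseed here
_ = torch.randn(4)

torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
    torch.cuda.set_rng_state(cuda_rng_state)
```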
|
||||
|
||||
|
||||
|
||||
networks/lora.py (322 changes)
@@ -19,7 +19,17 @@ class LoRAModule(torch.nn.Module):
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
):
|
||||
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
@@ -60,12 +70,87 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
org_forwarded = self.org_forward(x)
|
||||
|
||||
# module dropout
|
||||
if self.module_dropout is not None and self.training:
|
||||
if torch.rand(1) < self.module_dropout:
|
||||
return org_forwarded
|
||||
|
||||
lx = self.lora_down(x)
|
||||
|
||||
# normal dropout
|
||||
if self.dropout is not None and self.training:
|
||||
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
||||
|
||||
# rank dropout
|
||||
if self.rank_dropout is not None and self.training:
|
||||
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
||||
if len(lx.size()) == 3:
|
||||
mask = mask.unsqueeze(1) # for Text Encoder
|
||||
elif len(lx.size()) == 4:
|
||||
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
||||
lx = lx * mask
|
||||
|
||||
# scaling for rank dropout: treat as if the rank is changed
|
||||
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
||||
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
||||
else:
|
||||
scale = self.scale
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
return org_forwarded + lx * self.multiplier * scale
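Note: the reworked forward supports three independent dropout styles: `module_dropout` skips the whole LoRA branch for a step, `dropout` is ordinary element-wise dropout on the down-projected activation, and `rank_dropout` zeroes entire rank channels and then rescales by `1 / (1 - p)` so the expected magnitude stays roughly unchanged. A toy illustration of the rank-dropout part for a Linear layer:

```python
import torch

batch, rank = 2, 8
rank_dropout = 0.25
lx = torch.randn(batch, rank)                  # stand-in for lora_down(x)

mask = torch.rand(batch, rank) > rank_dropout  # drop whole rank channels per sample
lx = lx * mask
lx = lx * (1.0 / (1.0 - rank_dropout))         # compensate for the dropped ranks
```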
|
||||
|
||||
|
||||
class LoRAInfModule(LoRAModule):
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
**kwargs,
|
||||
):
|
||||
# no dropout for inference
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
||||
|
||||
self.org_module_ref = [org_module] # 後から参照できるように
|
||||
self.enabled = True
|
||||
|
||||
# check regional or not by lora_name
|
||||
self.text_encoder = False
|
||||
if lora_name.startswith("lora_te_"):
|
||||
self.regional = False
|
||||
self.use_sub_prompt = True
|
||||
self.text_encoder = True
|
||||
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
||||
self.regional = False
|
||||
self.use_sub_prompt = True
|
||||
elif "time_emb" in lora_name:
|
||||
self.regional = False
|
||||
self.use_sub_prompt = False
|
||||
else:
|
||||
self.regional = True
|
||||
self.use_sub_prompt = False
|
||||
|
||||
self.network: LoRANetwork = None
|
||||
|
||||
def set_network(self, network):
|
||||
self.network = network
|
||||
|
||||
# freezeしてマージする
|
||||
def merge_to(self, sd, dtype, device):
|
||||
# get up/down weight
|
||||
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
||||
@@ -97,44 +182,45 @@ class LoRAModule(torch.nn.Module):
|
||||
org_sd["weight"] = weight.to(dtype)
|
||||
self.org_module.load_state_dict(org_sd)
|
||||
|
||||
# 復元できるマージのため、このモジュールのweightを返す
|
||||
def get_weight(self, multiplier=None):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
# get up/down weight from module
|
||||
up_weight = self.lora_up.weight.to(torch.float)
|
||||
down_weight = self.lora_down.weight.to(torch.float)
|
||||
|
||||
# pre-calculated weight
|
||||
if len(down_weight.size()) == 2:
|
||||
# linear
|
||||
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = (
|
||||
self.multiplier
|
||||
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
* self.scale
|
||||
)
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
weight = self.multiplier * conved * self.scale
|
||||
|
||||
return weight
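Note: `get_weight` returns the dense update that the low-rank pair represents, so it can later be added to (or subtracted from) the original weight: `multiplier * (up @ down) * scale` for Linear, with the 1x1 and 3x3 convolution cases reshaped or convolved accordingly, where `scale = alpha / rank`. The Linear case as a short sketch:

```python
import torch

out_features, in_features, rank = 8, 16, 4
alpha, multiplier = 1.0, 1.0

up = torch.randn(out_features, rank)    # lora_up.weight
down = torch.randn(rank, in_features)   # lora_down.weight
scale = alpha / rank

delta_w = multiplier * (up @ down) * scale   # same shape as the original Linear weight
assert delta_w.shape == (out_features, in_features)
```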
|
||||
|
||||
def set_region(self, region):
|
||||
self.region = region
|
||||
self.region_mask = None
|
||||
|
||||
def forward(self, x):
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
||||
|
||||
class LoRAInfModule(LoRAModule):
|
||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
||||
|
||||
# check regional or not by lora_name
|
||||
self.text_encoder = False
|
||||
if lora_name.startswith("lora_te_"):
|
||||
self.regional = False
|
||||
self.use_sub_prompt = True
|
||||
self.text_encoder = True
|
||||
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
||||
self.regional = False
|
||||
self.use_sub_prompt = True
|
||||
elif "time_emb" in lora_name:
|
||||
self.regional = False
|
||||
self.use_sub_prompt = False
|
||||
else:
|
||||
self.regional = True
|
||||
self.use_sub_prompt = False
|
||||
|
||||
self.network: LoRANetwork = None
|
||||
|
||||
def set_network(self, network):
|
||||
self.network = network
|
||||
|
||||
def default_forward(self, x):
|
||||
# print("default_forward", self.lora_name, x.size())
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
||||
def forward(self, x):
|
||||
if not self.enabled:
|
||||
return self.org_forward(x)
|
||||
|
||||
if self.network is None or self.network.sub_prompt_index is None:
|
||||
return self.default_forward(x)
|
||||
if not self.regional and not self.use_sub_prompt:
|
||||
@@ -285,7 +371,36 @@ class LoRAInfModule(LoRAModule):
|
||||
return out
|
||||
|
||||
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||
def parse_block_lr_kwargs(nw_kwargs):
|
||||
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
|
||||
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
|
||||
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
|
||||
|
||||
# 以上のいずれにも設定がない場合は無効としてNoneを返す
|
||||
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
|
||||
return None, None, None
|
||||
|
||||
# extract learning rate weight for each block
|
||||
if down_lr_weight is not None:
|
||||
# if some parameters are not set, use zero
|
||||
if "," in down_lr_weight:
|
||||
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
||||
|
||||
if mid_lr_weight is not None:
|
||||
mid_lr_weight = float(mid_lr_weight)
|
||||
|
||||
if up_lr_weight is not None:
|
||||
if "," in up_lr_weight:
|
||||
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
||||
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
|
||||
)
|
||||
|
||||
return down_lr_weight, mid_lr_weight, up_lr_weight
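Note: `parse_block_lr_kwargs` pulls the per-block learning-rate weights out of the network kwargs: `down_lr_weight` and `up_lr_weight` are comma-separated lists (empty slots fall back to 0.0), `mid_lr_weight` is a single float, and `block_lr_zero_threshold` later zeroes out blocks whose weight falls below it. A sketch of just the string handling, with made-up values; the expected list lengths are validated elsewhere by `get_block_lr_weight`:

```python
# the strings normally arrive via --network_args, e.g.
#   --network_args "down_lr_weight=1,1,1,1,0.5,,0.5,0.5,0.25,0.25,0.25,0.25" "mid_lr_weight=0.1"
nw_kwargs = {
    "down_lr_weight": "1,1,1,1,0.5,,0.5,0.5,0.25,0.25,0.25,0.25",   # illustrative only
    "mid_lr_weight": "0.1",
    "block_lr_zero_threshold": "0.05",
}

down = [(float(s) if s else 0.0) for s in nw_kwargs["down_lr_weight"].split(",")]
mid = float(nw_kwargs["mid_lr_weight"])
threshold = float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
print(down, mid, threshold)   # the empty slot becomes 0.0
```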
|
||||
|
||||
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, neuron_dropout=None, **kwargs):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
@@ -303,9 +418,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
|
||||
# block dim/alpha/lr
|
||||
block_dims = kwargs.get("block_dims", None)
|
||||
down_lr_weight = kwargs.get("down_lr_weight", None)
|
||||
mid_lr_weight = kwargs.get("mid_lr_weight", None)
|
||||
up_lr_weight = kwargs.get("up_lr_weight", None)
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
||||
|
||||
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
|
||||
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
|
||||
@@ -317,23 +430,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
||||
)
|
||||
|
||||
# extract learning rate weight for each block
|
||||
if down_lr_weight is not None:
|
||||
# if some parameters are not set, use zero
|
||||
if "," in down_lr_weight:
|
||||
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
||||
|
||||
if mid_lr_weight is not None:
|
||||
mid_lr_weight = float(mid_lr_weight)
|
||||
|
||||
if up_lr_weight is not None:
|
||||
if "," in up_lr_weight:
|
||||
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
||||
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))
|
||||
)
|
||||
|
||||
# remove block dim/alpha without learning rate
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
|
||||
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
@@ -344,6 +440,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
conv_block_dims = None
|
||||
conv_block_alphas = None
|
||||
|
||||
# rank/module dropout
|
||||
rank_dropout = kwargs.get("rank_dropout", None)
|
||||
if rank_dropout is not None:
|
||||
rank_dropout = float(rank_dropout)
|
||||
module_dropout = kwargs.get("module_dropout", None)
|
||||
if module_dropout is not None:
|
||||
module_dropout = float(module_dropout)
|
||||
|
||||
# すごく引数が多いな ( ^ω^)・・・
|
||||
network = LoRANetwork(
|
||||
text_encoder,
|
||||
@@ -351,6 +455,9 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
multiplier=multiplier,
|
||||
lora_dim=network_dim,
|
||||
alpha=network_alpha,
|
||||
dropout=neuron_dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
conv_lora_dim=conv_dim,
|
||||
conv_alpha=conv_alpha,
|
||||
block_dims=block_dims,
|
||||
@@ -593,13 +700,19 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
if key not in modules_alpha:
|
||||
modules_alpha = modules_dim[key]
|
||||
modules_alpha[key] = modules_dim[key]
|
||||
|
||||
module_class = LoRAInfModule if for_inference else LoRAModule
|
||||
|
||||
network = LoRANetwork(
|
||||
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
||||
)
|
||||
|
||||
# block lr
|
||||
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
||||
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
||||
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
||||
|
||||
return network, weights_sd
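Note: `create_network_from_weights` (now also honoring the block-lr kwargs) rebuilds a network whose per-module rank and alpha come straight from a saved state dict rather than from command-line values; the same mechanism backs the new `--dim_from_weights` option in train_network.py further below. A sketch of how dim and alpha are recovered from the keys, using a single illustrative module name:

```python
import torch

weights_sd = {
    "lora_unet_mid_block_attn1_to_q.lora_down.weight": torch.zeros(4, 320),
    "lora_unet_mid_block_attn1_to_q.lora_up.weight": torch.zeros(320, 4),
    "lora_unet_mid_block_attn1_to_q.alpha": torch.tensor(1.0),
}

modules_dim, modules_alpha = {}, {}
for key, value in weights_sd.items():
    lora_name = key.split(".")[0]
    if "alpha" in key:
        modules_alpha[lora_name] = value.item()
    elif "lora_down" in key:
        modules_dim[lora_name] = value.shape[0]   # rank is the first dim of lora_down

# old files without alpha keys fall back to alpha == dim (the fix in the hunk above)
for key, dim in modules_dim.items():
    modules_alpha.setdefault(key, dim)
```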
|
||||
|
||||
|
||||
@@ -620,6 +733,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
dropout=None,
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
conv_lora_dim=None,
|
||||
conv_alpha=None,
|
||||
block_dims=None,
|
||||
@@ -646,11 +762,15 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.alpha = alpha
|
||||
self.conv_lora_dim = conv_lora_dim
|
||||
self.conv_alpha = conv_alpha
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
elif block_dims is not None:
|
||||
print(f"create LoRA network from block_dims")
|
||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
print(f"block_dims: {block_dims}")
|
||||
print(f"block_alphas: {block_alphas}")
|
||||
if conv_block_dims is not None:
|
||||
@@ -658,6 +778,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
print(f"conv_block_alphas: {conv_block_alphas}")
|
||||
else:
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
if self.conv_lora_dim is not None:
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
|
||||
@@ -704,7 +825,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
|
||||
lora = module_class(
|
||||
lora_name,
|
||||
child_module,
|
||||
self.multiplier,
|
||||
dim,
|
||||
alpha,
|
||||
dropout=dropout,
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
)
|
||||
loras.append(lora)
|
||||
return loras, skipped
|
||||
|
||||
@@ -769,6 +899,10 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
# マージできるかどうかを返す
|
||||
def is_mergeable(self):
|
||||
return True
|
||||
|
||||
# TODO refactor to common function with apply_to
|
||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||
apply_text_encoder = apply_unet = False
|
||||
@@ -797,7 +931,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
print(f"weights are merged")
|
||||
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
|
||||
def set_block_lr_weight(
|
||||
self,
|
||||
up_lr_weight: List[float] = None,
|
||||
@@ -955,3 +1089,83 @@ class LoRANetwork(torch.nn.Module):
|
||||
w = (w + 1) // 2
|
||||
|
||||
self.mask_dic = mask_dic
|
||||
|
||||
def backup_weights(self):
|
||||
# 重みのバックアップを行う
|
||||
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
if not hasattr(org_module, "_lora_org_weight"):
|
||||
sd = org_module.state_dict()
|
||||
org_module._lora_org_weight = sd["weight"].detach().clone()
|
||||
org_module._lora_restored = True
|
||||
|
||||
def restore_weights(self):
|
||||
# 重みのリストアを行う
|
||||
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
if not org_module._lora_restored:
|
||||
sd = org_module.state_dict()
|
||||
sd["weight"] = org_module._lora_org_weight
|
||||
org_module.load_state_dict(sd)
|
||||
org_module._lora_restored = True
|
||||
|
||||
def pre_calculation(self):
|
||||
# 事前計算を行う
|
||||
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
||||
for lora in loras:
|
||||
org_module = lora.org_module_ref[0]
|
||||
sd = org_module.state_dict()
|
||||
|
||||
org_weight = sd["weight"]
|
||||
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
||||
sd["weight"] = org_weight + lora_weight
|
||||
assert sd["weight"].shape == org_weight.shape
|
||||
org_module.load_state_dict(sd)
|
||||
|
||||
org_module._lora_restored = False
|
||||
lora.enabled = False
|
||||
|
||||
def apply_max_norm_regularization(self, max_norm_value, device):
|
||||
downkeys = []
|
||||
upkeys = []
|
||||
alphakeys = []
|
||||
norms = []
|
||||
keys_scaled = 0
|
||||
|
||||
state_dict = self.state_dict()
|
||||
for key in state_dict.keys():
|
||||
if "lora_down" in key and "weight" in key:
|
||||
downkeys.append(key)
|
||||
upkeys.append(key.replace("lora_down", "lora_up"))
|
||||
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
||||
|
||||
for i in range(len(downkeys)):
|
||||
down = state_dict[downkeys[i]].to(device)
|
||||
up = state_dict[upkeys[i]].to(device)
|
||||
alpha = state_dict[alphakeys[i]].to(device)
|
||||
dim = down.shape[0]
|
||||
scale = alpha / dim
|
||||
|
||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||
else:
|
||||
updown = up @ down
|
||||
|
||||
updown *= scale
|
||||
|
||||
norm = updown.norm().clamp(min=max_norm_value / 2)
|
||||
desired = torch.clamp(norm, max=max_norm_value)
|
||||
ratio = desired.cpu() / norm.cpu()
|
||||
sqrt_ratio = ratio**0.5
|
||||
if ratio != 1:
|
||||
keys_scaled += 1
|
||||
state_dict[upkeys[i]] *= sqrt_ratio
|
||||
state_dict[downkeys[i]] *= sqrt_ratio
|
||||
scalednorm = updown.norm() * ratio
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
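Note: `apply_max_norm_regularization` (behind the new `--scale_weight_norms` option) computes the effective update `up @ down * (alpha / dim)` for every LoRA pair, clamps its norm to the configured ceiling, and spreads the correction evenly over both factors by multiplying each with the square root of the ratio. A toy version for one Linear pair:

```python
import torch

max_norm = 1.0
up = torch.randn(8, 4)
down = torch.randn(4, 16)
alpha = 1.0
scale = alpha / down.shape[0]

updown = (up @ down) * scale
norm = updown.norm()
ratio = torch.clamp(norm, max=max_norm) / norm   # 1.0 when already below the ceiling
up = up * ratio**0.5                             # split the shrink between both factors
down = down * ratio**0.5
```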
|
||||
|
||||
@@ -23,7 +23,7 @@ def interrogate(args):
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
args.pretrained_model_name_or_path = args.sd_model
|
||||
args.vae = None
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)
|
||||
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
|
||||
|
||||
print(f"loading LoRA: {args.model}")
|
||||
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||
|
||||
@@ -148,13 +148,13 @@ def merge(args):
|
||||
|
||||
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving SD model to: {args.save_to}")
|
||||
print(f"\nsaving SD model to: {args.save_to}")
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||
args.sd_model, 0, 0, save_dtype, vae)
|
||||
else:
|
||||
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
print(f"\nsaving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
||||
|
||||
|
||||
@@ -219,8 +219,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
weight_name = None
|
||||
if 'lora_down' in key:
|
||||
block_down_name = key.split(".")[0]
|
||||
weight_name = key.split(".")[-1]
|
||||
block_down_name = key.rsplit('.lora_down', 1)[0]
|
||||
weight_name = key.rsplit(".", 1)[-1]
|
||||
lora_down_weight = value
|
||||
else:
|
||||
continue
|
||||
@@ -283,7 +283,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
|
||||
|
||||
|
||||
def resize(args):
|
||||
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
|
||||
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
||||
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
|
||||
@@ -9,7 +9,7 @@ pytorch-lightning==1.9.0
bitsandbytes==0.35.0
tensorboard==2.10.1
safetensors==0.2.6
gradio==3.16.2
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
toml==0.10.2
@@ -21,6 +21,6 @@ fairscale==0.4.13
# for WD14 captioning
# tensorflow<2.11
tensorflow==2.10.1
huggingface-hub==0.13.3
huggingface-hub==0.15.1
# for kohya_ss library
.

@@ -24,9 +24,9 @@ def convert(args):
|
||||
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
||||
|
||||
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
||||
assert (
|
||||
is_save_ckpt or args.reference_model is not None
|
||||
), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
||||
# assert (
|
||||
# is_save_ckpt or args.reference_model is not None
|
||||
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
||||
|
||||
# モデルを読み込む
|
||||
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
||||
@@ -34,7 +34,7 @@ def convert(args):
|
||||
|
||||
if is_load_ckpt:
|
||||
v2_model = args.v2
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
|
||||
@@ -61,7 +61,7 @@ def convert(args):
|
||||
)
|
||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||
else:
|
||||
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
|
||||
print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}")
|
||||
model_util.save_diffusers_checkpoint(
|
||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||
)
|
||||
@@ -76,6 +76,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
action="store_true",
|
||||
@@ -100,7 +103,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--reference_model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要",
|
||||
help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_safetensors",
|
||||
|
||||
@@ -62,7 +62,7 @@ def load_control_net(v2, unet, model):
|
||||
|
||||
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
||||
is_difference = "difference" in ctrl_sd_sd
|
||||
print("ControlNet: loading difference")
|
||||
print("ControlNet: loading difference:", is_difference)
|
||||
|
||||
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
|
||||
# またTransfer Controlの元weightとなる
|
||||
@@ -123,7 +123,8 @@ def load_preprocess(prep_type: str):
|
||||
|
||||
def preprocess_ctrl_net_hint_image(image):
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[:, :, ::-1].copy() # rgb to bgr
|
||||
# ControlNetのサンプルはcv2を使っているが、読み込みはGradioなので実はRGBになっている
|
||||
# image = image[:, :, ::-1].copy() # rgb to bgr
|
||||
image = image[None].transpose(0, 3, 1, 2) # nchw
|
||||
image = torch.from_numpy(image)
|
||||
return image # 0 to 1
|
||||
|
||||
train_db.py (73 changes)
@@ -23,7 +23,16 @@ from library.config_util import (
|
||||
BlueprintGenerator,
|
||||
)
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
|
||||
# perlin_noise,
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -37,26 +46,30 @@ def train(args):
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
else:
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -92,7 +105,7 @@ def train(args):
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# verify load/save model formats
|
||||
if load_stable_diffusion_format:
|
||||
@@ -196,6 +209,9 @@ def train(args):
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
if not train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
||||
|
||||
@@ -230,6 +246,7 @@ def train(args):
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
@@ -237,7 +254,7 @@ def train(args):
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||
@@ -268,8 +285,11 @@ def train(args):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||
elif args.multires_noise_iterations:
|
||||
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||
# elif args.perlin_noise:
|
||||
# noise = perlin_noise(noise, latents.device, args.perlin_noise) # only shape of noise is used currently
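Note: noise-offset handling moves into `apply_noise_offset`, which now also supports `--adaptive_noise_scale` (the offset grows or shrinks with the mean of the current latents), while `--multires_noise_iterations` switches to pyramid noise instead. A simplified sketch of the offset variant; the adaptive term is stated as an assumption about the helper's behaviour, not its exact formula:

```python
import torch

def apply_noise_offset_sketch(latents, noise, noise_offset, adaptive_noise_scale=None):
    offset = noise_offset
    if adaptive_noise_scale is not None:
        # shift the offset with the per-sample latent mean (simplified)
        offset = noise_offset + adaptive_noise_scale * latents.mean(dim=(1, 2, 3), keepdim=True)
        offset = torch.clamp(offset, min=0.0)
    # one constant per sample and channel, broadcast over H and W
    return noise + offset * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)

latents = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(latents)
noise = apply_noise_offset_sketch(latents, noise, 0.05, adaptive_noise_scale=0.01)
```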
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||
@@ -297,7 +317,8 @@ def train(args):
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
@@ -313,6 +334,8 @@ def train(args):
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
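Note: two optional loss reweightings are applied per timestep before the mean is taken: `--min_snr_gamma` scales the loss by `min(SNR_t, gamma) / SNR_t` so near-noiseless timesteps stop dominating, and `--scale_v_pred_loss_like_noise_pred` multiplies a v-prediction loss by `SNR_t / (SNR_t + 1)` so its magnitude behaves like an epsilon-prediction loss. These formulas are stated here as assumptions about the helpers in library/custom_train_functions.py; a tiny numeric illustration:

```python
import torch

snr = torch.tensor([0.05, 1.0, 20.0])   # made-up per-timestep signal-to-noise ratios
gamma = 5.0

min_snr_weight = torch.minimum(snr, torch.full_like(snr, gamma)) / snr
v_pred_scale = snr / (snr + 1.0)
print(min_snr_weight)   # tensor([1.0000, 1.0000, 0.2500])
print(v_pred_scale)     # tensor([0.0476, 0.5000, 0.9524])
```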
|
||||
|
||||
@@ -361,7 +384,7 @@ def train(args):
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
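Note: the `startswith("dadapt")` check extends `d*lr` logging from DAdaptation alone to the whole DAdapt family, and Prodigy is added explicitly. These optimizers keep their estimated step size `d` in each param group, so the effective learning rate logged is that estimate times the nominal `lr`. Illustration with made-up numbers in the shape those optimizers expose:

```python
param_groups = [{"d": 3.2e-4, "lr": 1.0}]   # values are illustrative only
effective_lr = param_groups[0]["d"] * param_groups[0]["lr"]
print(f"lr/d*lr = {effective_lr:.2e}")
```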
|
||||
|
||||
train_network.py (262 changes)
@@ -1,4 +1,3 @@
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
@@ -26,13 +25,27 @@ from library.config_util import (
|
||||
)
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
def generate_step_logs(
|
||||
args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
|
||||
):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
if keys_scaled is not None:
|
||||
logs["max_norm/keys_scaled"] = keys_scaled
|
||||
logs["max_norm/average_key_norm"] = mean_norm
|
||||
logs["max_norm/max_key_norm"] = maximum_norm
|
||||
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
|
||||
if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
|
||||
@@ -44,7 +57,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
|
||||
logs["lr/textencoder"] = float(lrs[0])
|
||||
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
|
||||
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
|
||||
else:
|
||||
idx = 0
|
||||
@@ -54,7 +67,7 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
|
||||
|
||||
for i in range(idx, len(lrs)):
|
||||
logs[f"lr/group{i}"] = float(lrs[i])
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower():
|
||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
@@ -79,42 +92,50 @@ def train(args):
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
||||
if use_user_config:
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
||||
if use_user_config:
|
||||
print(f"Loading dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print(
|
||||
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
print("Use DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Train with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
if use_dreambooth_method:
|
||||
print("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
print("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
# use arbitrary dataset class
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -136,7 +157,7 @@ def train(args):
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
print("preparing accelerator")
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
@@ -144,28 +165,35 @@ def train(args):
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
# TODO: modify other training scripts as well
|
||||
if pi == accelerator.state.local_process_index:
|
||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(
|
||||
args, weight_dtype, accelerator.device if args.lowram else "cpu"
|
||||
)
|
||||
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
# 差分追加学習のためにモデルを読み込む
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
if args.base_weights is not None:
|
||||
# base_weights が指定されている場合は、指定された重みを読み込みマージする
|
||||
for i, weight_path in enumerate(args.base_weights):
|
||||
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
||||
multiplier = 1.0
|
||||
else:
|
||||
multiplier = args.base_weights_multiplier[i]
|
||||
|
||||
print(f"merging module: {weight_path} with multiplier {multiplier}")
|
||||
|
||||
module, weights_sd = network_module.create_network_from_weights(
|
||||
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
|
||||
)
|
||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||
|
||||
print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -181,12 +209,6 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# prepare network
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args is not None:
|
||||
for net_arg in args.network_args:
|
||||
@@ -194,12 +216,23 @@ def train(args):
|
||||
net_kwargs[key] = value
|
||||
|
||||
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
|
||||
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
|
||||
if args.dim_from_weights:
|
||||
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
|
||||
else:
|
||||
# LyCORIS will work with this...
|
||||
network = network_module.create_network(
|
||||
1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, neuron_dropout=args.network_dropout, **net_kwargs
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
if hasattr(network, "prepare_network"):
|
||||
network.prepare_network(args)
|
||||
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
||||
print(
|
||||
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
|
||||
)
|
||||
args.scale_weight_norms = False
|
||||
|
||||
train_unet = not args.network_train_text_encoder_only
|
||||
train_text_encoder = not args.network_train_unet_only
|
||||
@@ -207,7 +240,7 @@ def train(args):
|
||||
|
||||
if args.network_weights is not None:
|
||||
info = network.load_weights(args.network_weights)
|
||||
print(f"load network weights from {args.network_weights}: {info}")
|
||||
print(f"loaded network weights from {args.network_weights}: {info}")
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
@@ -215,7 +248,7 @@ def train(args):
|
||||
network.enable_gradient_checkpointing() # may have no effect
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
print("prepare optimizer, data loader etc.")
|
||||
print("preparing optimizer, data loader etc.")
|
||||
|
||||
# 後方互換性を確保するよ
|
||||
try:
|
||||
@@ -260,7 +293,7 @@ def train(args):
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
print("enable full fp16 training.")
|
||||
print("enabling full fp16 training.")
|
||||
network.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
@@ -279,6 +312,9 @@ def train(args):
|
||||
else:
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare (train_network here only)
|
||||
text_encoder, unet, network = train_util.transform_if_model_is_DDP(text_encoder, unet, network)
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
@@ -288,23 +324,14 @@ def train(args):
|
||||
text_encoder.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if type(text_encoder) == DDP:
|
||||
text_encoder.module.text_model.embeddings.requires_grad_(True)
|
||||
else:
|
||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||
else:
|
||||
unet.eval()
|
||||
text_encoder.eval()
|
||||
|
||||
# support DistributedDataParallel
|
||||
if type(text_encoder) == DDP:
|
||||
text_encoder = text_encoder.module
|
||||
unet = unet.module
|
||||
network = network.module
|
||||
|
||||
network.prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
if not cache_latents:
|
||||
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -356,7 +383,8 @@ def train(args):
|
||||
"ss_lr_scheduler": args.lr_scheduler,
|
||||
"ss_network_module": args.network_module,
|
||||
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
||||
"ss_network_alpha": args.network_alpha, # some networks may not use this value
|
||||
"ss_network_alpha": args.network_alpha, # some networks may not have alpha
|
||||
"ss_network_dropout": args.network_dropout, # some networks may not have dropout
|
||||
"ss_mixed_precision": args.mixed_precision,
|
||||
"ss_full_fp16": bool(args.full_fp16),
|
||||
"ss_v2": bool(args.v2),
|
||||
@@ -366,6 +394,9 @@ def train(args):
|
||||
"ss_seed": args.seed,
|
||||
"ss_lowram": args.lowram,
|
||||
"ss_noise_offset": args.noise_offset,
|
||||
"ss_multires_noise_iterations": args.multires_noise_iterations,
|
||||
"ss_multires_noise_discount": args.multires_noise_discount,
|
||||
"ss_adaptive_noise_scale": args.adaptive_noise_scale,
|
||||
"ss_training_comment": args.training_comment, # will not be updated after training
|
||||
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
|
||||
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
|
||||
@@ -376,6 +407,7 @@ def train(args):
|
||||
"ss_face_crop_aug_range": args.face_crop_aug_range,
|
||||
"ss_prior_loss_weight": args.prior_loss_weight,
|
||||
"ss_min_snr_gamma": args.min_snr_gamma,
|
||||
"ss_scale_weight_norms": args.scale_weight_norms,
|
||||
}
|
||||
|
||||
if use_user_config:
|
||||
@@ -537,6 +569,8 @@ def train(args):
|
||||
noise_scheduler = DDPMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
|
||||
)
|
||||
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
|
||||
|
||||
@@ -544,17 +578,18 @@ def train(args):
|
||||
loss_total = 0.0
|
||||
del train_dataset_group
|
||||
|
||||
# if hasattr(network, "on_step_start"):
|
||||
# on_step_start = network.on_step_start
|
||||
# else:
|
||||
# on_step_start = lambda *args, **kwargs: None
|
||||
# callback for step start
|
||||
if hasattr(network, "on_step_start"):
|
||||
on_step_start = network.on_step_start
|
||||
else:
|
||||
on_step_start = lambda *args, **kwargs: None
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
metadata["ss_training_finished_at"] = str(time.time())
|
||||
metadata["ss_steps"] = str(steps)
|
||||
metadata["ss_epoch"] = str(epoch_no)
|
||||
@@ -572,7 +607,7 @@ def train(args):
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
@@ -582,7 +617,7 @@ def train(args):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(network):
|
||||
# on_step_start(text_encoder, unet)
|
||||
on_step_start(text_encoder, unet)
|
||||
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
@@ -607,11 +642,13 @@ def train(args):
|
||||
else:
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
||||
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
|
||||
elif args.multires_noise_iterations:
|
||||
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
@@ -638,6 +675,8 @@ def train(args):
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
@@ -650,6 +689,14 @@ def train(args):
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
@@ -685,8 +732,11 @@ def train(args):
logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

+ if args.scale_weight_norms:
+ progress_bar.set_postfix(**{**max_mean_logs, **logs})

if args.logging_dir is not None:
- logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
+ logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:

@@ -733,7 +783,7 @@ def train(args):
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)

print("model saved.")

@@ -770,6 +820,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=1,
help="alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)",
)
+ parser.add_argument(
+ "--network_dropout",
+ type=float,
+ default=None,
+ help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
+ )
parser.add_argument(
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
)

@@ -780,7 +836,31 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
)

+ parser.add_argument(
+ "--dim_from_weights",
+ action="store_true",
+ help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
+ )
+ parser.add_argument(
+ "--scale_weight_norms",
+ type=float,
+ default=None,
+ help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point) / 重みの値をスケーリングして勾配爆発を防ぐ(1が初期値としては適当)",
+ )
+ parser.add_argument(
+ "--base_weights",
+ type=str,
+ default=None,
+ nargs="*",
+ help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
+ )
+ parser.add_argument(
+ "--base_weights_multiplier",
+ type=float,
+ default=None,
+ nargs="*",
+ help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
+ )
return parser

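`--dim_from_weights` lets the LoRA rank be inferred from the file passed via `--network_weights` instead of from `--network_dim`. The helper below is a hypothetical illustration of how such an inference could look for a state dict saved as safetensors; the function name and the `lora_down.weight` key convention are assumptions for the sketch, not this repository's actual code.

```python
from safetensors.torch import load_file

def infer_dim_from_weights(weights_path: str) -> int:
    # Hypothetical: the rank of a LoRA is the output dimension of any lora_down weight.
    state_dict = load_file(weights_path)
    for key, tensor in state_dict.items():
        if "lora_down" in key and key.endswith(".weight"):
            return tensor.shape[0]  # (rank, in_features) for Linear layers
    raise ValueError(f"no lora_down weights found in {weights_path}")

# usage sketch:
# network_dim = infer_dim_from_weights(args.network_weights) if args.dim_from_weights else args.network_dim
```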
@@ -20,7 +20,13 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
- from library.custom_train_functions import apply_snr_weight
+ from library.custom_train_functions import (
+ apply_snr_weight,
+ prepare_scheduler_for_custom_training,
+ pyramid_noise_like,
+ apply_noise_offset,
+ scale_v_prediction_loss_like_noise_prediction,
+ )

imagenet_templates_small = [
"a photo of a {}",

@@ -98,7 +104,7 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)

# load the model
- text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
+ text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

# Convert the init_word to token_id
if args.init_word is not None:

@@ -147,43 +153,46 @@ def train(args):
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")

# prepare the dataset
- blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
- if args.dataset_config is not None:
- print(f"Load dataset config from {args.dataset_config}")
- user_config = config_util.load_user_config(args.dataset_config)
- ignored = ["train_data_dir", "reg_data_dir", "in_json"]
- if any(getattr(args, attr) is not None for attr in ignored):
- print(
- "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
- ", ".join(ignored)
+ if args.dataset_class is None:
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
+ if args.dataset_config is not None:
+ print(f"Load dataset config from {args.dataset_config}")
+ user_config = config_util.load_user_config(args.dataset_config)
+ ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+ if any(getattr(args, attr) is not None for attr in ignored):
+ print(
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+ ", ".join(ignored)
)
)
)
- else:
- use_dreambooth_method = args.in_json is None
- if use_dreambooth_method:
- print("Use DreamBooth method.")
- user_config = {
- "datasets": [
- {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
- ]
- }
- else:
- print("Train with captions.")
- user_config = {
- "datasets": [
- {
- "subsets": [
- {
- "image_dir": args.train_data_dir,
- "metadata_file": args.in_json,
- }
- ]
- }
- ]
- }
+ use_dreambooth_method = args.in_json is None
+ if use_dreambooth_method:
+ print("Use DreamBooth method.")
+ user_config = {
+ "datasets": [
+ {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
+ ]
+ }
+ else:
+ print("Train with captions.")
+ user_config = {
+ "datasets": [
+ {
+ "subsets": [
+ {
+ "image_dir": args.train_data_dir,
+ "metadata_file": args.in_json,
+ }
+ ]
+ }
+ ]
+ }

- blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
- train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+ else:
+ train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)

current_epoch = Value("i", 0)
current_step = Value("i", 0)
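When `--dataset_class` is given, this hunk skips the blueprint/config path and asks `train_util.load_arbitrary_dataset` for a user-supplied dataset instead. The exact contract is defined by `train_util`, but from the training loop shown in these diffs a batch at least needs `input_ids`, `loss_weights`, and either precomputed `latents` or images to encode. A purely hypothetical user dataset, with an invented class name and constructor, might look like this:

```python
import torch
from torch.utils.data import Dataset

class MyCachedLatentDataset(Dataset):
    """Hypothetical user dataset: serves precomputed latents and tokenized captions
    in the batch layout the training loop above consumes."""

    def __init__(self, latents, input_ids):
        # latents: list of (4, H/8, W/8) tensors; input_ids: list of token-id tensors
        self.latents = latents
        self.input_ids = input_ids

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, idx):
        return {
            "latents": self.latents[idx],       # used when latents are cached up front
            "input_ids": self.input_ids[idx],   # fed to the text encoder via get_hidden_states
            "loss_weights": torch.tensor(1.0),  # per-sample weight multiplied into the loss
        }
```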
@@ -280,6 +289,9 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler
)

+ # transform DDP after prepare
+ text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

@@ -335,6 +347,7 @@ def train(args):
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)

if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
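`prepare_scheduler_for_custom_training` is called once after the `DDPMScheduler` is constructed so that SNR-based loss weighting (Min-SNR, v-prediction scaling) can look up per-timestep values on the training device. A plausible minimal version, shown only as a sketch (the `all_snr` attribute name is an assumption), would precompute the per-timestep SNR and attach it to the scheduler:

```python
def prepare_scheduler_for_custom_training(noise_scheduler, device):
    # Sketch: cache SNR(t) = alpha_bar_t / (1 - alpha_bar_t) on the training device
    # so per-timestep loss weights can be indexed without host/device copies.
    if hasattr(noise_scheduler, "all_snr"):
        return  # already prepared
    alphas_cumprod = noise_scheduler.alphas_cumprod
    snr = alphas_cumprod / (1.0 - alphas_cumprod)
    noise_scheduler.all_snr = snr.to(device)
```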
@@ -344,7 +357,7 @@ def train(args):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)

- print(f"saving checkpoint: {ckpt_file}")
+ print(f"\nsaving checkpoint: {ckpt_file}")
save_weights(ckpt_file, embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)

@@ -357,7 +370,7 @@ def train(args):

# training loop
for epoch in range(num_train_epochs):
- print(f"epoch {epoch+1}/{num_train_epochs}")
+ print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

text_encoder.train()

@@ -384,8 +397,9 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
- # https://www.crosslabs.org//blog/diffusion-with-offset-noise
- noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
+ noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
+ elif args.multires_noise_iterations:
+ noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)

@@ -408,12 +422,14 @@ def train(args):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

- if args.min_snr_gamma:
- loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

loss_weights = batch["loss_weights"]  # per-sample weight
loss = loss * loss_weights

+ if args.min_snr_gamma:
+ loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+ if args.scale_v_pred_loss_like_noise_pred:
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)

loss = loss.mean()  # already a mean, so no need to divide by batch_size

accelerator.backward(loss)

@@ -460,7 +476,7 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
- if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)


@@ -20,7 +20,13 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
- from library.custom_train_functions import apply_snr_weight
+ from library.custom_train_functions import (
+ apply_snr_weight,
+ prepare_scheduler_for_custom_training,
+ pyramid_noise_like,
+ apply_noise_offset,
+ scale_v_prediction_loss_like_noise_prediction,
+ )
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI

imagenet_templates_small = [

@@ -88,6 +94,9 @@ def train(args):
print(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
)
+ assert (
+ args.dataset_class is None
+ ), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"

cache_latents = args.cache_latents

@@ -104,7 +113,7 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)

# load the model
- text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
+ text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

# Convert the init_word to token_id
if args.init_word is not None:

@@ -314,6 +323,9 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler
)

+ # transform DDP after prepare
+ text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()

@@ -369,6 +381,7 @@ def train(args):
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
+ prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)

if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)

@@ -378,7 +391,7 @@ def train(args):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)

- print(f"saving checkpoint: {ckpt_file}")
+ print(f"\nsaving checkpoint: {ckpt_file}")
save_weights(ckpt_file, embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)

@@ -391,7 +404,7 @@ def train(args):

# training loop
for epoch in range(num_train_epochs):
- print(f"epoch {epoch+1}/{num_train_epochs}")
+ print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1

text_encoder.train()

@@ -423,8 +436,9 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
- # https://www.crosslabs.org//blog/diffusion-with-offset-noise
- noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
+ noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
+ elif args.multires_noise_iterations:
+ noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)

@@ -447,11 +461,13 @@ def train(args):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

- loss_weights = batch["loss_weights"]  # per-sample weight
-
- loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)

+ loss_weights = batch["loss_weights"]  # per-sample weight
+ loss = loss * loss_weights
+ if args.scale_v_pred_loss_like_noise_pred:
+ loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)

loss = loss.mean()  # already a mean, so no need to divide by batch_size

@@ -499,7 +515,7 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
- if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+ if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():  # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)