Compare commits

..

127 Commits

Author SHA1 Message Date
Kohya S
f0ae7eea95 Update README.md 2023-02-23 21:59:20 +09:00
Kohya S
b22b0a5c75 Merge pull request #223 from kohya-ss/control_net
support ControlNet
2023-02-23 21:53:05 +09:00
Kohya S
c7a13c89c7 Merge branch 'main' into control_net 2023-02-23 21:51:03 +09:00
Kohya S
39a70f10bd Merge pull request #222 from kohya-ss/dev
fix training instability issue, add metadata
2023-02-23 21:50:38 +09:00
Kohya S
a3c0e4cf44 update change history 2023-02-23 21:49:34 +09:00
Kohya S
9b13444b9c raise error if options conflict 2023-02-23 21:35:47 +09:00
Kohya S
0eb01dea55 add max_grad_norm to metadata 2023-02-23 21:34:38 +09:00
Kohya S
a3aa3b1712 fix typos 2023-02-23 21:14:44 +09:00
Kohya S
95b5aed41b Merge pull request #221 from space-nuko/add-more-metadata
Add more missing metadata
2023-02-23 21:14:26 +09:00
Kohya S
d9184ab21c remove LoRA-ControlNet 2023-02-23 21:01:13 +09:00
Kohya S
e7dd77836d Merge branch 'main' into control_net 2023-02-23 20:57:34 +09:00
Kohya S
4c5c486d28 Merge branch 'main' into dev 2023-02-23 20:57:17 +09:00
Kohya S
f403ac6132 fix float32 training doesn't work in some case 2023-02-23 20:56:41 +09:00
space-nuko
b39cf6e2c0 Add more missing metadata 2023-02-23 02:25:24 -08:00
Kohya S
71b728d5fc Update README.md 2023-02-22 22:25:53 +09:00
Kohya S
f0ef81f865 Merge pull request #219 from kohya-ss/dev
Dev
2023-02-22 22:21:04 +09:00
Kohya S
f68a48b354 update readme 2023-02-22 22:19:36 +09:00
Kohya S
7a0d2a2d45 update readme 2023-02-22 22:16:23 +09:00
Kohya S
e13e503cbc update readme 2023-02-22 22:10:32 +09:00
Kohya S
125039f491 update readme 2023-02-22 22:06:47 +09:00
Kohya S
f2b300a221 Add about optimizer 2023-02-22 22:04:53 +09:00
Kohya S
9ab964d0b8 Add Adafactor optimzier 2023-02-22 21:09:47 +09:00
Kohya S
663aad2b0d refactor get_scheduler etc. 2023-02-20 22:47:43 +09:00
Kohya S
12d30afb39 Merge pull request #212 from mgz-dev/optimizer-expand-and-refactor
expand optimizer options and refactor
2023-02-20 20:13:41 +09:00
Kohya S
107fa754e5 Merge branch 'dev' into optimizer-expand-and-refactor 2023-02-20 20:12:42 +09:00
Kohya S
a17d1180cb Merge pull request #209 from BootsofLagrangian/dadaptation
Dadaptation optimizer
2023-02-20 20:06:55 +09:00
Kohya S
014fd3d037 support original controlnet 2023-02-20 12:54:44 +09:00
mgz-dev
b29c5a750c expand optimizer options and refactor
Refactor code to make it easier to add new optimizers, and support alternate optimizer parameters

-move redundant code to train_util for initializing optimizers
- add SGD Nesterov optimizers as option (since they are already available)
- add new parameters which may be helpful for tuning existing and new optimizers
2023-02-19 17:45:09 -06:00
unknown
b612d0b091 apply dadaptation 2023-02-19 19:07:26 +09:00
Kohya S
d94c0d70fe support network mul from prompt 2023-02-19 18:43:35 +09:00
unknown
045a3dbe48 apply dadaptation 2023-02-19 18:37:07 +09:00
Kohya S
08ae46b163 Merge pull request #208 from space-nuko/add-optimizer-to-metadata
Add optimizer to metadata
2023-02-19 17:21:57 +09:00
space-nuko
4e5db58a71 Add optimizer to metadata 2023-02-18 23:28:36 -08:00
Kohya S
e45e272e9d Merge branch 'main' into control_net 2023-02-19 16:25:00 +09:00
Kohya S
a9d29ac78c Merge pull request #207 from kohya-ss/dev
Dev
2023-02-19 15:29:40 +09:00
Kohya S
5c065eee79 update readme 2023-02-19 15:26:21 +09:00
Kohya S
048e7cd428 add lion optimizer support 2023-02-19 15:26:14 +09:00
Kohya S
a76ad2d1d5 add comment for future requirement update 2023-02-19 15:25:01 +09:00
Kohya S
9d0f9736bf Merge pull request #202 from vladmandic/main
fix git path
2023-02-19 15:01:21 +09:00
Kohya S
00bb8a65a6 Merge pull request #200 from Isotr0py/lowram
Add '--lowram' argument
2023-02-19 14:32:32 +09:00
Vladimir Mandic
dac2bd163a fix git path 2023-02-17 14:19:08 -05:00
Isotr0py
78d1fb5ce6 Add '--lowram' argument 2023-02-17 12:08:54 +08:00
Kohya S
14d7b24619 Merge pull request #198 from kohya-ss/dev
Dev
2023-02-16 22:35:47 +09:00
Kohya S
3bc0d83769 update readme 2023-02-16 22:21:51 +09:00
Kohya S
ffdfd5f615 fix name of loss for epoch 2023-02-16 22:21:36 +09:00
Kohya S
d01d953262 Merge pull request #196 from space-nuko/add-noise-offset-metadata
Add noise offset to metadata
2023-02-16 22:01:02 +09:00
Kohya S
914d1505df Merge pull request #189 from shirayu/improve_loss_track
Show the moving average loss
2023-02-16 22:00:26 +09:00
Kohya S
8590d5dbca add dtype 2023-02-16 21:59:35 +09:00
space-nuko
496c8cdc09 Add noise-offset to metadata 2023-02-16 02:56:39 -08:00
Kohya S
39aa390d2b Merge branch 'main' into control_net 2023-02-15 12:36:34 +09:00
Kohya S
82713e9aa6 Update README.md 2023-02-14 21:41:04 +09:00
Kohya S
e067d64b53 Merge pull request #190 from kohya-ss/dev
Dev
2023-02-14 21:32:03 +09:00
Kohya S
3d400667d2 fix typos 2023-02-14 21:29:40 +09:00
Kohya S
2aef2872fb update readme 2023-02-14 21:28:34 +09:00
Kohya S
43c0a69843 Add noise_offset 2023-02-14 21:15:48 +09:00
Yuta Hayashibe
8aed5125de Removed call of sum() 2023-02-14 21:11:30 +09:00
Kohya S
e0f007f2a9 Fix import 2023-02-14 20:55:38 +09:00
Kohya S
3c29784825 Add ja comment 2023-02-14 20:55:20 +09:00
Kohya S
8f1e930bf4 Merge pull request #187 from space-nuko/add-commit-hash
Add commit hash to metadata
2023-02-14 19:52:30 +09:00
Kohya S
f771396e90 Merge pull request #179 from mgz-dev/resize_lora-verbose-print
add verbosity option for resize_lora.py
2023-02-14 19:50:49 +09:00
Kohya S
f67b3f4452 Merge pull request #165 from Isotr0py/support-multi-gpu
Add support with multi-gpu train for train_newtork.py
2023-02-14 19:47:53 +09:00
Yuta Hayashibe
21f5b618c3 Show the moving average loss 2023-02-14 19:46:27 +09:00
Kohya S
64bffe5238 remove print 2023-02-14 19:25:43 +09:00
Kohya S
cebee02698 Official weights to LoRA 2023-02-13 23:38:38 +09:00
space-nuko
5471b0deb0 Add commit hash to metadata 2023-02-13 02:58:06 -08:00
Kohya S
bc9fc4ccee ControlNet by LoRA 2023-02-12 22:15:23 +09:00
Isotr0py
2b1a3080e7 Add type checking 2023-02-12 15:32:38 +08:00
Isotr0py
92a1af8024 Merge branch 'kohya-ss:main' into support-multi-gpu 2023-02-12 15:06:46 +08:00
michaelgzhang
b35b053b8d clean up print formatting 2023-02-11 03:14:43 -06:00
michaelgzhang
55521eece0 add verbosity option for resize_lora.py
add --verbose flag to print additional statistics during resize_lora function
correct some parameter references in resize_lora_model function
2023-02-11 02:38:13 -06:00
Kohya S
b32abdd327 Merge pull request #178 from kohya-ss/dev
Dev
2023-02-11 16:16:15 +09:00
Kohya S
d1ecfde487 fix typo 2023-02-11 16:12:27 +09:00
Kohya S
04ad46a9a7 update readme 2023-02-11 16:11:42 +09:00
Kohya S
4c561411aa revert batch size limiting for bucket 2023-02-11 16:02:56 +09:00
Kohya S
43a41c6c43 Merge pull request #177 from kohya-ss/dev
Dev
2023-02-11 15:11:07 +09:00
Kohya S
5367daa210 update readme 2023-02-11 15:09:45 +09:00
Kohya S
b825e4602c update readme 2023-02-11 15:05:45 +09:00
Kohya S
188e54b760 support multiple init words 2023-02-11 15:00:11 +09:00
Kohya S
2c5f5c324a Fix crash TI train close #172, tag drop wo shuffle 2023-02-11 14:41:44 +09:00
Kohya S
5777be5208 Update README.md 2023-02-11 13:36:33 +09:00
Kohya S
e727a0d222 Update README.md 2023-02-11 13:30:12 +09:00
Kohya S
cdd8882a01 Merge pull request #176 from kohya-ss/dev
Dev
2023-02-11 13:22:40 +09:00
Kohya S
3f3502fb57 add message 2023-02-11 13:20:58 +09:00
Kohya S
20c00603a8 Merge branch 'main' into dev 2023-02-11 13:16:13 +09:00
Kohya S
9239fefa52 add lora interrogator with text encoder 2023-02-11 13:15:57 +09:00
Kohya S
53d60543e5 Merge pull request #174 from kohya-ss/dev
Dev
2023-02-10 23:11:12 +09:00
Kohya S
22e3aca89c Update README.md 2023-02-10 23:07:53 +09:00
Kohya S
8d86f58174 add merge script with svd 2023-02-10 22:55:33 +09:00
Kohya S
e5cc64a563 support multibyte characters for filename 2023-02-10 22:55:21 +09:00
Kohya S
c7406d6b27 keep metadata when resizing 2023-02-10 22:55:00 +09:00
Kohya S
d2da3c4236 support for models with different alphas 2023-02-10 22:54:35 +09:00
Kohya S
2bad87f2f6 Update README-ja.md 2023-02-10 18:12:03 +09:00
Kohya S
ed62e566bb Update README.md 2023-02-10 18:11:39 +09:00
Kohya S
51b3dc2c11 Merge pull request #171 from kohya-ss/dev
Dev
2023-02-10 17:40:08 +09:00
Kohya S
74f4a8fab9 Merge branch 'main' into dev 2023-02-10 17:37:39 +09:00
Kohya S
a75baf9143 Add strict version no 2023-02-10 17:37:19 +09:00
Kohya S
b03721b4d9 Add todo comment 2023-02-10 17:36:38 +09:00
Kohya S
6b790bace6 Update README.md 2023-02-09 23:14:41 +09:00
Kohya S
dcaecfd20b Merge pull request #168 from kohya-ss/dev
Dev
2023-02-09 22:15:35 +09:00
Kohya S
553ac4aa1b add about resizeing script 2023-02-09 22:13:01 +09:00
Kohya S
f0c8c95871 add assocatied files copying 2023-02-09 22:12:41 +09:00
Kohya S
c2e1d4b71b fix typo 2023-02-09 21:38:01 +09:00
Kohya S
3a72e6f003 add tag dropout 2023-02-09 21:35:27 +09:00
Kohya S
f7b5abb595 add resizing script 2023-02-09 21:30:27 +09:00
Isotr0py
b8ad17902f fix get_hidden_states expected scalar Error again 2023-02-08 23:09:59 +08:00
Isotr0py
9a9ac79edf correct wrong inserted code for noise_pred fix 2023-02-08 22:30:20 +08:00
Isotr0py
6473aa1dd7 fix Input type error in noise_pred when using DDP 2023-02-08 21:32:21 +08:00
Isotr0py
b599adc938 fix Input type error when using DDP 2023-02-08 20:14:20 +08:00
Isotr0py
5e96e1369d fix get_hidden_states expected scalar Error 2023-02-08 20:14:13 +08:00
Isotr0py
c0be52a773 ignore get_hidden_states expected scalar Error 2023-02-08 20:13:09 +08:00
Isotr0py
fb312acb7f support DistributedDataParallel 2023-02-08 20:12:43 +08:00
Isotr0py
938bd71844 lower ram usage 2023-02-08 18:31:27 +08:00
Kohya S
b3020db63f support python 3.8 2023-02-07 22:29:12 +09:00
Kohya S
e42b2f7aa9 conditional caption dropout (in progress) 2023-02-07 22:28:56 +09:00
Kohya S
f9478f0d47 Merge pull request #159 from forestsource/main
Add Conditional Dropout options
2023-02-07 21:50:26 +09:00
Kohya S
4fc9f1f8c5 Merge pull request #157 from shirayu/improve_tag_shuffle
Always join with ", "
2023-02-07 21:47:05 +09:00
Kohya S
5a3d1a57b6 Merge pull request #154 from shirayu/typos_checker
Add typo check GitHub Action
2023-02-07 21:35:35 +09:00
forestsource
7db98baa86 Add dropout options 2023-02-07 00:01:30 +09:00
Kohya S
d591891048 Update README.md 2023-02-06 21:30:38 +09:00
Kohya S
3a93d18bb5 Merge pull request #158 from kohya-ss/dev
Dev
2023-02-06 21:26:14 +09:00
Kohya S
7511674333 update readme 2023-02-06 21:14:16 +09:00
Kohya S
883bd1269c Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-02-06 20:52:30 +09:00
Kohya S
2aa27b7a4b Update downsampling for larger image in no_upscale 2023-02-06 20:52:24 +09:00
Yuta Hayashibe
5ea5fefcd2 Always join with ", " 2023-02-06 12:29:41 +09:00
Kohya S
6a79ac6a03 Update README.md 2023-02-05 21:59:55 +09:00
Kohya S
ea2dfd09ef update bucketing features 2023-02-05 21:37:46 +09:00
Yuta Hayashibe
7380801dfc Add typo check GitHub Action 2023-02-05 19:22:18 +09:00
23 changed files with 2294 additions and 583 deletions

21
.github/workflows/typos.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
---
# yamllint disable rule:line-length
name: Typos
on: # yamllint disable-line rule:truthy
push:
pull_request:
types:
- opened
- synchronize
- reopened
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: typos-action
uses: crate-ci/typos@v1.13.10

View File

@@ -64,6 +64,12 @@ cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_set
accelerate config
```
<!--
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
pip install --use-pep517 --upgrade -r requirements.txt
pip install -U -I --no-deps xformers==0.0.16
-->
コマンドプロンプトでは以下になります。
@@ -116,7 +122,7 @@ accelerate configの質問には以下のように答えてください。bf1
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --upgrade -r requirements.txt
pip install --use-pep517 --upgrade -r requirements.txt
```
コマンドが成功すれば新しいバージョンが使用できます。

117
README.md
View File

@@ -1,64 +1,7 @@
This repository contains training, generation and utility scripts for Stable Diffusion.
## Updates
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for great work!!!
Note: The LoRA models for SD 2.x is not supported too in Web UI.
- 4 Feb. 2023, 2023/2/4
- ``--persistent_data_loader_workers`` option is added to ``fine_tune.py``, ``train_db.py`` and ``train_network.py``. This option may significantly reduce the waiting time between epochs. Thanks to hitomi!
- ``--debug_dataset`` option is now working on non-Windows environment. Thanks to tsukimiya!
- ``networks/resize_lora.py`` script is added. This can approximate the higher-rank (dim) LoRA model by a lower-rank LoRA model, e.g. 128 by 4. Thanks to mgz-dev!
- ``--help`` option shows usage.
- Currently the metadata is not copied. This will be fixed in the near future.
- ``--persistent_data_loader_workers``オプションが ``fine_tune.py``、 ``train_db.py``、``train_network.py``の各スクリプトに追加されました。このオプションを指定するとエポック間の待ち時間が大幅に短縮される可能性があります。hitomi氏に感謝します。
- ``--debug_dataset``オプションがWindows環境以外でも動くようになりました。tsukimiya氏に感謝します。
- ``networks/resize_lora.py``スクリプトを追加しました。高rankのLoRAモデルを低rankのLoRAモデルで近似しますつまり128 rank (dim)のLoRAに似た、4 rank (dim)のLoRAを作ることができます。mgz-dev氏に感謝します。
- 使い方は``--help``オプションを指定して参照してください。
- 現時点ではメタデータはコピーされません。近日中に対応予定です。
- 3 Feb. 2023, 2023/2/3
- Update finetune preprocessing scripts.
- ``.bmp`` and ``.jpeg`` are supported. Thanks to breakcore2 and p1atdev!
- The default weights of ``tag_images_by_wd14_tagger.py`` is now ``SmilingWolf/wd-v1-4-convnext-tagger-v2``. You can specify another model id from ``SmilingWolf`` by ``--repo_id`` option. Thanks to SmilingWolf for the great work.
- To change the weight, remove ``wd14_tagger_model`` folder, and run the script again.
- ``--max_data_loader_n_workers`` option is added to each script. This option uses the DataLoader for data loading to speed up loading, 20%~30% faster.
- Please specify 2 or 4, depends on the number of CPU cores.
- ``--recursive`` option is added to ``merge_dd_tags_to_metadata.py`` and ``merge_captions_to_metadata.py``, only works with ``--full_path``.
- ``make_captions_by_git.py`` is added. It uses [GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) for captioning.
- ``requirements.txt`` is updated. If you use this script, [please update the libraries](https://github.com/kohya-ss/sd-scripts#upgrade).
- Usage is almost the same as ``make_captions.py``, but batch size should be smaller.
- ``--remove_words`` option removes as much text as possible (such as ``the word "XXXX" on it``).
- ``--skip_existing`` option is added to ``prepare_buckets_latents.py``. Images with existing npz files are ignored by this option.
- ``clean_captions_and_tags.py`` is updated to remove duplicated or conflicting tags, e.g. ``shirt`` is removed when ``white shirt`` exists. if ``black hair`` is with ``red hair``, both are removed.
- Tag frequency is added to the metadata in ``train_network.py``. Thanks to space-nuko!
- __All tags and number of occurrences of the tag are recorded.__ If you do not want it, disable metadata storing with ``--no_metadata`` option.
- fine tuning用の前処理スクリプト群を更新しました。
- 拡張子 ``.bmp`` と ``.jpeg`` をサポートしました。breakcore2氏およびp1atdev氏に感謝します。
- ``tag_images_by_wd14_tagger.py`` のデフォルトの重みを ``SmilingWolf/wd-v1-4-convnext-tagger-v2`` に更新しました。他の ``SmilingWolf`` 氏の重みも ``--repo_id`` オプションで指定可能です。SmilingWolf氏に感謝します。
- 重みを変更するときには ``wd14_tagger_model`` フォルダを削除してからスクリプトを再実行してください。
- ``--max_data_loader_n_workers`` オプションが各スクリプトに追加されました。DataLoaderを用いることで読み込み処理を並列化し、処理を20~30%程度高速化します。
- CPUのコア数に応じて2~4程度の値を指定してください。
- ``--recursive`` オプションを ``merge_dd_tags_to_metadata.py`` と ``merge_captions_to_metadata.py`` に追加しました。``--full_path`` を指定したときのみ使用可能です。
- ``make_captions_by_git.py`` を追加しました。[GIT microsoft/git-large-textcaps](https://huggingface.co/microsoft/git-large-textcaps) を用いてキャプションニングを行います。
- ``requirements.txt`` が更新されていますので、[ライブラリをアップデート](https://github.com/kohya-ss/sd-scripts/blob/main/README-ja.md#%E3%82%A2%E3%83%83%E3%83%97%E3%82%B0%E3%83%AC%E3%83%BC%E3%83%89)してください。
- 使用法は ``make_captions.py``とほぼ同じですがバッチサイズは小さめにしてください。
- ``--remove_words`` オプションを指定するとテキスト読み取りを可能な限り削除します(``the word "XXXX" on it``のようなもの)。
- ``--skip_existing`` を ``prepare_buckets_latents.py`` に追加しました。すでにnpzファイルがある画像の処理をスキップします。
- ``clean_captions_and_tags.py``を重複タグや矛盾するタグを削除するよう機能追加しました。例:``white shirt`` タグがある場合、 ``shirt`` タグは削除されます。また``black hair``と``red hair``の両方がある場合、両方とも削除されます。
- ``train_network.py``で使用されているタグと回数をメタデータに記録するようになりました。space-nuko氏に感謝します。
- __すべてのタグと回数がメタデータに記録されます__ 望まない場合には``--no_metadata option``オプションでメタデータの記録を停止してください。
Stable Diffusion web UI本体で当リポジトリで学習したLoRAモデルによる画像生成がサポートされたようです。
SD2.x用のLoRAモデルはサポートされないようです。
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
##
[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。
[日本語版README](./README-ja.md)
@@ -67,10 +10,13 @@ For easier use (GUI and PowerShell scripts etc...), please visit [the repository
This repository contains the scripts for:
* DreamBooth training, including U-Net and Text Encoder
* fine-tuning (native training), including U-Net and Text Encoder
* Fine-tuning (native training), including U-Net and Text Encoder
* LoRA training
* image generation
* model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
* Texutl Inversion training
* Image generation
* Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ (SD 1.x based only) Thank you for great work!!!
## About requirements.txt
@@ -157,7 +103,7 @@ When a new release comes out you can upgrade your repo with the following comman
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --upgrade -r requirements.txt
pip install --use-pep517 --upgrade -r requirements.txt
```
Once the commands have completed successfully you should be ready to use the new version.
@@ -175,3 +121,48 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## Change History
- 23 Feb. 2023, 2023/2/23:
- Fix instability training issue in ``train_network.py``.
- ``fp16`` training is probably not affected by this issue.
- Training with ``float`` for SD2.x models will work now. Also training with ``bf16`` might be improved.
- This issue seems to have occurred in [PR#190](https://github.com/kohya-ss/sd-scripts/pull/190).
- Add some metadata to LoRA model. Thanks to space-nuko!
- Raise an error if optimizer options conflict (e.g. ``--optimizer_type`` and ``--use_8bit_adam``.)
- Support ControlNet in ``gen_img_diffusers.py`` (no documentation yet.)
- ``train_network.py`` で学習が不安定になる不具合を修正しました。
- ``fp16`` 精度での学習には恐らくこの問題は影響しません。
- ``float`` 精度での SD2.x モデルの学習が正しく動作するようになりました。また ``bf16`` 精度の学習も改善する可能性があります。
- この問題は [PR#190](https://github.com/kohya-ss/sd-scripts/pull/190) から起きていたようです。
- いくつかのメタデータを LoRA モデルに追加しました。 space-nuko 氏に感謝します。
- オプティマイザ関係のオプションが矛盾していた場合、エラーとするように修正しました(例: ``--optimizer_type`` と ``--use_8bit_adam``)。
- ``gen_img_diffusers.py`` で ControlNet をサポートしました(ドキュメントはのちほど追加します)。
- 22 Feb. 2023, 2023/2/22:
- Refactor optmizer options. Thanks to mgz-dev!
- Add ``--optimizer_type`` option for each training script. Please see help. Japanese documentation is [here](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E3%82%AA%E3%83%97%E3%83%86%E3%82%A3%E3%83%9E%E3%82%A4%E3%82%B6%E3%81%AE%E6%8C%87%E5%AE%9A%E3%81%AB%E3%81%A4%E3%81%84%E3%81%A6).
- ``--use_8bit_adam`` and ``--use_lion_optimizer`` options also work, but override above option.
- Add SGDNesterov and its 8bit.
- Add [D-Adaptation](https://github.com/facebookresearch/dadaptation) optimizer. Thanks to BootsofLagrangian and all!
- Please install D-Adaptation optimizer with ``pip install dadaptation`` (it is not in requirements.txt currently.)
- Please see https://github.com/kohya-ss/sd-scripts/issues/181 for details.
- Add AdaFactor optimizer. Thanks to Toshiaki!
- Extra lr scheduler settings (num_cycles etc.) are working in training scripts other than ``train_network.py``.
- Add ``--max_grad_norm`` option for each training script for gradient clipping. ``0.0`` disables clipping.
- Symbolic link can be loaded in each training script. Thanks to TkskKurumi!
- オプティマイザ関連のオプションを見直しました。mgz-dev氏に感謝します。
- ``--optimizer_type`` を各学習スクリプトに追加しました。ドキュメントは[こちら](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E3%82%AA%E3%83%97%E3%83%86%E3%82%A3%E3%83%9E%E3%82%A4%E3%82%B6%E3%81%AE%E6%8C%87%E5%AE%9A%E3%81%AB%E3%81%A4%E3%81%84%E3%81%A6)。
- ``--use_8bit_adam`` と ``--use_lion_optimizer`` のオプションは依然として動作しますがoptimizer_typeを上書きしますのでご注意ください。
- SGDNesterov オプティマイザおよびその8bit版を追加しました。
- [D-Adaptation](https://github.com/facebookresearch/dadaptation) オプティマイザを追加しました。BootsofLagrangian 氏および諸氏に感謝します。
- ``pip install dadaptation`` コマンドで別途インストールが必要です現時点ではrequirements.txtに含まれておりません
- こちらのissueもあわせてご覧ください。 https://github.com/kohya-ss/sd-scripts/issues/181
- AdaFactor オプティマイザを追加しました。Toshiaki氏に感謝します。
- 追加のスケジューラ設定num_cycles等が ``train_network.py`` 以外の学習スクリプトでも使えるようになりました。
- 勾配クリップ時の最大normを指定する ``--max_grad_norm`` オプションを追加しました。``0.0``を指定するとクリップしなくなります。
- 各学習スクリプトでシンボリックリンクが読み込めるようになりました。TkskKurumi氏に感謝します。
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。

15
_typos.toml Normal file
View File

@@ -0,0 +1,15 @@
# Files for typos
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
[default.extend-identifiers]
[default.extend-words]
NIN="NIN"
parms="parms"
nin="nin"
extention="extention" # Intentionally left
nd="nd"
[files]
extend-exclude = ["_typos.toml"]

View File

@@ -33,8 +33,13 @@ def train(args):
train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets()
if args.debug_dataset:
@@ -144,20 +149,7 @@ def train(args):
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
# 8-bit Adamを使う
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
@@ -171,8 +163,9 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
@@ -225,6 +218,8 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
for m in training_models:
m.train()
@@ -248,6 +243,9 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@@ -269,11 +267,11 @@ def train(args):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -286,9 +284,12 @@ def train(args):
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
accelerator.log(logs, step=global_step)
# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
@@ -298,7 +299,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()
@@ -331,9 +332,10 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
parser.add_argument("--diffusers_xformers", action='store_true',
help='use xformers by diffusers / Diffusersでxformersを使用する')

View File

@@ -52,6 +52,10 @@ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
def main(args):
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
@@ -77,32 +81,41 @@ def main(args):
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
max_reso, args.min_bucket_reso, args.max_bucket_reso)
bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
if not args.bucket_no_upscale:
bucket_manager.make_buckets()
else:
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
buckets_imgs = [[] for _ in range(len(bucket_resos))]
bucket_counts = [0 for _ in range(len(bucket_resos))]
img_ar_errors = []
def process_batch(is_last):
for j in range(len(buckets_imgs)):
bucket = buckets_imgs[j]
for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
f"latent shape {latents.shape}, {bucket[0][1].shape}"
for (image_key, _, _), latent in zip(bucket, latents):
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
np.savez(npz_file_name, latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, _, _), latent in zip(bucket, latents):
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
np.savez(npz_file_name, latent)
else:
# remove existing flipped npz
for image_key, _ in bucket:
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
if os.path.isfile(npz_file_name):
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
os.remove(npz_file_name)
bucket.clear()
@@ -114,6 +127,7 @@ def main(args):
else:
data = [[(None, ip)] for ip in image_paths]
bucket_counts = {}
for data_entry in tqdm(data, smoothing=0.0):
if data_entry[0] is None:
continue
@@ -134,29 +148,24 @@ def main(args):
if image_key not in metadata:
metadata[image_key] = {}
# 本当はこの部分もDataSetに持っていけば高速化できるがいろいろ大変
aspect_ratio = image.width / image.height
ar_errors = bucket_aspect_ratios - aspect_ratio
bucket_id = np.abs(ar_errors).argmin()
reso = bucket_resos[bucket_id]
ar_error = ar_errors[bucket_id]
# 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
img_ar_errors.append(abs(ar_error))
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
# どのサイズにリサイズするか→トリミングする方向で
if ar_error <= 0: # 横が長い→縦を合わせる
scale = reso[1] / image.height
else:
scale = reso[0] / image.width
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))
if not args.bucket_no_upscale:
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
assert resized_size[0] == reso[0] or resized_size[1] == reso[
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
# print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
# bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])
assert resized_size[0] == reso[0] or resized_size[1] == reso[
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
1], f"internal error resized size is small: {resized_size}, {reso}"
# 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing:
@@ -180,22 +189,24 @@ def main(args):
# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size//2:trim_size//2 + reso[0]]
elif resized_size[1] > reso[1]:
if resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size//2:trim_size//2 + reso[1]]
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
# # debug
# cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
# バッチへ追加
buckets_imgs[bucket_id].append((image_key, reso, image))
bucket_counts[bucket_id] += 1
metadata[image_key]['train_resolution'] = reso
bucket_manager.add_image(reso, (image_key, image))
# バッチを推論するか判定して推論する
process_batch(False)
@@ -203,8 +214,11 @@ def main(args):
# 残りを処理する
process_batch(True)
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
print(f"bucket {i} {reso}: {count}")
bucket_manager.sort()
for i, reso in enumerate(bucket_manager.resos):
count = bucket_counts.get(reso, 0)
if count > 0:
print(f"bucket {i} {reso}: {count}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error: {np.mean(img_ar_errors)}")
@@ -230,6 +244,10 @@ if __name__ == '__main__':
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
parser.add_argument("--bucket_reso_steps", type=int, default=64,
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
parser.add_argument("--bucket_no_upscale", action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
parser.add_argument("--full_path", action="store_true",

View File

@@ -47,7 +47,7 @@ VGG(
"""
import json
from typing import List, Optional, Union
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
import inspect
@@ -60,7 +60,6 @@ import math
import os
import random
import re
from typing import Any, Callable, List, Optional, Union
import diffusers
import numpy as np
@@ -81,6 +80,8 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo
import library.model_util as model_util
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -487,6 +488,9 @@ class PipelineLike():
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
# ControlNet
self.control_nets: List[ControlNetInfo] = []
# Textual Inversion
def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids
@@ -500,7 +504,11 @@ class PipelineLike():
new_tokens.append(token)
return new_tokens
def set_control_nets(self, ctrl_nets):
self.control_nets = ctrl_nets
# region xformersとか使う部分独自に書き換えるので関係なし
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
@@ -752,7 +760,7 @@ class PipelineLike():
text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True) # prompt複数件でもOK
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
if isinstance(clip_guide_images, PIL.Image.Image):
clip_guide_images = [clip_guide_images]
@@ -765,7 +773,7 @@ class PipelineLike():
image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
if len(image_embeddings_clip) == 1:
image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
else:
elif self.vgg16_guidance_scale > 0:
size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV) # とりあえず1/4に小さいか?
clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
clip_guide_images = torch.cat(clip_guide_images, dim=0)
@@ -774,6 +782,10 @@ class PipelineLike():
image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
if len(image_embeddings_vgg16) == 1:
image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
else:
# ControlNetのhintにguide imageを流用する
# 前処理はControlNet側で行う
pass
# set timesteps
self.scheduler.set_timesteps(num_inference_steps, self.device)
@@ -864,12 +876,21 @@ class PipelineLike():
extra_step_kwargs["eta"] = eta
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
if self.control_nets:
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
if self.control_nets:
noise_pred = original_control_net.call_unet_and_control_net(
i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
else:
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
@@ -1817,6 +1838,34 @@ def preprocess_mask(mask):
# return text_encoder
class BatchDataBase(NamedTuple):
# バッチ分割が必要ないデータ
step: int
prompt: str
negative_prompt: str
seed: int
init_image: Any
mask_image: Any
clip_prompt: str
guide_image: Any
class BatchDataExt(NamedTuple):
# バッチ分割が必要なデータ
width: int
height: int
steps: int
scale: float
negative_scale: float
strength: float
network_muls: Tuple[float]
class BatchData(NamedTuple):
base: BatchDataBase
ext: BatchDataExt
def main(args):
if args.fp16:
dtype = torch.float16
@@ -1995,11 +2044,13 @@ def main(args):
# networkを組み込む
if args.network_module:
networks = []
network_default_muls = []
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {}
if args.network_args and i < len(args.network_args):
@@ -2014,7 +2065,7 @@ def main(args):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight):
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
@@ -2037,6 +2088,18 @@ def main(args):
else:
networks = []
# ControlNetの処理
control_nets: List[ControlNetInfo] = []
if args.control_net_models:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
if args.opt_channels_last:
print(f"set optimizing: channels last")
text_encoder.to(memory_format=torch.channels_last)
@@ -2050,9 +2113,14 @@ def main(args):
if vgg16_model is not None:
vgg16_model.to(memory_format=torch.channels_last)
for cn in control_nets:
cn.unet.to(memory_format=torch.channels_last)
cn.net.to(memory_format=torch.channels_last)
pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
pipe.set_control_nets(control_nets)
print("pipeline is ready.")
if args.diffusers_xformers:
@@ -2186,9 +2254,12 @@ def main(args):
prev_image = None # for VGG16 guided
if args.guide_image_path is not None:
print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
guide_images = load_images(args.guide_image_path)
print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
guide_images = []
for p in args.guide_image_path:
guide_images.extend(load_images(p))
print(f"loaded {len(guide_images)} guide images for guidance")
if len(guide_images) == 0:
print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
guide_images = None
@@ -2219,33 +2290,37 @@ def main(args):
iter_seed = random.randint(0, 0x7fffffff)
# バッチ処理の関数
def process_batch(batch, highres_fix, highres_1st=False):
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
batch_size = len(batch)
# highres_fixの処理
if highres_fix and not highres_1st:
# 1st stageのバッチを作成して呼び出す
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
print("process 1st stage1")
batch_1st = []
for params1, (width, height, steps, scale, negative_scale, strength) in batch:
width_1st = int(width * args.highres_fix_scale + .5)
height_1st = int(height * args.highres_fix_scale + .5)
for base, ext in batch:
width_1st = int(ext.width * args.highres_fix_scale + .5)
height_1st = int(ext.height * args.highres_fix_scale + .5)
width_1st = width_1st - width_1st % 32
height_1st = height_1st - height_1st % 32
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
ext.negative_scale, ext.strength, ext.network_muls)
batch_1st.append(BatchData(base, ext_1st))
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
print("process 2nd stage1")
batch_2nd = []
for i, (b1, image) in enumerate(zip(batch, images_1st)):
image = image.resize((width, height), resample=PIL.Image.LANCZOS)
(step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
for i, (bd, image) in enumerate(zip(batch, images_1st)):
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
batch_2nd.append(bd_2nd)
batch = batch_2nd
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
height, steps, scale, negative_scale, strength) = batch[0]
# このバッチの情報を取り出す
(step_first, _, _, _, init_image, mask_image, _, guide_image), \
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
prompts = []
@@ -2295,9 +2370,13 @@ def main(args):
all_masks_are_same = mask_images[-2] is mask_image
if guide_image is not None:
guide_images.append(guide_image)
if i > 0 and all_guide_images_are_same:
all_guide_images_are_same = guide_images[-2] is guide_image
if type(guide_image) is list:
guide_images.extend(guide_image)
all_guide_images_are_same = False
else:
guide_images.append(guide_image)
if i > 0 and all_guide_images_are_same:
all_guide_images_are_same = guide_images[-2] is guide_image
# make start code
torch.manual_seed(seed)
@@ -2320,7 +2399,19 @@ def main(args):
if guide_images is not None and all_guide_images_are_same:
guide_images = guide_images[0]
# ControlNet使用時はguide imageをリサイズする
if control_nets:
# TODO resampleのメソッド
guide_images = guide_images if type(guide_images) == list else [guide_images]
guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
if len(guide_images) == 1:
guide_images = guide_images[0]
# generate
if networks:
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m)
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
if highres_1st and not args.highres_fix_save_1st:
@@ -2398,6 +2489,7 @@ def main(args):
strength = 0.8 if args.strength is None else args.strength
negative_prompt = ""
clip_prompt = None
network_muls = None
prompt_args = prompt.strip().split(' --')
prompt = prompt_args[0]
@@ -2461,6 +2553,15 @@ def main(args):
clip_prompt = m.group(1)
print(f"clip prompt: {clip_prompt}")
continue
m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
if m: # network multiplies
network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks):
network_muls.append(network_muls[-1])
print(f"network mul: {network_muls}")
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
@@ -2498,7 +2599,12 @@ def main(args):
mask_image = mask_images[global_step % len(mask_images)]
if guide_images is not None:
guide_image = guide_images[global_step % len(guide_images)]
if control_nets: # 複数件の場合あり
c = len(control_nets)
p = global_step % (len(guide_images) // c)
guide_image = guide_images[p * c:p * c + c]
else:
guide_image = guide_images[global_step % len(guide_images)]
elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
if prev_image is None:
print("Generate 1st image without guide image.")
@@ -2506,9 +2612,8 @@ def main(args):
print("Use previous image as guide image.")
guide_image = prev_image
# TODO named tupleか何かにする
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
(width, height, steps, scale, negative_scale, strength))
b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
process_batch(batch_data, highres_fix)
batch_data.clear()
@@ -2578,12 +2683,15 @@ if __name__ == '__main__':
parser.add_argument("--opt_channels_last", action='store_true',
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
parser.add_argument("--network_module", type=str, default=None, nargs='*',
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
help='additional network weights to load / 追加ネットワークの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*',
help='additional network multiplier / 追加ネットワークの効果の倍率')
parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--network_show_meta", action='store_true',
help='show metadata of network model / ネットワークモデルのメタデータを表示する')
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
@@ -2597,7 +2705,8 @@ if __name__ == '__main__':
help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
parser.add_argument("--highres_fix_scale", type=float, default=None,
help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
parser.add_argument("--highres_fix_steps", type=int, default=28,
@@ -2607,5 +2716,13 @@ if __name__ == '__main__':
parser.add_argument("--negative_scale", type=float, default=None,
help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")
parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
help='ControlNet models to use / 使用するControlNetのモデル名')
parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')
args = parser.parse_args()
main(args)

View File

@@ -1163,15 +1163,14 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
resos = list(resos)
resos.sort()
aspect_ratios = [w / h for w, h in resos]
return resos, aspect_ratios
return resos
if __name__ == '__main__':
resos, aspect_ratios = make_bucket_resolutions((512, 768))
resos = make_bucket_resolutions((512, 768))
print(len(resos))
print(resos)
aspect_ratios = [w / h for w, h in resos]
print(aspect_ratios)
ars = set()

File diff suppressed because it is too large Load Diff

View File

@@ -5,6 +5,7 @@
import math
import os
from typing import List
import torch
from library import train_util
@@ -98,7 +99,7 @@ class LoRANetwork(torch.nn.Module):
self.alpha = alpha
# create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
loras = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
@@ -125,6 +126,11 @@ class LoRANetwork(torch.nn.Module):
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file, safe_open

View File

@@ -0,0 +1,122 @@
from tqdm import tqdm
from library import model_util
import argparse
from transformers import CLIPTokenizer
import torch
import library.model_util as model_util
import lora
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def interrogate(args):
# いろいろ準備する
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
print(f"loading LoRA: {args.model}")
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# text encoder向けの重みがあるかチェックする本当はlora側でやるのがいい
has_te_weight = False
for key in network.weights_sd.keys():
if 'lora_te' in key:
has_te_weight = True
break
if not has_te_weight:
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
return
del vae
print("loading tokenizer")
if args.v2:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
text_encoder.to(DEVICE)
text_encoder.eval()
unet.to(DEVICE)
unet.eval() # U-Netは呼び出さないので不要だけど
# トークンをひとつひとつ当たっていく
token_id_start = 0
token_id_end = max(tokenizer.all_special_ids)
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
def get_all_embeddings(text_encoder):
embs = []
with torch.no_grad():
for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
batch = []
for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
# tokens = [tid] # こちらは結果がいまひとつ
batch.append(tokens)
# batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1]
# clip skip対応
batch = torch.tensor(batch).to(DEVICE)
if args.clip_skip is None:
encoder_hidden_states = text_encoder(batch)[0]
else:
enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.to("cpu")
embs.extend(encoder_hidden_states)
return torch.stack(embs)
print("get original text encoder embeddings.")
orig_embs = get_all_embeddings(text_encoder)
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
network.to(DEVICE)
network.eval()
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
print("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder)
# 比べる:とりあえず単純に差分の絶対値で
print("comparing...")
diffs = {}
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
diff = torch.mean(torch.abs(orig_emb - lora_emb))
# diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
diff = float(diff.detach().to('cpu').numpy())
diffs[token_id_start + i] = diff
diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
# 結果を表示する
print("top 100:")
for i, (token, diff) in enumerate(diffs_sorted[:100]):
# if diff < 1e-6:
# break
string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None,
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--batch_size", type=int, default=16,
help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
args = parser.parse_args()
interrogate(args)

View File

@@ -1,5 +1,5 @@
import math
import argparse
import os
import torch
@@ -85,43 +85,76 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
# conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {}
base_alphas = {} # alpha for merged model
base_dims = {}
alpha = None
dim = None
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
# get alpha and dim
alphas = {} # alpha for current model
dims = {} # dims for current model
for key in lora_sd.keys():
if 'alpha' in key:
lora_module_name = key[:key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
elif "lora_down" in key:
lora_module_name = key[:key.rfind(".lora_down")]
dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim
if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim
for lora_module_name in dims.keys():
if lora_module_name not in alphas:
alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
else:
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key]
continue
lora_module_name = key[:key.rfind(".lora_")]
base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
else:
if "lora_down" in key:
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
merged_sd[key] = lora_sd[key] * scale
# set alpha to sd
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
print(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
return merged_sd, dim, alpha
return merged_sd
def merge(args):
@@ -152,7 +185,7 @@ def merge(args):
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)

179
networks/merge_lora_old.py Normal file
View File

@@ -0,0 +1,179 @@
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
import library.model_util as model_util
import lora
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
text_encoder.to(merge_dtype)
unet.to(merge_dtype)
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder, unet]):
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
# conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {}
alpha = None
dim = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
else:
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key]
else:
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
else:
if "lora_down" in key:
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
return merged_sd, dim, alpha
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
args = parser.parse_args()
merge(args)

View File

@@ -5,148 +5,178 @@
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location='cpu')
metadata = None
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype):
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
if model_util.is_safetensors(file_name):
save_file(model, file_name, metadata)
else:
torch.save(model, file_name)
def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
print("Loading Model...")
lora_sd = load_state_dict(model, merge_dtype)
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
network_alpha = None
network_dim = None
verbose_str = "\n"
network_alpha = None
network_dim = None
CLAMP_QUANTILE = 0.99
CLAMP_QUANTILE = 0.99
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
scale = network_alpha/network_dim
new_alpha = float(scale*new_rank) # calculate new alpha from scale
scale = network_alpha/network_dim
new_alpha = float(scale*new_rank) # calculate new alpha from scale
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}")
lora_down_weight = None
lora_up_weight = None
lora_down_weight = None
lora_up_weight = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
print("resizing lora...")
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
block_down_name = key.split(".")[0]
lora_down_weight = value
if 'lora_up' in key:
block_up_name = key.split(".")[0]
lora_up_weight = value
print("resizing lora...")
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
block_down_name = key.split(".")[0]
lora_down_weight = value
if 'lora_up' in key:
block_up_name = key.split(".")[0]
lora_up_weight = value
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
if (block_down_name == block_up_name) and weights_loaded:
if (block_down_name == block_up_name) and weights_loaded:
conv2d = (len(lora_down_weight.size()) == 4)
conv2d = (len(lora_down_weight.size()) == 4)
if conv2d:
lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze()
if conv2d:
lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze()
if args.device:
org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)
if device:
org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
U, S, Vh = torch.linalg.svd(full_weight_matrix)
U, S, Vh = torch.linalg.svd(full_weight_matrix)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
if verbose:
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))
verbose_str+=f"{block_down_name:76} | "
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
Vh = Vh[:new_rank, :]
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
Vh = Vh[:new_rank, :]
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3)
if args.device:
U = U.to(org_device)
Vh = Vh.to(org_device)
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
if conv2d:
U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3)
if device:
U = U.to(org_device)
Vh = Vh.to(org_device)
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
if verbose:
print(verbose_str)
print("resizing complete")
return o_lora_sd, network_dim, new_alpha
print("resizing complete")
return o_lora_sd
def resize(args):
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype)
print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
print("resizing rank...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
# update metadata
if metadata is None:
metadata = {}
comment = metadata.get("ss_training_comment", "")
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
if __name__ == '__main__':
@@ -161,6 +191,8 @@ if __name__ == '__main__':
parser.add_argument("--model", type=str, default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument("--verbose", action="store_true",
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
args = parser.parse_args()
resize(args)

164
networks/svd_merge_lora.py Normal file
View File

@@ -0,0 +1,164 @@
import math
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import library.model_util as model_util
import lora
CLAMP_QUANTILE = 0.99
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
# merge
print(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if 'lora_down' not in key:
continue
lora_module_name = key[:key.rfind(".lora_down")]
down_weight = lora_sd[key]
network_dim = down_weight.size()[0]
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
in_dim = down_weight.size()[1]
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
# merge to weight
if device:
up_weight = up_weight.to(device)
down_weight = down_weight.to(device)
# W <- W + U * D
scale = (alpha / network_dim)
if not conv2d: # linear
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
merged_sd[lora_module_name] = weight
# extract from merged weights
print("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
conv2d = (len(mat.size()) == 4)
if conv2d:
mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
up_weight = U
down_weight = Vh
if conv2d:
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
return merged_lora_sd
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
args = parser.parse_args()
merge(args)

View File

@@ -1,23 +1,24 @@
accelerate==0.15.0
transformers==4.26.0
ftfy
albumentations
opencv-python
einops
ftfy==6.1.1
albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
diffusers[torch]==0.10.2
pytorch_lightning
pytorch-lightning==1.9.0
bitsandbytes==0.35.0
tensorboard
tensorboard==2.10.1
safetensors==0.2.6
gradio
altair
easygui
gradio==3.16.2
altair==4.2.2
easygui==0.98.3
# for BLIP captioning
requests
timm==0.4.12
fairscale==0.4.4
requests==2.28.2
timm==0.6.12
fairscale==0.4.13
# for WD14 captioning
tensorflow<2.11
huggingface-hub
# tensorflow<2.11
tensorflow==2.10.1
huggingface-hub==0.12.0
# for kohya_ss library
.
.

24
tools/canny.py Normal file
View File

@@ -0,0 +1,24 @@
import argparse
import cv2
def canny(args):
img = cv2.imread(args.input)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
canny_img = cv2.Canny(img, args.thres1, args.thres2)
# canny_img = 255 - canny_img
cv2.imwrite(args.output, canny_img)
print("done!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, default=None, help="input path")
parser.add_argument("--output", type=str, default=None, help="output path")
parser.add_argument("--thres1", type=int, default=32, help="thres1")
parser.add_argument("--thres2", type=int, default=224, help="thres2")
args = parser.parse_args()
canny(args)

View File

@@ -0,0 +1,320 @@
from typing import List, NamedTuple, Any
import numpy as np
import cv2
import torch
from safetensors.torch import load_file
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
import library.model_util as model_util
class ControlNetInfo(NamedTuple):
unet: Any
net: Any
prep: Any
weight: float
ratio: float
class ControlNet(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
# make control model
self.control_model = torch.nn.Module()
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
zero_convs = torch.nn.ModuleList()
for i, dim in enumerate(dims):
sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
zero_convs.append(sub_list)
self.control_model.add_module("zero_convs", zero_convs)
middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
dims = [16, 16, 32, 32, 96, 96, 256, 320]
strides = [1, 1, 2, 1, 2, 1, 2, 1]
prev_dim = 3
input_hint_block = torch.nn.Sequential()
for i, (dim, stride) in enumerate(zip(dims, strides)):
input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
if i < len(dims) - 1:
input_hint_block.append(torch.nn.SiLU())
prev_dim = dim
self.control_model.add_module("input_hint_block", input_hint_block)
def load_control_net(v2, unet, model):
device = unet.device
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
# state dictを読み込む
print(f"ControlNet: loading control SD model : {model}")
if model_util.is_safetensors(model):
ctrl_sd_sd = load_file(model)
else:
ctrl_sd_sd = torch.load(model, map_location='cpu')
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
is_difference = "difference" in ctrl_sd_sd
print("ControlNet: loading difference")
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
# またTransfer Controlの元weightとなる
ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
# 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
for key in list(ctrl_unet_sd_sd.keys()):
ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
zero_conv_sd = {}
for key in list(ctrl_sd_sd.keys()):
if key.startswith("control_"):
unet_key = "model.diffusion_" + key[len("control_"):]
if unet_key not in ctrl_unet_sd_sd: # zero conv
zero_conv_sd[key] = ctrl_sd_sd[key]
continue
if is_difference: # Transfer Control
ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
else:
ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
unet_config = model_util.create_unet_diffusers_config(v2)
ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
# ControlNetのU-Netを作成する
ctrl_unet = UNet2DConditionModel(**unet_config)
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
print("ControlNet: loading Control U-Net:", info)
# U-Net以外のControlNetを作成する
# TODO support middle only
ctrl_net = ControlNet()
info = ctrl_net.load_state_dict(zero_conv_sd)
print("ControlNet: loading ControlNet:", info)
ctrl_unet.to(unet.device, dtype=unet.dtype)
ctrl_net.to(unet.device, dtype=unet.dtype)
return ctrl_unet, ctrl_net
def load_preprocess(prep_type: str):
if prep_type is None or prep_type.lower() == "none":
return None
if prep_type.startswith("canny"):
args = prep_type.split("_")
th1 = int(args[1]) if len(args) >= 2 else 63
th2 = int(args[2]) if len(args) >= 3 else 191
def canny(img):
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
return cv2.Canny(img, th1, th2)
return canny
print("Unsupported prep type:", prep_type)
return None
def preprocess_ctrl_net_hint_image(image):
image = np.array(image).astype(np.float32) / 255.0
image = image[:, :, ::-1].copy() # rgb to bgr
image = image[None].transpose(0, 3, 1, 2) # nchw
image = torch.from_numpy(image)
return image # 0 to 1
def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
guided_hints = []
for i, cnet_info in enumerate(control_nets):
# hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
b_hints = []
if len(hints) == 1: # すべて同じ画像をhintとして使う
hint = hints[0]
if cnet_info.prep is not None:
hint = cnet_info.prep(hint)
hint = preprocess_ctrl_net_hint_image(hint)
b_hints = [hint for _ in range(b_size)]
else:
for bi in range(b_size):
hint = hints[(bi * len(control_nets) + i) % len(hints)]
if cnet_info.prep is not None:
hint = cnet_info.prep(hint)
hint = preprocess_ctrl_net_hint_image(hint)
b_hints.append(hint)
b_hints = torch.cat(b_hints, dim=0)
b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
guided_hints.append(guided_hint)
return guided_hints
def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
# ControlNet
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
cnet_cnt = len(control_nets)
cnet_idx = step % cnet_cnt
cnet_info = control_nets[cnet_idx]
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
return original_unet(sample, timestep, encoder_hidden_states)
guided_hint = guided_hints[cnet_idx]
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
outs = [o * cnet_info.weight for o in outs]
# U-Net
return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
"""
# これはmergeのバージョン
# ControlNet
cnet_outs_list = []
for i, cnet_info in enumerate(control_nets):
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
continue
guided_hint = guided_hints[i]
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
for i in range(len(outs)):
outs[i] *= cnet_info.weight
cnet_outs_list.append(outs)
count = len(cnet_outs_list)
if count == 0:
return original_unet(sample, timestep, encoder_hidden_states)
# sum of controlnets
for i in range(1, count):
cnet_outs_list[0] += cnet_outs_list[i]
# U-Net
return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
"""
def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
# copy from UNet2DConditionModel
default_overall_up_factor = 2**unet.num_upsamplers
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
print("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 0. center input if necessary
if unet.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = unet.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=unet.dtype)
emb = unet.time_embedding(t_emb)
outs = [] # output of ControlNet
zc_idx = 0
# 2. pre-process
sample = unet.conv_in(sample)
if is_control_net:
sample += guided_hint
outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
zc_idx += 1
# 3. down
down_block_res_samples = (sample,)
for downsample_block in unet.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
if is_control_net:
for rs in res_samples:
outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
zc_idx += 1
down_block_res_samples += res_samples
# 4. mid
sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
if is_control_net:
outs.append(control_net.control_model.middle_block_out[0](sample))
return outs
if not is_control_net:
sample += ctrl_outs.pop()
# 5. up
for i, upsample_block in enumerate(unet.up_blocks):
is_final_block = i == len(unet.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
if not is_control_net and len(ctrl_outs) > 0:
res_samples = list(res_samples)
apply_ctrl_outs = ctrl_outs[-len(res_samples):]
ctrl_outs = ctrl_outs[:-len(res_samples)]
for j in range(len(res_samples)):
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
res_samples = tuple(res_samples)
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = unet.conv_norm_out(sample)
sample = unet.conv_act(sample)
sample = unet.conv_out(sample)
return UNet2DConditionOutput(sample=sample)

View File

@@ -0,0 +1,122 @@
import glob
import os
import cv2
import argparse
import shutil
import math
from PIL import Image
import numpy as np
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
# Split the max_resolution string by "," and strip any whitespaces
max_resolutions = [res.strip() for res in max_resolution.split(',')]
# # Calculate max_pixels from max_resolution string
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Select interpolation method
if interpolation == 'lanczos4':
cv2_interpolation = cv2.INTER_LANCZOS4
elif interpolation == 'cubic':
cv2_interpolation = cv2.INTER_CUBIC
else:
cv2_interpolation = cv2.INTER_AREA
# Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp etc...
if not filename.endswith(img_exts):
# Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue
# Load image
# img = cv2.imread(os.path.join(src_img_folder, filename))
image = Image.open(os.path.join(src_img_folder, filename))
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
base, _ = os.path.splitext(filename)
for max_resolution in max_resolutions:
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]
# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels
# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
new_height, new_width = img.shape[0:2]
# Calculate the new height and width that are divisible by divisible_by (with/without resizing)
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]
# Split filename into base and extension
new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
# Save resized image in dst_img_folder
# cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
image = Image.fromarray(img)
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
proc = "Resized" if current_pixels > max_pixels else "Saved"
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
# If other files with same basename, copy them with resolution suffix
if copy_associated_files:
asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
for asoc_file in asoc_files:
ext = os.path.splitext(asoc_file)[1]
if ext in img_exts:
continue
for max_resolution in max_resolutions:
new_asoc_file = base + '+' + max_resolution + ext
print(f"Copy {asoc_file} as {new_asoc_file}")
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
def main():
parser = argparse.ArgumentParser(
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
parser.add_argument('--max_resolution', type=str,
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
if __name__ == '__main__':
main()

View File

@@ -35,10 +35,16 @@ def train(args):
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
if args.no_token_padding:
train_dataset.disable_token_padding()
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets()
if args.debug_dataset:
@@ -109,25 +115,12 @@ def train(args):
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
# 8-bit Adamを使う
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
if train_text_encoder:
trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
else:
trainable_params = unet.parameters()
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
@@ -143,9 +136,10 @@ def train(args):
if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
# lr schedulerを用意する TODO gradient_accumulation_stepsの扱いが何かおかしいかもしれない。後で確認する
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
@@ -200,8 +194,11 @@ def train(args):
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth")
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
# 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train()
@@ -209,7 +206,6 @@ def train(args):
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
@@ -226,10 +222,13 @@ def train(args):
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0]
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
@@ -263,12 +262,12 @@ def train(args):
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
if train_text_encoder:
params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
else:
params_to_clip = unet.parameters()
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -281,11 +280,18 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / (step+1)
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -293,7 +299,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()
@@ -326,9 +332,10 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
parser.add_argument("--no_token_padding", action="store_true",
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にするDiffusers版DreamBoothと同じ動作")

View File

@@ -1,6 +1,5 @@
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from typing import Optional, Union
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
import importlib
import argparse
import gc
@@ -24,83 +23,24 @@ def collate_fn(examples):
return examples[0]
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if args.network_train_unet_only:
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
elif args.network_train_text_encoder_only:
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
else:
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d']*lr_scheduler.optimizers[-1].param_groups[0]['lr']
return logs
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
# Which is a newer release of diffusers than currently packaged with sd-scripts
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
def get_scheduler_fix(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
num_cycles: int = 1,
power: float = 1.0,
):
"""
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (`int`, *optional*):
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
power (`float`, *optional*, defaults to 1.0):
Power factor. See `POLYNOMIAL` scheduler
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
if name == SchedulerType.COSINE_WITH_RESTARTS:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
)
if name == SchedulerType.POLYNOMIAL:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
)
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -120,15 +60,22 @@ def train(args):
print("Use DreamBooth method.")
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
args.random_crop, args.debug_dataset)
else:
print("Train with captions.")
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets()
if args.debug_dataset:
@@ -148,6 +95,11 @@ def train(args):
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# work on low-ram device
if args.lowram:
text_encoder.to("cuda")
unet.to("cuda")
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -194,21 +146,8 @@ def train(args):
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
# 8-bit Adamを使う
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
@@ -222,11 +161,9 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
# lr_scheduler = diffusers.optimization.get_scheduler(
lr_scheduler = get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
@@ -251,17 +188,26 @@ def train(args):
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
text_encoder.train()
# set top parameter requires_grad = True for gradient checkpointing works
text_encoder.text_model.embeddings.requires_grad_(True)
if type(text_encoder) == DDP:
text_encoder.module.text_model.embeddings.requires_grad_(True)
else:
text_encoder.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
text_encoder.eval()
# support DistributedDataParallel
if type(text_encoder) == DDP:
text_encoder = text_encoder.module
unet = unet.module
network = network.module
network.prepare_grad_etc(text_encoder, unet)
if not cache_latents:
@@ -329,15 +275,26 @@ def train(args):
"ss_shuffle_caption": bool(args.shuffle_caption),
"ss_cache_latents": bool(args.cache_latents),
"ss_enable_bucket": bool(train_dataset.enable_bucket),
"ss_bucket_no_upscale": bool(train_dataset.bucket_no_upscale),
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
"ss_seed": args.seed,
"ss_lowram": args.lowram,
"ss_keep_tokens": args.keep_tokens,
"ss_noise_offset": args.noise_offset,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
"ss_training_comment": args.training_comment # will not be updated after training
"ss_training_comment": args.training_comment, # will not be updated after training
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
"ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
"ss_max_grad_norm": args.max_grad_norm,
"ss_caption_dropout_rate": args.caption_dropout_rate,
"ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
"ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight,
}
# uncomment if another network is added
@@ -371,13 +328,16 @@ def train(args):
if accelerator.is_main_process:
accelerator.init_trackers("network_train")
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
metadata["ss_epoch"] = str(epoch+1)
network.on_epoch_start(text_encoder, unet)
loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(network):
with torch.no_grad():
@@ -396,6 +356,9 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@@ -406,7 +369,8 @@ def train(args):
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
@@ -423,9 +387,9 @@ def train(args):
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = network.get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -437,8 +401,13 @@ def train(args):
global_step += 1
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / (step+1)
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -450,7 +419,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()
@@ -461,6 +430,7 @@ def train(args):
def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
metadata["ss_training_finished_at"] = str(time.time())
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
@@ -478,6 +448,7 @@ def train(args):
# end of epoch
metadata["ss_epoch"] = str(num_train_epochs)
metadata["ss_training_finished_at"] = str(time.time())
is_main_process = accelerator.is_main_process
if is_main_process:
@@ -506,8 +477,9 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -515,10 +487,6 @@ if __name__ == '__main__':
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
parser.add_argument("--lr_scheduler_power", type=float, default=1,
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
parser.add_argument("--network_weights", type=str, default=None,
help="pretrained weights for network / 学習するネットワークの初期重み")

View File

@@ -50,12 +50,14 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
--train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
--reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
--resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
--max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16
--max_train_steps=400 --optimizer_type=AdamW8bit --xformers --mixed_precision=fp16
--save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
--network_module=networks.lora
```
--output_dirオプションで指定したディレクトリに、LoRAのモデルが保存されます。
2023/2/22:オプティマイザの指定方法が変わりました。[こちら](#オプティマイザの指定について)をご覧ください。)
--output_dirオプションで指定したフォルダに、LoRAのモデルが保存されます。
その他、以下のオプションが指定できます。
@@ -76,6 +78,42 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
--network_train_unet_onlyと--network_train_text_encoder_onlyの両方とも未指定時デフォルトはText EncoderとU-Netの両方のLoRAモジュールを有効にします。
## オプティマイザの指定について
--optimizer_type オプションでオプティマイザの種類を指定します。以下が指定できます。
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
- 過去のバージョンのオプション未指定時と同じ
- AdamW8bit : 引数は同上
- 過去のバージョンの--use_8bit_adam指定時と同じ
- Lion : https://github.com/lucidrains/lion-pytorch
- 過去のバージョンの--use_lion_optimizer指定時と同じ
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
- SGDNesterov8bit : 引数は同上
- DAdaptation : https://github.com/facebookresearch/dadaptation
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
- 任意のオプティマイザ
オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。
オプション引数を指定する場合は、それぞれのオプティマイザの仕様をご確認ください。
一部のオプティマイザでは必須の引数があり、省略すると自動的に追加されますSGDNesterovのmomentumなど。コンソールの出力を確認してください。
D-Adaptationオプティマイザは学習率を自動調整します。学習率のオプションに指定した値は学習率そのものではなくD-Adaptationが決定した学習率の適用率になりますので、通常は1.0を指定してください。Text EncoderにU-Netの半分の学習率を指定したい場合は、``--text_encoder_lr=0.5 --unet_lr=1.0``と指定します。
AdaFactorオプティマイザはrelative_step=Trueを指定すると学習率を自動調整できます省略時はデフォルトで追加されます。自動調整する場合は学習率のスケジューラにはadafactor_schedulerが強制的に使用されます。またscale_parameterとwarmup_initを指定するとよいようです。
自動調整する場合のオプション指定はたとえば ``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"`` のようになります。
学習率を自動調整しない場合はオプション引数 ``relative_step=False`` を追加してください。その場合、学習率のスケジューラにはconstant_with_warmupが、また勾配のclip normをしないことが推奨されているようです。そのため引数は ``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0`` のようになります。
### 任意のオプティマイザを使う
``torch.optim`` のオプティマイザを使う場合にはクラス名のみを(``--optimizer_type=RMSprop``など)、他のモジュールのオプティマイザを使う時は「モジュール名.クラス名」を指定してください(``--optimizer_type=bitsandbytes.optim.lamb.LAMB``など)。
内部でimportlibしているだけで動作は未確認です。必要ならパッケージをインストールしてください。
## マージスクリプトについて
merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。
@@ -178,6 +216,38 @@ Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRA
- --save_precision
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
## 画像リサイズスクリプト
(のちほどドキュメントを整理しますがとりあえずここに説明を書いておきます。)
Aspect Ratio Bucketingの機能拡張で、小さな画像については拡大しないでそのまま教師データとすることが可能になりました。元の教師画像を縮小した画像を、教師データに加えると精度が向上したという報告とともに前処理用のスクリプトをいただきましたので整備して追加しました。bmaltais氏に感謝します。
### スクリプトの実行方法
以下のように指定してください。元の画像そのまま、およびリサイズ後の画像が変換先フォルダに保存されます。リサイズ後の画像には、ファイル名に ``+512x512`` のようにリサイズ先の解像度が付け加えられます(画像サイズとは異なります)。リサイズ先の解像度より小さい画像は拡大されることはありません。
```
python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png
--copy_associated_files 元画像フォルダ 変換先フォルダ
```
元画像フォルダ内の画像ファイルが、指定した解像度(複数指定可)と同じ面積になるようにリサイズされ、変換先フォルダに保存されます。画像以外のファイルはそのままコピーされます。
``--max_resolution`` オプションにリサイズ先のサイズを例のように指定してください。面積がそのサイズになるようにリサイズします。複数指定すると、それぞれの解像度でリサイズされます。``512x512,384x384,256x256``なら、変換先フォルダの画像は、元サイズとリサイズ後サイズ×3の計4枚になります。
``--save_as_png`` オプションを指定するとpng形式で保存します。省略するとjpeg形式quality=100で保存されます。
``--copy_associated_files`` オプションを指定すると、拡張子を除き画像と同じファイル名(たとえばキャプションなど)のファイルが、リサイズ後の画像のファイル名と同じ名前でコピーされます。
### その他のオプション
- divisible_by
- リサイズ後の画像のサイズ(縦、横のそれぞれ)がこの値で割り切れるように、画像中心を切り出します。
- interpolation
- 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。
## 追加情報
### cloneofsimo氏のリポジトリとの違い

View File

@@ -98,12 +98,12 @@ def train(args):
# Convert the init_word to token_id
if args.init_word is not None:
init_token_id = tokenizer.encode(args.init_word, add_special_tokens=False)
assert len(
init_token_id) == 1, f"init word {args.init_word} is not converted to single token / 初期化単語が二つ以上のトークンに変換されます。別の単語を使ってください"
init_token_id = init_token_id[0]
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}")
else:
init_token_id = None
init_token_ids = None
# add new word to tokenizer, count is num_vectors_per_token
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
@@ -120,9 +120,9 @@ def train(args):
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
if init_token_id is not None:
for token_id in token_ids:
token_embeds[token_id] = token_embeds[init_token_id]
if init_token_ids is not None:
for i, token_id in enumerate(token_ids):
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
@@ -143,13 +143,15 @@ def train(args):
print("Use DreamBooth method.")
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
else:
print("Train with captions.")
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset)
@@ -196,22 +198,8 @@ def train(args):
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
# 8-bit Adamを使う
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
trainable_params = text_encoder.get_input_embeddings().parameters()
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
@@ -225,15 +213,16 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler)
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
print(len(index_no_updates), torch.sum(index_no_updates))
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
@@ -294,6 +283,7 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
text_encoder.train()
@@ -312,10 +302,14 @@ def train(args):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) # weight_dtype) use float instead of fp16/bf16 because text encoder is float
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@@ -343,9 +337,9 @@ def train(args):
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = text_encoder.get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -362,7 +356,9 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d']*lr_scheduler.optimizers[0].param_groups[0]['lr']
accelerator.log(logs, step=global_step)
loss_total += current_loss
@@ -380,8 +376,8 @@ def train(args):
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
d = updated_embs - bef_epo_embs
print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
# d = updated_embs - bef_epo_embs
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
@@ -475,8 +471,9 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt")
@@ -488,7 +485,7 @@ if __name__ == '__main__':
parser.add_argument("--token_string", type=str, default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること")
parser.add_argument("--init_word", type=str, default=None,
help="word to initialize vector / ベクトルを初期化に使用する単語、tokenizerで一語になること")
help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument("--use_object_template", action='store_true',
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する")
parser.add_argument("--use_style_template", action='store_true',