Compare commits

...

128 Commits

Author SHA1 Message Date
Kohya S
e5ac095749 add about dev and sd3 branch to README 2024-11-07 21:39:47 +09:00
Kohya S
ca44e3e447 reduce VRAM usage, instead of increasing main RAM usage 2024-10-27 10:19:05 +09:00
Kohya S
56b4ea963e Fix LoRA metadata hash calculation bug in svd_merge_lora.py, sdxl_merge_lora.py, and resize_lora.py closes #1722 2024-10-26 22:01:10 +09:00
Kohya S
9c757c2fba fix SDXL block index to match LBW 2024-09-19 21:14:57 +09:00
Kohya S
b755ebd0a4 add LBW support for SDXL merge LoRA 2024-09-13 21:29:31 +09:00
Kohya S
f4a0bea6dc format by black 2024-09-13 21:26:06 +09:00
terracottahaniwa
734d2e5b2b Support Lora Block Weight (LBW) to svd_merge_lora.py (#1575)
* support lora block weight

* solve license incompatibility

* Fix issue: lbw index calculation
2024-09-13 20:45:35 +09:00
Kohya S
3387dc7306 formatting, update README 2024-09-13 19:45:42 +09:00
Kohya S
57ae44eb61 refactor to make safer 2024-09-13 19:45:00 +09:00
Maru-mee
1d7118a622 Support : OFT merge to base model (#1580)
* Support : OFT merge to base model

* Fix typo

* Fix typo_2

* Delete unused parameter 'eye'
2024-09-13 19:01:36 +09:00
Kohya S.
de25945a93 Merge pull request #1550 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.24.3
Bump crate-ci/typos from 1.19.0 to 1.24.3
2024-09-07 10:50:46 +09:00
dependabot[bot]
1bcf8d600b Bump crate-ci/typos from 1.19.0 to 1.24.3
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.19.0 to 1.24.3.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.19.0...v1.24.3)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-09-01 01:33:04 +00:00
Kohya S.
f8f5b16958 Merge pull request #1540 from kohya-ss/dependabot/pip/opencv-python-4.8.1.78
Bump opencv-python from 4.7.0.68 to 4.8.1.78
2024-08-31 21:37:07 +09:00
Kohya S.
826ab5ce2e Merge pull request #1532 from nandometzger/main
Update train_util.py, bug fix
2024-08-31 21:36:33 +09:00
dependabot[bot]
3a6154b7b0 Bump opencv-python from 4.7.0.68 to 4.8.1.78
Bumps [opencv-python](https://github.com/opencv/opencv-python) from 4.7.0.68 to 4.8.1.78.
- [Release notes](https://github.com/opencv/opencv-python/releases)
- [Commits](https://github.com/opencv/opencv-python/commits)

---
updated-dependencies:
- dependency-name: opencv-python
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-08-31 06:21:16 +00:00
Nando Metzger
2a3aefb4e4 Update train_util.py, bug fix 2024-08-30 08:15:05 +02:00
Kohya S
25f961bc77 fix to work cache_latents/text_encoder_outputs 2024-06-23 13:24:30 +09:00
Kohya S
71e2c91330 Merge pull request #1230 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.19.0
Bump crate-ci/typos from 1.17.2 to 1.19.0
2024-04-07 21:14:18 +09:00
Kohya S
bfb352bc43 change huber_schedule from exponential to snr 2024-04-07 21:07:52 +09:00
Kohya S
c973b29da4 update readme 2024-04-07 20:51:52 +09:00
Kohya S
683f3d6ab3 Merge pull request #1212 from kohya-ss/dev
Version 0.8.6
2024-04-07 20:42:41 +09:00
Kohya S
dfa30790a9 update readme 2024-04-07 20:34:26 +09:00
Kohya S
d30ebb205c update readme, add metadata for network module 2024-04-07 14:58:17 +09:00
kabachuha
90b18795fc Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption (#1228)
* add huber loss and huber_c compute to train_util

* add reduction modes

* add huber_c retrieval from timestep getter

* move get timesteps and huber to own function

* add conditional loss to all training scripts

* add cond loss to train network

* add (scheduled) huber_loss to args

* fixup twice timesteps getting

* PHL-schedule should depend on noise scheduler's num timesteps

* *2 multiplier to huber loss cause of 1/2 a^2 conv.

The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another

* add option for smooth l1 (huber / delta)

* unify huber scheduling

* add snr huber scheduler

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-04-07 13:54:21 +09:00
Kohya S
089727b5ee update readme 2024-04-07 12:42:49 +09:00
Kohya S
921036dd91 Merge pull request #1240 from kohya-ss/verify-command-line-args
verify command line args if wandb is enabled
2024-04-07 12:27:03 +09:00
ykume
cd587ce62c verify command line args if wandb is enabled 2024-04-05 08:23:03 +09:00
Kohya S
b748b48dbb fix attention couple+deep shink cause error in some reso 2024-04-03 12:43:08 +09:00
dependabot[bot]
80e9f72234 Bump crate-ci/typos from 1.17.2 to 1.19.0
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.17.2 to 1.19.0.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.17.2...v1.19.0)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-04-01 01:50:22 +00:00
Kohya S
2258a1b753 add save/load hook to remove U-Net/TEs from state 2024-03-31 15:50:35 +09:00
Kohya S
059ee047f3 fix typo 2024-03-30 23:02:24 +09:00
Kohya S
2c2ca9d726 update tagger doc 2024-03-30 22:55:56 +09:00
Kohya S
f5323e3c4b update tagger doc 2024-03-30 22:10:37 +09:00
Kohya S
cae5aa0a56 update wd14 tagger and doc 2024-03-30 21:48:22 +09:00
Kohya S
6ba84288d9 Merge pull request #1216 from Disty0/dev
Rating support for WD Tagger
2024-03-30 18:50:49 +09:00
Kohya S
434dc408f9 update readme 2024-03-30 17:12:36 +09:00
Kohya S
ae3f625739 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2024-03-30 14:57:43 +09:00
Kohya S
f1f30ab418 fix to work with num_beams>1 closes #1149 2024-03-30 14:57:39 +09:00
Disty0
bc586ce190 Add --use_rating_tags and --character_tags_first for WD Tagger 2024-03-29 13:56:42 +03:00
Disty0
4012fd24f6 IPEX fix pin_memory 2024-03-28 21:08:16 +03:00
Disty0
954731d564 fix typo 2024-03-27 22:00:59 +03:00
Disty0
dd9763be31 Rating support for WD Tagger 2024-03-27 21:53:40 +03:00
Kohya S
b86af6798d Merge pull request #1213 from Disty0/dev
Add OpenVINO and ROCm ONNX Runtime for WD14
2024-03-27 23:15:33 +09:00
Disty0
6f7e93d5cc Add OpenVINO and ROCm ONNX Runtime for WD14 2024-03-27 03:21:13 +03:00
Kohya S
6c08e97e1f update readme 2024-03-26 20:48:08 +09:00
Kohya S
78e0a7630c Merge pull request #1206 from kohya-ss/dataset-cache
Add metadata caching for DreamBooth dataset
2024-03-26 19:49:23 +09:00
Kohya S
c86e356013 Merge branch 'dev' into dataset-cache 2024-03-26 19:43:40 +09:00
Kohya S
5a2afb3588 Merge pull request #1207 from kohya-ss/masked-loss
Add masked loss
2024-03-26 19:41:31 +09:00
Kohya S
ab1e389347 Merge branch 'dev' into masked-loss 2024-03-26 19:39:30 +09:00
Kohya S
ea05e3fd5b Merge pull request #1139 from kohya-ss/deep-speed
Deep speed
2024-03-26 19:33:57 +09:00
Kohya S
a2b8531627 make each script consistent, fix to work w/o DeepSpeed 2024-03-25 22:28:46 +09:00
Kohya S
c24422fb9d Merge branch 'dev' into deep-speed 2024-03-25 22:11:05 +09:00
Kohya S
9c4492b58a fix pytorch version 2.1.1 to 2.1.2 2024-03-24 23:17:25 +09:00
Kohya S
9bbb28c361 update PyTorch version and reorganize dependencies 2024-03-24 22:06:37 +09:00
Kohya S
1648ade6da format by black 2024-03-24 20:55:48 +09:00
Kohya S
993b2ab4c1 Merge branch 'dev' into deep-speed 2024-03-24 18:45:59 +09:00
Kohya S
8d5858826f Merge branch 'dev' into masked-loss 2024-03-24 18:19:53 +09:00
Kohya S
025347214d refactor metadata caching for DreamBooth dataset 2024-03-24 18:09:32 +09:00
Kohaku-Blueleaf
ae97c8bfd1 [Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initilization (#1178)
* support meta cached dataset

* add cache meta scripts

* random ip_noise_gamma strength

* random noise_offset strength

* use correct settings for parser

* cache path/caption/size only

* revert mess up commit

* revert mess up commit

* Update requirements.txt

* Add arguments for meta cache.

* remove pickle implementation

* Return sizes when enable cache

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-03-24 15:40:18 +09:00
Kohya S
381c44955e update readme and typing hint 2024-03-24 11:27:18 +09:00
Kohya S
ad97410ba5 Merge pull request #1205 from feffy380/patch-1
register reg images with correct subset
2024-03-24 11:14:07 +09:00
Kohya S
691f04322a update readme 2024-03-24 11:10:26 +09:00
Kohya S
79d1c12ab0 disable sample_every_n_xxx if value less than 1 ref #1202 2024-03-24 11:06:37 +09:00
feffy380
0c7baea88c register reg images with correct subset 2024-03-23 17:28:02 +01:00
Kohya S
f4a4c11cd3 support multiline captions ref #1155 2024-03-23 18:51:37 +09:00
Kohya S
594c7f7050 format by black 2024-03-23 16:11:31 +09:00
Kohya S
d17c0f5084 update dataset config doc 2024-03-21 08:31:29 +09:00
Kohya S
a35e7bd595 Merge pull request #1200 from BootsofLagrangian/deep-speed
Fix sdxl_train.py in deepspeed branch
2024-03-20 21:32:35 +09:00
BootsofLagrangian
d9456020d7 Fix most of ZeRO stage uses optimizer partitioning
- we have to prepare optimizer and ds_model at the same time.
 - pull/1139#issuecomment-1986790007

Signed-off-by: BootsofLagrangian <hard2251@yonsei.ac.kr>
2024-03-20 20:52:59 +09:00
Kohya S
fbb98f144e Merge branch 'dev' into deep-speed 2024-03-20 18:15:26 +09:00
Kohya S
9b6b39f204 Merge branch 'dev' into masked-loss 2024-03-20 18:14:36 +09:00
Kohya S
855add067b update option help and readme 2024-03-20 18:14:05 +09:00
Kohya S
bf6cd4b9da Merge pull request #1168 from gesen2egee/save_state_on_train_end
Save state on train end
2024-03-20 18:02:13 +09:00
Kohya S
3b0db0f17f update readme 2024-03-20 17:45:35 +09:00
Kohya S
119cc99fb0 Merge pull request #1167 from Horizon1704/patch-1
Add "encoding='utf-8'" for --config_file
2024-03-20 17:39:08 +09:00
Kohya S
5f6196e4c7 update readme 2024-03-20 16:35:23 +09:00
Victor Espinoza-Guerra
46331a9e8e English Translation of config_README-ja.md (#1175)
* Add files via upload

Creating template to work on.

* Update config_README-en.md

Total Conversion from Japanese to English.

* Update config_README-en.md

* Update config_README-en.md

* Update config_README-en.md
2024-03-20 16:31:01 +09:00
Kohya S
cf09c6aa9f Merge pull request #1177 from KohakuBlueleaf/random-strength-noise
Random strength for Noise Offset and input perturbation noise
2024-03-20 16:17:16 +09:00
Kohya S
80dbbf5e48 tagger now stores model under repo_id subdir 2024-03-20 16:14:57 +09:00
Kohya S
7da41be281 Merge pull request #1192 from sdbds/main
Add WDV3 support
2024-03-20 15:49:55 +09:00
Kohya S
e281e867e6 Merge branch 'main' into dev 2024-03-20 15:49:08 +09:00
青龍聖者@bdsqlsz
6c51c971d1 fix typo 2024-03-20 09:35:21 +08:00
青龍聖者@bdsqlsz
a71c35ccd9 Update requirements.txt 2024-03-18 22:31:59 +08:00
青龍聖者@bdsqlsz
5410a8c79b Update requirements.txt 2024-03-18 22:31:00 +08:00
青龍聖者@bdsqlsz
a7dff592d3 Update tag_images_by_wd14_tagger.py
add WDV3
2024-03-18 22:29:05 +08:00
Kohya S
f9317052ed update readme for timestep embs bug 2024-03-18 08:53:23 +09:00
Kohya S
86e40fabbc Merge branch 'dev' into deep-speed 2024-03-17 19:30:42 +09:00
Kohya S
3419c3de0d common masked loss func, apply to all training script 2024-03-17 19:30:20 +09:00
Kohya S
7081a0cf0f extension of src image could be different than target image 2024-03-17 18:09:15 +09:00
Kohya S
0ef4fe70f0 Merge branch 'dev' into masked-loss 2024-03-17 11:18:18 +09:00
Kohya S
443f02942c fix doc 2024-03-15 21:35:14 +09:00
Kohya S
0a8ec5224e Merge branch 'main' into dev 2024-03-15 21:33:07 +09:00
Kohya S
6b1520a46b Merge pull request #1187 from kohya-ss/fix-timeemb
fix sdxl timestep embedding
2024-03-15 21:17:13 +09:00
Kohya S
f811b115ba fix sdxl timestep embedding 2024-03-15 21:05:00 +09:00
kblueleaf
53954a1e2e use correct settings for parser 2024-03-13 18:21:49 +08:00
kblueleaf
86399407b2 random noise_offset strength 2024-03-13 18:21:49 +08:00
kblueleaf
948029fe61 random ip_noise_gamma strength 2024-03-13 18:21:49 +08:00
Kohya S
97524f1bda Merge branch 'dev' into deep-speed 2024-03-12 20:41:41 +09:00
Kohya S
74c266a597 Merge branch 'dev' into masked-loss 2024-03-12 20:40:57 +09:00
gesen2egee
d282c45002 Update train_network.py 2024-03-11 23:56:09 +08:00
gesen2egee
095b8035e6 save state on train end 2024-03-10 23:33:38 +08:00
Horizon1704
124ec45876 Add "encoding='utf-8'" 2024-03-10 22:53:05 +08:00
Kohya S
14c9372a38 add doc about Colab/rich issue 2024-03-03 21:47:37 +09:00
Kohya S
a9b64ffba8 support masked loss in sdxl_train ref #589 2024-02-27 21:43:55 +09:00
Kohya S
e3ccf8fbf7 make deepspeed_utils 2024-02-27 21:30:46 +09:00
Kohya S
0e4a5738df Merge pull request #1101 from BootsofLagrangian/deepspeed
support deepspeed
2024-02-27 18:59:00 +09:00
Kohya S
eefb3cc1e7 Merge branch 'deep-speed' into deepspeed 2024-02-27 18:57:42 +09:00
Kohya S
074d32af20 Merge branch 'main' into dev 2024-02-27 18:53:43 +09:00
Kohya S
2d7389185c Merge pull request #1094 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.17.2
Bump crate-ci/typos from 1.16.26 to 1.17.2
2024-02-27 18:23:41 +09:00
Kohya S
4a5546d40e fix typo 2024-02-26 23:39:56 +09:00
Kohya S
175193623b update readme 2024-02-26 23:29:41 +09:00
Kohya S
f2c727fc8c add minimal impl for masked loss 2024-02-26 23:19:58 +09:00
Kohya S
577e9913ca add some new dataset settings 2024-02-26 20:01:25 +09:00
Kohya S
fccbee2727 revert logging #1137 2024-02-25 10:43:14 +09:00
Kohya S
e0acb10f31 Merge pull request #1137 from shirayu/replace_print_with_logger
Replaced print with logger
2024-02-25 10:34:19 +09:00
Yuta Hayashibe
5d5f39b6e6 Replaced print with logger 2024-02-25 01:24:11 +09:00
BootsofLagrangian
4d5186d1cf refactored codes, some function moved into train_utils.py 2024-02-22 16:20:53 +09:00
BootsofLagrangian
03f0816f86 the reason not working grad accum steps found. it was becasue of my accelerate settings 2024-02-09 17:47:49 +09:00
BootsofLagrangian
a98fecaeb1 forgot setting mixed_precision for deepspeed. sorry 2024-02-07 17:19:46 +09:00
BootsofLagrangian
2445a5b74e remove test requirements 2024-02-07 16:48:18 +09:00
BootsofLagrangian
62556619bd fix full_fp16 compatible and train_step 2024-02-07 16:42:05 +09:00
BootsofLagrangian
7d2a9268b9 apply offloading method runable for all trainer 2024-02-05 22:42:06 +09:00
BootsofLagrangian
3970bf4080 maybe fix branch to run offloading 2024-02-05 22:40:43 +09:00
BootsofLagrangian
4295f91dcd fix all trainer about vae 2024-02-05 20:19:56 +09:00
BootsofLagrangian
2824312d5e fix vae type error during training sdxl 2024-02-05 20:13:28 +09:00
BootsofLagrangian
64873c1b43 fix offload_optimizer_device typo 2024-02-05 17:11:50 +09:00
BootsofLagrangian
dfe08f395f support deepspeed 2024-02-04 03:12:42 +09:00
dependabot[bot]
716a92cbed Bump crate-ci/typos from 1.16.26 to 1.17.2
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.16.26 to 1.17.2.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.16.26...v1.17.2)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-02-01 01:57:52 +00:00
46 changed files with 2912 additions and 902 deletions

View File

@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.16.26
uses: crate-ci/typos@v1.24.3

View File

@@ -1,12 +1,12 @@
SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。
SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。
## リポジトリについて
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
[README in English](./README.md) ←更新情報はこちらにあります
開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています英語ですのであわせてご覧ください。bmaltais氏に感謝します。
以下のスクリプトがあります。
@@ -21,6 +21,7 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
* [データセット設定](./docs/config_README-ja.md)
* [SDXL学習](./docs/train_SDXL-en.md) (英語版)
* [DreamBoothの学習について](./docs/train_db_README-ja.md)
* [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
* [LoRAの学習について](./docs/train_network_README-ja.md)
@@ -44,9 +45,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
## Windows環境でのインストール
スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。
以下の例ではPyTorchは2.0.1CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。
なお、python -m venvの行で「python」とだけ表示された場合、py -m venvのようにpythonをpyに変更してください。
@@ -59,21 +58,21 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install xformers==0.0.20
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
accelerate config
```
コマンドプロンプトでも同一です。
(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。
注:`bitsandbytes==0.43.0``prodigyopt==1.0``lion-pytorch==0.0.6``requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
この例では PyTorch および xfomers は2.1.2CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
accelerate configの質問には以下のように答えてください。bf16で学習する場合、最後の質問にはbf16と答えてください。
※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます……。数字キーの0、1、2……で選択できますので、そちらを使ってください。
```txt
- This machine
- No distributed training
@@ -87,41 +86,6 @@ accelerate configの質問には以下のように答えてください。bf1
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``に「0」と答えてください。id `0`のGPUが使われます。
### オプション:`bitsandbytes`8bit optimizerを使う
`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます0.41.1または以降のバージョンを推奨)。
Windowsでは0.35.0または0.41.1を推奨します。
- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。
- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。
注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659
以下の手順に従い、`bitsandbytes`をインストールしてください。
### 0.35.0を使う場合
PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。
```powershell
cd sd-scripts
.\venv\Scripts\activate
pip install bitsandbytes==0.35.0
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
```
### 0.41.1を使う場合
jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。
```powershell
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
```
## アップグレード
新しいリリースがあった場合、以下のコマンドで更新できます。
@@ -151,4 +115,47 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## その他の情報
### LoRAの名称について
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
<!--
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
-->
### 学習中のサンプル画像生成
プロンプトファイルは例えば以下のようになります。
```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
`( )` や `[ ]` などの重みづけも動作します。

489
README.md
View File

@@ -1,5 +1,3 @@
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
This repository contains training, generation and utility scripts for Stable Diffusion.
[__Change History__](#change-history) is moved to the bottom of the page.
@@ -7,6 +5,11 @@ This repository contains training, generation and utility scripts for Stable Dif
[日本語版READMEはこちら](./README-ja.md)
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch.
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
This repository contains the scripts for:
@@ -20,9 +23,9 @@ This repository contains the scripts for:
## About requirements.txt
These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.)
The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below.
The scripts are tested with Pytorch 2.0.1. 1.12.1 is not tested but should work.
The scripts are tested with Pytorch 2.1.2. 2.0.1 and 1.12.1 is not tested but should work.
## Links to usage documentation
@@ -32,11 +35,13 @@ Most of the documents are written in Japanese.
* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
* [Chinese version](./docs/train_README-zh.md)
* [SDXL training](./docs/train_SDXL-en.md) (English version)
* [Dataset config](./docs/config_README-ja.md)
* [English version](./docs/config_README-en.md)
* [DreamBooth training guide](./docs/train_db_README-ja.md)
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
* [training LoRA](./docs/train_network_README-ja.md)
* [training Textual Inversion](./docs/train_ti_README-ja.md)
* [Training LoRA](./docs/train_network_README-ja.md)
* [Training Textual Inversion](./docs/train_ti_README-ja.md)
* [Image generation](./docs/gen_img_README-ja.md)
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
@@ -64,14 +69,18 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install xformers==0.0.20
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
accelerate config
```
__Note:__ Now bitsandbytes is optional. Please install any version of bitsandbytes as needed. Installation instructions are in the following section.
If `python -m venv` shows only `python`, change `python` to `py`.
__Note:__ Now `bitsandbytes==0.43.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually.
This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`.
<!--
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
@@ -90,48 +99,13 @@ Answers to accelerate config:
- fp16
```
note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occurred in training. In this case, answer `0` for the 6th question:
If you'd like to use bf16, please answer `bf16` to the last question.
Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occurred in training. In this case, answer `0` for the 6th question:
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``
(Single GPU with id `0` will be used.)
### Optional: Use `bitsandbytes` (8bit optimizer)
For 8bit optimizer, you need to install `bitsandbytes`. For Linux, please install `bitsandbytes` as usual (0.41.1 or later is recommended.)
For Windows, there are several versions of `bitsandbytes`:
- `bitsandbytes` 0.35.0: Stable version. AdamW8bit is available. `full_bf16` is not available.
- `bitsandbytes` 0.41.1: Lion8bit, PagedAdamW8bit and PagedLion8bit are available. `full_bf16` is available.
Note: `bitsandbytes`above 0.35.0 till 0.41.0 seems to have an issue: https://github.com/TimDettmers/bitsandbytes/issues/659
Follow the instructions below to install `bitsandbytes` for Windows.
### bitsandbytes 0.35.0 for Windows
Open a regular Powershell terminal and type the following inside:
```powershell
cd sd-scripts
.\venv\Scripts\activate
pip install bitsandbytes==0.35.0
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
```
This will install `bitsandbytes` 0.35.0 and copy the necessary files to the `bitsandbytes` directory.
### bitsandbytes 0.41.1 for Windows
Install the Windows version whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or other sources, like:
```powershell
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
```
## Upgrade
When a new release comes out you can upgrade your repo with the following command:
@@ -145,6 +119,10 @@ pip install --use-pep517 --upgrade -r requirements.txt
Once the commands have completed successfully you should be ready to use the new version.
### Upgrade PyTorch
If you want to upgrade PyTorch, you can upgrade it with `pip install` command in [Windows Installation](#windows-installation) section. `xformers` is also required to be upgraded when PyTorch is upgraded.
## Credits
The implementation for LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for great work!
@@ -162,209 +140,218 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## SDXL training
The documentation in this section will be moved to a separate document later.
### Training scripts for SDXL
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- The full bfloat16 training might be unstable. Please use it at your own risk.
- The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- `--weighted_captions` option is not supported yet for both scripts.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
### Utility scripts for SDXL
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
- The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage.
### Tips for SDXL training
- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (4 to 8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings for SDXL
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
### ControlNet-LLLite
ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details.
## Change History
### Feb 24, 2024 / 2024/2/24: v0.8.4
### Oct 27, 2024 / 2024-10-27:
- The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu!
- The log is formatted by default. The `rich` library is required. Please see [Upgrade](#upgrade) and update the library.
- If `rich` is not installed, the log output will be the same as before.
- The following options are available in each training script:
- `--console_log_simple` option can be used to switch to the previous log output.
- `--console_log_level` option can be used to specify the log level. The default is `INFO`.
- `--console_log_file` option can be used to output the log to a file. The default is `None` (output to the console).
- The sample image generation during multi-GPU training is now done with multiple GPUs. PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) Thanks to DKnight54!
- The support for mps devices is improved. PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) Thanks to akx! If mps device exists instead of CUDA, the mps device is used automatically.
- The `--new_conv_rank` option to specify the new rank of Conv2d is added to `networks/resize_lora.py`. PR [#1102](https://github.com/kohya-ss/sd-scripts/pull/1102) Thanks to mgz-dev!
- An option `--highvram` to disable the optimization for environments with little VRAM is added to the training scripts. If you specify it when there is enough VRAM, the operation will be faster.
- Currently, only the cache part of latents is optimized.
- The IPEX support is improved. PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Thanks to Disty0!
- Fixed a bug that `svd_merge_lora.py` crashes in some cases. PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) Thanks to mgz-dev!
- DyLoRA is fixed to work with SDXL. PR [#1126](https://github.com/kohya-ss/sd-scripts/pull/1126) Thanks to tamlog06!
- The common image generation script `gen_img.py` for SD 1/2 and SDXL is added. The basic functions are the same as the scripts for SD 1/2 and SDXL, but some new features are added.
- External scripts to generate prompts can be supported. It can be called with `--from_module` option. (The documentation will be added later)
- The normalization method after prompt weighting can be specified with `--emb_normalize_mode` option. `original` is the original method, `abs` is the normalization with the average of the absolute values, `none` is no normalization.
- Gradual Latent Hires fix is added to each generation script. See [here](./docs/gen_img_README-ja.md#about-gradual-latent) for details.
- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient).
- This will be included in the next release.
- `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します32GBあれば十分です
- これは次回リリースに含まれます。
- ログ出力が改善されました。 PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) shirayu 氏に感謝します。
- デフォルトでログが成形されます。`rich` ライブラリが必要なため、[Upgrade](#upgrade) を参照し更新をお願いします。
- `rich` がインストールされていない場合は、従来のログ出力になります。
- 各学習スクリプトでは以下のオプションが有効です。
- `--console_log_simple` オプションで従来のログ出力に切り替えられます。
- `--console_log_level` でログレベルを指定できます。デフォルトは `INFO` です。
- `--console_log_file` でログファイルを出力できます。デフォルトは `None`(コンソールに出力) です。
- 複数 GPU 学習時に学習中のサンプル画像生成を複数 GPU で行うようになりました。 PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) DKnight54 氏に感謝します。
- mps デバイスのサポートが改善されました。 PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) akx 氏に感謝します。CUDA ではなく mps が存在する場合には自動的に mps デバイスを使用します。
- `networks/resize_lora.py` に Conv2d の新しいランクを指定するオプション `--new_conv_rank` が追加されました。 PR [#1102](https://github.com/kohya-ss/sd-scripts/pull/1102) mgz-dev 氏に感謝します。
- 学習スクリプトに VRAMが少ない環境向け最適化を無効にするオプション `--highvram` を追加しました。VRAM に余裕がある場合に指定すると動作が高速化されます。
- 現在は latents のキャッシュ部分のみ高速化されます。
- IPEX サポートが改善されました。 PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Disty0 氏に感謝します。
- `svd_merge_lora.py` が場合によってエラーになる不具合が修正されました。 PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) mgz-dev 氏に感謝します。
- DyLoRA が SDXL で動くよう修正されました。PR [#1126](https://github.com/kohya-ss/sd-scripts/pull/1126) tamlog06 氏に感謝します。
- SD 1/2 および SDXL 共通の生成スクリプト `gen_img.py` を追加しました。基本的な機能は SD 1/2、SDXL 向けスクリプトと同じですが、いくつかの新機能が追加されています。
- プロンプトを動的に生成する外部スクリプトをサポートしました。 `--from_module` で呼び出せます。(ドキュメントはのちほど追加します)
- プロンプト重みづけ後の正規化方法を `--emb_normalize_mode` で指定できます。`original` は元の方法、`abs` は絶対値の平均値で正規化、`none` は正規化を行いません。
- Gradual Latent Hires fix を各生成スクリプトに追加しました。詳細は [こちら](./docs/gen_img_README-ja.md#about-gradual-latent)。
### Oct 26, 2024 / 2024-10-26:
- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
- It will be included in the next release.
### Jan 27, 2024 / 2024/1/27: v0.8.3
- `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。
- 以上は次回リリースに含まれます。
- Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380!
- `safetensors` is updated. Please see [Upgrade](#upgrade) and update the library.
- Fixed a bug that the training crashes when `network_multiplier` is specified with multi-GPU training. PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) Thanks to fireicewolf!
- Fixed a bug that the training crashes when training ControlNet-LLLite.
### Sep 13, 2024 / 2024-09-13:
- `--fp8_base` 指定時に `--save_state` での保存がエラーになる不具合が修正されました。 PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) feffy380 氏に感謝します。
- `safetensors` がバージョンアップされていますので、[Upgrade](#upgrade) を参照し更新をお願いします。
- 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。
- ControlNet-LLLite の学習がエラーになる不具合を修正しました。
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details.
- `sdxl_merge_lora.py` also supports LBW.
- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW.
- These will be included in the next release.
### Jan 23, 2024 / 2024/1/23: v0.8.2
- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。
- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。
- `sdxl_merge_lora.py` でも LBW がサポートされました。
- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。
- 以上は次回リリースに含まれます。
- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
- Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
- PyTorch 2.1 or later is required.
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.
### Jun 23, 2024 / 2024-06-23:
- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc.
- This is an experimental option and may be removed or changed in the future.
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
- Fixed a bug in multi-GPU Textual Inversion training.
- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。)
- 実験的 LoRA等の学習スクリプトで、ベースモデルU-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
- `train_network.py` または `sdxl_train_network.py``--fp8_base` を指定してください。
- PyTorch 2.1 以降が必要です。
- PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。
- 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。
- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。
- 実験的オプションのため、将来的に削除または仕様変更される可能性があります。
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
- また、五段階の状態を用意し、それぞれ `0.2``0.4``0.6``0.8``1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
- `.toml` ファイルで `[[datasets]]``network_multiplier` を指定してください。
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。
- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。
- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。
- マルチ GPU での Textual Inversion 学習の不具合を修正しました。
### Apr 7, 2024 / 2024-04-07: v0.8.7
- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例
- The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.
```toml
[general]
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = 1.0
- Scheduled Huber Loss の `huber_schedule` のデフォルト値を `exponential` から、より良い結果が期待できる `snr` に変更しました。
... subset settings ...
### Apr 7, 2024 / 2024-04-07: v0.8.6
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = -1.0
#### Highlights
... subset settings ...
```
- The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
- Especially `imagesize` is newly added, so if you cannot update the libraries immediately, please install with `pip install imagesize==1.4.1` separately.
- `bitsandbytes==0.43.0`, `prodigyopt==1.0`, `lion-pytorch==0.0.6` are included in the requirements.txt.
- `bitsandbytes` no longer requires complex procedures as it now officially supports Windows.
- Also, the PyTorch version is updated to 2.1.2 (PyTorch does not need to be updated immediately). In the upgrade procedure, PyTorch is not updated, so please manually install or update torch, torchvision, xformers if necessary (see [Upgrade PyTorch](#upgrade-pytorch)).
- When logging to wandb is enabled, the entire command line is exposed. Therefore, it is recommended to write wandb API key and HuggingFace token in the configuration file (`.toml`). Thanks to bghira for raising the issue.
- A warning is displayed at the start of training if such information is included in the command line.
- Also, if there is an absolute path, the path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. In such cases, an INFO log is displayed.
- See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details.
- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.
- Other improvements include the addition of masked loss, scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details.
#### Training scripts
- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
- Fixed a bug that U-Net and Text Encoders are included in the state in `train_network.py` and `sdxl_train_network.py`. The saving and loading of the state are faster, the file size is smaller, and the memory usage when loading is reduced.
- DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details.
- The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#about-masked-loss) for details.
- Scheduled Huber Loss has been introduced to each training scripts. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) Thanks to kabachuha for the PR and cheald, drhead, and others for the discussion! See the PR and [Scheduled Huber Loss](#about-scheduled-huber-loss) for details.
- The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf!
- The options `--save_state_on_train_end` are added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee!
- The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and ignore them when a number less than or equal to `0` is specified. Thanks to S-Del for raising the issue.
#### Dataset settings
- The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150!
- The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704!
- Fixed a bug that the last subset settings are applied to all images when multiple subsets of regularization images are specified in the dataset settings. The settings for each subset are correctly applied to each image. PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) Thanks to feffy380!
- Some features are added to the dataset subset settings.
- `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping.
- Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped.
- `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. The multi-line caption is also enabled.
- `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end.
- The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order.
- See [Dataset config](./docs/config_README-en.md) for details.
- The dataset with DreamBooth method supports caching image information (size, caption). PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178) and [#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) Thanks to KohakuBlueleaf! See [DreamBooth method specific options](./docs/config_README-en.md#dreambooth-specific-options) for details.
#### Image tagging
- The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds!
- Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`.
- The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`.
- Some options are added to `tag_image_by_wd14_tagger.py`.
- Some are added in PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) Thanks to Disty0!
- Output rating tags `--use_rating_tags` and `--use_rating_tags_as_last_tag`
- Output character tags first `--character_tags_first`
- Expand character tags and series `--character_tag_expand`
- Specify tags to output first `--always_first_tags`
- Replace tags `--tag_replacement`
- See [Tagging documentation](./docs/wd14_tagger_README-en.md) for details.
- Fixed an error when specifying `--beam_search` and a value of 2 or more for `--num_beams` in `make_captions.py`.
#### About Masked loss
The masked loss is supported in each training script. To enable the masked loss, specify the `--masked_loss` option.
The feature is not fully tested, so there may be bugs. If you find any issues, please open an Issue.
ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
#### About Scheduled Huber Loss
Scheduled Huber Loss has been introduced to each training scripts. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data.
With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images.
To address this, the proposed method employs a clever application of the Huber loss function. By scheduling the use of Huber loss in the early stages of training (when noise is high) and MSE in the later stages, it strikes a balance between outlier robustness and fine detail reproduction.
Experimental results have confirmed that this method achieves higher accuracy on data containing outliers compared to pure Huber loss or MSE. The increase in computational cost is minimal.
The newly added arguments loss_type, huber_schedule, and huber_c allow for the selection of the loss function type (Huber, smooth L1, MSE), scheduling method (exponential, constant, SNR), and Huber's parameter. This enables optimization based on the characteristics of the dataset.
See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details.
- `loss_type`: Specify the loss function type. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, and `l2` for MSE loss. The default is `l2`, which is the same as before.
- `huber_schedule`: Specify the scheduling method. Choose `exponential`, `constant`, or `snr`. The default is `snr`.
- `huber_c`: Specify the Huber's parameter. The default is `0.1`.
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
#### 主要な変更点
- 依存ライブラリが更新されました。[アップグレード](./README-ja.md#アップグレード) を参照しライブラリを更新してください。
- 特に `imagesize` が新しく追加されていますので、すぐにライブラリの更新ができない場合は `pip install imagesize==1.4.1` で個別にインストールしてください。
- `bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` が requirements.txt に含まれるようになりました。
- `bitsandbytes` が公式に Windows をサポートしたため複雑な手順が不要になりました。
- また PyTorch のバージョンを 2.1.2 に更新しました。PyTorch はすぐに更新する必要はありません。更新時は、アップグレードの手順では PyTorch が更新されませんので、torch、torchvision、xformers を手動でインストールしてください。
- wandb へのログ出力が有効の場合、コマンドライン全体が公開されます。そのため、コマンドラインに wandb の API キーや HuggingFace のトークンなどが含まれる場合、設定ファイル(`.toml`)への記載をお勧めします。問題提起していただいた bghira 氏に感謝します。
- このような場合には学習開始時に警告が表示されます。
- また絶対パスの指定がある場合、そのパスが公開される可能性がありますので、相対パスを指定するか設定ファイルに記載することをお勧めします。このような場合は INFO ログが表示されます。
- 詳細は [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) および PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) をご覧ください。
- Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。
- その他、マスクロス追加、Scheduled Huber Loss 追加、DeepSpeed 対応、データセット設定の改善、画像タグ付けの改善などがあります。詳細は以下をご覧ください。
#### 学習スクリプト
- `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。
- `train_network.py` および `sdxl_train_network.py` で、state に U-Net および Text Encoder が含まれる不具合を修正しました。state の保存、読み込みが高速化され、ファイルサイズも小さくなり、また読み込み時のメモリ使用量も削減されます。
- DeepSpeed がサポートされました。PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) 、[#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) BootsofLagrangian 氏に感謝します。詳細は PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) をご覧ください。
- 各学習スクリプトでマスクロスをサポートしました。PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) 詳細は [マスクロスについて](#マスクロスについて) をご覧ください。
- 各学習スクリプトに Scheduled Huber Loss を追加しました。PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) ご提案いただいた kabachuha 氏、および議論を深めてくださった cheald 氏、drhead 氏を始めとする諸氏に感謝します。詳細は当該 PR および [Scheduled Huber Loss について](#scheduled-huber-loss-について) をご覧ください。
- 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。
- 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。
- 各学習スクリプトで `--sample_every_n_epochs` および `--sample_every_n_steps` オプションに `0` 以下の数値を指定した時、警告を表示するとともにそれらを無視するよう変更しました。問題提起していただいた S-Del 氏に感謝します。
#### データセット設定
- データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。
- データセット設定で、正則化画像のサブセットを複数指定した時、最後のサブセットの各種設定がすべてのサブセットの画像に適用される不具合が修正されました。それぞれのサブセットの設定が、それぞれの画像に正しく適用されます。PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) feffy380 氏に感謝します。
- データセットのサブセット設定にいくつかの機能を追加しました。
- シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。
- `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。また複数行キャプションも有効になります。
- `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。
- 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。
- 詳細は [データセット設定](./docs/config_README-ja.md) をご覧ください。
- DreamBooth 方式の DataSet で画像情報サイズ、キャプションをキャッシュする機能が追加されました。PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178)、[#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) KohakuBlueleaf 氏に感謝します。詳細は [データセット設定](./docs/config_README-ja.md#dreambooth-方式専用のオプション) をご覧ください。
- データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。
#### 画像のタグ付け
- `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。
- Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。
- `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。
- `tag_image_by_wd14_tagger.py` にいくつかのオプションを追加しました。
- 一部は PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) で追加されました。Disty0 氏に感謝します。
- レーティングタグを出力する `--use_rating_tags` および `--use_rating_tags_as_last_tag`
- キャラクタタグを最初に出力する `--character_tags_first`
- キャラクタタグとシリーズを展開する `--character_tag_expand`
- 常に最初に出力するタグを指定する `--always_first_tags`
- タグを置換する `--tag_replacement`
- 詳細は [タグ付けに関するドキュメント](./docs/wd14_tagger_README-ja.md) をご覧ください。
- `make_captions.py` で `--beam_search` を指定し `--num_beams` に2以上の値を指定した時のエラーを修正しました。
#### マスクロスについて
各学習スクリプトでマスクロスをサポートしました。マスクロスを有効にするには `--masked_loss` オプションを指定してください。
機能は完全にテストされていないため、不具合があるかもしれません。その場合は Issue を立てていただけると助かります。
マスクの指定には ControlNet データセットを使用します。マスク画像は RGB 画像である必要があります。R チャンネルのピクセル値 255 がロス計算対象、0 がロス計算対象外になります。0-255 の値は、0-1 の範囲に変換されます(つまりピクセル値 128 の部分はロスの重みが半分になります)。データセットの詳細は [LLLite ドキュメント](./docs/train_lllite_README-ja.md#データセットの準備) をご覧ください。
#### Scheduled Huber Loss について
各学習スクリプトに、学習データ中の異常値や外れ値data corruptionへの耐性を高めるための手法、Scheduled Huber Lossが導入されました。
従来のMSEL2損失関数では、異常値の影響を大きく受けてしまい、生成画像の品質低下を招く恐れがありました。一方、Huber損失関数は異常値の影響を抑えられますが、画像の細部再現性が損なわれがちでした。
この手法ではHuber損失関数の適用を工夫し、学習の初期段階イズが大きい場合ではHuber損失を、後期段階ではMSEを用いるようスケジューリングすることで、異常値耐性と細部再現性のバランスを取ります。
実験の結果では、この手法が純粋なHuber損失やMSEと比べ、異常値を含むデータでより高い精度を達成することが確認されています。また計算コストの増加はわずかです。
具体的には、新たに追加された引数loss_type、huber_schedule、huber_cで、損失関数の種類Huber, smooth L1, MSEとスケジューリング方法exponential, constant, SNRを選択できます。これによりデータセットに応じた最適化が可能になります。
詳細は PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) をご覧ください。
- `loss_type` : 損失関数の種類を指定します。`huber` で Huber損失、`smooth_l1` で smooth L1 損失、`l2` で MSE 損失を選択します。デフォルトは `l2` で、従来と同様です。
- `huber_schedule` : スケジューリング方法を指定します。`exponential` で指数関数的、`constant` で一定、`snr` で信号対雑音比に基づくスケジューリングを選択します。デフォルトは `snr` です。
- `huber_c` : Huber損失のパラメータを指定します。デフォルトは `0.1` です。
PR 内でいくつかの比較が共有されています。この機能を試す場合、最初は `--loss_type smooth_l1 --huber_schedule snr --huber_c 0.1` などで試してみるとよいかもしれません。
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
## Additional Information
### Naming of LoRA
The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
@@ -377,27 +364,14 @@ The LoRA supported by `train_network.py` has been named to avoid confusion. The
In addition to 1., LoRA for Conv2d layers with 3x3 kernel
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg).
<!--
LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
To use LoRA-C3Lier with Web UI, please use our extension.
To use LoRA-C3Lier with Web UI, please use our extension.
-->
### LoRAの名称について
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
## Sample image generation during training
### Sample image generation during training
A prompt file might look like this, for example
```
@@ -418,26 +392,3 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
* `--s` Specifies the number of steps in the generation.
The prompt weighting such as `( )` and `[ ]` are working.
## サンプル画像生成
プロンプトファイルは例えば以下のようになります。
```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
`( )``[ ]` などの重みづけも動作します。

384
docs/config_README-en.md Normal file
View File

@@ -0,0 +1,384 @@
Original Source by kohya-ss
First version:
A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150
Some parts are manually added.
# Config Readme
This README is about the configuration files that can be passed with the `--dataset_config` option.
## Overview
By passing a configuration file, users can make detailed settings.
* Multiple datasets can be configured
* For example, by setting `resolution` for each dataset, they can be mixed and trained.
* In training methods that support both the DreamBooth approach and the fine-tuning approach, datasets of the DreamBooth method and the fine-tuning method can be mixed.
* Settings can be changed for each subset
* A subset is a partition of the dataset by image directory or metadata. Several subsets make up a dataset.
* Options such as `keep_tokens` and `flip_aug` can be set for each subset. On the other hand, options such as `resolution` and `batch_size` can be set for each dataset, and their values are common among subsets belonging to the same dataset. More details will be provided later.
The configuration file format can be JSON or TOML. Considering the ease of writing, it is recommended to use [TOML](https://toml.io/ja/v1.0.0-rc.2). The following explanation assumes the use of TOML.
Here is an example of a configuration file written in TOML.
```toml
[general]
shuffle_caption = true
caption_extension = '.txt'
keep_tokens = 1
# This is a DreamBooth-style dataset
[[datasets]]
resolution = 512
batch_size = 4
keep_tokens = 2
[[datasets.subsets]]
image_dir = 'C:\hoge'
class_tokens = 'hoge girl'
# This subset uses keep_tokens = 2 (the value of the parent datasets)
[[datasets.subsets]]
image_dir = 'C:\fuga'
class_tokens = 'fuga boy'
keep_tokens = 3
[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg'
class_tokens = 'human'
keep_tokens = 1
# This is a fine-tuning dataset
[[datasets]]
resolution = [768, 768]
batch_size = 2
[[datasets.subsets]]
image_dir = 'C:\piyo'
metadata_file = 'C:\piyo\piyo_md.json'
# This subset uses keep_tokens = 1 (the value of [general])
```
In this example, three directories are trained as a DreamBooth-style dataset at 512x512 (batch size 4), and one directory is trained as a fine-tuning dataset at 768x768 (batch size 2).
## Settings for datasets and subsets
Settings for datasets and subsets are divided into several registration locations.
* `[general]`
* This is where options that apply to all datasets or all subsets are specified.
* If there are options with the same name in the dataset-specific or subset-specific settings, the dataset-specific or subset-specific settings take precedence.
* `[[datasets]]`
* `datasets` is where settings for datasets are registered. This is where options that apply individually to each dataset are specified.
* If there are subset-specific settings, the subset-specific settings take precedence.
* `[[datasets.subsets]]`
* `datasets.subsets` is where settings for subsets are registered. This is where options that apply individually to each subset are specified.
Here is an image showing the correspondence between image directories and registration locations in the previous example.
```
C:\
├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
├─ reg -> [[datasets.subsets]] No.3 ┘ |
└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
```
The image directory corresponds to each `[[datasets.subsets]]`. Then, multiple `[[datasets.subsets]]` are combined to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`.
The available options for each registration location may differ, but if the same option is specified, the value in the lower registration location will take precedence. You can check how the `keep_tokens` option is handled in the previous example for better understanding.
Additionally, the available options may vary depending on the method that the learning approach supports.
* Options specific to the DreamBooth method
* Options specific to the fine-tuning method
* Options available when using the caption dropout technique
When using both the DreamBooth method and the fine-tuning method, they can be used together with a learning approach that supports both.
When using them together, a point to note is that the method is determined based on the dataset, so it is not possible to mix DreamBooth method subsets and fine-tuning method subsets within the same dataset.
In other words, if you want to use both methods together, you need to set up subsets of different methods belonging to different datasets.
In terms of program behavior, if the `metadata_file` option exists, it is determined to be a subset of fine-tuning. Therefore, for subsets belonging to the same dataset, as long as they are either "all have the `metadata_file` option" or "all have no `metadata_file` option," there is no problem.
Below, the available options will be explained. For options with the same name as the command-line argument, the explanation will be omitted in principle. Please refer to other READMEs.
### Common options for all learning methods
These are options that can be specified regardless of the learning method.
#### Data set specific options
These are options related to the configuration of the data set. They cannot be described in `datasets.subsets`.
| Option Name | Example Setting | `[general]` | `[[datasets]]` |
| ---- | ---- | ---- | ---- |
| `batch_size` | `1` | o | o |
| `bucket_no_upscale` | `true` | o | o |
| `bucket_reso_steps` | `64` | o | o |
| `enable_bucket` | `true` | o | o |
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |
* `batch_size`
* This corresponds to the command-line argument `--train_batch_size`.
These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
#### Options for Subsets
These options are related to subset configuration.
| Option Name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `color_aug` | `false` | o | o | o |
| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
| `flip_aug` | `true` | o | o | o |
| `keep_tokens` | `2` | o | o | o |
| `num_repeats` | `10` | o | o | o |
| `random_crop` | `false` | o | o | o |
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o |
| `caption_suffix` | `", from side"` | o | o | o |
| `caption_separator` | (not specified) | o | o | o |
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
* `num_repeats`
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
* `caption_prefix`, `caption_suffix`
* Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`.
* `caption_separator`
* Specifies the string to separate the tags. The default is `,`. This option is usually not necessary to set.
* `keep_tokens_separator`
* Specifies the string to separate the parts to be fixed in the caption. For example, if you specify `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` will remain, and the rest will be shuffled and dropped. The comma in between is not necessary. As a result, the prompt will be `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`, etc.
* `secondary_separator`
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
* `enable_wildcard`
* Enables wildcard notation. This will be explained later.
### DreamBooth-specific options
DreamBooth-specific options only exist as subsets-specific options.
#### Subset-specific options
Options related to the configuration of DreamBooth subsets.
| Option Name | Example Setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `"sks girl"` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |
Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`.
* `image_dir`
* Specifies the path to the image directory. This is a required option.
* Images must be placed directly under the directory.
* `class_tokens`
* Sets the class tokens.
* Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur.
* `cache_info`
* Specifies whether to cache the image size and caption. If not specified, it is set to `false`. The cache is saved in `metadata_cache.json` in `image_dir`.
* Caching speeds up the loading of the dataset after the first time. It is effective when dealing with thousands of images or more.
* `is_reg`
* Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization.
### Fine-tuning method specific options
The options for the fine-tuning method only exist for subset-specific options.
#### Subset-specific options
These options are related to the configuration of the fine-tuning method's subsets.
| Option name | Example setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o |
| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) |
* `image_dir`
* Specify the path to the image directory. Unlike the DreamBooth method, specifying it is not mandatory, but it is recommended to do so.
* The case where it is not necessary to specify is when the `--full_path` is added to the command line when generating the metadata file.
* The images must be placed directly under the directory.
* `metadata_file`
* Specify the path to the metadata file used for the subset. This is a required option.
* It is equivalent to the command-line argument `--in_json`.
* Due to the specification that a metadata file must be specified for each subset, it is recommended to avoid creating a metadata file with images from different directories as a single metadata file. It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets.
### Options available when caption dropout method can be used
The options available when the caption dropout method can be used exist only for subsets. Regardless of whether it's the DreamBooth method or fine-tuning method, if it supports caption dropout, it can be specified.
#### Subset-specific options
Options related to the setting of subsets that caption dropout can be used for.
| Option Name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- |
| `caption_dropout_every_n_epochs` | o | o | o |
| `caption_dropout_rate` | o | o | o |
| `caption_tag_dropout_rate` | o | o | o |
## Behavior when there are duplicate subsets
In the case of the DreamBooth dataset, if there are multiple `image_dir` directories with the same content, they are considered to be duplicate subsets. For the fine-tuning dataset, if there are multiple `metadata_file` files with the same content, they are considered to be duplicate subsets. If duplicate subsets exist in the dataset, subsequent subsets will be ignored.
However, if they belong to different datasets, they are not considered duplicates. For example, if you have subsets with the same `image_dir` in different datasets, they will not be considered duplicates. This is useful when you want to train with the same image but with different resolutions.
```toml
# If data sets exist separately, they are not considered duplicates and are both used for training.
[[datasets]]
resolution = 512
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
```
## Command Line Argument and Configuration File
There are options in the configuration file that have overlapping roles with command line argument options.
The following command line argument options are ignored if a configuration file is passed:
* `--train_data_dir`
* `--reg_data_dir`
* `--in_json`
The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file.
| Command Line Argument Option | Prioritized Configuration File Option |
| ------------------------------- | ------------------------------------- |
| `--bucket_no_upscale` | |
| `--bucket_reso_steps` | |
| `--caption_dropout_every_n_epochs` | |
| `--caption_dropout_rate` | |
| `--caption_extension` | |
| `--caption_tag_dropout_rate` | |
| `--color_aug` | |
| `--dataset_repeats` | `num_repeats` |
| `--enable_bucket` | |
| `--face_crop_aug_range` | |
| `--flip_aug` | |
| `--keep_tokens` | |
| `--min_bucket_reso` | |
| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--train_batch_size` | `batch_size` |
## Error Guide
Currently, we are using an external library to check if the configuration file is written correctly, but the development has not been completed, and there is a problem that the error message is not clear. In the future, we plan to improve this problem.
As a temporary measure, we will list common errors and their solutions. If you encounter an error even though it should be correct or if the error content is not understandable, please contact us as it may be a bug.
* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: This error occurs when a required option is not provided. It is highly likely that you forgot to specify the option or misspelled the option name.
* The error location is indicated by `...` in the error message. For example, if you encounter an error like `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']`, it means that the `image_dir` option does not exist in the 0th `subsets` of the 0th `datasets` setting.
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful.
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it.
## Miscellaneous
### Multi-line captions
By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption.
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
a girl with a microphone standing on a stage
detailed digital art of a girl with a microphone on a stage
```
It can be combined with wildcard notation.
In metadata files, you can also specify multiple-line captions. In the `.json` metadata file, use `\n` to represent a line break. If the caption file consists of multiple lines, `merge_captions_to_metadata.py` will create a metadata file in this format.
The tags in the metadata (`tags`) are added to each line of the caption.
```json
{
"/path/to/image.png": {
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
},
...
}
```
In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc.
### Example of configuration file : `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc.
```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]
[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator= "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above
[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
### Example of caption, secondary_separator notation: `secondary_separator = ";;;"`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped).
### Example of caption, enable_wildcard notation: `enable_wildcard = true`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`.
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`).
### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"`
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc.

View File

@@ -1,5 +1,3 @@
For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future.
`--dataset_config` で渡すことができる設定ファイルに関する説明です。
## 概要
@@ -140,12 +138,28 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o |
| `caption_suffix` | `“, from side”` | o | o | o |
| `caption_separator` | (通常は設定しません) | o | o | o |
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
* `num_repeats`
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
* `caption_prefix`, `caption_suffix`
* キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。
* `caption_separator`
* タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。
* `keep_tokens_separator`
* キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb``ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh``aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。
* `secondary_separator`
* 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。
* `enable_wildcard`
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
### DreamBooth 方式専用のオプション
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
@@ -159,6 +173,7 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
| `image_dir` | `C:\hoge` | - | - | o必須 |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `“sks girl”` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |
まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats``class_tokens` で明示的に指定する必要があることに注意してください。
@@ -169,6 +184,9 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
* `class_tokens`
* クラストークンを設定します。
* 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。
* `cache_info`
* 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir``metadata_cache.json` というファイル名で保存されます。
* キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。
* `is_reg`
* サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
@@ -280,4 +298,89 @@ resolution = 768
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。
## その他
### 複数行キャプション
`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
a girl with a microphone standing on a stage
detailed digital art of a girl with a microphone on a stage
```
ワイルドカード記法と組み合わせることも可能です。
メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。
メタデータのタグ (`tags`) は、キャプションの各行に追加されます。
```json
{
"/path/to/image.png": {
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
},
...
}
```
この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...``test multiline caption2, open mouth, simple background ...` 等になります。
### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等
```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]
[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator= "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above
[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` drop されたケース)などになります。
### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
タグ文字列に `{``}` そのものを含めたい場合は `{{``}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。
### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general``1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。

84
docs/train_SDXL-en.md Normal file
View File

@@ -0,0 +1,84 @@
## SDXL training
The documentation will be moved to the training documentation in the future. The following is a brief explanation of the training scripts for SDXL.
### Training scripts for SDXL
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- The full bfloat16 training might be unstable. Please use it at your own risk.
- The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- `--weighted_captions` option is not supported yet for both scripts.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
### Utility scripts for SDXL
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
- The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage.
### Tips for SDXL training
- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (4 to 8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings for SDXL
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
### ControlNet-LLLite
ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details.

View File

@@ -21,9 +21,13 @@ ComfyUIのカスタムードを用意しています。: https://github.com/k
## モデルの学習
### データセットの準備
通常のdatasetに加え`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
DreamBooth 方式の dataset`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。
たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。
finetuning 方式の dataset はサポートしていません。)
conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。
```toml
[[datasets.subsets]]

View File

@@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1
### Preparing the dataset
In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
In addition to the normal DreamBooth method dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
(We do not support the finetuning method dataset.)
```toml
[[datasets.subsets]]

View File

@@ -0,0 +1,88 @@
# Image Tagging using WD14Tagger
This document is based on the information from this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger).
Using onnx for inference is recommended. Please install onnx with the following command:
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
```
The model weights will be automatically downloaded from Hugging Face.
# Usage
Run the script to perform tagging.
```powershell
python finetune/tag_images_by_wd14_tagger.py --onnx --repo_id <model repo id> --batch_size <batch size> <training data folder>
```
For example, if using the repository `SmilingWolf/wd-swinv2-tagger-v3` with a batch size of 4, and the training data is located in the parent folder `train_data`, it would be:
```powershell
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
```
On the first run, the model files will be automatically downloaded to the `wd14_tagger_model` folder (the folder can be changed with an option).
Tag files will be created in the same directory as the training data images, with the same filename and a `.txt` extension.
![Generated tag files](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
![Tags and image](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
## Example
To output in the Animagine XL 3.1 format, it would be as follows (enter on a single line in practice):
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
--batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
--use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
--always_first_tags "1girl,1boy" ..\train_data
```
## Available Repository IDs
[SmilingWolf's V2 and V3 models](https://huggingface.co/SmilingWolf) are available for use. Specify them in the format like `SmilingWolf/wd-vit-tagger-v3`. The default when omitted is `SmilingWolf/wd-v1-4-convnext-tagger-v2`.
# Options
## General Options
- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately.
- `--batch_size`: Number of images to process at once. Default is 1. Adjust according to VRAM capacity.
- `--caption_extension`: File extension for caption files. Default is `.txt`.
- `--max_data_loader_n_workers`: Maximum number of workers for DataLoader. Specifying a value of 1 or more will use DataLoader to speed up image loading. If unspecified, DataLoader will not be used.
- `--thresh`: Confidence threshold for outputting tags. Default is 0.35. Lowering the value will assign more tags but accuracy will decrease.
- `--general_threshold`: Confidence threshold for general tags. If omitted, same as `--thresh`.
- `--character_threshold`: Confidence threshold for character tags. If omitted, same as `--thresh`.
- `--recursive`: If specified, subfolders within the specified folder will also be processed recursively.
- `--append_tags`: Append tags to existing tag files.
- `--frequency_tags`: Output tag frequencies.
- `--debug`: Debug mode. Outputs debug information if specified.
## Model Download
- `--model_dir`: Folder to save model files. Default is `wd14_tagger_model`.
- `--force_download`: Re-download model files if specified.
## Tag Editing
- `--remove_underscore`: Remove underscores from output tags.
- `--undesired_tags`: Specify tags not to output. Multiple tags can be specified, separated by commas. For example, `black eyes,black hair`.
- `--use_rating_tags`: Output rating tags at the beginning of the tags.
- `--use_rating_tags_as_last_tag`: Add rating tags at the end of the tags.
- `--character_tags_first`: Output character tags first.
- `--character_tag_expand`: Expand character tag series names. For example, split the tag `chara_name_(series)` into `chara_name, series`.
- `--always_first_tags`: Specify tags to always output first when a certain tag appears in an image. Multiple tags can be specified, separated by commas. For example, `1girl,1boy`.
- `--caption_separator`: Separate tags with this string in the output file. Default is `, `.
- `--tag_replacement`: Perform tag replacement. Specify in the format `tag1,tag2;tag3,tag4`. If using `,` and `;`, escape them with `\`. \
For example, specify `aira tsubase,aira tsubase (uniform)` (when you want to train a specific costume), `aira tsubase,aira tsubase\, heir of shadows` (when the series name is not included in the tag).
When using `tag_replacement`, it is applied after `character_tag_expand`.
When specifying `remove_underscore`, specify `undesired_tags`, `always_first_tags`, and `tag_replacement` without including underscores.
When specifying `caption_separator`, separate `undesired_tags` and `always_first_tags` with `caption_separator`. Always separate `tag_replacement` with `,`.

View File

@@ -0,0 +1,88 @@
# WD14Taggerによるタグ付け
こちらのgithubページhttps://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。
onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
```
モデルの重みはHugging Faceから自動的にダウンロードしてきます。
# 使い方
スクリプトを実行してタグ付けを行います。
```
python fintune/tag_images_by_wd14_tagger.py --onnx --repo_id <モデルのrepo id> --batch_size <バッチサイズ> <教師データフォルダ>
```
レポジトリに `SmilingWolf/wd-swinv2-tagger-v3` を使用し、バッチサイズを4にして、教師データを親フォルダの `train_data`に置いた場合、以下のようになります。
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
```
初回起動時にはモデルファイルが `wd14_tagger_model` フォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。
タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。
![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
## 記述例
Animagine XL 3.1 方式で出力する場合、以下のようになります(実際には 1 行で入力してください)。
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
--batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
--use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
--always_first_tags "1girl,1boy" ..\train_data
```
## 使用可能なリポジトリID
[SmilingWolf 氏の V2、V3 のモデル](https://huggingface.co/SmilingWolf)が使用可能です。`SmilingWolf/wd-vit-tagger-v3` のように指定してください。省略時のデフォルトは `SmilingWolf/wd-v1-4-convnext-tagger-v2` です。
# オプション
## 一般オプション
- `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。
- `--batch_size` : 一度に処理する画像の数。デフォルトは1です。VRAMの容量に応じて増減してください。
- `--caption_extension` : キャプションファイルの拡張子。デフォルトは `.txt` です。
- `--max_data_loader_n_workers` : DataLoader の最大ワーカー数です。このオプションに 1 以上の数値を指定すると、DataLoader を用いて画像読み込みを高速化します。未指定時は DataLoader を用いません。
- `--thresh` : 出力するタグの信頼度の閾値。デフォルトは0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。
- `--general_threshold` : 一般タグの信頼度の閾値。省略時は `--thresh` と同じです。
- `--character_threshold` : キャラクタータグの信頼度の閾値。省略時は `--thresh` と同じです。
- `--recursive` : 指定すると、指定したフォルダ内のサブフォルダも再帰的に処理します。
- `--append_tags` : 既存のタグファイルにタグを追加します。
- `--frequency_tags` : タグの頻度を出力します。
- `--debug` : デバッグモード。指定するとデバッグ情報を出力します。
## モデルのダウンロード
- `--model_dir` : モデルファイルの保存先フォルダ。デフォルトは `wd14_tagger_model` です。
- `--force_download` : 指定するとモデルファイルを再ダウンロードします。
## タグ編集関連
- `--remove_underscore` : 出力するタグからアンダースコアを削除します。
- `--undesired_tags` : 出力しないタグを指定します。カンマ区切りで複数指定できます。たとえば `black eyes,black hair` のように指定します。
- `--use_rating_tags` : タグの最初にレーティングタグを出力します。
- `--use_rating_tags_as_last_tag` : タグの最後にレーティングタグを追加します。
- `--character_tags_first` : キャラクタータグを最初に出力します。
- `--character_tag_expand` : キャラクタータグのシリーズ名を展開します。たとえば `chara_name_(series)` のタグを `chara_name, series` に分割します。
- `--always_first_tags` : あるタグが画像に出力されたとき、そのタグを最初に出力するタグを指定します。カンマ区切りで複数指定できます。たとえば `1girl,1boy` のように指定します。
- `--caption_separator` : 出力するファイルでタグをこの文字列で区切ります。デフォルトは `, ` です。
- `--tag_replacement` : タグの置換を行います。`tag1,tag2;tag3,tag4` のように指定します。`,` および `;` を使う場合は `\` でエスケープしてください。\
たとえば `aira tsubase,aira tsubase (uniform)` (特定の衣装を学習させたいとき)、`aira tsubase,aira tsubase\, heir of shadows` (シリーズ名がタグに含まれないとき)のように指定します。
`tag_replacement``character_tag_expand` の後に適用されます。
`remove_underscore` 指定時は、`undesired_tags``always_first_tags``tag_replacement` はアンダースコアを含めずに指定してください。
`caption_separator` 指定時は、`undesired_tags``always_first_tags``caption_separator` で区切ってください。`tag_replacement` は必ず `,` で区切ってください。

View File

@@ -10,7 +10,9 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
@@ -42,6 +44,7 @@ from library.custom_train_functions import (
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -108,6 +111,7 @@ def train(args):
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -158,7 +162,7 @@ def train(args):
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
@@ -191,7 +195,7 @@ def train(args):
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models:
m.requires_grad_(True)
@@ -246,13 +250,23 @@ def train(args):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
@@ -311,13 +325,13 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with accelerator.accumulate(*training_models):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
latents = latents * 0.18215
b_size = latents.shape[0]
@@ -340,7 +354,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -354,7 +368,7 @@ def train(args):
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
@@ -366,7 +380,7 @@ def train(args):
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -457,7 +471,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -477,6 +491,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
@@ -492,6 +507,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser
@@ -500,6 +520,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -134,8 +134,9 @@ class BLIP_Decoder(nn.Module):
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
image_embeds = self.visual_encoder(image)
if not sample:
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
# recent version of transformers seems to do repeat_interleave automatically
# if not sample:
# image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}

View File

@@ -6,75 +6,95 @@ from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
assert not args.recursive or (
args.recursive and args.full_path
), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
logger.info("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding='utf-8').strip()
logger.info("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding="utf-8").strip()
if not os.path.exists(caption_path):
caption_path = os.path.join(image_path, args.caption_extension)
if not os.path.exists(caption_path):
caption_path = os.path.join(image_path, args.caption_extension)
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
metadata[image_key]['caption'] = caption
if args.debug:
logger.info(f"{image_key} {caption}")
metadata[image_key]["caption"] = caption
if args.debug:
logger.info(f"{image_key} {caption}")
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
logger.info("done!")
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument(
"--in_json",
type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む",
)
parser.add_argument(
"--caption_extention",
type=str,
default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument(
"--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子"
)
parser.add_argument(
"--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--recursive",
action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = parser.parse_args()
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)
main(args)

View File

@@ -6,70 +6,88 @@ from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
assert not args.recursive or (
args.recursive and args.full_path
), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
logger.info("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding='utf-8').strip()
logger.info("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding="utf-8").strip()
if not os.path.exists(tags_path):
tags_path = os.path.join(image_path, args.caption_extension)
if not os.path.exists(tags_path):
tags_path = os.path.join(image_path, args.caption_extension)
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
metadata[image_key]['tags'] = tags
if args.debug:
logger.info(f"{image_key} {tags}")
metadata[image_key]["tags"] = tags
if args.debug:
logger.info(f"{image_key} {tags}")
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
logger.info("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--caption_extension", type=str, default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument(
"--in_json",
type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む",
)
parser.add_argument(
"--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--recursive",
action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
)
parser.add_argument(
"--caption_extension",
type=str,
default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子",
)
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
main(args)
args = parser.parse_args()
main(args)

View File

@@ -12,8 +12,10 @@ from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# from wd14 tagger
@@ -60,12 +62,12 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
tensor = torch.tensor(image)
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
return (image, img_path)
def collate_fn_remove_corrupted(batch):
@@ -79,34 +81,42 @@ def collate_fn_remove_corrupted(batch):
def main(args):
# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
if not os.path.exists(model_location) or args.force_download:
os.makedirs(args.model_dir, exist_ok=True)
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files = ["selected_tags.csv"]
files += FILES_ONNX
else:
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(model_location, SUB_DIR),
force_download=True,
force_filename=file,
)
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(args.model_dir, SUB_DIR),
force_download=True,
force_filename=file,
)
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
else:
logger.info("using existing wd14 tagger model")
# 画像を読み込む
# モデルを読み込む
if args.onnx:
import torch
import onnx
import onnxruntime as ort
onnx_path = f"{args.model_dir}/model.onnx"
onnx_path = f"{model_location}/model.onnx"
logger.info("Running wd14 tagger with onnx")
logger.info(f"loading onnx model: {onnx_path}")
@@ -120,10 +130,10 @@ def main(args):
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
except Exception:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str:
if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
# some rebatch model may use 'N' as dynamic axes
logger.warning(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
@@ -132,32 +142,79 @@ def main(args):
del model
ort_sess = ort.InferenceSession(
onnx_path,
providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"],
)
if "OpenVINOExecutionProvider" in ort.get_available_providers():
# requires provider options for gpu support
# fp16 causes nonsense outputs
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU_FP32"}],
)
else:
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
),
)
else:
from tensorflow.keras.models import load_model
model = load_model(f"{args.model_dir}")
model = load_model(f"{model_location}")
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
line = [row for row in reader]
header = line[0] # tag_id,name,category,count
rows = line[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
# preprocess tags in advance
if args.character_tag_expand:
for i, tag in enumerate(character_tags):
if tag.endswith(")"):
# chara_name_(series) -> chara_name, series
# chara_name_(costume)_(series) -> chara_name_(costume), series
tags = tag.split("(")
character_tag = "(".join(tags[:-1])
if character_tag.endswith("_"):
character_tag = character_tag[:-1]
series_tag = tags[-1].replace(")", "")
character_tags[i] = character_tag + args.caption_separator + series_tag
if args.remove_underscore:
rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags]
general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags]
character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags]
if args.tag_replacement is not None:
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;
escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
tag_replacements = escaped_tag_replacements.split(";")
for tag_replacement in tag_replacements:
tags = tag_replacement.split(",") # source, target
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
logger.info(f"replacing tag: {source} -> {target}")
if source in general_tags:
general_tags[general_tags.index(source)] = target
elif source in character_tags:
character_tags[character_tags.index(source)] = target
elif source in rating_tags:
rating_tags[rating_tags.index(source)] = target
# 画像を読み込む
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
@@ -166,14 +223,19 @@ def main(args):
caption_separator = args.caption_separator
stripped_caption_separator = caption_separator.strip()
undesired_tags = set(args.undesired_tags.split(stripped_caption_separator))
undesired_tags = args.undesired_tags.split(stripped_caption_separator)
undesired_tags = set([tag.strip() for tag in undesired_tags if tag.strip() != ""])
always_first_tags = None
if args.always_first_tags is not None:
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
if args.onnx:
if len(imgs) < args.batch_size:
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
# if len(imgs) < args.batch_size:
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
else:
@@ -181,22 +243,16 @@ def main(args):
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = label_names[:4]
# rating_index = ratings_names["probs"].argmax()
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
# Everything else is tags: pick any where prediction confidence > threshold
combined_tags = []
general_tag_text = ""
rating_tag_text = ""
character_tag_text = ""
general_tag_text = ""
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
tag_name = tag_name.replace("_", " ")
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
@@ -204,13 +260,37 @@ def main(args):
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
if args.remove_underscore and len(tag_name) > 3:
tag_name = tag_name.replace("_", " ")
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
combined_tags.append(tag_name)
# 最初の4つはratingなのでargmaxで選ぶ
# First 4 labels are actually ratings: pick one with argmax
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
ratings_probs = prob[:4]
rating_index = ratings_probs.argmax()
found_rating = rating_tags[rating_index]
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
# 一番最初に置くタグを指定する
# Always put some tags at the beginning
if always_first_tags is not None:
for tag in always_first_tags:
if tag in combined_tags:
combined_tags.remove(tag)
combined_tags.insert(0, tag)
# 先頭のカンマを取る
if len(general_tag_text) > 0:
@@ -243,6 +323,7 @@ def main(args):
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
@@ -267,9 +348,7 @@ def main(args):
continue
image, image_path = data
if image is not None:
image = image.detach().numpy()
else:
if image is None:
try:
image = Image.open(image_path)
if image.mode != "RGB":
@@ -300,7 +379,9 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--repo_id",
type=str,
@@ -314,9 +395,13 @@ def setup_parser() -> argparse.ArgumentParser:
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ",
)
parser.add_argument(
"--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします"
"--force_download",
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
@@ -329,8 +414,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument(
"--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
)
parser.add_argument(
"--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値"
)
parser.add_argument(
"--general_threshold",
type=float,
@@ -343,28 +432,67 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
)
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
parser.add_argument(
"--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する"
)
parser.add_argument(
"--remove_underscore",
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument(
"--undesired_tags",
type=str,
default="",
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
)
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
)
parser.add_argument(
"--always_first_tags",
type=str,
default=None,
help="comma-separated list of tags to always put at the beginning, e.g. `1girl,1boy`"
+ " / 必ず先頭に置くタグのカンマ区切りリスト、例 : `1girl,1boy`",
)
parser.add_argument(
"--caption_separator",
type=str,
default=", ",
help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
)
parser.add_argument(
"--tag_replacement",
type=str,
default=None,
help="tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\`. e.g. `tag1,tag2;tag3,tag4`"
+ " / タグの置換を `置換元1,置換先1;置換元2,置換先2; ...`で指定する。`\` で `,` と `;` をエスケープできる。例: `tag1,tag2;tag3,tag4`",
)
parser.add_argument(
"--character_tag_expand",
action="store_true",
help="expand tag tail parenthesis to another tag for character tags. `chara_name_(series)` becomes `chara_name, series`"
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
)
return parser

View File

@@ -61,6 +61,12 @@ from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
@@ -82,12 +88,12 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
print("Enable memory efficient attention for U-Net")
logger.info("Enable memory efficient attention for U-Net")
# これはDiffusersのU-Netではなく自前のU-Netなので置き換えなくても良い
unet.set_use_memory_efficient_attention(False, True)
elif xformers:
print("Enable xformers for U-Net")
logger.info("Enable xformers for U-Net")
try:
import xformers.ops
except ImportError:
@@ -95,7 +101,7 @@ def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditio
unet.set_use_memory_efficient_attention(True, False)
elif sdpa:
print("Enable SDPA for U-Net")
logger.info("Enable SDPA for U-Net")
unet.set_use_memory_efficient_attention(False, False)
unet.set_use_sdpa(True)
@@ -112,7 +118,7 @@ def replace_vae_modules(vae: diffusers.models.AutoencoderKL, mem_eff_attn, xform
def replace_vae_attn_to_memory_efficient():
print("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
logger.info("VAE Attention.forward has been replaced to FlashAttention (not xformers)")
flash_func = FlashAttentionFunction
def forward_flash_attn(self, hidden_states, **kwargs):
@@ -168,7 +174,7 @@ def replace_vae_attn_to_memory_efficient():
def replace_vae_attn_to_xformers():
print("VAE: Attention.forward has been replaced to xformers")
logger.info("VAE: Attention.forward has been replaced to xformers")
import xformers.ops
def forward_xformers(self, hidden_states, **kwargs):
@@ -224,7 +230,7 @@ def replace_vae_attn_to_xformers():
def replace_vae_attn_to_sdpa():
print("VAE: Attention.forward has been replaced to sdpa")
logger.info("VAE: Attention.forward has been replaced to sdpa")
def forward_sdpa(self, hidden_states, **kwargs):
residual = hidden_states
@@ -386,10 +392,10 @@ class PipelineLike:
def set_gradual_latent(self, gradual_latent):
if gradual_latent is None:
print("gradual_latent is disabled")
logger.info("gradual_latent is disabled")
self.gradual_latent = None
else:
print(f"gradual_latent is enabled: {gradual_latent}")
logger.info(f"gradual_latent is enabled: {gradual_latent}")
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
@torch.no_grad()
@@ -467,7 +473,7 @@ class PipelineLike:
do_classifier_free_guidance = guidance_scale > 1.0
if not do_classifier_free_guidance and negative_scale is not None:
print(f"negative_scale is ignored if guidance scalle <= 1.0")
logger.warning(f"negative_scale is ignored if guidance scalle <= 1.0")
negative_scale = None
# get unconditional embeddings for classifier free guidance
@@ -576,7 +582,7 @@ class PipelineLike:
text_pool = text_pool[num_sub_prompts - 1 :: num_sub_prompts] # last subprompt
if init_image is not None and self.clip_vision_model is not None:
print(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
logger.info(f"encode by clip_vision_model and apply clip_vision_strength={self.clip_vision_strength}")
vision_input = self.clip_vision_processor(init_image, return_tensors="pt", device=self.device)
pixel_values = vision_input["pixel_values"].to(self.device, dtype=text_embeddings.dtype)
@@ -742,8 +748,8 @@ class PipelineLike:
enable_gradual_latent = False
if self.gradual_latent:
if not hasattr(self.scheduler, "set_gradual_latent_params"):
print("gradual_latent is not supported for this scheduler. Ignoring.")
print(self.scheduler.__class__.__name__)
logger.warning("gradual_latent is not supported for this scheduler. Ignoring.")
logger.warning(f"{self.scheduler.__class__.__name__}")
else:
enable_gradual_latent = True
step_elapsed = 1000
@@ -792,7 +798,7 @@ class PipelineLike:
if not enabled or ratio >= 1.0:
continue
if ratio < i / len(timesteps):
print(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
control_net.set_cond_image(None)
each_control_net_enabled[j] = False
@@ -1013,7 +1019,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
if word.strip() == "BREAK":
# pad until next multiple of tokenizer's max token length
pad_len = tokenizer.model_max_length - (len(text_token) % tokenizer.model_max_length)
print(f"BREAK pad_len: {pad_len}")
logger.info(f"BREAK pad_len: {pad_len}")
for i in range(pad_len):
# v2のときEOSをつけるべきかどうかわからないぜ
# if i == 0:
@@ -1043,7 +1049,7 @@ def get_prompts_with_weights(tokenizer: CLIPTokenizer, token_replacer, prompt: L
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
logger.warning("warning: Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
@@ -1344,7 +1350,7 @@ def handle_dynamic_prompt_variants(prompt, repeat_count):
elif len(count_range) == 2:
count_range = [int(count_range[0]), int(count_range[1])]
else:
print(f"invalid count range: {count_range}")
logger.warning(f"invalid count range: {count_range}")
count_range = [1, 1]
if count_range[0] > count_range[1]:
count_range = [count_range[1], count_range[0]]
@@ -1488,9 +1494,9 @@ def main(args):
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
# モデルを読み込む
if not os.path.exists(args.ckpt): # ファイルがないならパターンで探し、一つだけ該当すればそれを使う
@@ -1510,7 +1516,7 @@ def main(args):
else:
# if `text_encoder_2` subdirectory exists, sdxl
is_sdxl = os.path.isdir(os.path.join(name_or_path, "text_encoder_2"))
print(f"SDXL: {is_sdxl}")
logger.info(f"SDXL: {is_sdxl}")
if is_sdxl:
if args.clip_skip is None:
@@ -1526,10 +1532,10 @@ def main(args):
args.clip_skip = 2 if args.v2 else 1
if use_stable_diffusion_format:
print("load StableDiffusion checkpoint")
logger.info("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
else:
print("load Diffusers pretrained models")
logger.info("load Diffusers pretrained models")
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
text_encoder = loading_pipe.text_encoder
vae = loading_pipe.vae
@@ -1553,7 +1559,7 @@ def main(args):
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, dtype)
print("additional VAE loaded")
logger.info("additional VAE loaded")
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
@@ -1562,7 +1568,7 @@ def main(args):
replace_vae_modules(vae, mem_eff, args.xformers, args.sdpa)
# tokenizerを読み込む
print("loading tokenizer")
logger.info("loading tokenizer")
if is_sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
@@ -1654,7 +1660,7 @@ def main(args):
noise = None
if noise == None:
print(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
logger.warning(f"unexpected noise request: {self.sampler_noise_index}, {shape}")
noise = torch.randn(shape, dtype=dtype, device=device, generator=generator)
self.sampler_noise_index += 1
@@ -1715,7 +1721,7 @@ def main(args):
vae_dtype = dtype
if args.no_half_vae:
print("set vae_dtype to float32")
logger.info("set vae_dtype to float32")
vae_dtype = torch.float32
vae.to(vae_dtype).to(device)
vae.eval()
@@ -1739,10 +1745,10 @@ def main(args):
network_merge = args.network_merge_n_models
else:
network_merge = 0
print(f"network_merge: {network_merge}")
logger.info(f"network_merge: {network_merge}")
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
logger.info("import network module: {network_module}")
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
@@ -1760,7 +1766,7 @@ def main(args):
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
logger.info(f"load network weights from: {network_weight}")
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
@@ -1768,7 +1774,7 @@ def main(args):
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
logger.info(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoders, unet, for_inference=True, **net_kwargs
@@ -1778,20 +1784,20 @@ def main(args):
mergeable = network.is_mergeable()
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")
logger.warning("network is not mergiable. ignore merge option.")
if not mergeable or i >= network_merge:
# not merging
network.apply_to(text_encoders, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
logger.info(f"weights are loaded: {info}")
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
if network_pre_calc:
print("backup original weights")
logger.info("backup original weights")
network.backup_weights()
networks.append(network)
@@ -1805,7 +1811,7 @@ def main(args):
# upscalerの指定があれば取得する
upscaler = None
if args.highres_fix_upscaler:
print("import upscaler module:", args.highres_fix_upscaler)
logger.info("import upscaler module: {args.highres_fix_upscaler}")
imported_module = importlib.import_module(args.highres_fix_upscaler)
us_kwargs = {}
@@ -1814,7 +1820,7 @@ def main(args):
key, value = net_arg.split("=")
us_kwargs[key] = value
print("create upscaler")
logger.info("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device)
@@ -1833,7 +1839,7 @@ def main(args):
control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
if args.control_net_lllite_models:
for i, model_file in enumerate(args.control_net_lllite_models):
print(f"loading ControlNet-LLLite: {model_file}")
logger.info(f"loading ControlNet-LLLite: {model_file}")
from safetensors.torch import load_file
@@ -1867,7 +1873,7 @@ def main(args):
), "ControlNet and ControlNet-LLLite cannot be used at the same time"
if args.opt_channels_last:
print(f"set optimizing: channels last")
logger.info(f"set optimizing: channels last")
for text_encoder in text_encoders:
text_encoder.to(memory_format=torch.channels_last)
vae.to(memory_format=torch.channels_last)
@@ -1894,7 +1900,7 @@ def main(args):
)
pipe.set_control_nets(control_nets)
pipe.set_control_net_lllites(control_net_lllites)
print("pipeline is ready.")
logger.info("pipeline is ready.")
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
@@ -1965,7 +1971,7 @@ def main(args):
token_ids1 = tokenizers[0].convert_tokens_to_ids(token_strings)
token_ids2 = tokenizers[1].convert_tokens_to_ids(token_strings) if is_sdxl else None
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
logger.info(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids1} and {token_ids2}")
assert (
min(token_ids1) == token_ids1[0] and token_ids1[-1] == token_ids1[0] + len(token_ids1) - 1
), f"token ids1 is not ordered"
@@ -2002,7 +2008,7 @@ def main(args):
# promptを取得する
prompt_list = None
if args.from_file is not None:
print(f"reading prompts from {args.from_file}")
logger.info(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
@@ -2019,7 +2025,7 @@ def main(args):
spec.loader.exec_module(module)
return module
print(f"reading prompts from module: {args.from_module}")
logger.info(f"reading prompts from module: {args.from_module}")
prompt_module = load_module_from_path("prompt_module", args.from_module)
prompter = prompt_module.get_prompter(args, pipe, networks)
@@ -2050,7 +2056,7 @@ def main(args):
for p in paths:
image = Image.open(p)
if image.mode != "RGB":
print(f"convert image to RGB from {image.mode}: {p}")
logger.info(f"convert image to RGB from {image.mode}: {p}")
image = image.convert("RGB")
images.append(image)
@@ -2066,14 +2072,14 @@ def main(args):
return resized
if args.image_path is not None:
print(f"load image for img2img: {args.image_path}")
logger.info(f"load image for img2img: {args.image_path}")
init_images = load_images(args.image_path)
assert len(init_images) > 0, f"No image / 画像がありません: {args.image_path}"
print(f"loaded {len(init_images)} images for img2img")
logger.info(f"loaded {len(init_images)} images for img2img")
# CLIP Vision
if args.clip_vision_strength is not None:
print(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
logger.info(f"load CLIP Vision model: {CLIP_VISION_MODEL}")
vision_model = CLIPVisionModelWithProjection.from_pretrained(CLIP_VISION_MODEL, projection_dim=1280)
vision_model.to(device, dtype)
processor = CLIPImageProcessor.from_pretrained(CLIP_VISION_MODEL)
@@ -2081,22 +2087,22 @@ def main(args):
pipe.clip_vision_model = vision_model
pipe.clip_vision_processor = processor
pipe.clip_vision_strength = args.clip_vision_strength
print(f"CLIP Vision model loaded.")
logger.info(f"CLIP Vision model loaded.")
else:
init_images = None
if args.mask_path is not None:
print(f"load mask for inpainting: {args.mask_path}")
logger.info(f"load mask for inpainting: {args.mask_path}")
mask_images = load_images(args.mask_path)
assert len(mask_images) > 0, f"No mask image / マスク画像がありません: {args.image_path}"
print(f"loaded {len(mask_images)} mask images for inpainting")
logger.info(f"loaded {len(mask_images)} mask images for inpainting")
else:
mask_images = None
# promptがないとき、画像のPngInfoから取得する
if init_images is not None and prompter is None and not args.interactive:
print("get prompts from images' metadata")
logger.info("get prompts from images' metadata")
prompt_list = []
for img in init_images:
if "prompt" in img.text:
@@ -2127,17 +2133,17 @@ def main(args):
h = int(h * args.highres_fix_scale + 0.5)
if init_images is not None:
print(f"resize img2img source images to {w}*{h}")
logger.info(f"resize img2img source images to {w}*{h}")
init_images = resize_images(init_images, (w, h))
if mask_images is not None:
print(f"resize img2img mask images to {w}*{h}")
logger.info(f"resize img2img mask images to {w}*{h}")
mask_images = resize_images(mask_images, (w, h))
regional_network = False
if networks and mask_images:
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
regional_network = True
print("use mask as region")
logger.info("use mask as region")
size = None
for i, network in enumerate(networks):
@@ -2162,14 +2168,14 @@ def main(args):
prev_image = None # for VGG16 guided
if args.guide_image_path is not None:
print(f"load image for ControlNet guidance: {args.guide_image_path}")
logger.info(f"load image for ControlNet guidance: {args.guide_image_path}")
guide_images = []
for p in args.guide_image_path:
guide_images.extend(load_images(p))
print(f"loaded {len(guide_images)} guide images for guidance")
logger.info(f"loaded {len(guide_images)} guide images for guidance")
if len(guide_images) == 0:
print(
logger.warning(
f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}"
)
guide_images = None
@@ -2200,7 +2206,7 @@ def main(args):
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}")
logger.info(f"iteration {gen_iter+1}/{args.n_iter}")
if args.iter_same_seed:
iter_seed = seed_random.randint(0, 2**32 - 1)
else:
@@ -2219,7 +2225,7 @@ def main(args):
# 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
print("process 1st stage")
logger.info("process 1st stage")
batch_1st = []
for _, base, ext in batch:
@@ -2264,7 +2270,7 @@ def main(args):
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
print("process 2nd stage")
logger.info("process 2nd stage")
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
if upscaler:
@@ -2437,7 +2443,7 @@ def main(args):
n.restore_weights()
for n in networks:
n.pre_calculation()
print("pre-calculation... done")
logger.info("pre-calculation... done")
images = pipe(
prompts,
@@ -2520,7 +2526,7 @@ def main(args):
cv2.waitKey()
cv2.destroyAllWindows()
except ImportError:
print(
logger.warning(
"opencv-python is not installed, cannot preview / opencv-pythonがインストールされていないためプレビューできません"
)
@@ -2535,7 +2541,7 @@ def main(args):
# interactive
valid = False
while not valid:
print("\nType prompt:")
logger.info("\nType prompt:")
try:
raw_prompt = input()
except EOFError:
@@ -2595,74 +2601,74 @@ def main(args):
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
length = len(prompter) if hasattr(prompter, "__len__") else 0
print(f"prompt {prompt_index+1}/{length}: {prompt}")
logger.info(f"prompt {prompt_index+1}/{length}: {prompt}")
for parg in prompt_args[1:]:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
print(f"width: {width}")
logger.info(f"width: {width}")
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
print(f"height: {height}")
logger.info(f"height: {height}")
continue
m = re.match(r"ow (\d+)", parg, re.IGNORECASE)
if m:
original_width = int(m.group(1))
print(f"original width: {original_width}")
logger.info(f"original width: {original_width}")
continue
m = re.match(r"oh (\d+)", parg, re.IGNORECASE)
if m:
original_height = int(m.group(1))
print(f"original height: {original_height}")
logger.info(f"original height: {original_height}")
continue
m = re.match(r"nw (\d+)", parg, re.IGNORECASE)
if m:
original_width_negative = int(m.group(1))
print(f"original width negative: {original_width_negative}")
logger.info(f"original width negative: {original_width_negative}")
continue
m = re.match(r"nh (\d+)", parg, re.IGNORECASE)
if m:
original_height_negative = int(m.group(1))
print(f"original height negative: {original_height_negative}")
logger.info(f"original height negative: {original_height_negative}")
continue
m = re.match(r"ct (\d+)", parg, re.IGNORECASE)
if m:
crop_top = int(m.group(1))
print(f"crop top: {crop_top}")
logger.info(f"crop top: {crop_top}")
continue
m = re.match(r"cl (\d+)", parg, re.IGNORECASE)
if m:
crop_left = int(m.group(1))
print(f"crop left: {crop_left}")
logger.info(f"crop left: {crop_left}")
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
steps = max(1, min(1000, int(m.group(1))))
print(f"steps: {steps}")
logger.info(f"steps: {steps}")
continue
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seeds = [int(d) for d in m.group(1).split(",")]
print(f"seeds: {seeds}")
logger.info(f"seeds: {seeds}")
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
print(f"scale: {scale}")
logger.info(f"scale: {scale}")
continue
m = re.match(r"nl ([\d\.]+|none|None)", parg, re.IGNORECASE)
@@ -2671,25 +2677,25 @@ def main(args):
negative_scale = None
else:
negative_scale = float(m.group(1))
print(f"negative scale: {negative_scale}")
logger.info(f"negative scale: {negative_scale}")
continue
m = re.match(r"t ([\d\.]+)", parg, re.IGNORECASE)
if m: # strength
strength = float(m.group(1))
print(f"strength: {strength}")
logger.info(f"strength: {strength}")
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
print(f"negative prompt: {negative_prompt}")
logger.info(f"negative prompt: {negative_prompt}")
continue
m = re.match(r"c (.+)", parg, re.IGNORECASE)
if m: # clip prompt
clip_prompt = m.group(1)
print(f"clip prompt: {clip_prompt}")
logger.info(f"clip prompt: {clip_prompt}")
continue
m = re.match(r"am ([\d\.\-,]+)", parg, re.IGNORECASE)
@@ -2697,89 +2703,89 @@ def main(args):
network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks):
network_muls.append(network_muls[-1])
print(f"network mul: {network_muls}")
logger.info(f"network mul: {network_muls}")
continue
# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}")
logger.info(f"deep shrink depth 1: {ds_depth_1}")
continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
logger.info(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}")
logger.info(f"deep shrink depth 2: {ds_depth_2}")
continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
logger.info(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}")
logger.info(f"deep shrink ratio: {ds_ratio}")
continue
# Gradual Latent
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
print(f"gradual latent timesteps: {gl_timesteps}")
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio: {ds_ratio}")
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent every n steps: {gl_every_n_steps}")
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio step: {gl_ratio_step}")
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent s noise: {gl_s_noise}")
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent unsharp params: {gl_unsharp_params}")
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}")
# override Deep Shrink
if ds_depth_1 is not None:
@@ -2825,7 +2831,7 @@ def main(args):
if seed is None:
seed = seed_random.randint(0, 2**32 - 1)
if args.interactive:
print(f"seed: {seed}")
logger.info(f"seed: {seed}")
# prepare init image, guide image and mask
init_image = mask_image = guide_image = None
@@ -2841,7 +2847,7 @@ def main(args):
width = width - width % 32
height = height - height % 32
if width != init_image.size[0] or height != init_image.size[1]:
print(
logger.warning(
f"img2img image size is not divisible by 32 so aspect ratio is changed / img2imgの画像サイズが32で割り切れないためリサイズされます。画像が歪みます"
)
@@ -2903,12 +2909,14 @@ def main(args):
process_batch(batch_data, highres_fix)
batch_data.clear()
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む"
)

View File

@@ -489,10 +489,10 @@ class PipelineLike:
def set_gradual_latent(self, gradual_latent):
if gradual_latent is None:
print("gradual_latent is disabled")
logger.info("gradual_latent is disabled")
self.gradual_latent = None
else:
print(f"gradual_latent is enabled: {gradual_latent}")
logger.info(f"gradual_latent is enabled: {gradual_latent}")
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
# region xformersとか使う部分独自に書き換えるので関係なし
@@ -971,8 +971,8 @@ class PipelineLike:
enable_gradual_latent = False
if self.gradual_latent:
if not hasattr(self.scheduler, "set_gradual_latent_params"):
print("gradual_latent is not supported for this scheduler. Ignoring.")
print(self.scheduler.__class__.__name__)
logger.info("gradual_latent is not supported for this scheduler. Ignoring.")
logger.info(f'{self.scheduler.__class__.__name__}')
else:
enable_gradual_latent = True
step_elapsed = 1000
@@ -3314,42 +3314,42 @@ def main(args):
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
print(f"gradual latent timesteps: {gl_timesteps}")
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio: {ds_ratio}")
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent every n steps: {gl_every_n_steps}")
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio step: {gl_ratio_step}")
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent s noise: {gl_s_noise}")
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent unsharp params: {gl_unsharp_params}")
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex:
@@ -3369,7 +3369,7 @@ def main(args):
if gl_unsharp_params is not None:
unsharp_params = gl_unsharp_params.split(",")
us_ksize, us_sigma, us_strength = [float(v) for v in unsharp_params[:3]]
print(unsharp_params)
logger.info(f'{unsharp_params}')
us_target_x = True if len(unsharp_params) < 4 else bool(int(unsharp_params[3]))
us_ksize = int(us_ksize)
else:

View File

@@ -41,12 +41,17 @@ from .train_util import (
DatasetGroup,
)
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def add_config_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
parser.add_argument(
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
)
# TODO: inherit Params class in Subset, Dataset
@@ -60,6 +65,8 @@ class BaseSubsetParams:
caption_separator: str = (",",)
keep_tokens: int = 0
keep_tokens_separator: str = (None,)
secondary_separator: Optional[str] = None
enable_wildcard: bool = False
color_aug: bool = False
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
@@ -78,6 +85,7 @@ class DreamBoothSubsetParams(BaseSubsetParams):
is_reg: bool = False
class_tokens: Optional[str] = None
caption_extension: str = ".caption"
cache_info: bool = False
@dataclass
@@ -89,6 +97,7 @@ class FineTuningSubsetParams(BaseSubsetParams):
class ControlNetSubsetParams(BaseSubsetParams):
conditioning_data_dir: str = None
caption_extension: str = ".caption"
cache_info: bool = False
@dataclass
@@ -181,6 +190,8 @@ class ConfigSanitizer:
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
"secondary_separator": str,
"enable_wildcard": bool,
"token_warmup_min": int,
"token_warmup_step": Any(float, int),
"caption_prefix": str,
@@ -196,6 +207,7 @@ class ConfigSanitizer:
DB_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
"class_tokens": str,
"cache_info": bool,
}
DB_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
@@ -208,6 +220,7 @@ class ConfigSanitizer:
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
"cache_info": bool,
}
CN_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
@@ -244,9 +257,10 @@ class ConfigSanitizer:
}
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
assert (
support_dreambooth or support_finetuning or support_controlnet
), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
assert support_dreambooth or support_finetuning or support_controlnet, (
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
+ " / DreamBooth モードか fine tuning モードか controlnet モードのども指定されていません。1つ以上指定してください。"
)
self.db_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -313,7 +327,10 @@ class ConfigSanitizer:
self.dataset_schema = validate_flex_dataset
elif support_dreambooth:
self.dataset_schema = self.db_dataset_schema
if support_controlnet:
self.dataset_schema = self.cn_dataset_schema
else:
self.dataset_schema = self.db_dataset_schema
elif support_finetuning:
self.dataset_schema = self.ft_dataset_schema
elif support_controlnet:
@@ -358,7 +375,9 @@ class ConfigSanitizer:
return self.argparse_config_validator(argparse_namespace)
except MultipleInvalid:
# XXX: this should be a bug
logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
logger.error(
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
)
raise
# NOTE: value would be overwritten by latter dict if there is already the same key
@@ -504,6 +523,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
@@ -541,11 +562,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
" ",
)
logger.info(f'{info}')
logger.info(f"{info}")
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]")
dataset.make_buckets()
@@ -632,13 +653,17 @@ def load_user_config(file: str) -> dict:
with open(file, "r") as f:
config = json.load(f)
except Exception:
logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
logger.error(
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
)
raise
elif file.name.lower().endswith(".toml"):
try:
config = toml.load(file)
except Exception:
logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
logger.error(
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
)
raise
else:
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
@@ -665,13 +690,13 @@ if __name__ == "__main__":
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
logger.info("[argparse_namespace]")
logger.info(f'{vars(argparse_namespace)}')
logger.info(f"{vars(argparse_namespace)}")
user_config = load_user_config(config_args.dataset_config)
logger.info("")
logger.info("[user_config]")
logger.info(f'{user_config}')
logger.info(f"{user_config}")
sanitizer = ConfigSanitizer(
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
@@ -680,10 +705,10 @@ if __name__ == "__main__":
logger.info("")
logger.info("[sanitized_user_config]")
logger.info(f'{sanitized_user_config}')
logger.info(f"{sanitized_user_config}")
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
logger.info("")
logger.info("[blueprint]")
logger.info(f'{blueprint}')
logger.info(f"{blueprint}")

View File

@@ -3,11 +3,14 @@ import argparse
import random
import re
from typing import List, Optional, Union
from .utils import setup_logging
from .utils import setup_logging
setup_logging()
import logging
import logging
logger = logging.getLogger(__name__)
def prepare_scheduler_for_custom_training(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"):
return
@@ -64,7 +67,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
else:
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
loss = loss * snr_weight
@@ -92,13 +95,15 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
loss = loss + loss / scale * v_pred_like_loss
return loss
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
weight = 1/torch.sqrt(snr_t)
weight = 1 / torch.sqrt(snr_t)
loss = weight * loss
return loss
# TODO train_utilと分散しているのでどちらかに寄せる
@@ -474,6 +479,17 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise
def apply_masked_loss(loss, batch):
# mask image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
return loss
"""
##########################################
# Perlin Noise

139
library/deepspeed_utils.py Normal file
View File

@@ -0,0 +1,139 @@
import os
import argparse
import torch
from accelerate import DeepSpeedPlugin, Accelerator
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
parser.add_argument(
"--offload_optimizer_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--fp16_master_weights_and_gradients",
action="store_true",
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
)
def prepare_deepspeed_args(args: argparse.Namespace):
if not args.deepspeed:
return
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
args.max_data_loader_n_workers = 1
def prepare_deepspeed_plugin(args: argparse.Namespace):
if not args.deepspeed:
return None
try:
import deepspeed
except ImportError as e:
logger.error(
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
)
exit(1)
deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_clipping=args.max_grad_norm,
offload_optimizer_device=args.offload_optimizer_device,
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device,
offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag,
zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
if args.full_fp16 or args.fp16_master_weights_and_gradients:
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
logger.info("[DeepSpeed] full fp16 enable.")
else:
logger.info(
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
)
if args.offload_optimizer_device is not None:
logger.info("[DeepSpeed] start to manually build cpu_adam.")
deepspeed.ops.op_builder.CPUAdamBuilder().load()
logger.info("[DeepSpeed] building cpu_adam done.")
return deepspeed_plugin
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
def prepare_deepspeed_model(args: argparse.Namespace, **models):
# remove None from models
models = {k: v for k, v in models.items() if v is not None}
class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()
for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(torch.nn.ModuleDict({key: model}))
def get_models(self):
return self.models
ds_model = DeepSpeedWrapper(**models)
return ds_model

View File

@@ -32,6 +32,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.nn.Module.cuda = torch.nn.Module.xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
@@ -147,9 +148,9 @@ def ipex_init(): # pylint: disable=too-many-statements
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 2024
ipex._C._DeviceProperties.minor = 0
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]

View File

@@ -122,15 +122,15 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
mat2[start_idx:end_idx],
out=out
)
torch.xpu.synchronize(input.device)
else:
return original_torch_bmm(input, mat2, out=out)
torch.xpu.synchronize(input.device)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
# Slice SDPA
@@ -153,7 +153,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
@@ -161,7 +161,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
@@ -169,9 +169,9 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
torch.xpu.synchronize(query.device)
else:
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
torch.xpu.synchronize(query.device)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
return hidden_states

View File

@@ -12,7 +12,7 @@ device_supports_fp64 = torch.xpu.has_fp64_dtype()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
@@ -42,7 +42,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
if antialias or align_corners is not None or mode == 'bicubic':
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
@@ -190,6 +190,16 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
else:
return original_Tensor_cuda(self, device, *args, **kwargs)
original_Tensor_pin_memory = torch.Tensor.pin_memory
@wraps(torch.Tensor.pin_memory)
def Tensor_pin_memory(self, device=None, *args, **kwargs):
if device is None:
device = "xpu"
if check_device(device):
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_pin_memory(self, device, *args, **kwargs)
original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
@@ -216,7 +226,9 @@ def torch_empty(*args, device=None, **kwargs):
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, **kwargs):
def torch_randn(*args, device=None, dtype=None, **kwargs):
if dtype == bytes:
dtype = None
if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
else:
@@ -256,11 +268,13 @@ def torch_Generator(device=None):
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
def torch_load(f, map_location=None, *args, **kwargs):
if map_location is None:
map_location = "xpu"
if check_device(map_location):
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
else:
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
return original_torch_load(f, *args, map_location=map_location, **kwargs)
# Hijack Functions:
@@ -268,6 +282,7 @@ def ipex_hijacks():
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
torch.Tensor.pin_memory = Tensor_pin_memory
torch.UntypedStorage.__init__ = UntypedStorage_init
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.empty = torch_empty

View File

@@ -31,8 +31,10 @@ from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
IN_CHANNELS: int = 4
@@ -1074,7 +1076,7 @@ class SdxlUNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
@@ -1132,7 +1134,7 @@ class InferSdxlUNet2DConditionModel:
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
@@ -1164,7 +1166,7 @@ class InferSdxlUNet2DConditionModel:
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)

View File

@@ -24,7 +24,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:

View File

@@ -63,12 +63,14 @@ from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import imagesize
import cv2
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging
setup_logging()
@@ -364,6 +366,8 @@ class BaseSubset:
caption_separator: str,
keep_tokens: int,
keep_tokens_separator: str,
secondary_separator: Optional[str],
enable_wildcard: bool,
color_aug: bool,
flip_aug: bool,
face_crop_aug_range: Optional[Tuple[float, float]],
@@ -382,6 +386,8 @@ class BaseSubset:
self.caption_separator = caption_separator
self.keep_tokens = keep_tokens
self.keep_tokens_separator = keep_tokens_separator
self.secondary_separator = secondary_separator
self.enable_wildcard = enable_wildcard
self.color_aug = color_aug
self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range
@@ -405,11 +411,14 @@ class DreamBoothSubset(BaseSubset):
is_reg: bool,
class_tokens: Optional[str],
caption_extension: str,
cache_info: bool,
num_repeats,
shuffle_caption,
caption_separator: str,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -431,6 +440,8 @@ class DreamBoothSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -449,6 +460,7 @@ class DreamBoothSubset(BaseSubset):
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
self.cache_info = cache_info
def __eq__(self, other) -> bool:
if not isinstance(other, DreamBoothSubset):
@@ -466,6 +478,8 @@ class FineTuningSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -487,6 +501,8 @@ class FineTuningSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -514,11 +530,14 @@ class ControlNetSubset(BaseSubset):
image_dir: str,
conditioning_data_dir: str,
caption_extension: str,
cache_info: bool,
num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -540,6 +559,8 @@ class ControlNetSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -557,6 +578,7 @@ class ControlNetSubset(BaseSubset):
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
self.cache_info = cache_info
def __eq__(self, other) -> bool:
if not isinstance(other, ControlNetSubset):
@@ -675,15 +697,48 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out:
caption = ""
else:
# process wildcards
if subset.enable_wildcard:
# if caption is multiline, random choice one line
if "\n" in caption:
caption = random.choice(caption.split("\n"))
# wildcard is like '{aaa|bbb|ccc...}'
# escape the curly braces like {{ or }}
replacer1 = ""
replacer2 = ""
while replacer1 in caption or replacer2 in caption:
replacer1 += ""
replacer2 += ""
caption = caption.replace("{{", replacer1).replace("}}", replacer2)
# replace the wildcard
def replace_wildcard(match):
return random.choice(match.group(1).split("|"))
caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption)
# unescape the curly braces
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
else:
# if caption is multiline, use the first line
caption = caption.split("\n")[0]
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
fixed_tokens = []
flex_tokens = []
fixed_suffix_tokens = []
if (
hasattr(subset, "keep_tokens_separator")
and subset.keep_tokens_separator
and subset.keep_tokens_separator in caption
):
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
if subset.keep_tokens_separator in flex_part:
flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1)
fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()]
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
else:
@@ -718,7 +773,11 @@ class BaseDataset(torch.utils.data.Dataset):
flex_tokens = dropout_tags(flex_tokens)
caption = ", ".join(fixed_tokens + flex_tokens)
caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens)
# process secondary separator
if subset.secondary_separator:
caption = caption.replace(subset.secondary_separator, subset.caption_separator)
# textual inversion対応
for str_from, str_to in self.replacements.items():
@@ -1027,8 +1086,7 @@ class BaseDataset(torch.utils.data.Dataset):
)
def get_image_size(self, image_path):
image = Image.open(image_path)
return image.size
return imagesize.get(image_path)
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
img = load_image(image_path)
@@ -1357,6 +1415,8 @@ class BaseDataset(torch.utils.data.Dataset):
class DreamBoothDataset(BaseDataset):
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
def __init__(
self,
subsets: Sequence[DreamBoothSubset],
@@ -1400,7 +1460,7 @@ class DreamBoothDataset(BaseDataset):
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension):
def read_caption(img_path, caption_extension, enable_wildcard):
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
@@ -1419,35 +1479,66 @@ class DreamBoothDataset(BaseDataset):
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
caption = lines[0].strip()
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
else:
caption = lines[0].strip()
break
return caption
def load_dreambooth_dir(subset: DreamBoothSubset):
if not os.path.isdir(subset.image_dir):
logger.warning(f"not directory: {subset.image_dir}")
return [], []
return [], [], []
info_cache_file = os.path.join(subset.image_dir, self.IMAGE_INFO_CACHE_FILE)
use_cached_info_for_subset = subset.cache_info
if use_cached_info_for_subset:
logger.info(
f"using cached image info for this subset / このサブセットで、キャッシュされた画像情報を使います: {info_cache_file}"
)
if not os.path.isfile(info_cache_file):
logger.warning(
f"image info file not found. You can ignore this warning if this is the first time to use this subset"
+ " / キャッシュファイルが見つかりませんでした。初回実行時はこの警告を無視してください: {metadata_file}"
)
use_cached_info_for_subset = False
if use_cached_info_for_subset:
# json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...}
with open(info_cache_file, "r", encoding="utf-8") as f:
metas = json.load(f)
img_paths = list(metas.keys())
sizes = [meta["resolution"] for meta in metas.values()]
# we may need to check image size and existence of image files, but it takes time, so user should check it before training
else:
img_paths = glob_images(subset.image_dir, "*")
sizes = [None] * len(img_paths)
img_paths = glob_images(subset.image_dir, "*")
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
)
captions.append("")
missing_captions.append(img_path)
else:
if cap_for_img is None:
captions.append(subset.class_tokens)
if use_cached_info_for_subset:
captions = [meta["caption"] for meta in metas.values()]
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
else:
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
)
captions.append("")
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
if cap_for_img is None:
captions.append(subset.class_tokens)
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
@@ -1464,12 +1555,24 @@ class DreamBoothDataset(BaseDataset):
logger.warning(missing_caption + f"... and {remaining_missing_captions} more")
break
logger.warning(missing_caption)
return img_paths, captions
if not use_cached_info_for_subset and subset.cache_info:
logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}")
sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")]
matas = {}
for img_path, caption, size in zip(img_paths, captions, sizes):
matas[img_path] = {"caption": caption, "resolution": list(size)}
with open(info_cache_file, "w", encoding="utf-8") as f:
json.dump(matas, f, ensure_ascii=False, indent=2)
logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}")
# if sizes are not set, image size will be read in make_buckets
return img_paths, captions, sizes
logger.info("prepare images.")
num_train_images = 0
num_reg_images = 0
reg_infos: List[ImageInfo] = []
reg_infos: List[Tuple[ImageInfo, DreamBoothSubset]] = []
for subset in subsets:
if subset.num_repeats < 1:
logger.warning(
@@ -1483,7 +1586,7 @@ class DreamBoothDataset(BaseDataset):
)
continue
img_paths, captions = load_dreambooth_dir(subset)
img_paths, captions, sizes = load_dreambooth_dir(subset)
if len(img_paths) < 1:
logger.warning(
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1495,10 +1598,12 @@ class DreamBoothDataset(BaseDataset):
else:
num_train_images += subset.num_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.is_reg:
reg_infos.append(info)
reg_infos.append((info, subset))
else:
self.register_image(info, subset)
@@ -1519,7 +1624,7 @@ class DreamBoothDataset(BaseDataset):
n = 0
first_loop = True
while n < num_train_images:
for info in reg_infos:
for info, subset in reg_infos:
if first_loop:
self.register_image(info, subset)
n += info.num_repeats
@@ -1611,10 +1716,24 @@ class FineTuningDataset(BaseDataset):
caption = img_md.get("caption")
tags = img_md.get("tags")
if caption is None:
caption = tags
elif tags is not None and len(tags) > 0:
caption = caption + ", " + tags
tags_list.append(tags)
caption = tags # could be multiline
tags = None
if subset.enable_wildcard:
# tags must be single line
if tags is not None:
tags = tags.replace("\n", subset.caption_separator)
# add tags to each line of caption
if caption is not None and tags is not None:
caption = "\n".join(
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
)
else:
# use as is
if tags is not None and len(tags) > 0:
caption = caption + subset.caption_separator + tags
tags_list.append(tags)
if caption is None:
caption = ""
@@ -1764,16 +1883,22 @@ class ControlNetDataset(BaseDataset):
db_subsets = []
for subset in subsets:
assert (
not subset.random_crop
), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
db_subset = DreamBoothSubset(
subset.image_dir,
False,
None,
subset.caption_extension,
subset.cache_info,
subset.num_repeats,
subset.shuffle_caption,
subset.caption_separator,
subset.keep_tokens,
subset.keep_tokens_separator,
subset.secondary_separator,
subset.enable_wildcard,
subset.color_aug,
subset.flip_aug,
subset.face_crop_aug_range,
@@ -1812,7 +1937,7 @@ class ControlNetDataset(BaseDataset):
# assert all conditioning data exists
missing_imgs = []
cond_imgs_with_img = set()
cond_imgs_with_pair = set()
for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
subset = None
@@ -1826,23 +1951,29 @@ class ControlNetDataset(BaseDataset):
logger.warning(f"not directory: {subset.conditioning_data_dir}")
continue
img_basename = os.path.basename(info.absolute_path)
ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
if not os.path.exists(ctrl_img_path):
img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0]
ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename)
if len(ctrl_img_path) < 1:
missing_imgs.append(img_basename)
continue
ctrl_img_path = ctrl_img_path[0]
ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path
info.cond_img_path = ctrl_img_path
cond_imgs_with_img.add(ctrl_img_path)
cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive
extra_imgs = []
for subset in subsets:
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
extra_imgs.extend(
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
)
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
assert (
len(missing_imgs) == 0
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
assert (
len(extra_imgs) == 0
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
self.conditioning_image_transforms = IMAGE_TRANSFORMS
@@ -2888,7 +3019,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--save_state",
action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
)
parser.add_argument(
"--save_state_on_train_end",
action="store_true",
help="save training state (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを学習完了時に保存する",
)
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
@@ -2971,6 +3107,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
parser.add_argument(
"--ddp_timeout",
type=int,
@@ -3033,12 +3170,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインするオプション",
)
parser.add_argument(
"--noise_offset",
type=float,
default=None,
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する有効にする場合は0.1程度を推奨)",
)
parser.add_argument(
"--noise_offset_random_strength",
action="store_true",
help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。",
)
parser.add_argument(
"--multires_noise_iterations",
type=int,
@@ -3052,6 +3195,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) "
+ "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)",
)
parser.add_argument(
"--ip_noise_gamma_random_strength",
action="store_true",
help="Use random strength between 0~ip_noise_gamma for input perturbation noise."
+ "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。",
)
# parser.add_argument(
# "--perlin_noise",
# type=int,
@@ -3087,6 +3236,27 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する1~1000で指定、省略時はデフォルト値(1000)",
)
parser.add_argument(
"--loss_type",
type=str,
default="l2",
choices=["l2", "huber", "smooth_l1"],
help="The type of loss function to use (L2, Huber, or smooth L1), default is L2 / 使用する損失関数の種類L2、Huber、またはsmooth L1、デフォルトはL2",
)
parser.add_argument(
"--huber_schedule",
type=str,
default="snr",
choices=["constant", "exponential", "snr"],
help="The scheduling method for Huber loss (constant, exponential, or SNR-based). Only used when loss_type is 'huber' or 'smooth_l1'. default is snr"
+ " / Huber損失のスケジューリング方法constant、exponential、またはSNRベース。loss_typeが'huber'または'smooth_l1'の場合に有効、デフォルトは snr",
)
parser.add_argument(
"--huber_c",
type=float,
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type. default is 0.1 / Huber損失のパラメータ。loss_typeがhuberまたはsmooth l1の場合に有効。デフォルトは0.1",
)
parser.add_argument(
"--lowram",
@@ -3195,6 +3365,74 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
)
def add_masked_loss_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
parser.add_argument(
"--masked_loss",
action="store_true",
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
)
# verify command line args for training
def verify_command_line_training_args(args: argparse.Namespace):
# if wandb is enabled, the command line is exposed to the public
# check whether sensitive options are included in the command line arguments
# if so, warn or inform the user to move them to the configuration file
# wandbが有効な場合、コマンドラインが公開される
# 学習用のコマンドライン引数に敏感なオプションが含まれているかどうかを確認し、
# 含まれている場合は設定ファイルに移動するようにユーザーに警告または通知する
wandb_enabled = args.log_with is not None and args.log_with != "tensorboard" # "all" or "wandb"
if not wandb_enabled:
return
sensitive_args = ["wandb_api_key", "huggingface_token"]
sensitive_path_args = [
"pretrained_model_name_or_path",
"vae",
"tokenizer_cache_dir",
"train_data_dir",
"conditioning_data_dir",
"reg_data_dir",
"output_dir",
"logging_dir",
]
for arg in sensitive_args:
if getattr(args, arg, None) is not None:
logger.warning(
f"wandb is enabled, but option `{arg}` is included in the command line. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれています。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
)
# if path is absolute, it may include sensitive information
for arg in sensitive_path_args:
if getattr(args, arg, None) is not None and os.path.isabs(getattr(args, arg)):
logger.info(
f"wandb is enabled, but option `{arg}` is included in the command line and it is an absolute path. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file or use relative path."
+ f" / wandbが有効で、かつオプション `{arg}` がコマンドラインに含まれており、絶対パスです。コマンドラインは公開されるため、`.toml`ファイルに移動するか、相対パスを使用することをお勧めします。"
)
if getattr(args, "config_file", None) is not None:
logger.info(
f"wandb is enabled, but option `config_file` is included in the command line. Because the command line is exposed to the public, please be careful about the information included in the path."
+ f" / wandbが有効で、かつオプション `config_file` がコマンドラインに含まれています。コマンドラインは公開されるため、パスに含まれる情報にご注意ください。"
)
# other sensitive options
if args.huggingface_repo_id is not None and args.huggingface_repo_visibility != "public":
logger.info(
f"wandb is enabled, but option huggingface_repo_id is included in the command line and huggingface_repo_visibility is not 'public'. Because the command line is exposed to the public, it is recommended to move it to the `.toml` file."
+ f" / wandbが有効で、かつオプション huggingface_repo_id がコマンドラインに含まれており、huggingface_repo_visibility が 'public' ではありません。コマンドラインは公開されるため、`.toml`ファイルに移動することをお勧めします。"
)
def verify_training_args(args: argparse.Namespace):
r"""
Verify training arguments. Also reflect highvram option to global variable
@@ -3250,6 +3488,18 @@ def verify_training_args(args: argparse.Namespace):
+ " / zero_terminal_snrが有効ですが、v_parameterizationが有効ではありません。学習結果は想定外になる可能性があります"
)
if args.sample_every_n_epochs is not None and args.sample_every_n_epochs <= 0:
logger.warning(
"sample_every_n_epochs is less than or equal to 0, so it will be disabled / sample_every_n_epochsに0以下の値が指定されたため無効になります"
)
args.sample_every_n_epochs = None
if args.sample_every_n_steps is not None and args.sample_every_n_steps <= 0:
logger.warning(
"sample_every_n_steps is less than or equal to 0, so it will be disabled / sample_every_n_stepsに0以下の値が指定されたため無効になります"
)
args.sample_every_n_steps = None
def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
@@ -3258,6 +3508,12 @@ def add_dataset_arguments(
parser.add_argument(
"--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--cache_info",
action="store_true",
help="cache meta information (caption and image size) for faster dataset loading. only available for DreamBooth"
+ " / メタ情報キャプションとサイズをキャッシュしてデータセット読み込みを高速化する。DreamBooth方式のみ有効",
)
parser.add_argument(
"--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
)
@@ -3284,6 +3540,18 @@ def add_dataset_arguments(
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
)
parser.add_argument(
"--secondary_separator",
type=str,
default=None,
help="a secondary separator for caption. This separator is replaced to caption_separator after dropping/shuffling caption"
+ " / captionのセカンダリ区切り文字。この区切り文字はcaptionのドロップやシャッフル後にcaption_separatorに置き換えられる",
)
parser.add_argument(
"--enable_wildcard",
action="store_true",
help="enable wildcard for caption (e.g. '{image|picture|rendition}') / captionのワイルドカードを有効にする'{image|picture|rendition}'",
)
parser.add_argument(
"--caption_prefix",
type=str,
@@ -3474,7 +3742,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
exit(1)
logger.info(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
with open(config_path, "r", encoding="utf-8") as f:
config_dict = toml.load(f)
# combine all sections into one
@@ -3983,6 +4251,10 @@ def load_tokenizer(args: argparse.Namespace):
def prepare_accelerator(args: argparse.Namespace):
"""
this function also prepares deepspeed plugin
"""
if args.logging_dir is None:
logging_dir = None
else:
@@ -4028,6 +4300,8 @@ def prepare_accelerator(args: argparse.Namespace):
),
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
@@ -4035,6 +4309,7 @@ def prepare_accelerator(args: argparse.Namespace):
project_dir=logging_dir,
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
)
print("accelerator device:", accelerator.device)
return accelerator
@@ -4105,7 +4380,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
@@ -4116,7 +4390,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
accelerator.device if args.lowram else "cpu",
unet_use_linear_projection_in_v2=unet_use_linear_projection_in_v2,
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
@@ -4125,7 +4398,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4592,11 +4864,47 @@ def save_sd_model_on_train_end_common(
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):
# TODO: if a huber loss is selected, it will use constant timesteps for each batch
# as. In the future there may be a smarter way
if args.loss_type == "huber" or args.loss_type == "smooth_l1":
timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu")
timestep = timesteps.item()
if args.huber_schedule == "exponential":
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
huber_c = math.exp(-alpha * timestep)
elif args.huber_schedule == "snr":
alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
elif args.huber_schedule == "constant":
huber_c = args.huber_c
else:
raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
timesteps = timesteps.repeat(b_size).to(device)
elif args.loss_type == "l2":
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
huber_c = 1 # may be anything, as it's not used
else:
raise NotImplementedError(f"Unknown loss type {args.loss_type}")
timesteps = timesteps.long()
return timesteps, huber_c
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
if args.noise_offset_random_strength:
noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
else:
noise_offset = args.noise_offset
noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
if args.multires_noise_iterations:
noise = custom_train_functions.pyramid_noise_like(
noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
@@ -4607,17 +4915,44 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
timesteps = timesteps.long()
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps)
if args.ip_noise_gamma_random_strength:
strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
else:
strength = args.ip_noise_gamma
noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
return noise, noisy_latents, timesteps
return noise, noisy_latents, timesteps, huber_c
# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1
):
if loss_type == "l2":
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == "huber":
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == "smooth_l1":
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
else:
raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
return loss
def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):

View File

@@ -327,10 +327,10 @@ class DyLoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}")
logger.info(f"create LoRA for Text Encoder {index}")
else:
index = None
print(f"create LoRA for Text Encoder")
logger.info("create LoRA for Text Encoder")
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)

View File

@@ -247,14 +247,13 @@ class LoRAInfModule(LoRAModule):
area = x.size()[1]
mask = self.network.mask_dic.get(area, None)
if mask is None:
# raise ValueError(f"mask is None for resolution {area}")
if mask is None or len(x.size()) == 2:
# emb_layers in SDXL doesn't have mask
# if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4:
if len(x.size()) == 3:
mask = torch.reshape(mask, (1, -1, 1))
return mask

View File

@@ -39,12 +39,7 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, state_dict, metadata):
if model_util.is_safetensors(file_name):
save_file(state_dict, file_name, metadata)
else:
@@ -349,12 +344,18 @@ def resize(args):
metadata["ss_network_dim"] = "Dynamic"
metadata["ss_network_alpha"] = "Dynamic"
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
save_to_file(args.save_to, state_dict, metadata)
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -1,18 +1,25 @@
import itertools
import math
import argparse
import os
import time
import concurrent.futures
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora
import oft
from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name)
@@ -28,36 +35,58 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, model, metadata):
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name, metadata=metadata)
else:
torch.save(model, file_name)
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
text_encoder1.to(merge_dtype)
def detect_method_from_training_model(models, dtype):
for model in models:
# TODO It is better to use key names to detect the method
lora_sd, _ = load_state_dict(model, dtype)
for key in tqdm(lora_sd.keys()):
if "lora_up" in key or "lora_down" in key:
return "LoRA"
elif "oft_blocks" in key:
return "OFT"
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype):
text_encoder1.to(merge_dtype)
text_encoder2.to(merge_dtype)
unet.to(merge_dtype)
# detect the method: OFT or LoRA_module
method = detect_method_from_training_model(models, merge_dtype)
logger.info(f"method:{method}")
if lbws:
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
else:
LBW_TARGET_IDX = []
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
if i <= 1:
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
if method == "LoRA":
if i <= 1:
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
else:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
elif method == "OFT":
prefix = oft.OFTNetwork.OFT_PREFIX_UNET
# ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
for name, module in root_module.named_modules():
@@ -68,65 +97,172 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
logger.info(f"merging...")
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if method == "LoRA":
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
if lbw:
index = get_lbw_block_index(key, True)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
# W <- W + U * D
weight = module.weight
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
elif method == "OFT":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for key in tqdm(lora_sd.keys()):
if "oft_blocks" in key:
oft_blocks = lora_sd[key]
dim = oft_blocks.shape[0]
break
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
oft_blocks = lora_sd[key]
alpha = oft_blocks.item()
break
def merge_to(key):
if "alpha" in key:
return
# find original module for this OFT
module_name = ".".join(key.split(".")[:-1])
if module_name not in name_to_module:
logger.info(f"no module found for LoRA weight: {key}")
continue
logger.info(f"no module found for OFT weight: {key}")
return
module = name_to_module[module_name]
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
oft_blocks = lora_sd[key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
if isinstance(module, torch.nn.Linear):
out_dim = module.out_features
elif isinstance(module, torch.nn.Conv2d):
out_dim = module.out_channels
# W <- W + U * D
weight = module.weight
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
num_blocks = dim
block_size = out_dim // dim
constraint = (0 if alpha is None else alpha) * out_dim
multiplier = 1
if lbw:
index = get_lbw_block_index(key, False)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
multiplier *= lbw_weights[index]
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = multiplier * block_R + (1 - multiplier) * I
R = torch.block_diag(*block_R_weighted)
# get org weight
org_sd = module.state_dict()
org_weight = org_sd["weight"].to(device)
R = R.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
weight = torch.einsum("oi, op -> pi", org_weight, R)
weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor
module.weight = torch.nn.Parameter(weight)
# TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
# detect the method: OFT or LoRA_module
method = detect_method_from_training_model(models, merge_dtype)
if method == "OFT":
raise ValueError(
"OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません"
)
if lbws:
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
else:
LBW_TARGET_IDX = []
merged_sd = {}
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if lora_metadata is not None:
if v2 is None:
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
@@ -164,7 +300,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
@@ -178,8 +314,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if lbw:
index = get_lbw_block_index(key, True)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
@@ -201,7 +343,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
merged_sd[key_up] = merged_sd[key_up][:, perm]
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -229,7 +371,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
assert len(args.models) == len(
args.ratios
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
if args.lbws:
assert len(args.models) == len(
args.lbws
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
else:
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
def str_to_dtype(p):
if p == "float":
@@ -257,7 +407,7 @@ def merge(args):
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype)
if args.no_metadata:
sai_metadata = None
@@ -273,7 +423,13 @@ def merge(args):
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
logger.info(f"calculating hashes and creating metadata...")
@@ -290,7 +446,7 @@ def merge(args):
metadata.update(sai_metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
save_to_file(args.save_to, state_dict, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -316,12 +472,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
"--models",
type=str,
nargs="*",
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
parser.add_argument(
"--no_metadata",
action="store_true",
@@ -337,8 +500,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
)
return parser

View File

@@ -1,5 +1,8 @@
import argparse
import itertools
import json
import os
import re
import time
import torch
from safetensors.torch import load_file, save_file
@@ -8,12 +11,195 @@ from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLAMP_QUANTILE = 0.99
ACCEPTABLE = [12, 17, 20, 26]
SDXL_LAYER_NUM = [12, 20]
LAYER12 = {
"BASE": True,
"IN00": False,
"IN01": False,
"IN02": False,
"IN03": False,
"IN04": True,
"IN05": True,
"IN06": False,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": False,
"OUT07": False,
"OUT08": False,
"OUT09": False,
"OUT10": False,
"OUT11": False,
}
LAYER17 = {
"BASE": True,
"IN00": False,
"IN01": True,
"IN02": True,
"IN03": False,
"IN04": True,
"IN05": True,
"IN06": False,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": False,
"OUT01": False,
"OUT02": False,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": True,
"OUT10": True,
"OUT11": True,
}
LAYER20 = {
"BASE": True,
"IN00": True,
"IN01": True,
"IN02": True,
"IN03": True,
"IN04": True,
"IN05": True,
"IN06": True,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": False,
"OUT10": False,
"OUT11": False,
}
LAYER26 = {
"BASE": True,
"IN00": True,
"IN01": True,
"IN02": True,
"IN03": True,
"IN04": True,
"IN05": True,
"IN06": True,
"IN07": True,
"IN08": True,
"IN09": True,
"IN10": True,
"IN11": True,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": True,
"OUT10": True,
"OUT11": True,
}
assert len([v for v in LAYER12.values() if v]) == 12
assert len([v for v in LAYER17.values() if v]) == 17
assert len([v for v in LAYER20.values() if v]) == 20
assert len([v for v in LAYER26.values() if v]) == 26
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int:
# lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder
if "text_model_encoder_" in lora_name: # LoRA for text encoder
return 0
# lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2
block_idx = -1 # invalid lora name
if not is_sdxl:
NUM_OF_BLOCKS = 12 # up/down blocks
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
up_down = g[0]
i = int(g[1])
j = int(g[3])
if up_down == "down":
if g[2] == "resnets" or g[2] == "attentions":
idx = 3 * i + j + 1
elif g[2] == "downsamplers":
idx = 3 * (i + 1)
else:
return block_idx # invalid lora name
elif up_down == "up":
if g[2] == "resnets" or g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers":
idx = 3 * i + 2
else:
return block_idx # invalid lora name
if g[0] == "down":
block_idx = 1 + idx # 1-based index, down block index
elif g[0] == "up":
block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index
elif "mid_block_" in lora_name:
block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block
else:
# SDXL: some numbers are skipped
if lora_name.startswith("lora_unet_"):
name = lora_name[len("lora_unet_") :]
if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts
block_idx = 1
elif name.startswith("input_blocks_"): # 1-8 to 2-9
block_idx = 1 + int(name.split("_")[2])
elif name.startswith("middle_block_"): # 13
block_idx = 13
elif name.startswith("output_blocks_"): # 0-8 to 14-22
block_idx = 14 + int(name.split("_")[2])
elif name.startswith("out_"): # 23, No LoRA in sd-scripts
block_idx = 23
return block_idx
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -30,24 +216,53 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, state_dict, metadata):
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(state_dict, file_name, metadata=metadata)
else:
torch.save(state_dict, file_name)
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
def format_lbws(lbws):
try:
# lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している
lbws = [json.loads(lbw) for lbw in lbws]
except Exception:
raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください")
assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください"
assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください"
assert all(
len(lbw) in ACCEPTABLE for lbw in lbws
), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
assert all(
all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws
), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください"
layer_num = len(lbws[0])
is_sdxl = True if layer_num in SDXL_LAYER_NUM else False
FLAGS = {
"12": LAYER12.values(),
"17": LAYER17.values(),
"20": LAYER20.values(),
"26": LAYER26.values(),
}[str(layer_num)]
LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag]
return lbws, is_sdxl, LBW_TARGET_IDX
def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
merged_sd = {}
v2 = None
v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2
base_model = None
for model, ratio in zip(models, ratios):
if lbws:
lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws)
else:
is_sdxl = False
LBW_TARGET_IDX = []
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
@@ -57,6 +272,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
if base_model is None:
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
# merge
logger.info(f"merging...")
for key in tqdm(list(lora_sd.keys())):
@@ -80,10 +301,10 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
if device:
weight = weight.to(device)
# merge to weight
if device:
@@ -93,6 +314,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
# W <- W + U * D
scale = alpha / network_dim
if lbw:
index = get_lbw_block_index(key, is_sdxl)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
if device: # and isinstance(scale, torch.Tensor):
scale = scale.to(device)
@@ -109,13 +336,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
merged_sd[lora_module_name] = weight
merged_sd[lora_module_name] = weight.to("cpu")
# extract from merged weights
logger.info("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
if device:
mat = mat.to(device)
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
@@ -154,7 +384,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu")
# build minimum metadata
dims = f"{new_rank}"
@@ -169,7 +399,15 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
assert len(args.models) == len(
args.ratios
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
if args.lbws:
assert len(args.models) == len(
args.lbws
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
else:
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
def str_to_dtype(p):
if p == "float":
@@ -187,9 +425,15 @@ def merge(args):
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
state_dict, metadata, v2, base_model = merge_lora_models(
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
)
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
@@ -211,7 +455,7 @@ def merge(args):
metadata.update(sai_metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
save_to_file(args.save_to, state_dict, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -231,12 +475,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
"--models",
type=str,
nargs="*",
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument(
"--new_conv_rank",
@@ -244,7 +495,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
)
parser.add_argument(
"--no_metadata",
action="store_true",

View File

@@ -3,11 +3,13 @@ transformers==4.36.2
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.7.0.68
opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
# bitsandbytes==0.39.1
tensorboard==2.10.1
bitsandbytes==0.43.0
prodigyopt==1.0
lion-pytorch==0.0.6
tensorboard
safetensors==0.4.2
# gradio==3.16.2
altair==4.2.2
@@ -15,6 +17,8 @@ easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.20.1
# for Image utils
imagesize==1.4.1
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
@@ -22,13 +26,16 @@ huggingface-hub==0.20.1
# for WD14 captioning (tensorflow)
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.14.1
# onnxruntime-gpu==1.16.0
# onnxruntime==1.16.0
# onnx==1.15.0
# onnxruntime-gpu==1.17.1
# onnxruntime==1.17.1
# for cuda 12.1(default 11.8)
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
open-clip-torch==2.20.0
# open-clip-torch==2.20.0
# For logging
rich==13.7.0
# for kohya_ss library

View File

@@ -380,10 +380,10 @@ class PipelineLike:
def set_gradual_latent(self, gradual_latent):
if gradual_latent is None:
print("gradual_latent is disabled")
logger.info("gradual_latent is disabled")
self.gradual_latent = None
else:
print(f"gradual_latent is enabled: {gradual_latent}")
logger.info(f"gradual_latent is enabled: {gradual_latent}")
self.gradual_latent = gradual_latent # (ds_ratio, start_timesteps, every_n_steps, ratio_step)
@torch.no_grad()
@@ -789,8 +789,8 @@ class PipelineLike:
enable_gradual_latent = False
if self.gradual_latent:
if not hasattr(self.scheduler, "set_gradual_latent_params"):
print("gradual_latent is not supported for this scheduler. Ignoring.")
print(self.scheduler.__class__.__name__)
logger.info("gradual_latent is not supported for this scheduler. Ignoring.")
logger.info(f'{self.scheduler.__class__.__name__}')
else:
enable_gradual_latent = True
step_elapsed = 1000
@@ -2614,84 +2614,84 @@ def main(args):
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
print(f"gradual latent timesteps: {gl_timesteps}")
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio: {ds_ratio}")
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent every n steps: {gl_every_n_steps}")
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio step: {gl_ratio_step}")
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent s noise: {gl_s_noise}")
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent unsharp params: {gl_unsharp_params}")
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
# Gradual Latent
m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent timesteps
gl_timesteps = int(m.group(1))
print(f"gradual latent timesteps: {gl_timesteps}")
logger.info(f"gradual latent timesteps: {gl_timesteps}")
continue
m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio
gl_ratio = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio: {ds_ratio}")
logger.info(f"gradual latent ratio: {ds_ratio}")
continue
m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent every n steps
gl_every_n_steps = int(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent every n steps: {gl_every_n_steps}")
logger.info(f"gradual latent every n steps: {gl_every_n_steps}")
continue
m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent ratio step
gl_ratio_step = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent ratio step: {gl_ratio_step}")
logger.info(f"gradual latent ratio step: {gl_ratio_step}")
continue
m = re.match(r"glsn ([\d\.]+)", parg, re.IGNORECASE)
if m: # gradual latent s noise
gl_s_noise = float(m.group(1))
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent s noise: {gl_s_noise}")
logger.info(f"gradual latent s noise: {gl_s_noise}")
continue
m = re.match(r"glus ([\d\.\-,]+)", parg, re.IGNORECASE)
if m: # gradual latent unsharp params
gl_unsharp_params = m.group(1)
gl_timesteps = gl_timesteps if gl_timesteps is not None else -1 # -1 means override
print(f"gradual latent unsharp params: {gl_unsharp_params}")
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
except ValueError as ex:

View File

@@ -11,20 +11,24 @@ import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler
from PIL import Image
import open_clip
# import open_clip
from safetensors.torch import load_file
from library import model_util, sdxl_model_util
import networks.lora as lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
@@ -154,12 +158,13 @@ if __name__ == "__main__":
text_model2.eval()
unet.set_use_memory_efficient_attention(True, False)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(True)
# Tokenizers
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
# tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name)
# LoRA
for weights_file in args.lora_weights:
@@ -192,7 +197,9 @@ if __name__ == "__main__":
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
# logger.info("emb1", emb1.shape)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
uc_vector = c_vector.clone().to(
DEVICE, dtype=DTYPE
) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
# crossattn
@@ -215,13 +222,22 @@ if __name__ == "__main__":
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
# text encoder 2
with torch.no_grad():
tokens = tokenizer2(text2).to(DEVICE)
# tokens = tokenizer2(text2).to(DEVICE)
tokens = tokenizer2(
text,
truncation=True,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(DEVICE)
with torch.no_grad():
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
# logger.info("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# 連結して終了 concat and finish
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)

View File

@@ -11,11 +11,13 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util
from library import deepspeed_utils, sdxl_model_util
import library.train_util as train_util
@@ -39,6 +41,7 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -97,6 +100,7 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
assert (
@@ -124,7 +128,7 @@ def train(args):
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -398,18 +402,33 @@ def train(args):
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
if train_text_encoder1:
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoder1 if train_text_encoder1 else None,
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
@@ -424,6 +443,8 @@ def train(args):
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
@@ -561,7 +582,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -576,9 +597,12 @@ def train(args):
or args.scale_v_pred_loss_like_noise_pred
or args.v_pred_like_loss
or args.debiased_estimation_loss
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
@@ -592,7 +616,7 @@ def train(args):
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -712,7 +736,7 @@ def train(args):
accelerator.end_training()
if args.save_state: # and is_main_process:
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -744,6 +768,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
@@ -779,7 +805,6 @@ def setup_parser() -> argparse.ArgumentParser:
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
)
return parser
@@ -787,6 +812,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -22,7 +22,7 @@ from accelerate.utils import set_seed
import accelerate
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
import library.model_util as model_util
import library.train_util as train_util
@@ -394,10 +394,10 @@ def train(args):
with accelerator.accumulate(unet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
@@ -439,7 +439,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -458,7 +458,7 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -549,7 +549,7 @@ def train(args):
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
@@ -566,6 +566,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -611,6 +612,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -18,7 +18,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
import library.model_util as model_util
import library.train_util as train_util
@@ -361,10 +361,10 @@ def train(args):
with accelerator.accumulate(network):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
@@ -406,7 +406,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -426,7 +426,7 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -534,6 +534,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -579,6 +580,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -178,6 +178,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = SdxlNetworkTrainer()

View File

@@ -7,7 +7,6 @@ import torch
from library.device_utils import init_ipex
init_ipex()
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util
import train_textual_inversion
@@ -132,6 +131,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = SdxlTextualInversionTrainer()

View File

@@ -16,12 +16,13 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
# check cache latents arg
@@ -94,6 +95,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# acceleratorを準備する
logger.info("prepare accelerator")
args.deepspeed = False
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -170,6 +172,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)

View File

@@ -16,12 +16,13 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
# check cache arg
@@ -99,6 +100,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# acceleratorを準備する
logger.info("prepare accelerator")
args.deepspeed = False
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -171,6 +173,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)

View File

@@ -11,6 +11,7 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -396,7 +397,7 @@ def train(args):
with accelerator.accumulate(controlnet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -419,13 +420,8 @@ def train(args):
)
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(b_size,),
device=latents.device,
)
timesteps = timesteps.long()
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -456,7 +452,7 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -565,7 +561,7 @@ def train(args):
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく
@@ -584,6 +580,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -615,6 +612,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -11,7 +11,10 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
@@ -32,6 +35,7 @@ from library.custom_train_functions import (
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
@@ -46,6 +50,7 @@ logger = logging.getLogger(__name__)
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -57,7 +62,7 @@ def train(args):
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -219,12 +224,25 @@ def train(args):
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
training_models = [unet, text_encoder]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
training_models = [unet]
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
@@ -296,12 +314,14 @@ def train(args):
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
if len(training_models) == 2:
training_models = training_models[0] # remove text_encoder from training_models
with accelerator.accumulate(unet):
with accelerator.accumulate(*training_models):
with torch.no_grad():
# latentに変換
if cache_latents:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
@@ -326,7 +346,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -338,7 +358,9 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -444,7 +466,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -464,6 +486,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
@@ -499,6 +523,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -13,18 +13,15 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
from library import deepspeed_utils, model_util
import library.train_util as train_util
from library.train_util import (
DreamBoothDataset,
)
from library.train_util import DreamBoothDataset
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
@@ -39,6 +36,7 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
@@ -141,6 +139,7 @@ class NetworkTrainer:
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -157,7 +156,7 @@ class NetworkTrainer:
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -413,20 +412,36 @@ class NetworkTrainer:
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if train_unet:
unet = accelerator.prepare(unet)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoders[0] if train_text_encoder else None,
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
network=network,
)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_model = ds_model
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
if train_unet:
unet = accelerator.prepare(unet)
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
)
training_model = network
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
@@ -456,6 +471,31 @@ class NetworkTrainer:
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# before resuming make hook for saving/loading to save/load the network weights only
def save_model_hook(models, weights, output_dir):
# pop weights of other models than network to save only network weights
if accelerator.is_main_process:
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
weights.pop(i)
# print(f"save model hook: {len(weights)} weights will be saved")
def load_model_hook(models, input_dir):
# remove models except network
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
models.pop(i)
# print(f"load model hook: {len(models)} models will be loaded")
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
@@ -529,6 +569,11 @@ class NetworkTrainer:
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_ip_noise_gamma": args.ip_noise_gamma,
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
"ss_noise_offset_random_strength": args.noise_offset_random_strength,
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
"ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule,
"ss_huber_c": args.huber_c,
}
if use_user_config:
@@ -564,6 +609,11 @@ class NetworkTrainer:
"random_crop": bool(subset.random_crop),
"shuffle_caption": bool(subset.shuffle_caption),
"keep_tokens": subset.keep_tokens,
"keep_tokens_separator": subset.keep_tokens_separator,
"secondary_separator": subset.secondary_separator,
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
}
image_dir_or_metadata_file = None
@@ -753,21 +803,21 @@ class NetworkTrainer:
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
latents = latents * self.vae_scale_factor
# get multiplier for each sample
if network_has_multiplier:
@@ -798,7 +848,7 @@ class NetworkTrainer:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
@@ -828,7 +878,11 @@ class NetworkTrainer:
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -935,7 +989,7 @@ class NetworkTrainer:
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
@@ -952,6 +1006,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -1052,6 +1108,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = NetworkTrainer()

View File

@@ -8,12 +8,14 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import model_util
from library import deepspeed_utils, model_util
import library.train_util as train_util
import library.huggingface_util as huggingface_util
@@ -29,6 +31,7 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
@@ -268,7 +271,7 @@ class TextualInversionTrainer:
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
if args.dataset_config is not None:
accelerator.print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -558,10 +561,10 @@ class TextualInversionTrainer:
with accelerator.accumulate(text_encoders[0]):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
latents = latents * self.vae_scale_factor
# Get the text embedding for conditioning
@@ -569,7 +572,7 @@ class TextualInversionTrainer:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
@@ -585,7 +588,9 @@ class TextualInversionTrainer:
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -732,7 +737,7 @@ class TextualInversionTrainer:
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
@@ -749,6 +754,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
@@ -799,6 +806,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = TextualInversionTrainer()

View File

@@ -8,7 +8,9 @@ from multiprocessing import Value
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
@@ -31,6 +33,7 @@ from library.custom_train_functions import (
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
apply_masked_loss,
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -200,7 +203,7 @@ def train(args):
logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -439,7 +442,7 @@ def train(args):
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -458,7 +461,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -470,7 +473,9 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -586,7 +591,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
@@ -662,6 +667,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
@@ -707,6 +714,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)