Compare commits

..

188 Commits

Author SHA1 Message Date
Kohya S
08ae46b163 Merge pull request #208 from space-nuko/add-optimizer-to-metadata
Add optimizer to metadata
2023-02-19 17:21:57 +09:00
space-nuko
4e5db58a71 Add optimizer to metadata 2023-02-18 23:28:36 -08:00
Kohya S
a9d29ac78c Merge pull request #207 from kohya-ss/dev
Dev
2023-02-19 15:29:40 +09:00
Kohya S
5c065eee79 update readme 2023-02-19 15:26:21 +09:00
Kohya S
048e7cd428 add lion optimizer support 2023-02-19 15:26:14 +09:00
Kohya S
a76ad2d1d5 add comment for future requirement update 2023-02-19 15:25:01 +09:00
Kohya S
9d0f9736bf Merge pull request #202 from vladmandic/main
fix git path
2023-02-19 15:01:21 +09:00
Kohya S
00bb8a65a6 Merge pull request #200 from Isotr0py/lowram
Add '--lowram' argument
2023-02-19 14:32:32 +09:00
Vladimir Mandic
dac2bd163a fix git path 2023-02-17 14:19:08 -05:00
Isotr0py
78d1fb5ce6 Add '--lowram' argument 2023-02-17 12:08:54 +08:00
Kohya S
14d7b24619 Merge pull request #198 from kohya-ss/dev
Dev
2023-02-16 22:35:47 +09:00
Kohya S
3bc0d83769 update readme 2023-02-16 22:21:51 +09:00
Kohya S
ffdfd5f615 fix name of loss for epoch 2023-02-16 22:21:36 +09:00
Kohya S
d01d953262 Merge pull request #196 from space-nuko/add-noise-offset-metadata
Add noise offset to metadata
2023-02-16 22:01:02 +09:00
Kohya S
914d1505df Merge pull request #189 from shirayu/improve_loss_track
Show the moving average loss
2023-02-16 22:00:26 +09:00
space-nuko
496c8cdc09 Add noise-offset to metadata 2023-02-16 02:56:39 -08:00
Kohya S
82713e9aa6 Update README.md 2023-02-14 21:41:04 +09:00
Kohya S
e067d64b53 Merge pull request #190 from kohya-ss/dev
Dev
2023-02-14 21:32:03 +09:00
Kohya S
3d400667d2 fix typos 2023-02-14 21:29:40 +09:00
Kohya S
2aef2872fb update readme 2023-02-14 21:28:34 +09:00
Kohya S
43c0a69843 Add noise_offset 2023-02-14 21:15:48 +09:00
Yuta Hayashibe
8aed5125de Removed call of sum() 2023-02-14 21:11:30 +09:00
Kohya S
e0f007f2a9 Fix import 2023-02-14 20:55:38 +09:00
Kohya S
3c29784825 Add ja comment 2023-02-14 20:55:20 +09:00
Kohya S
8f1e930bf4 Merge pull request #187 from space-nuko/add-commit-hash
Add commit hash to metadata
2023-02-14 19:52:30 +09:00
Kohya S
f771396e90 Merge pull request #179 from mgz-dev/resize_lora-verbose-print
add verbosity option for resize_lora.py
2023-02-14 19:50:49 +09:00
Kohya S
f67b3f4452 Merge pull request #165 from Isotr0py/support-multi-gpu
Add support with multi-gpu train for train_newtork.py
2023-02-14 19:47:53 +09:00
Yuta Hayashibe
21f5b618c3 Show the moving average loss 2023-02-14 19:46:27 +09:00
space-nuko
5471b0deb0 Add commit hash to metadata 2023-02-13 02:58:06 -08:00
Isotr0py
2b1a3080e7 Add type checking 2023-02-12 15:32:38 +08:00
Isotr0py
92a1af8024 Merge branch 'kohya-ss:main' into support-multi-gpu 2023-02-12 15:06:46 +08:00
michaelgzhang
b35b053b8d clean up print formatting 2023-02-11 03:14:43 -06:00
michaelgzhang
55521eece0 add verbosity option for resize_lora.py
add --verbose flag to print additional statistics during resize_lora function
correct some parameter references in resize_lora_model function
2023-02-11 02:38:13 -06:00
Kohya S
b32abdd327 Merge pull request #178 from kohya-ss/dev
Dev
2023-02-11 16:16:15 +09:00
Kohya S
d1ecfde487 fix typo 2023-02-11 16:12:27 +09:00
Kohya S
04ad46a9a7 update readme 2023-02-11 16:11:42 +09:00
Kohya S
4c561411aa revert batch size limiting for bucket 2023-02-11 16:02:56 +09:00
Kohya S
43a41c6c43 Merge pull request #177 from kohya-ss/dev
Dev
2023-02-11 15:11:07 +09:00
Kohya S
5367daa210 update readme 2023-02-11 15:09:45 +09:00
Kohya S
b825e4602c update readme 2023-02-11 15:05:45 +09:00
Kohya S
188e54b760 support multiple init words 2023-02-11 15:00:11 +09:00
Kohya S
2c5f5c324a Fix crash TI train close #172, tag drop wo shuffle 2023-02-11 14:41:44 +09:00
Kohya S
5777be5208 Update README.md 2023-02-11 13:36:33 +09:00
Kohya S
e727a0d222 Update README.md 2023-02-11 13:30:12 +09:00
Kohya S
cdd8882a01 Merge pull request #176 from kohya-ss/dev
Dev
2023-02-11 13:22:40 +09:00
Kohya S
3f3502fb57 add message 2023-02-11 13:20:58 +09:00
Kohya S
20c00603a8 Merge branch 'main' into dev 2023-02-11 13:16:13 +09:00
Kohya S
9239fefa52 add lora interrogator with text encoder 2023-02-11 13:15:57 +09:00
Kohya S
53d60543e5 Merge pull request #174 from kohya-ss/dev
Dev
2023-02-10 23:11:12 +09:00
Kohya S
22e3aca89c Update README.md 2023-02-10 23:07:53 +09:00
Kohya S
8d86f58174 add merge script with svd 2023-02-10 22:55:33 +09:00
Kohya S
e5cc64a563 support multibyte characters for filename 2023-02-10 22:55:21 +09:00
Kohya S
c7406d6b27 keep metadata when resizing 2023-02-10 22:55:00 +09:00
Kohya S
d2da3c4236 support for models with different alphas 2023-02-10 22:54:35 +09:00
Kohya S
2bad87f2f6 Update README-ja.md 2023-02-10 18:12:03 +09:00
Kohya S
ed62e566bb Update README.md 2023-02-10 18:11:39 +09:00
Kohya S
51b3dc2c11 Merge pull request #171 from kohya-ss/dev
Dev
2023-02-10 17:40:08 +09:00
Kohya S
74f4a8fab9 Merge branch 'main' into dev 2023-02-10 17:37:39 +09:00
Kohya S
a75baf9143 Add strict version no 2023-02-10 17:37:19 +09:00
Kohya S
b03721b4d9 Add todo comment 2023-02-10 17:36:38 +09:00
Kohya S
6b790bace6 Update README.md 2023-02-09 23:14:41 +09:00
Kohya S
dcaecfd20b Merge pull request #168 from kohya-ss/dev
Dev
2023-02-09 22:15:35 +09:00
Kohya S
553ac4aa1b add about resizeing script 2023-02-09 22:13:01 +09:00
Kohya S
f0c8c95871 add assocatied files copying 2023-02-09 22:12:41 +09:00
Kohya S
c2e1d4b71b fix typo 2023-02-09 21:38:01 +09:00
Kohya S
3a72e6f003 add tag dropout 2023-02-09 21:35:27 +09:00
Kohya S
f7b5abb595 add resizing script 2023-02-09 21:30:27 +09:00
Isotr0py
b8ad17902f fix get_hidden_states expected scalar Error again 2023-02-08 23:09:59 +08:00
Isotr0py
9a9ac79edf correct wrong inserted code for noise_pred fix 2023-02-08 22:30:20 +08:00
Isotr0py
6473aa1dd7 fix Input type error in noise_pred when using DDP 2023-02-08 21:32:21 +08:00
Isotr0py
b599adc938 fix Input type error when using DDP 2023-02-08 20:14:20 +08:00
Isotr0py
5e96e1369d fix get_hidden_states expected scalar Error 2023-02-08 20:14:13 +08:00
Isotr0py
c0be52a773 ignore get_hidden_states expected scalar Error 2023-02-08 20:13:09 +08:00
Isotr0py
fb312acb7f support DistributedDataParallel 2023-02-08 20:12:43 +08:00
Isotr0py
938bd71844 lower ram usage 2023-02-08 18:31:27 +08:00
Kohya S
b3020db63f support python 3.8 2023-02-07 22:29:12 +09:00
Kohya S
e42b2f7aa9 conditional caption dropout (in progress) 2023-02-07 22:28:56 +09:00
Kohya S
f9478f0d47 Merge pull request #159 from forestsource/main
Add Conditional Dropout options
2023-02-07 21:50:26 +09:00
Kohya S
4fc9f1f8c5 Merge pull request #157 from shirayu/improve_tag_shuffle
Always join with ", "
2023-02-07 21:47:05 +09:00
Kohya S
5a3d1a57b6 Merge pull request #154 from shirayu/typos_checker
Add typo check GitHub Action
2023-02-07 21:35:35 +09:00
forestsource
7db98baa86 Add dropout options 2023-02-07 00:01:30 +09:00
Kohya S
d591891048 Update README.md 2023-02-06 21:30:38 +09:00
Kohya S
3a93d18bb5 Merge pull request #158 from kohya-ss/dev
Dev
2023-02-06 21:26:14 +09:00
Kohya S
7511674333 update readme 2023-02-06 21:14:16 +09:00
Kohya S
883bd1269c Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-02-06 20:52:30 +09:00
Kohya S
2aa27b7a4b Update downsampling for larger image in no_upscale 2023-02-06 20:52:24 +09:00
Yuta Hayashibe
5ea5fefcd2 Always join with ", " 2023-02-06 12:29:41 +09:00
Kohya S
6a79ac6a03 Update README.md 2023-02-05 21:59:55 +09:00
Kohya S
ea2dfd09ef update bucketing features 2023-02-05 21:37:46 +09:00
Yuta Hayashibe
7380801dfc Add typo check GitHub Action 2023-02-05 19:22:18 +09:00
Kohya S
ae33d72479 Merge pull request #153 from shirayu/fix_a_typo
Fix a typo
2023-02-04 21:21:24 +09:00
Yuta Hayashibe
19c2752e87 Fix a typo 2023-02-04 21:18:34 +09:00
Kohya S
d80af9c17b Merge pull request #152 from kohya-ss/dev
Dev
2023-02-04 20:53:58 +09:00
Kohya S
fb230aff1b Update README.md 2023-02-04 20:52:24 +09:00
Kohya S
8cbd3f4fca Add device option to calculate on GPU 2023-02-04 20:36:10 +09:00
Kohya S
b18db9fbbd Merge pull request #147 from mgz-dev/resize_lora_rank
resize lora rank
2023-02-04 18:23:07 +09:00
Kohya S
b1635f4bf6 Merge pull request #144 from tsukimiya/debug_dataset_linux_support
Fixed --debug_dataset option to work in non-Windows environments
2023-02-04 18:19:04 +09:00
Kohya S
44013fe0ef Merge pull request #140 from hitomi/main
Add persistent_workers options in DataLoader
2023-02-04 18:16:31 +09:00
Kohya S
9fd7fb813d Merge branch 'dev' into main 2023-02-04 18:16:03 +09:00
mgz
89a9d3a92c Merge branch 'kohya-ss:main' into resize_lora_rank 2023-02-03 23:12:11 +00:00
Kohya S
9682772b09 Update README-ja.md 2023-02-03 22:10:17 +09:00
Kohya S
b18a09edb5 Update README.md 2023-02-03 22:09:55 +09:00
Kohya S
c086e85d17 Merge pull request #148 from kohya-ss/dev
Dev
2023-02-03 22:05:49 +09:00
Kohya S
26efa88908 Update README.md 2023-02-03 22:02:49 +09:00
Kohya S
1bec2bfe07 Add cleaning duplicated tags 2023-02-03 21:05:55 +09:00
Kohya S
76f53429be Fix existing npz skip feature 2023-02-03 21:05:14 +09:00
Kohya S
73d612ff9c Add cleaning patterns 2023-02-03 21:04:37 +09:00
Kohya S
58a809eaff Add comment 2023-02-03 21:04:03 +09:00
Kohya S
93134cdd15 Add tag freq for FinetuneDataset 2023-02-03 21:03:42 +09:00
michaelgzhang
b7e7ee387a resize lora rank
add script which can be used to convert higher rank lora to approximate lower rank lora using svd
2023-02-03 01:00:02 -06:00
Kohya S
57d8483eaf add GIT captioning, refactoring, DataLoader 2023-02-03 08:45:33 +09:00
tsukimiya
949ee6fcc9 Fixed --debug_dataset option to work in non-Windows environments 2023-02-03 00:37:27 +09:00
hitomi
26a81d075c add --persistent_data_loader_workers option 2023-02-01 16:02:15 +08:00
Kohya S
8c3a52ecc9 Merge pull request #129 from p1atdev/main
Add support for .jpeg images in glob
2023-01-31 21:03:46 +09:00
Kohya S
86f4e20337 Merge branch 'dev' into main 2023-01-31 21:02:18 +09:00
Kohya S
9abbee0632 Merge pull request #110 from breakcore2/main
add recursive tag search when merging tags to metadata
2023-01-31 21:00:15 +09:00
Kohya S
74eba06d13 Merge pull request #104 from space-nuko/caption-frequency-metadata
Add tag frequency metadata
2023-01-31 20:56:15 +09:00
unknown
4e1acc62f9 Merge branch 'main' of https://github.com/kohya-ss/sd-scripts 2023-01-29 22:32:06 +09:00
unknown
c20745b6e8 fix: #53 2023-01-29 22:30:45 +09:00
Kohya S
4cabb37977 Update README.md 2023-01-29 21:50:17 +09:00
Kohya S
86eba1d2cf Update README.md 2023-01-29 21:23:05 +09:00
Kohya S
05940940c0 Merge pull request #128 from kohya-ss/dev
Dev
2023-01-29 21:16:09 +09:00
Kohya S
6bbb4d426e Fix unet config in Diffusers (sample_size=64) 2023-01-29 20:43:58 +09:00
Kohya S
7817e95a86 change name of arg 2023-01-29 20:28:24 +09:00
Kohya S
443ce7a30b Merge pull request #121 from mgz-dev/monkeypatch-lr_schedulers
monkeypatch updated get_scheduler for diffusers
2023-01-29 18:14:47 +09:00
Kohya S
ed2e431950 Merge branch 'main' into caption-frequency-metadata 2023-01-29 17:50:23 +09:00
michaelgzhang
0fef7b4684 monkeypatch updated get_scheduler for diffusers
enables use of "num_cycles" and "power" for cosine_with_restarts and polynomial learning rate schedulers
2023-01-27 16:42:11 -06:00
Kohya S
67e698af67 Merge pull request #114 from shirayu/fix_typos
Fix typos
2023-01-27 19:14:35 +09:00
Kohya S
7c35aee042 Update train_ti_README-ja.md 2023-01-26 22:22:37 +09:00
Yuta Hayashibe
481823796e Fix typos 2023-01-26 22:12:29 +09:00
Kohya S
835b0d54cd Update train_ti_README-ja.md 2023-01-26 22:11:37 +09:00
Kohya S
505768ea86 Update documents for TI 2023-01-26 22:06:29 +09:00
Kohya S
1614d30d1b Merge pull request #113 from kohya-ss/textual_inversion
Add supporting for Textual inversion
2023-01-26 21:41:48 +09:00
Kohya S
25566182a8 Support newer traiing args 2023-01-26 21:37:14 +09:00
Kohya S
6dffc88b44 Support Textual Inversion 2023-01-26 21:36:43 +09:00
breakcore2
64d5ceda71 simplify arg to --recursive 2023-01-26 01:06:33 -08:00
breakcore2
e8806f29dc Merge branch 'kohya-ss:main' into main 2023-01-26 01:02:17 -08:00
breakcore2
2ce9ad235c add recursive structure merge dd tags and convert to pathlib 2023-01-26 01:01:38 -08:00
Kohya S
3fb12e41b7 Merge branch 'main' into textual_inversion 2023-01-26 17:50:20 +09:00
Kohya S
591e3c1813 Update train_network_README-ja.md 2023-01-26 08:37:14 +09:00
Kohya S
b5ba463512 Update fine_tune_README_ja.md 2023-01-26 08:32:51 +09:00
Kohya S
e0d7f1d99d Update train_db_README-ja.md 2023-01-26 08:32:05 +09:00
Kohya S
a68501bede Update README-ja.md 2023-01-25 14:02:27 +09:00
Kohya S
c425afb08b Update README.md 2023-01-25 14:00:42 +09:00
Kohya S
46029b2707 Update README.md 2023-01-24 20:57:33 +09:00
Kohya S
02acae8e1d Merge pull request #107 from kohya-ss/dev
merge dev to main
2023-01-24 20:21:57 +09:00
Kohya S
91a50ea637 Change img_ar_errors to mean because too many imgs 2023-01-24 20:17:15 +09:00
Kohya S
9f644d8dc3 Change default save format to safetensors 2023-01-24 20:16:21 +09:00
Kohya S
36dc97c841 Merge pull request #103 from space-nuko/bucketing-metadata
Add bucketing metadata
2023-01-24 19:06:21 +09:00
Kohya S
e6bad080cb Merge pull request #102 from space-nuko/precalculate-hashes
Precalculate .safetensors model hashes after training
2023-01-24 19:03:45 +09:00
Kohya S
7f17237ada Merge pull request #92 from forestsource/add_save_n_epoch_ratio
Add save_n_epoch_ratio
2023-01-24 18:59:47 +09:00
Kohya S
ebd3ea380c Merge branch 'main' into dev 2023-01-24 18:57:49 +09:00
Kohya S
bf3a13bb4e Fix error for loading bf16 weights 2023-01-24 18:57:21 +09:00
Kohya S
1a170c4762 Merge pull request #106 from shirayu/patch-1
Fix markdown
2023-01-24 18:51:46 +09:00
Yuta Hayashibe
552cdbd6d8 Fix markdown 2023-01-24 18:39:05 +09:00
Kohya S
a86514f1ad Merge pull request #97 from shirayu/patch-1
Fix a link
2023-01-24 18:08:46 +09:00
space-nuko
2e8a3d20dd Add tag frequency metadata 2023-01-23 17:43:03 -08:00
space-nuko
66051883fb Add bucketing metadata 2023-01-23 17:26:58 -08:00
space-nuko
f7fbdc4b2a Precalculate .safetensors model hashes after training 2023-01-23 17:21:04 -08:00
breakcore2
00f1296537 Merge branch 'kohya-ss:main' into main 2023-01-22 22:57:44 -08:00
Yuta Hayashibe
ebdb624d29 Fix a link 2023-01-23 00:25:32 +09:00
Kohya S
93df55d597 Merge pull request #96 from shirayu/patch-1
``--network_dim`` is removed from ``gen_img_diffusers.py``
2023-01-22 23:29:52 +09:00
Yuta Hayashibe
56bc806d52 `--network_dim is removed from gen_img_diffusers.py` 2023-01-22 23:10:10 +09:00
Kohya S
25f8ac731f Update README-ja.md 2023-01-22 22:22:53 +09:00
Kohya S
4ba1667978 Update README.md 2023-01-22 22:19:07 +09:00
Kohya S
0ca064287e Update README.md 2023-01-22 22:03:15 +09:00
Kohya S
a3171714ce Update README.md 2023-01-22 21:57:59 +09:00
Kohya S
4a1668fe37 Merge pull request #95 from kohya-ss/dev
support alpha etc.
2023-01-22 21:47:45 +09:00
Kohya S
4eb356f165 Upate readme 2023-01-22 21:33:58 +09:00
Kohya S
a7218574f2 Update help message 2023-01-22 21:33:48 +09:00
Kohya S
ddfe94b33b Update for alpha value 2023-01-22 21:33:35 +09:00
Kohya S
8746188ed7 Add traning_comment metadata. 2023-01-22 18:33:19 +09:00
Kohya S
1bfcf164f1 Merge branch 'main' into dev 2023-01-22 11:26:18 +09:00
forestsource
5e817e4343 Add save_n_epoch_ratio 2023-01-22 03:00:28 +09:00
Kohya S
b4636d4185 Add scaling alpha for LoRA 2023-01-21 20:37:34 +09:00
Kohya S
22ee0ac467 Move TE/UN loss calc to train script 2023-01-21 12:51:17 +09:00
Kohya S
17089b1287 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-01-21 12:46:20 +09:00
Kohya S
7ee808d5d7 Merge pull request #79 from mgz-dev/tensorboard-improvements
expand details in tensorboard logs
2023-01-21 12:46:13 +09:00
Kohya S
9ff26af68b Update to add grad_ckpting etc to metadata 2023-01-21 12:36:31 +09:00
Kohya S
7dbcef745a Merge pull request #77 from space-nuko/ss-extra-metadata
More helpful metadata
2023-01-21 12:18:23 +09:00
space-nuko
da48f74e7b Add new version model/VAE hash to training metadata 2023-01-18 23:00:16 -08:00
mgz
e5d9f483f0 Merge branch 'kohya-ss:main' into tensorboard-improvements 2023-01-18 21:30:15 +00:00
michaelgzhang
303c3410e2 expand details in tensorboard logs
- Update tensorboard logging to track both unet and textencoder learning rates
- Update tensorboard logging to track both current and moving average epoch loss
- Clean up tensorboard log variable names for dashboard formatting
2023-01-18 13:10:13 -06:00
space-nuko
de1dde1a06 More helpful metadata
- dataset/reg image dirs
- random session ID
- keep_tokens
- training date
- output name
2023-01-17 16:28:35 -08:00
Kohya S
186a2665ad Merge branch 'main' into textual_inversion 2023-01-15 16:08:53 +09:00
breakcore2
29c9008e07 Merge branch 'kohya-ss:main' into main 2023-01-13 23:04:37 -08:00
Kohya S
c1b14fcdd6 initial version of TI 2023-01-12 20:47:08 +09:00
breakcore2
4735b21318 add .bmp support for wd14 tagger 2023-01-06 22:21:06 -08:00
33 changed files with 3123 additions and 411 deletions

21
.github/workflows/typos.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
---
# yamllint disable rule:line-length
name: Typos
on: # yamllint disable-line rule:truthy
push:
pull_request:
types:
- opened
- synchronize
- reopened
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: typos-action
uses: crate-ci/typos@v1.13.10

3
.gitignore vendored
View File

@@ -3,4 +3,5 @@ __pycache__
wd14_tagger_model
venv
*.egg-info
build
build
.vscode

View File

@@ -1,7 +1,7 @@
## リポジトリについて
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
[README in English](./README.md)
[README in English](./README.md) ←更新情報はこちらにあります
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています英語ですのであわせてご覧ください。bmaltais氏に感謝します。
@@ -16,9 +16,11 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください将来的にはすべてこちらへ移すかもしれません
* note.com [環境整備とDreamBooth学習スクリプトについて](https://note.com/kohya_ss/n/nba4eceaa4594)
* [DreamBooth学習について](./train_db_README-ja.md)
* [fine-tuningのガイド](./fine_tune_README_ja.md):
BLIPによるキャプショニングと、DeepDanbooruまたはWD14 taggerによるタグ付けを含みます
* [LoRAの学習について](./train_network_README-ja.md)
* [Textual Inversionの学習について](./train_ti_README-ja.md)
* note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
@@ -44,12 +46,11 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
通常の管理者ではないPowerShellを開き以下を順に実行します。
```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv --system-site-packages venv
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
@@ -63,6 +64,12 @@ cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_set
accelerate config
```
<!--
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
pip install --use-pep517 --upgrade -r requirements.txt
pip install -U -I --no-deps xformers==0.0.16
-->
コマンドプロンプトでは以下になります。
@@ -70,7 +77,7 @@ accelerate config
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv --system-site-packages venv
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
@@ -84,6 +91,8 @@ copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cud
accelerate config
```
(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。
accelerate configの質問には以下のように答えてください。bf16で学習する場合、最後の質問にはbf16と答えてください。
※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます……。数字キーの0、1、2……で選択できますので、そちらを使ってください。
@@ -101,6 +110,10 @@ accelerate configの質問には以下のように答えてください。bf1
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``に「0」と答えてください。id `0`のGPUが使われます。
### PyTorchとxformersのバージョンについて
他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください。
## アップグレード
新しいリリースがあった場合、以下のコマンドで更新できます。
@@ -109,7 +122,7 @@ accelerate configの質問には以下のように答えてください。bf1
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --upgrade -r <requirement file name>
pip install --use-pep517 --upgrade -r requirements.txt
```
コマンドが成功すれば新しいバージョンが使用できます。

148
README.md
View File

@@ -1,17 +1,7 @@
This repository contains training, generation and utility scripts for Stable Diffusion.
## Updates
- 22 Jan. 2023, 2023/1/22
- Fix script to check LoRA weights ``check_lora_weights.py``. Some layer weights were shown as ``0.0`` even if the layer is trained, because of the overflow of ``torch.mean``. Sorry for the confusion.
- Noe the script shows the mean of the absolute values of the weights, and the minimum of the absolute values of the weights.
- LoRAの重みをチェックするスクリプト ``check_lora_weights.py`` を修正しました。一部のレイヤーで学習されているにもかかわらず重みが ``0.0`` と表示されていました。混乱を招き申し訳ありません。
- スクリプトを「重みの絶対の平均」と「重みの絶対値の最小値」を表示するよう修正しました。
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
##
[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。
[日本語版README](./README-ja.md)
@@ -20,10 +10,13 @@ For easier use (GUI and PowerShell scripts etc...), please visit [the repository
This repository contains the scripts for:
* DreamBooth training, including U-Net and Text Encoder
* fine-tuning (native training), including U-Net and Text Encoder
* Fine-tuning (native training), including U-Net and Text Encoder
* LoRA training
* image generation
* model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
* Texutl Inversion training
* Image generation
* Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ (SD 1.x based only) Thank you for great work!!!
## About requirements.txt
@@ -39,6 +32,7 @@ All documents are in Japanese currently, and CUI based.
* [Step by Step fine-tuning guide](./fine_tune_README_ja.md):
Including BLIP captioning and tagging by DeepDanbooru or WD14 tagger
* [training LoRA](./train_network_README-ja.md)
* [training Textual Inversion](./train_ti_README-ja.md)
* note.com [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
@@ -63,7 +57,7 @@ Open a regular Powershell terminal and type the following inside:
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv --system-site-packages venv
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
@@ -75,9 +69,10 @@ cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\ce
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config
```
update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python).
Answers to accelerate config:
```txt
@@ -95,6 +90,11 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
(Single GPU with id `0` will be used.)
### about PyTorch and xformers
Other versions of PyTorch and xformers seem to have problems with training.
If there is no other reason, please install the specified version.
## Upgrade
When a new release comes out you can upgrade your repo with the following command:
@@ -103,7 +103,7 @@ When a new release comes out you can upgrade your repo with the following comman
cd sd-scripts
git pull
.\venv\Scripts\activate
pip install --upgrade -r requirements.txt
pip install --use-pep517 --upgrade -r requirements.txt
```
Once the commands have completed successfully you should be ready to use the new version.
@@ -121,3 +121,115 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## Change History
- 19 Feb. 2023, 2023/2/19:
- Add ``--use_lion_optimizer`` to each training script to use [Lion optimizer](https://github.com/lucidrains/lion-pytorch).
- Please install Lion optimizer with ``pip install lion-pytorch`` (it is not in ``requirements.txt`` currently.)
- Add ``--lowram`` option to ``train_network.py``. Load models to VRAM instead of VRAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle). Thanks to Isotr0py!
- Default behavior (without lowram) has reverted to the same as before 14 Feb.
- Fixed git commit hash to be set correctly regardless of the working directory. Thanks to vladmandic!
- ``--use_lion_optimizer`` オプションを各学習スクリプトに追加しました。 [Lion optimizer](https://github.com/lucidrains/lion-pytorch) を使用できます。
- あらかじめ ``pip install lion-pytorch`` でインストールしてください(現在は ``requirements.txt`` に含まれていません)。
- ``--lowram`` オプションを ``train_network.py`` に追加しました。モデルをRAMではなくVRAMに読み込みますColabやKaggleなど、VRAMがRAMに比べて多い環境で有効です。 Isotr0py 氏に感謝します。
- lowram オプションなしのデフォルト動作は2/14より前と同じに戻しました。
- git commit hash を現在のフォルダ位置に関わらず正しく取得するように修正しました。vladmandic 氏に感謝します。
- 16 Feb. 2023, 2023/2/16:
- Noise offset is recorded to the metadata. Thanks to space-nuko!
- Show the moving average loss to prevent loss jumping in ``train_network.py`` and ``train_db.py``. Thanks to shirayu!
- Noise offsetがメタデータに記録されるようになりました。space-nuko氏に感謝します。
- ``train_network.py``と``train_db.py``で学習中に表示されるlossの値が移動平均になりました。epochの先頭で表示されるlossが大きく変動する事象を解決します。shirayu氏に感謝します。
- 14 Feb. 2023, 2023/2/14:
- Add support with multi-gpu trainining for ``train_network.py``. Thanks to Isotr0py!
- Add ``--verbose`` option for ``resize_lora.py``. For details, see [this PR](https://github.com/kohya-ss/sd-scripts/pull/179). Thanks to mgz-dev!
- Git commit hash is added to the metadata for LoRA. Thanks to space-nuko!
- Add ``--noise_offset`` option for each training scripts.
- Implementation of https://www.crosslabs.org//blog/diffusion-with-offset-noise
- This option may improve ability to generate darker/lighter images. May work with LoRA.
- ``train_network.py``でマルチGPU学習をサポートしました。Isotr0py氏に感謝します。
- ``--verbose``オプションを ``resize_lora.py`` に追加しました。表示される情報の詳細は [こちらのPR](https://github.com/kohya-ss/sd-scripts/pull/179) をご参照ください。mgz-dev氏に感謝します。
- LoRAのメタデータにgitのcommit hashを追加しました。space-nuko氏に感謝します。
- ``--noise_offset`` オプションを各学習スクリプトに追加しました。
- こちらの記事の実装になります: https://www.crosslabs.org//blog/diffusion-with-offset-noise
- 全体的に暗い、明るい画像の生成結果が良くなる可能性があるようです。LoRA学習でも有効なようです。
- 11 Feb. 2023, 2023/2/11:
- ``lora_interrogator.py`` is added in ``networks`` folder. See ``python networks\lora_interrogator.py -h`` for usage.
- For LoRAs where the activation word is unknown, this script compares the output of Text Encoder after applying LoRA to that of unapplied to find out which token is affected by LoRA. Hopefully you can figure out the activation word. LoRA trained with captions does not seem to be able to interrogate.
- Batch size can be large (like 64 or 128).
- ``train_textual_inversion.py`` now supports multiple init words.
- Following feature is reverted to be the same as before. Sorry for confusion:
> Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
- ``lora_interrogator.py`` を ``network``フォルダに追加しました。使用法は ``python networks\lora_interrogator.py -h`` でご確認ください。
- このスクリプトは、起動promptがわからないLoRAについて、LoRA適用前後のText Encoderの出力を比較することで、どのtokenの出力が変化しているかを調べます。運が良ければ起動用の単語が分かります。キャプション付きで学習されたLoRAは影響が広範囲に及ぶため、調査は難しいようです。
- バッチサイズはわりと大きくできます64や128など
- ``train_textual_inversion.py`` で複数のinit_word指定が可能になりました。
- 次の機能を削除し元に戻しました。混乱を招き申し訳ありません。
> これらのオプションによりbucketが細分化され、ひとつのバッチ内に同一画像が重複して存在することが増えたため、バッチサイズを``そのbucketの画像種類数``までに制限する機能を追加しました。
- 10 Feb. 2023, 2023/2/10:
- Updated ``requirements.txt`` to prevent upgrading with pip taking a long time or failure to upgrade.
- ``resize_lora.py`` keeps the metadata of the model. ``dimension is resized from ...`` is added to the top of ``ss_training_comment``.
- ``merge_lora.py`` supports models with different ``alpha``s. If there is a problem, old version is ``merge_lora_old.py``.
- ``svd_merge_lora.py`` is added. This script merges LoRA models with any rank (dim) and alpha, and approximate a new LoRA with svd for a specified rank (dim).
- Note: merging scripts erase the metadata currently.
- ``resize_images_to_resolution.py`` supports multibyte characters in filenames.
- pipでの更新が長時間掛かったり、更新に失敗したりするのを防ぐため、``requirements.txt``を更新しました。
- ``resize_lora.py``がメタデータを保持するようになりました。 ``dimension is resized from ...`` という文字列が ``ss_training_comment`` の先頭に追加されます。
- ``merge_lora.py``がalphaが異なるモデルをサポートしました。 何か問題がありましたら旧バージョン ``merge_lora_old.py`` をお使いください。
- ``svd_merge_lora.py`` を追加しました。 複数の任意のdim (rank)、alphaのLoRAモデルをマージし、svdで任意dim(rank)のLoRAで近似します。
- 注:マージ系のスクリプトは現時点ではメタデータを消去しますのでご注意ください。
- ``resize_images_to_resolution.py``が日本語ファイル名をサポートしました。
- 9 Feb. 2023, 2023/2/9:
- Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource!
- ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout).
- ``--caption_dropout_every_n_epochs`` option specifies how many epochs to drop captions. If ``3`` is specified, in epoch 3, 6, 9 ..., images are trained with all captions empty. Default is None (no dropout).
- ``--caption_tag_dropout_rate`` option specified the dropout rate for tags (comma separated tokens) (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the tag is removed from the caption. If ``--keep_tokens`` option is set, these tokens (tags) are not dropped. Default is 0 (no droupout).
- The bulk image downsampling script is added. Documentation is [here](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E7%94%BB%E5%83%8F%E3%83%AA%E3%82%B5%E3%82%A4%E3%82%BA%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%97%E3%83%88) (in Jpanaese). Thanks to bmaltais!
- Typo check is added. Thanks to shirayu!
- キャプションのドロップアウトを``train_db.py``、``fine_tune.py``、``train_network.py``の各スクリプトに追加しました。forestsource氏に感謝します。
- ``--caption_dropout_rate``オプションでキャプションのドロップアウト率を指定します0~1.0、 0.1を指定すると10%の確率でドロップアウト)。ドロップアウトされた場合、画像は空のキャプションで学習されます。デフォルトは 0 (ドロップアウトなし)です。
- ``--caption_dropout_every_n_epochs`` オプションで何エポックごとにキャプションを完全にドロップアウトするか指定します。たとえば``3``を指定すると、エポック3、6、9……で、すべての画像がキャプションなしで学習されます。デフォルトは None (ドロップアウトなし)です。
- ``--caption_tag_dropout_rate`` オプションで各タグカンマ区切りの各部分のドロップアウト率を指定します0~1.0、 0.1を指定すると10%の確率でドロップアウト)。ドロップアウトが起きるとそのタグはそのときだけキャプションから取り除かれて学習されます。``--keep_tokens`` オプションを指定していると、シャッフルされない部分のタグはドロップアウトされません。デフォルトは 0 (ドロップアウトなし)です。
- 画像の一括縮小スクリプトを追加しました。ドキュメントは [こちら](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E7%94%BB%E5%83%8F%E3%83%AA%E3%82%B5%E3%82%A4%E3%82%BA%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%97%E3%83%88) です。bmaltais氏に感謝します。
- 誤字チェッカが追加されました。shirayu氏に感謝します。
- 6 Feb. 2023, 2023/2/6
- ``--bucket_reso_steps`` and ``--bucket_no_upscale`` options are added to training scripts (fine tuning, DreamBooth, LoRA and Textual Inversion) and ``prepare_buckets_latents.py``.
- ``--bucket_reso_steps`` takes the steps for buckets in aspect ratio bucketing. Default is 64, same as before.
- Any value greater than or equal to 1 can be specified; 64 is highly recommended and a value divisible by 8 is recommended.
- If less than 64 is specified, padding will occur within U-Net. The result is unknown.
- If you specify a value that is not divisible by 8, it will be truncated to divisible by 8 inside VAE, because the size of the latent is 1/8 of the image size.
- If ``--bucket_no_upscale`` option is specified, images smaller than the bucket size will be processed without upscaling.
- Internally, a bucket smaller than the image size is created (for example, if the image is 300x300 and ``bucket_reso_steps=64``, the bucket is 256x256). The image will be trimmed.
- Implementation of [#130](https://github.com/kohya-ss/sd-scripts/issues/130).
- Images with an area larger than the maximum size specified by ``--resolution`` are downsampled to the max bucket size.
- Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
- ``--random_crop`` now also works with buckets enabled.
- Instead of always cropping the center of the image, the image is shifted left, right, up, and down to be used as the training data. This is expected to train to the edges of the image.
- Implementation of discussion [#34](https://github.com/kohya-ss/sd-scripts/discussions/34).
- ``--bucket_reso_steps``および``--bucket_no_upscale``オプションを、学習スクリプトおよび``prepare_buckets_latents.py``に追加しました。
- ``--bucket_reso_steps``オプションでは、bucketの解像度の単位を指定できます。デフォルトは64で、今までと同じ動作です。
- 1以上の任意の値を指定できます。基本的には64を推奨します。64以外の値では、8で割り切れる値を推奨します。
- 64未満を指定するとU-Netの内部でpaddingが発生します。どのような結果になるかは未知数です。
- 8で割り切れない値を指定すると余りはVAE内部で切り捨てられます。
- ``--bucket_no_upscale``オプションを指定すると、bucketサイズよりも小さい画像は拡大せずそのまま処理します。
- 内部的には画像サイズ以下のサイズのbucketを作成しますたとえば画像が300x300で``bucket_reso_steps=64``の場合、256x256のbucket。余りは都度trimmingされます。
- [#130](https://github.com/kohya-ss/sd-scripts/issues/130) を実装したものです。
- ``--resolution``で指定した最大サイズよりも面積が大きい画像は、最大サイズと同じ面積になるようアスペクト比を維持したまま縮小され、そのサイズを元にbucketが作られます。
- これらのオプションによりbucketが細分化され、ひとつのバッチ内に同一画像が重複して存在することが増えたため、バッチサイズを``そのbucketの画像種類数``までに制限する機能を追加しました。
- たとえば繰り返し回数10で、あるbucketに1枚しか画像がなく、バッチサイズが10以上のとき、今まではepoch内で、同一画像を10枚含むバッチが1回だけ使用されていました。
- 機能追加後はepoch内にサイズ1のバッチが10回、使用されます。
- ``--random_crop``がbucketを有効にした場合にも機能するようになりました。
- 常に画像の中央を切り取るのではなく、左右、上下にずらして教師データにします。これにより画像端まで学習されることが期待されます。
- discussionの[#34](https://github.com/kohya-ss/sd-scripts/discussions/34)を実装したものです。
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。

15
_typos.toml Normal file
View File

@@ -0,0 +1,15 @@
# Files for typos
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
[default.extend-identifiers]
[default.extend-words]
NIN="NIN"
parms="parms"
nin="nin"
extention="extention" # Intentionally left
nd="nd"
[files]
extend-exclude = ["_typos.toml"]

View File

@@ -33,8 +33,13 @@ def train(args):
train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets()
if args.debug_dataset:
@@ -153,6 +158,13 @@ def train(args):
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit
elif args.use_lion_optimizer:
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print("use Lion optimizer")
optimizer_class = lion_pytorch.Lion
else:
optimizer_class = torch.optim.AdamW
@@ -163,7 +175,7 @@ def train(args):
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
@@ -200,6 +212,8 @@ def train(args):
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -223,6 +237,8 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
for m in training_models:
m.train()
@@ -246,6 +262,9 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@@ -329,7 +348,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser)

View File

@@ -324,7 +324,7 @@ __※引数を都度書き換えて、別のメタデータファイルに書き
## 学習の実行
たとえば以下のように実行します。以下は省メモリ化のための設定です。
```
accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
--pretrained_model_name_or_path=model.ckpt
--in_json meta_lat.json
--train_data_dir=train_data
@@ -336,7 +336,7 @@ accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
--save_every_n_epochs=4
```
accelerateのnum_cpu_threads_per_processにはCPUのコア数を指定するとよいようです。
accelerateのnum_cpu_threads_per_processには通常は1を指定するとよいようです。
pretrained_model_name_or_pathに学習対象のモデルを指定しますStable DiffusionのcheckpointかDiffusersのモデル。Stable Diffusionのcheckpointは.ckptと.safetensorsに対応しています拡張子で自動判定

View File

@@ -5,13 +5,32 @@ import argparse
import glob
import os
import json
import re
from tqdm import tqdm
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
PATTERNS_REMOVE_IN_MULTI = [
PATTERN_HAIR_LENGTH,
PATTERN_HAIR_CUT,
re.compile(r', [\w\-]+ eyes, '),
re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
# 複数の髪型定義がある場合は削除する
re.compile(
r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
]
def clean_tags(image_key, tags):
# replace '_' to ' '
tags = tags.replace('^_^', '^@@@^')
tags = tags.replace('_', ' ')
tags = tags.replace('^@@@^', '^_^')
# remove rating: deepdanbooruのみ
tokens = tags.split(", rating")
@@ -26,6 +45,37 @@ def clean_tags(image_key, tags):
print(f"{image_key} {tags}")
tags = tokens[0]
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
# 複数の人物がいる場合は髪色等のタグを削除する
if 'girls' in tags or 'boys' in tags:
for pat in PATTERNS_REMOVE_IN_MULTI:
found = pat.findall(tags)
if len(found) > 1: # 二つ以上、タグがある
tags = pat.sub("", tags)
# 髪の特殊対応
srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
if srch_hair_len:
org = srch_hair_len.group()
tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
found = PATTERN_HAIR.findall(tags)
if len(found) > 1:
tags = PATTERN_HAIR.sub("", tags)
if srch_hair_len:
tags = tags.replace(", @@@, ", org) # 戻す
# white shirtとshirtみたいな重複タグの削除
found = PATTERN_WORD.findall(tags)
for word in found:
if re.search(f", ((\w+) )+{word}, ", tags):
tags = tags.replace(f", {word}, ", "")
tags = tags.replace(", , ", ", ")
assert tags.startswith(", ") and tags.endswith(", ")
tags = tags[2:-2]
return tags
@@ -88,13 +138,23 @@ def main(args):
if tags is None:
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
else:
metadata[image_key]['tags'] = clean_tags(image_key, tags)
org = tags
tags = clean_tags(image_key, tags)
metadata[image_key]['tags'] = tags
if args.debug and org != tags:
print("FROM: " + org)
print("TO: " + tags)
caption = metadata[image_key].get('caption')
if caption is None:
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
else:
metadata[image_key]['caption'] = clean_caption(caption)
org = caption
caption = clean_caption(caption)
metadata[image_key]['caption'] = caption
if args.debug and org != caption:
print("FROM: " + org)
print("TO: " + caption)
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
@@ -108,6 +168,7 @@ if __name__ == '__main__':
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--debug", action="store_true", help="debug mode")
args, unknown = parser.parse_known_args()
if len(unknown) == 1:

View File

@@ -11,18 +11,59 @@ import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from blip.blip import blip_decoder
# from Salesforce_BLIP.models.blip import blip_decoder
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
IMAGE_SIZE = 384
# 正方形でいいのか? という気がするがソースがそうなので
IMAGE_TRANSFORM = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
# 共通化したいが微妙に処理が異なる……
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# fix the seed for reproducibility
seed = args.seed # + utils.get_rank()
seed = args.seed # + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if not os.path.exists("blip"):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
@@ -31,24 +72,15 @@ def main(args):
os.chdir('finetune')
print(f"load images from {args.train_data_dir}")
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
print(f"loading BLIP caption: {args.caption_weights}")
image_size = 384
model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")
# 正方形でいいのか? という気がするがソースがそうなので
transform = transforms.Compose([
transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
# captioningする
def run_batch(path_imgs):
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
@@ -66,18 +98,35 @@ def main(args):
if args.debug:
print(image_path, caption)
b_imgs = []
for image_path in tqdm(image_paths, smoothing=0.0):
raw_image = Image.open(image_path)
if raw_image.mode != "RGB":
print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
raw_image = raw_image.convert("RGB")
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingTransformDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
image = transform(raw_image)
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
img_tensor, image_path = data
if img_tensor is None:
try:
raw_image = Image.open(image_path)
if raw_image.mode != 'RGB':
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, img_tensor))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
@@ -95,6 +144,8 @@ if __name__ == '__main__':
parser.add_argument("--beam_search", action="store_true",
help="use beam search (default Nucleus sampling) / beam searchを使うこのオプション未指定時はNucleus sampling")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")

View File

@@ -0,0 +1,145 @@
import argparse
import os
import re
from PIL import Image
from tqdm import tqdm
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PATTERN_REPLACE = [
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
re.compile(r'with the number \d+ on (it|\w+ \w+)'),
re.compile(r'with the words "'),
re.compile(r'word \w+ on it'),
re.compile(r'that says the word \w+ on it'),
re.compile('that says\'the word "( on it)?'),
]
# 誤検知しまくりの with the word xxxx を消す
def remove_words(captions, debug):
removed_caps = []
for caption in captions:
cap = caption
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
removed_caps.append(cap)
return removed_caps
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
# input_idsがバッチサイズと同じ件数である必要があるバッチサイズはこの関数から参照できないので外から渡す
# ここより上で置き換えようとするとすごく大変
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
if input_ids.size()[0] != curr_batch_size[0]:
input_ids = input_ids.repeat(curr_batch_size[0], 1)
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
print(f"load images from {args.train_data_dir}")
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
# できればcacheに依存せず明示的にダウンロードしたい
print(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
# captioningする
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]
curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
if args.remove_words:
captions = remove_words(captions, args.debug)
for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
image, image_path = data
if image is None:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
parser.add_argument("--remove_words", action="store_true",
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
parser.add_argument("--debug", action="store_true", help="debug mode")
args = parser.parse_args()
main(args)

View File

@@ -1,26 +1,24 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import glob
import os
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
if args.in_json is None and os.path.isfile(args.out_json):
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
@@ -28,11 +26,10 @@ def main(args):
print("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = os.path.splitext(image_path)[0] + args.caption_extension
with open(caption_path, "rt", encoding='utf-8') as f:
caption = f.readlines()[0].strip()
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding='utf-8').strip()
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
@@ -42,8 +39,7 @@ def main(args):
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
print("done!")
@@ -51,12 +47,15 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode")
args = parser.parse_args()

View File

@@ -1,26 +1,24 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import glob
import os
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
import library.train_util as train_util
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
if args.in_json is None and os.path.isfile(args.out_json):
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
@@ -28,11 +26,10 @@ def main(args):
print("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = os.path.splitext(image_path)[0] + '.txt'
with open(tags_path, "rt", encoding='utf-8') as f:
tags = f.readlines()[0].strip()
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding='utf-8').strip()
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
@@ -42,8 +39,8 @@ def main(args):
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
print("done!")
@@ -51,9 +48,14 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--caption_extension", type=str, default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
args = parser.parse_args()

View File

@@ -1,20 +1,16 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import glob
import os
import json
from tqdm import tqdm
import numpy as np
from diffusers import AutoencoderKL
from PIL import Image
import cv2
import torch
from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -26,6 +22,16 @@ IMAGE_TRANSFORMS = transforms.Compose(
)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def get_latents(vae, images, weight_dtype):
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
img_tensors = torch.stack(img_tensors)
@@ -35,9 +41,22 @@ def get_latents(vae, images, weight_dtype):
return latents
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
else:
base_name = image_key
if flip:
base_name += '_flip'
return os.path.join(data_dir, base_name)
def main(args):
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
@@ -62,89 +81,144 @@ def main(args):
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
max_reso, args.min_bucket_reso, args.max_bucket_reso)
bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
if not args.bucket_no_upscale:
bucket_manager.make_buckets()
else:
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
bucket_aspect_ratios = np.array(bucket_aspect_ratios)
buckets_imgs = [[] for _ in range(len(bucket_resos))]
bucket_counts = [0 for _ in range(len(bucket_resos))]
img_ar_errors = []
for i, image_path in enumerate(tqdm(image_paths, smoothing=0.0)):
def process_batch(is_last):
for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
f"latent shape {latents.shape}, {bucket[0][1].shape}"
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
np.savez(npz_file_name, latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
np.savez(npz_file_name, latent)
else:
# remove existing flipped npz
for image_key, _ in bucket:
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
if os.path.isfile(npz_file_name):
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
os.remove(npz_file_name)
bucket.clear()
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
bucket_counts = {}
for data_entry in tqdm(data, smoothing=0.0):
if data_entry[0] is None:
continue
img_tensor, image_path = data_entry[0]
if img_tensor is not None:
image = transforms.functional.to_pil_image(img_tensor)
else:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
if image_key not in metadata:
metadata[image_key] = {}
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
# 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
aspect_ratio = image.width / image.height
ar_errors = bucket_aspect_ratios - aspect_ratio
bucket_id = np.abs(ar_errors).argmin()
reso = bucket_resos[bucket_id]
ar_error = ar_errors[bucket_id]
reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
img_ar_errors.append(abs(ar_error))
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
# どのサイズにリサイズするか→トリミングする方向で
if ar_error <= 0: # 横が長い→縦を合わせる
scale = reso[1] / image.height
else:
scale = reso[0] / image.width
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))
if not args.bucket_no_upscale:
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
assert resized_size[0] == reso[0] or resized_size[1] == reso[
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
# print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
# bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])
assert resized_size[0] == reso[0] or resized_size[1] == reso[
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
1], f"internal error resized size is small: {resized_size}, {reso}"
# 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
if args.flip_aug:
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
found = True
for npz_file in npz_files:
if not os.path.exists(npz_file):
found = False
break
dat = np.load(npz_file)['arr_0']
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False
break
if found:
continue
# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size//2:trim_size//2 + reso[0]]
elif resized_size[1] > reso[1]:
if resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size//2:trim_size//2 + reso[1]]
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
# # debug
# cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
# バッチへ追加
buckets_imgs[bucket_id].append((image_key, reso, image))
bucket_counts[bucket_id] += 1
metadata[image_key]['train_resolution'] = reso
bucket_manager.add_image(reso, (image_key, image))
# バッチを推論するか判定して推論する
is_last = i == len(image_paths) - 1
for j in range(len(buckets_imgs)):
bucket = buckets_imgs[j]
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
process_batch(False)
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
np.savez(os.path.join(args.train_data_dir, npz_file_name), latent)
# 残りを処理する
process_batch(True)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, reso, _), latent in zip(bucket, latents):
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent)
bucket.clear()
for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
print(f"bucket {i} {reso}: {count}")
bucket_manager.sort()
for i, reso in enumerate(bucket_manager.resos):
count = bucket_counts.get(reso, 0)
if count > 0:
print(f"bucket {i} {reso}: {count}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error: {np.mean(img_ar_errors)}")
@@ -162,18 +236,26 @@ if __name__ == '__main__':
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--max_resolution", type=str, default="512,512",
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
parser.add_argument("--bucket_reso_steps", type=int, default=64,
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
parser.add_argument("--bucket_no_upscale", action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--flip_aug", action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
parser.add_argument("--skip_existing", action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ")
args = parser.parse_args()
main(args)

View File

@@ -1,6 +1,3 @@
# このスクリプトのライセンスは、Apache License 2.0とします
# (c) 2022 Kohya S. @kohya_ss
import argparse
import csv
import glob
@@ -12,32 +9,87 @@ from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
import library.train_util as train_util
# from wd14 tagger
IMAGE_SIZE = 448
WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-vit-tagger'
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
def preprocess_image(image):
image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR
# pad to square
size = max(image.shape[0:2])
pad_x = size - image.shape[1]
pad_y = size - image.shape[0]
pad_l = pad_x // 2
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
image = image.astype(np.float32)
return image
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
tensor = torch.tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print("downloading wd14 tagger model from hf_hub")
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
else:
print("using existing wd14 tagger model")
# 画像を読み込む
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
print("loading model and labels")
@@ -72,7 +124,7 @@ def main(args):
# Everything else is tags: pick any where prediction confidence > threshold
tag_text = ""
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
if p >= args.thresh:
if p >= args.thresh and i < len(tags):
tag_text += ", " + tags[i]
if len(tag_text) > 0:
@@ -83,34 +135,37 @@ def main(args):
if args.debug:
print(image_path, tag_text)
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
b_imgs = []
for image_path in tqdm(image_paths, smoothing=0.0):
img = Image.open(image_path) # cv2は日本語ファイル名で死ぬのとモード変換したいのでpillowで開く
if img.mode != 'RGB':
img = img.convert("RGB")
img = np.array(img)
img = img[:, :, ::-1] # RGB->BGR
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
# pad to square
size = max(img.shape[0:2])
pad_x = size - img.shape[1]
pad_y = size - img.shape[0]
pad_l = pad_x // 2
pad_t = pad_y // 2
img = np.pad(img, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
image, image_path = data
if image is not None:
image = image.detach().numpy()
else:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
# cv2.imshow("img", img)
# cv2.waitKey()
# cv2.destroyAllWindows()
img = img.astype(np.float32)
b_imgs.append((image_path, img))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
@@ -121,7 +176,7 @@ def main(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=WD14_TAGGER_REPO,
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
@@ -129,6 +184,8 @@ if __name__ == '__main__':
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")

View File

@@ -470,6 +470,9 @@ class PipelineLike():
self.scheduler = scheduler
self.safety_checker = None
# Textual Inversion
self.token_replacements = {}
# CLIP guidance
self.clip_guidance_scale = clip_guidance_scale
self.clip_image_guidance_scale = clip_image_guidance_scale
@@ -484,6 +487,19 @@ class PipelineLike():
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
# Textual Inversion
def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids
def replace_token(self, tokens):
new_tokens = []
for token in tokens:
if token in self.token_replacements:
new_tokens.extend(self.token_replacements[token])
else:
new_tokens.append(token)
return new_tokens
# region xformersとか使う部分独自に書き換えるので関係なし
def enable_xformers_memory_efficient_attention(self):
r"""
@@ -1507,6 +1523,9 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
token = pipe.replace_token(token)
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
@@ -1826,12 +1845,12 @@ def main(args):
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
else:
print("load Diffusers pretrained models")
pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
text_encoder = pipe.text_encoder
vae = pipe.vae
unet = pipe.unet
tokenizer = pipe.tokenizer
del pipe
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
text_encoder = loading_pipe.text_encoder
vae = loading_pipe.vae
unet = loading_pipe.unet
tokenizer = loading_pipe.tokenizer
del loading_pipe
# VAEを読み込む
if args.vae is not None:
@@ -1981,7 +2000,6 @@ def main(args):
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
net_kwargs = {}
if args.network_args and i < len(args.network_args):
@@ -1992,22 +2010,22 @@ def main(args):
key, value = net_arg.split("=")
net_kwargs[key] = value
network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
if network is None:
return
if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if os.path.splitext(network_weight)[1] == '.safetensors':
if model_util.is_safetensors(network_weight):
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network.load_weights(network_weight)
network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
else:
raise ValueError("No weight. Weight is required.")
if network is None:
return
network.apply_to(text_encoder, unet)
@@ -2040,6 +2058,44 @@ def main(args):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
# Textual Inversionを処理する
if args.textual_inversion_embeddings:
token_ids_embeds = []
for embeds_file in args.textual_inversion_embeddings:
if model_util.is_safetensors(embeds_file):
from safetensors.torch import load_file
data = load_file(embeds_file)
else:
data = torch.load(embeds_file, map_location="cpu")
embeds = next(iter(data.values()))
if type(embeds) != torch.Tensor:
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
num_vectors_per_token = embeds.size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens = tokenizer.add_tokens(token_strings)
assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
if num_vectors_per_token > 1:
pipe.add_token_replacement(token_ids[0], token_ids)
token_ids_embeds.append((token_ids, embeds))
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
for token_ids, embeds in token_ids_embeds:
for token_id, embed in zip(token_ids, embeds):
token_embeds[token_id] = embed
# promptを取得する
if args.from_file is not None:
print(f"reading prompts from {args.from_file}")
@@ -2158,8 +2214,8 @@ def main(args):
os.makedirs(args.outdir, exist_ok=True)
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
for iter in range(args.n_iter):
print(f"iteration {iter+1}/{args.n_iter}")
for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}")
iter_seed = random.randint(0, 0x7fffffff)
# バッチ処理の関数
@@ -2526,10 +2582,10 @@ if __name__ == '__main__':
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')

View File

@@ -16,7 +16,7 @@ BETA_END = 0.0120
UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
UNET_PARAMS_IMAGE_SIZE = 32 # unused
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
@@ -1163,15 +1163,14 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
resos = list(resos)
resos.sort()
aspect_ratios = [w / h for w, h in resos]
return resos, aspect_ratios
return resos
if __name__ == '__main__':
resos, aspect_ratios = make_bucket_resolutions((512, 768))
resos = make_bucket_resolutions((512, 768))
print(len(resos))
print(resos)
aspect_ratios = [w / h for w, h in resos]
print(aspect_ratios)
ars = set()

View File

@@ -4,13 +4,16 @@ import argparse
import json
import shutil
import time
from typing import NamedTuple
from typing import Dict, List, NamedTuple, Tuple
from accelerate import Accelerator
from torch.autograd.function import Function
import glob
import math
import os
import random
import hashlib
import subprocess
from io import BytesIO
from tqdm import tqdm
import torch
@@ -24,6 +27,7 @@ from PIL import Image
import cv2
from einops import rearrange
from torch import einsum
import safetensors.torch
import library.model_util as model_util
@@ -42,6 +46,7 @@ DEFAULT_LAST_OUTPUT_NAME = "last"
# region dataset
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"]
# , ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"] # Linux?
class ImageInfo():
@@ -51,16 +56,142 @@ class ImageInfo():
self.caption: str = caption
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.image_size: tuple[int, int] = None
self.bucket_reso: tuple[int, int] = None
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
self.latents: torch.Tensor = None
self.latents_flipped: torch.Tensor = None
self.latents_npz: str = None
self.latents_npz_flipped: str = None
class BucketManager():
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
self.no_upscale = no_upscale
if max_reso is None:
self.max_reso = None
self.max_area = None
else:
self.max_reso = max_reso
self.max_area = max_reso[0] * max_reso[1]
self.min_size = min_size
self.max_size = max_size
self.reso_steps = reso_steps
self.resos = []
self.reso_to_id = {}
self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key
def add_image(self, reso, image):
bucket_id = self.reso_to_id[reso]
self.buckets[bucket_id].append(image)
def shuffle(self):
for bucket in self.buckets:
random.shuffle(bucket)
def sort(self):
# 解像度順にソートする表示時、メタデータ格納時の見栄えをよくするためだけ。bucketsも入れ替えてreso_to_idも振り直す
sorted_resos = self.resos.copy()
sorted_resos.sort()
sorted_buckets = []
sorted_reso_to_id = {}
for i, reso in enumerate(sorted_resos):
bucket_id = self.reso_to_id[reso]
sorted_buckets.append(self.buckets[bucket_id])
sorted_reso_to_id[reso] = i
self.resos = sorted_resos
self.buckets = sorted_buckets
self.reso_to_id = sorted_reso_to_id
def make_buckets(self):
resos = model_util.make_bucket_resolutions(self.max_reso, self.min_size, self.max_size, self.reso_steps)
self.set_predefined_resos(resos)
def set_predefined_resos(self, resos):
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
self.predefined_resos = resos.copy()
self.predefined_resos_set = set(resos)
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
def add_if_new_reso(self, reso):
if reso not in self.reso_to_id:
bucket_id = len(self.resos)
self.reso_to_id[reso] = bucket_id
self.resos.append(reso)
self.buckets.append([])
# print(reso, bucket_id, len(self.buckets))
def round_to_steps(self, x):
x = int(x + .5)
return x - x % self.reso_steps
def select_bucket(self, image_width, image_height):
aspect_ratio = image_width / image_height
if not self.no_upscale:
# 同じaspect ratioがあるかもしれないのでfine tuningで、no_upscale=Trueで前処理した場合、解像度が同じものを優先する
reso = (image_width, image_height)
if reso in self.predefined_resos_set:
pass
else:
ar_errors = self.predefined_aspect_ratios - aspect_ratio
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
reso = self.predefined_resos[predefined_bucket_id]
ar_reso = reso[0] / reso[1]
if aspect_ratio > ar_reso: # 横が長い→縦を合わせる
scale = reso[1] / image_height
else:
scale = reso[0] / image_width
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
# print("use predef", image_width, image_height, reso, resized_size)
else:
if image_width * image_height > self.max_area:
# 画像が大きすぎるのでアスペクト比を保ったまま縮小することを前提にbucketを決める
resized_width = math.sqrt(self.max_area * aspect_ratio)
resized_height = self.max_area / resized_width
assert abs(resized_width / resized_height - aspect_ratio) < 1e-2, "aspect is illegal"
# リサイズ後の短辺または長辺をreso_steps単位にするaspect ratioの差が少ないほうを選ぶ
# 元のbucketingと同じロジック
b_width_rounded = self.round_to_steps(resized_width)
b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
ar_width_rounded = b_width_rounded / b_height_in_wr
b_height_rounded = self.round_to_steps(resized_height)
b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
ar_height_rounded = b_width_in_hr / b_height_rounded
# print(b_width_rounded, b_height_in_wr, ar_width_rounded)
# print(b_width_in_hr, b_height_rounded, ar_height_rounded)
if abs(ar_width_rounded - aspect_ratio) < abs(ar_height_rounded - aspect_ratio):
resized_size = (b_width_rounded, int(b_width_rounded / aspect_ratio + .5))
else:
resized_size = (int(b_height_rounded * aspect_ratio + .5), b_height_rounded)
# print(resized_size)
else:
resized_size = (image_width, image_height) # リサイズは不要
# 画像のサイズ未満をbucketのサイズとするpaddingせずにcroppingする
bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
# print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
reso = (bucket_width, bucket_height)
self.add_if_new_reso(reso)
ar_error = (reso[0] / reso[1]) - aspect_ratio
return reso, resized_size, ar_error
class BucketBatchIndex(NamedTuple):
bucket_index: int
bucket_batch_size: int
batch_index: int
@@ -79,9 +210,25 @@ class BaseDataset(torch.utils.data.Dataset):
self.debug_dataset = debug_dataset
self.random_crop = random_crop
self.token_padding_disabled = False
self.dataset_dirs_info = {}
self.reg_dataset_dirs_info = {}
self.tag_frequency = {}
self.enable_bucket = False
self.bucket_manager: BucketManager = None # not initialized
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_reso_steps = None
self.bucket_no_upscale = None
self.bucket_info = None # for metadata
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.dropout_rate: float = 0
self.dropout_every_n_epochs: int = None
self.tag_dropout_rate: float = 0
# augmentation
flip_p = 0.5 if flip_aug else 0.0
if color_aug:
@@ -102,23 +249,83 @@ class BaseDataset(torch.utils.data.Dataset):
self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])
self.image_data: dict[str, ImageInfo] = {}
self.image_data: Dict[str, ImageInfo] = {}
self.replacements = {}
def set_current_epoch(self, epoch):
self.current_epoch = epoch
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
# コンストラクタで渡さないのはTextual Inversionで意識したくないからということにしておく
self.dropout_rate = dropout_rate
self.dropout_every_n_epochs = dropout_every_n_epochs
self.tag_dropout_rate = tag_dropout_rate
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir
for caption in captions:
for tag in caption.split(","):
if tag and not tag.isspace():
tag = tag.lower()
frequency = frequency_for_dir.get(tag, 0)
frequency_for_dir[tag] = frequency + 1
def disable_token_padding(self):
self.token_padding_disabled = True
def add_replacement(self, str_from, str_to):
self.replacements[str_from] = str_to
def process_caption(self, caption):
if self.shuffle_caption:
tokens = caption.strip().split(",")
if self.shuffle_keep_tokens is None:
random.shuffle(tokens)
else:
if len(tokens) > self.shuffle_keep_tokens:
keep_tokens = tokens[:self.shuffle_keep_tokens]
tokens = tokens[self.shuffle_keep_tokens:]
random.shuffle(tokens)
tokens = keep_tokens + tokens
caption = ",".join(tokens).strip()
# dropoutの決定tag dropがこのメソッド内にあるのでここで行うのが良い
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
if is_drop_out:
caption = ""
else:
if self.shuffle_caption or self.tag_dropout_rate > 0:
def dropout_tags(tokens):
if self.tag_dropout_rate <= 0:
return tokens
l = []
for token in tokens:
if random.random() >= self.tag_dropout_rate:
l.append(token)
return l
tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None:
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
else:
if len(tokens) > self.shuffle_keep_tokens:
keep_tokens = tokens[:self.shuffle_keep_tokens]
tokens = tokens[self.shuffle_keep_tokens:]
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
tokens = keep_tokens + tokens
caption = ", ".join(tokens)
# textual inversion対応
for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if type(str_to) == list:
caption = random.choice(str_to)
else:
caption = str_to
else:
caption = caption.replace(str_from, str_to)
return caption
def get_input_ids(self, caption):
@@ -178,59 +385,80 @@ class BaseDataset(torch.utils.data.Dataset):
else:
print("prepare dataset")
bucket_resos = self.bucket_resos
bucket_aspect_ratios = np.array(self.bucket_aspect_ratios)
# bucketを作成する
# bucketを作成し、画像をbucketに振り分ける
if self.enable_bucket:
if self.bucket_manager is None: # fine tuningの場合でmetadataに定義がある場合は、すでに初期化済み
self.bucket_manager = BucketManager(self.bucket_no_upscale, (self.width, self.height),
self.min_bucket_reso, self.max_bucket_reso, self.bucket_reso_steps)
if not self.bucket_no_upscale:
self.bucket_manager.make_buckets()
else:
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
img_ar_errors = []
for image_info in self.image_data.values():
# bucketを決める
image_width, image_height = image_info.image_size
aspect_ratio = image_width / image_height
ar_errors = bucket_aspect_ratios - aspect_ratio
image_info.bucket_reso, image_info.resized_size, ar_error = self.bucket_manager.select_bucket(image_width, image_height)
bucket_id = np.abs(ar_errors).argmin()
image_info.bucket_reso = bucket_resos[bucket_id]
# print(image_info.image_key, image_info.bucket_reso)
img_ar_errors.append(abs(ar_error))
ar_error = ar_errors[bucket_id]
img_ar_errors.append(ar_error)
self.bucket_manager.sort()
else:
self.bucket_manager = BucketManager(False, (self.width, self.height), None, None, None)
self.bucket_manager.set_predefined_resos([(self.width, self.height)]) # ひとつの固定サイズbucketのみ
for image_info in self.image_data.values():
image_info.bucket_reso = bucket_resos[0] # bucket_resos contains (width, height) only
# 画像をbucketに分割する
self.buckets: list[str] = [[] for _ in range(len(bucket_resos))]
reso_to_index = {}
for i, reso in enumerate(bucket_resos):
reso_to_index[reso] = i
image_width, image_height = image_info.image_size
image_info.bucket_reso, image_info.resized_size, _ = self.bucket_manager.select_bucket(image_width, image_height)
for image_info in self.image_data.values():
bucket_index = reso_to_index[image_info.bucket_reso]
for _ in range(image_info.num_repeats):
self.buckets[bucket_index].append(image_info.image_key)
self.bucket_manager.add_image(image_info.bucket_reso, image_info.image_key)
# bucket情報を表示、格納する
if self.enable_bucket:
self.bucket_info = {"buckets": {}}
print("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む")
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}")
for i, (reso, bucket) in enumerate(zip(self.bucket_manager.resos, self.bucket_manager.buckets)):
count = len(bucket)
if count > 0:
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
# 参照用indexを作る
self.buckets_indices: list(BucketBatchIndex) = []
for bucket_index, bucket in enumerate(self.buckets):
img_ar_errors = np.array(img_ar_errors)
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
print(f"mean ar error (without repeats): {mean_img_ar_error}")
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
self.buckets_indices: List(BucketBatchIndex) = []
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
self.buckets_indices.append(BucketBatchIndex(bucket_index, batch_index))
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
#  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
#
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
# # そのためバッチサイズを画像種類までに制限する
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない
# # TO DO 正則化画像をepochまたがりで利用する仕組み
# num_of_image_types = len(set(bucket))
# bucket_batch_size = min(self.batch_size, num_of_image_types)
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
# for batch_index in range(batch_count):
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
# ↑ここまで
self.shuffle_buckets()
self._length = len(self.buckets_indices)
def shuffle_buckets(self):
random.shuffle(self.buckets_indices)
for bucket in self.buckets:
random.shuffle(bucket)
self.bucket_manager.shuffle()
def load_image(self, image_path):
image = Image.open(image_path)
@@ -239,28 +467,30 @@ class BaseDataset(torch.utils.data.Dataset):
img = np.array(image, np.uint8)
return img
def resize_and_trim(self, image, reso):
def trim_and_resize_if_required(self, image, reso, resized_size):
image_height, image_width = image.shape[0:2]
ar_img = image_width / image_height
ar_reso = reso[0] / reso[1]
if ar_img > ar_reso: # 横が長い→縦を合わせる
scale = reso[1] / image_height
else:
scale = reso[0] / image_width
resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size//2:trim_size//2 + reso[0]]
elif resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size//2:trim_size//2 + reso[1]]
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
f"internal error, illegal trimmed size: {image.shape}, {reso}"
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
image_height, image_width = image.shape[0:2]
if image_width > reso[0]:
trim_size = image_width - reso[0]
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
# print("w", trim_size, p)
image = image[:, p:p + reso[0]]
if image_height > reso[1]:
trim_size = image_height - reso[1]
p = trim_size // 2 if not self.random_crop else random.randint(0, trim_size)
# print("h", trim_size, p)
image = image[p:p + reso[1]]
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
return image
def cache_latents(self, vae):
# TODO ここを高速化したい
print("caching latents.")
for info in tqdm(self.image_data.values()):
if info.latents_npz is not None:
@@ -272,7 +502,7 @@ class BaseDataset(torch.utils.data.Dataset):
continue
image = self.load_image(info.absolute_path)
image = self.resize_and_trim(image, info.bucket_reso)
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)
img_tensor = self.image_transforms(image)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
@@ -362,8 +592,9 @@ class BaseDataset(torch.utils.data.Dataset):
if index == 0:
self.shuffle_buckets()
bucket = self.buckets[self.buckets_indices[index].bucket_index]
image_index = self.buckets_indices[index].batch_index * self.batch_size
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
loss_weights = []
captions = []
@@ -371,7 +602,7 @@ class BaseDataset(torch.utils.data.Dataset):
latents_list = []
images = []
for image_key in bucket[image_index:image_index + self.batch_size]:
for image_key in bucket[image_index:image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
@@ -389,7 +620,7 @@ class BaseDataset(torch.utils.data.Dataset):
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img = self.resize_and_trim(img, image_info.bucket_reso)
img = self.trim_and_resize_if_required(img, image_info.bucket_reso, image_info.resized_size)
else:
if face_cx > 0: # 顔位置情報あり
img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
@@ -446,7 +677,7 @@ class BaseDataset(torch.utils.data.Dataset):
class DreamBoothDataset(BaseDataset):
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
@@ -461,11 +692,15 @@ class DreamBoothDataset(BaseDataset):
if self.enable_bucket:
assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
(self.width, self.height), min_bucket_reso, max_bucket_reso)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
else:
self.bucket_resos = [(self.width, self.height)]
self.bucket_aspect_ratios = [self.width / self.height]
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path):
# captionの候補ファイル名を作る
@@ -512,6 +747,8 @@ class DreamBoothDataset(BaseDataset):
cap_for_img = read_caption(img_path)
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
self.set_tag_frequency(os.path.basename(dir), captions) # タグ頻度を記録
return n_repeats, img_paths, captions
print("prepare train images.")
@@ -520,9 +757,13 @@ class DreamBoothDataset(BaseDataset):
for dir in train_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir))
num_train_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info)
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images
@@ -530,16 +771,19 @@ class DreamBoothDataset(BaseDataset):
num_reg_images = 0
if reg_data_dir:
print("prepare reg images.")
reg_infos: list[ImageInfo] = []
reg_infos: List[ImageInfo] = []
reg_dirs = os.listdir(reg_data_dir)
for dir in reg_dirs:
n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir))
num_reg_images += n_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info)
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images:
print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
@@ -566,7 +810,7 @@ class DreamBoothDataset(BaseDataset):
class FineTuningDataset(BaseDataset):
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps, bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None:
super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)
@@ -582,6 +826,7 @@ class FineTuningDataset(BaseDataset):
self.train_data_dir = train_data_dir
self.batch_size = batch_size
tags_list = []
for image_key, img_md in metadata.items():
# path情報を作る
if os.path.exists(image_key):
@@ -589,7 +834,7 @@ class FineTuningDataset(BaseDataset):
else:
# わりといい加減だがいい方法が思いつかん
abs_path = glob_images(train_data_dir, image_key)
assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}"
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
abs_path = abs_path[0]
caption = img_md.get('caption')
@@ -598,12 +843,13 @@ class FineTuningDataset(BaseDataset):
caption = tags
elif tags is not None and len(tags) > 0:
caption = caption + ', ' + tags
tags_list.append(tags)
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
image_info.image_size = img_md.get('train_resolution')
if not self.color_aug:
if not self.color_aug and not self.random_crop:
# if npz exists, use them
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)
@@ -611,8 +857,13 @@ class FineTuningDataset(BaseDataset):
self.num_train_images = len(metadata) * dataset_repeats
self.num_reg_images = 0
# TODO do not record tag freq when no tag
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
# check existence of all npz files
if not self.color_aug:
use_npz_latents = not (self.color_aug or self.random_crop)
if use_npz_latents:
npz_any = False
npz_all = True
for image_info in self.image_data.values():
@@ -627,11 +878,15 @@ class FineTuningDataset(BaseDataset):
break
if not npz_any:
print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
use_npz_latents = False
print(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
elif not npz_all:
use_npz_latents = False
print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None
if self.flip_aug:
print("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
# else:
# print("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
# check min/max bucket size
sizes = set()
@@ -645,25 +900,34 @@ class FineTuningDataset(BaseDataset):
resos.add(tuple(image_info.image_size))
if sizes is None:
if use_npz_latents:
use_npz_latents = False
print(f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します")
assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
self.enable_bucket = enable_bucket
if self.enable_bucket:
assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
(self.width, self.height), min_bucket_reso, max_bucket_reso)
else:
self.bucket_resos = [(self.width, self.height)]
self.bucket_aspect_ratios = [self.width / self.height]
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
else:
if not enable_bucket:
print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
print("using bucket info in metadata / メタデータ内のbucket情報を使います")
self.enable_bucket = True
self.bucket_resos = list(resos)
self.bucket_resos.sort()
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
assert not bucket_no_upscale, "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
# bucket情報を初期化しておく、make_bucketsで再作成しない
self.bucket_manager = BucketManager(False, None, None, None, None)
self.bucket_manager.set_predefined_resos(resos)
# npz情報をきれいにしておく
if not use_npz_latents:
for image_info in self.image_data.values():
image_info.latents_npz = image_info.latents_npz_flipped = None
def image_key_to_npz_file(self, image_key):
base_name = os.path.splitext(image_key)[0]
@@ -689,38 +953,59 @@ class FineTuningDataset(BaseDataset):
return npz_file_norm, npz_file_flip
def debug_dataset(train_dataset):
def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
train_dataset.set_current_epoch(1)
k = 0
for example in train_dataset:
for i, example in enumerate(train_dataset):
if example['latents'] is not None:
print("sample has latents from npz file")
for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
print(f"sample has latents from npz file: {example['latents'].size()}")
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
if show_input_ids:
print(f"input ids: {iid}")
if example['images'] is not None:
im = example['images'][j]
print(f"image size: {im.size()}")
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
cv2.imshow("img", im)
if os.name == 'nt': # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27:
break
if k == 27 or example['images'] is None:
if k == 27 or (example['images'] is None and i >= 8):
break
def glob_images(dir, base):
def glob_images(directory, base="*"):
img_paths = []
for ext in IMAGE_EXTENSIONS:
if base == '*':
img_paths.extend(glob.glob(os.path.join(glob.escape(dir), base + ext)))
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else:
img_paths.extend(glob.glob(glob.escape(os.path.join(dir, base + ext))))
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
# img_paths = list(set(img_paths)) # 重複を排除
# img_paths.sort()
return img_paths
def glob_images_pathlib(dir_path, recursive):
image_paths = []
if recursive:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.rglob('*' + ext))
else:
for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.glob('*' + ext))
# image_paths = list(set(image_paths)) # 重複を排除
# image_paths.sort()
return image_paths
# endregion
@@ -749,9 +1034,9 @@ def default(val, d):
def model_hash(filename):
"""Old model hash used by stable-diffusion-webui"""
try:
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
file.seek(0x100000)
@@ -761,6 +1046,68 @@ def model_hash(filename):
return 'NOFILE'
def calculate_sha256(filename):
"""New model hash used by stable-diffusion-webui"""
hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024
with open(filename, "rb") as f:
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
def precalculate_safetensors_hashes(tensors, metadata):
"""Precalculate the model hashes needed by sd-webui-additional-networks to
save time on indexing the model later."""
# Because writing user metadata to the file can change the result of
# sd_models.model_hash(), only retain the training metadata for purposes of
# calculating the hash, as they are meant to be immutable
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
bytes = safetensors.torch.save(tensors, metadata)
b = BytesIO(bytes)
model_hash = addnet_hash_safetensors(b)
legacy_hash = addnet_hash_legacy(b)
return model_hash, legacy_hash
def addnet_hash_legacy(b):
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
m = hashlib.sha256()
b.seek(0x100000)
m.update(b.read(0x10000))
return m.hexdigest()[0:8]
def addnet_hash_safetensors(b):
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024
b.seek(0)
header = b.read(8)
n = int.from_bytes(header, "little")
offset = n + 8
b.seek(offset)
for chunk in iter(lambda: b.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
def get_git_revision_hash() -> str:
try:
return subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__)).decode('ascii').strip()
except:
return "(unknown)"
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135
@@ -1028,8 +1375,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存するたとえば5を指定すると最低5個のファイルが保存される")
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
@@ -1039,6 +1389,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長未指定で75、150または225が指定可")
parser.add_argument("--use_8bit_adam", action="store_true",
help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
parser.add_argument("--use_lion_optimizer", action="store_true",
help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う lion-pytorch のインストールが必要)")
parser.add_argument("--mem_eff_attn", action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
parser.add_argument("--xformers", action="store_true",
@@ -1048,8 +1400,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
parser.add_argument("--max_train_epochs", type=int, default=None,
help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします)")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります")
parser.add_argument("--persistent_data_loader_workers", action="store_true",
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / grandient checkpointingを有効にする")
@@ -1067,6 +1423,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")
parser.add_argument("--noise_offset", type=float, default=None,
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する有効にする場合は0.1程度を推奨)")
parser.add_argument("--lowram", action="store_true",
help="enable low RAM optimization. e.g. load models to VRAM instead of RAM (for machines which have bigger VRAM than RAM such as Colab and Kaggle) / メインメモリが少ない環境向け最適化を有効にする。たとえばVRAMにモデルを読み込むなどColabやKaggleなどRAMに比べてVRAMが多い環境向け")
if support_dreambooth:
# DreamBooth training
@@ -1081,7 +1441,7 @@ def verify_training_args(args: argparse.Namespace):
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool):
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
# dataset common
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--shuffle_caption", action="store_true",
@@ -1107,6 +1467,20 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
parser.add_argument("--bucket_reso_steps", type=int, default=64,
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
parser.add_argument("--bucket_no_upscale", action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
parser.add_argument("--caption_dropout_rate", type=float, default=0,
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
if support_dreambooth:
# DreamBooth dataset
@@ -1138,6 +1512,7 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
if args.cache_latents:
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
assert not args.random_crop, "when caching latents, random_crop cannot be used / latentをキャッシュするときはrandom_cropは使えません"
# assert args.resolution is not None, f"resolution is required / resolution解像度を指定してください"
if args.resolution is not None:
@@ -1149,14 +1524,14 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
if args.face_crop_aug_range is not None:
args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
assert len(args.face_crop_aug_range) == 2, \
assert len(args.face_crop_aug_range) == 2 and args.face_crop_aug_range[0] <= args.face_crop_aug_range[1], \
f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
else:
args.face_crop_aug_range = None
if support_metadata:
if args.in_json is not None and args.color_aug:
print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます")
if args.in_json is not None and (args.color_aug or args.random_crop):
print(f"latents in npz is ignored when color_aug or random_crop is True / color_augまたはrandom_cropを有効にした場合、npzファイルのlatentsは無視されます")
def load_tokenizer(args: argparse.Namespace):
@@ -1259,9 +1634,6 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
else:
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
if weight_dtype is not None:
# this is required for additional network training
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
@@ -1288,6 +1660,10 @@ def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encod
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
if weight_dtype is not None:
# this is required for additional network training
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
return encoder_hidden_states
@@ -1391,5 +1767,30 @@ def save_state_on_train_end(args: argparse.Namespace, accelerator):
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
# endregion
# region 前処理用
class ImageLoadingDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
# convert to tensor temporarily so dataloader will accept it
tensor_pil = transforms.functional.pil_to_tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor_pil, img_path)
# endregion

View File

@@ -44,9 +44,9 @@ def svd(args):
print(f"loading SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
# create LoRA network to extract weights
lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o)
lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t)
# create LoRA network to extract weights: Use dim (rank) as alpha
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
@@ -77,10 +77,10 @@ def svd(args):
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = diff.float()
if args.device:
diff = diff.to(args.device)
diffs[lora_name] = diff
# make LoRA with svd
@@ -116,6 +116,9 @@ def svd(args):
print(f"LoRA has {len(lora_sd)} weights.")
for key in list(lora_sd.keys()):
if "alpha" in key:
continue
lora_name = key.split('.')[0]
i = 0 if "lora_up" in key else 1
@@ -124,7 +127,7 @@ def svd(args):
if len(lora_sd[key].size()) == 4:
weights = weights.unsqueeze(2).unsqueeze(3)
assert weights.size() == lora_sd[key].size()
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
lora_sd[key] = weights
# load state dict to LoRA and save it
@@ -135,7 +138,10 @@ def svd(args):
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
lora_network_o.save_weights(args.save_to, save_dtype, {})
# minimum metadata
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}")
@@ -151,8 +157,8 @@ if __name__ == '__main__':
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数デフォルト4")
parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う")
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
args = parser.parse_args()
svd(args)

View File

@@ -5,17 +5,22 @@
import math
import os
from typing import List
import torch
from library import train_util
class LoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
""" if alpha == 0 or None, alpha is rank (no scaling). """
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
if org_module.__class__.__name__ == 'Conv2d':
in_dim = org_module.in_channels
@@ -28,6 +33,12 @@ class LoRAModule(torch.nn.Module):
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
@@ -41,13 +52,37 @@ class LoRAModule(torch.nn.Module):
del self.org_module
def forward(self, x):
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs):
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
if network_dim is None:
network_dim = 4 # default
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim)
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location='cpu')
# get dim (rank)
network_alpha = None
network_dim = None
for key, value in weights_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is None:
network_alpha = network_dim
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
network.weights_sd = weights_sd
return network
@@ -57,13 +92,14 @@ class LoRANetwork(torch.nn.Module):
LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
# create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
loras = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
@@ -71,7 +107,7 @@ class LoRANetwork(torch.nn.Module):
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim)
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
loras.append(lora)
return loras
@@ -149,21 +185,21 @@ class LoRANetwork(torch.nn.Module):
return params
self.requires_grad_(True)
params = []
all_params = []
if self.text_encoder_loras:
param_data = {'params': enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data['lr'] = text_encoder_lr
params.append(param_data)
all_params.append(param_data)
if self.unet_loras:
param_data = {'params': enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data['lr'] = unet_lr
params.append(param_data)
all_params.append(param_data)
return params
return all_params
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
@@ -188,6 +224,14 @@ class LoRANetwork(torch.nn.Module):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import save_file
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)

View File

@@ -0,0 +1,122 @@
from tqdm import tqdm
from library import model_util
import argparse
from transformers import CLIPTokenizer
import torch
import library.model_util as model_util
import lora
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def interrogate(args):
# いろいろ準備する
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
print(f"loading LoRA: {args.model}")
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# text encoder向けの重みがあるかチェックする本当はlora側でやるのがいい
has_te_weight = False
for key in network.weights_sd.keys():
if 'lora_te' in key:
has_te_weight = True
break
if not has_te_weight:
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
return
del vae
print("loading tokenizer")
if args.v2:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
text_encoder.to(DEVICE)
text_encoder.eval()
unet.to(DEVICE)
unet.eval() # U-Netは呼び出さないので不要だけど
# トークンをひとつひとつ当たっていく
token_id_start = 0
token_id_end = max(tokenizer.all_special_ids)
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
def get_all_embeddings(text_encoder):
embs = []
with torch.no_grad():
for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
batch = []
for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
# tokens = [tid] # こちらは結果がいまひとつ
batch.append(tokens)
# batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1]
# clip skip対応
batch = torch.tensor(batch).to(DEVICE)
if args.clip_skip is None:
encoder_hidden_states = text_encoder(batch)[0]
else:
enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.to("cpu")
embs.extend(encoder_hidden_states)
return torch.stack(embs)
print("get original text encoder embeddings.")
orig_embs = get_all_embeddings(text_encoder)
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
network.to(DEVICE)
network.eval()
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
print("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder)
# 比べる:とりあえず単純に差分の絶対値で
print("comparing...")
diffs = {}
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
diff = torch.mean(torch.abs(orig_emb - lora_emb))
# diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
diff = float(diff.detach().to('cpu').numpy())
diffs[token_id_start + i] = diff
diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
# 結果を表示する
print("top 100:")
for i, (token, diff) in enumerate(diffs_sorted[:100]):
# if diff < 1e-6:
# break
string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None,
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--batch_size", type=int, default=16,
help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
args = parser.parse_args()
interrogate(args)

View File

@@ -1,5 +1,5 @@
import math
import argparse
import os
import torch
@@ -61,6 +61,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
@@ -73,33 +74,85 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight)
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
# conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {}
base_alphas = {} # alpha for merged model
base_dims = {}
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
# get alpha and dim
alphas = {} # alpha for current model
dims = {} # dims for current model
for key in lora_sd.keys():
if 'alpha' in key:
lora_module_name = key[:key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
elif "lora_down" in key:
lora_module_name = key[:key.rfind(".lora_down")]
dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim
if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim
for lora_module_name in dims.keys():
if lora_module_name not in alphas:
alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
continue
lora_module_name = key[:key.rfind(".lora_")]
base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * ratio
merged_sd[key] = lora_sd[key] * scale
# set alpha to sd
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
return merged_sd
@@ -145,7 +198,7 @@ if __name__ == '__main__':
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging / マージの計算時の精度")
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
parser.add_argument("--save_to", type=str, default=None,

179
networks/merge_lora_old.py Normal file
View File

@@ -0,0 +1,179 @@
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
import library.model_util as model_util
import lora
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
text_encoder.to(merge_dtype)
unet.to(merge_dtype)
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder, unet]):
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
# conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {}
alpha = None
dim = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
else:
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key]
else:
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
else:
if "lora_down" in key:
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
return merged_sd, dim, alpha
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
args = parser.parse_args()
merge(args)

198
networks/resize_lora.py Normal file
View File

@@ -0,0 +1,198 @@
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo and kohya
import argparse
import os
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
def load_state_dict(file_name, dtype):
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location='cpu')
metadata = None
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if model_util.is_safetensors(file_name):
save_file(model, file_name, metadata)
else:
torch.save(model, file_name)
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
network_alpha = None
network_dim = None
verbose_str = "\n"
CLAMP_QUANTILE = 0.99
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
scale = network_alpha/network_dim
new_alpha = float(scale*new_rank) # calculate new alpha from scale
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
lora_down_weight = None
lora_up_weight = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
print("resizing lora...")
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
block_down_name = key.split(".")[0]
lora_down_weight = value
if 'lora_up' in key:
block_up_name = key.split(".")[0]
lora_up_weight = value
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
if (block_down_name == block_up_name) and weights_loaded:
conv2d = (len(lora_down_weight.size()) == 4)
if conv2d:
lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze()
if device:
org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
U, S, Vh = torch.linalg.svd(full_weight_matrix)
if verbose:
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))
verbose_str+=f"{block_down_name:76} | "
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3)
if device:
U = U.to(org_device)
Vh = Vh.to(org_device)
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
if verbose:
print(verbose_str)
print("resizing complete")
return o_lora_sd, network_dim, new_alpha
def resize(args):
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print("resizing rank...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
# update metadata
if metadata is None:
metadata = {}
comment = metadata.get("ss_training_comment", "")
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument("--verbose", action="store_true",
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
args = parser.parse_args()
resize(args)

164
networks/svd_merge_lora.py Normal file
View File

@@ -0,0 +1,164 @@
import math
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import library.model_util as model_util
import lora
CLAMP_QUANTILE = 0.99
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
# merge
print(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if 'lora_down' not in key:
continue
lora_module_name = key[:key.rfind(".lora_down")]
down_weight = lora_sd[key]
network_dim = down_weight.size()[0]
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
in_dim = down_weight.size()[1]
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
# merge to weight
if device:
up_weight = up_weight.to(device)
down_weight = down_weight.to(device)
# W <- W + U * D
scale = (alpha / network_dim)
if not conv2d: # linear
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
merged_sd[lora_module_name] = weight
# extract from merged weights
print("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
conv2d = (len(mat.size()) == 4)
if conv2d:
mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
up_weight = U
down_weight = Vh
if conv2d:
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
return merged_lora_sd
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
args = parser.parse_args()
merge(args)

View File

@@ -1,23 +1,24 @@
accelerate==0.15.0
transformers==4.25.1
ftfy
albumentations
opencv-python
einops
transformers==4.26.0
ftfy==6.1.1
albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
diffusers[torch]==0.10.2
pytorch_lightning
pytorch-lightning==1.9.0
bitsandbytes==0.35.0
tensorboard
tensorboard==2.10.1
safetensors==0.2.6
gradio
altair
easygui
gradio==3.16.2
altair==4.2.2
easygui==0.98.3
# for BLIP captioning
requests
timm==0.4.12
fairscale==0.4.4
requests==2.28.2
timm==0.6.12
fairscale==0.4.13
# for WD14 captioning
tensorflow<2.11
huggingface-hub
# tensorflow<2.11
tensorflow==2.10.1
huggingface-hub==0.12.0
# for kohya_ss library
.
.

View File

@@ -1,8 +1,4 @@
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
# v1: initial version
# v2: support safetensors
# v3: fix to support another format
# v4: support safetensors in Diffusers
import argparse
import os

View File

@@ -0,0 +1,122 @@
import glob
import os
import cv2
import argparse
import shutil
import math
from PIL import Image
import numpy as np
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
# Split the max_resolution string by "," and strip any whitespaces
max_resolutions = [res.strip() for res in max_resolution.split(',')]
# # Calculate max_pixels from max_resolution string
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Select interpolation method
if interpolation == 'lanczos4':
cv2_interpolation = cv2.INTER_LANCZOS4
elif interpolation == 'cubic':
cv2_interpolation = cv2.INTER_CUBIC
else:
cv2_interpolation = cv2.INTER_AREA
# Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp etc...
if not filename.endswith(img_exts):
# Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue
# Load image
# img = cv2.imread(os.path.join(src_img_folder, filename))
image = Image.open(os.path.join(src_img_folder, filename))
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
base, _ = os.path.splitext(filename)
for max_resolution in max_resolutions:
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]
# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels
# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
new_height, new_width = img.shape[0:2]
# Calculate the new height and width that are divisible by divisible_by (with/without resizing)
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]
# Split filename into base and extension
new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
# Save resized image in dst_img_folder
# cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
image = Image.fromarray(img)
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
proc = "Resized" if current_pixels > max_pixels else "Saved"
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
# If other files with same basename, copy them with resolution suffix
if copy_associated_files:
asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
for asoc_file in asoc_files:
ext = os.path.splitext(asoc_file)[1]
if ext in img_exts:
continue
for max_resolution in max_resolutions:
new_asoc_file = base + '+' + max_resolution + ext
print(f"Copy {asoc_file} as {new_asoc_file}")
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
def main():
parser = argparse.ArgumentParser(
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
parser.add_argument('--max_resolution', type=str,
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
if __name__ == '__main__':
main()

View File

@@ -35,10 +35,16 @@ def train(args):
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
if args.no_token_padding:
train_dataset.disable_token_padding()
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets()
if args.debug_dataset:
@@ -118,6 +124,13 @@ def train(args):
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit
elif args.use_lion_optimizer:
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print("use Lion optimizer")
optimizer_class = lion_pytorch.Lion
else:
optimizer_class = torch.optim.AdamW
@@ -133,7 +146,7 @@ def train(args):
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
@@ -176,6 +189,8 @@ def train(args):
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -198,8 +213,11 @@ def train(args):
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth")
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
# 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train()
@@ -207,7 +225,6 @@ def train(args):
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
@@ -224,10 +241,13 @@ def train(args):
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0]
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
@@ -282,8 +302,13 @@ def train(args):
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / (step+1)
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
@@ -291,7 +316,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()
@@ -324,7 +349,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser)

View File

@@ -72,7 +72,7 @@ identifierとclassを使い、たとえば「shs dog」などでモデルを学
※LoRA等の追加ネットワークを学習する場合のコマンドは ``train_db.py`` ではなく ``train_network.py`` となります。また追加でnetwork_\*オプションが必要となりますので、LoRAのガイドを参照してください。
```
accelerate launch --num_cpu_threads_per_process 8 train_db.py
accelerate launch --num_cpu_threads_per_process 1 train_db.py
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
--train_data_dir=<学習用データのディレクトリ>
--reg_data_dir=<正則化画像のディレクトリ>
@@ -89,7 +89,7 @@ accelerate launch --num_cpu_threads_per_process 8 train_db.py
--gradient_checkpointing
```
num_cpu_threads_per_processにはCPUコア数を指定するとよいようです。
num_cpu_threads_per_processには通常は1を指定するとよいようです。
pretrained_model_name_or_pathに追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル.ckptまたは.safetensors、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID"stabilityai/stable-diffusion-2"などが指定できます。学習後のモデルの保存形式はデフォルトでは元のモデルと同じになりますsave_model_asオプションで変更できます
@@ -159,7 +159,7 @@ v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述
![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png)
各yamlファイルは[https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion](Stability AIのSD2.0のリポジトリ)にあります。
各yamlファイルは[Stability AIのSD2.0のリポジトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にあります。
# その他の学習オプション

View File

@@ -1,8 +1,16 @@
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Optional, Union
import importlib
import argparse
import gc
import math
import os
import random
import time
import json
from tqdm import tqdm
import torch
@@ -18,7 +26,86 @@ def collate_fn(examples):
return examples[0]
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if args.network_train_unet_only:
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
elif args.network_train_text_encoder_only:
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
else:
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
return logs
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
# Which is a newer release of diffusers than currently packaged with sd-scripts
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
def get_scheduler_fix(
name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
num_cycles: int = 1,
power: float = 1.0,
):
"""
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (`int`, *optional*):
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
power (`float`, *optional*, defaults to 1.0):
Power factor. See `POLYNOMIAL` scheduler
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
if name == SchedulerType.COSINE_WITH_RESTARTS:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
)
if name == SchedulerType.POLYNOMIAL:
return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
)
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
@@ -35,22 +122,29 @@ def train(args):
print("Use DreamBooth method.")
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range,
args.random_crop, args.debug_dataset)
else:
print("Train with captions.")
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets()
if args.debug_dataset:
train_util.debug_dataset(train_dataset)
return
if len(train_dataset) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります")
return
# acceleratorを準備する
@@ -63,6 +157,11 @@ def train(args):
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# work on low-ram device
if args.lowram:
text_encoder.to("cuda")
unet.to("cuda")
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -88,7 +187,8 @@ def train(args):
key, value = net_arg.split('=')
net_kwargs[key] = value
network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs)
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
if network is None:
return
@@ -116,9 +216,18 @@ def train(args):
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit
elif args.use_lion_optimizer:
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print("use Lion optimizer")
optimizer_class = lion_pytorch.Lion
else:
optimizer_class = torch.optim.AdamW
optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
@@ -128,7 +237,7 @@ def train(args):
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
@@ -136,8 +245,11 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
# lr_scheduler = diffusers.optimization.get_scheduler(
lr_scheduler = get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
@@ -162,17 +274,26 @@ def train(args):
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
text_encoder.train()
# set top parameter requires_grad = True for gradient checkpointing works
text_encoder.text_model.embeddings.requires_grad_(True)
if type(text_encoder) == DDP:
text_encoder.module.text_model.embeddings.requires_grad_(True)
else:
text_encoder.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
text_encoder.eval()
# support DistributedDataParallel
if type(text_encoder) == DDP:
text_encoder = text_encoder.module
unet = unet.module
network = network.module
network.prepare_grad_etc(text_encoder, unet)
if not cache_latents:
@@ -192,6 +313,8 @@ def train(args):
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -206,21 +329,26 @@ def train(args):
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
metadata = {
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
"ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data
"ss_num_train_images": train_dataset.num_train_images, # includes repeating
"ss_num_reg_images": train_dataset.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs,
"ss_batch_size_per_device": args.train_batch_size,
"ss_total_batch_size": total_batch_size,
"ss_gradient_checkpointing": args.gradient_checkpointing,
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
"ss_max_train_steps": args.max_train_steps,
"ss_lr_warmup_steps": args.lr_warmup_steps,
"ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_alpha": args.network_alpha, # some networks may not use this value
"ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2),
@@ -232,10 +360,19 @@ def train(args):
"ss_random_crop": bool(args.random_crop),
"ss_shuffle_caption": bool(args.shuffle_caption),
"ss_cache_latents": bool(args.cache_latents),
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
"ss_max_bucket_reso": args.max_bucket_reso,
"ss_seed": args.seed
"ss_enable_bucket": bool(train_dataset.enable_bucket),
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
"ss_seed": args.seed,
"ss_keep_tokens": args.keep_tokens,
"ss_noise_offset": args.noise_offset,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
"ss_training_comment": args.training_comment, # will not be updated after training
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
"ss_optimizer": optimizer_name
}
# uncomment if another network is added
@@ -246,6 +383,7 @@ def train(args):
sd_model_name = args.pretrained_model_name_or_path
if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
sd_model_name = os.path.basename(sd_model_name)
metadata["ss_sd_model_name"] = sd_model_name
@@ -253,6 +391,7 @@ def train(args):
vae_name = args.vae
if os.path.exists(vae_name):
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name
@@ -267,13 +406,16 @@ def train(args):
if accelerator.is_main_process:
accelerator.init_trackers("network_train")
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
metadata["ss_epoch"] = str(epoch+1)
network.on_epoch_start(text_encoder, unet)
loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(network):
with torch.no_grad():
@@ -292,6 +434,9 @@ def train(args):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@@ -302,7 +447,8 @@ def train(args):
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
with autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
@@ -333,20 +479,25 @@ def train(args):
global_step += 1
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / (step+1)
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()
@@ -402,26 +553,34 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt")
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors")
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
parser.add_argument("--lr_scheduler_power", type=float, default=1,
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
parser.add_argument("--network_weights", type=str, default=None,
help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
parser.add_argument("--network_dim", type=int, default=None,
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_alpha", type=float, default=1,
help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定')
parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
parser.add_argument("--network_train_text_encoder_only", action="store_true",
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
parser.add_argument("--training_comment", type=str, default=None,
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
args = parser.parse_args()
train(args)

View File

@@ -10,7 +10,7 @@
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extention](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
## 学習方法
@@ -24,7 +24,7 @@ DreamBoothの手法identifiersksなどとclass、オプションで正
[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション``network_dim``や``network_alpha``など)を追加してください。
ほぼすべてのオプションStable Diffusionのモデル保存関係を除くが使えますが、stop_text_encoder_trainingはサポートしていません。
@@ -32,7 +32,7 @@ DreamBoothの手法identifiersksなどとclass、オプションで正
[fine-tuningのガイド](./fine_tune_README_ja.md) を参照し、各手順を実行してください。
学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプションモデル保存関係を除くがそのまま使えます。
学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプションモデル保存関係を除くがそのまま使えます。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション``network_dim``や``network_alpha``など)を追加してください。
なお「latentsの事前取得」は行わなくても動作します。VAEから学習時またはキャッシュ時にlatentを取得するため学習速度は遅くなりますが、代わりにcolor_augが使えるようになります。
@@ -45,7 +45,7 @@ train_network.pyでは--network_moduleオプションに、学習対象のモジ
以下はコマンドラインの例ですDreamBooth手法
```
accelerate launch --num_cpu_threads_per_process 12 train_network.py
accelerate launch --num_cpu_threads_per_process 1 train_network.py
--pretrained_model_name_or_path=..\models\model.ckpt
--train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
--reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
@@ -55,12 +55,14 @@ accelerate launch --num_cpu_threads_per_process 12 train_network.py
--network_module=networks.lora
```
--output_dirオプションで指定したディレクトリに、LoRAのモデルが保存されます。
--output_dirオプションで指定したフォルダに、LoRAのモデルが保存されます。
その他、以下のオプションが指定できます。
* --network_dim
* LoRAの次元数を指定します(``--networkdim=4``など。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
* LoRAのRANKを指定します(``--networkdim=4``など。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
* --network_alpha
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
* --network_weights
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
* --network_train_unet_only
@@ -126,7 +128,7 @@ python networks\merge_lora.py
--ratiosにそれぞれのモデルの比率どのくらい重みを元モデルに反映するかを0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
v1で学習したLoRAとv2で学習したLoRA、次元数の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
v1で学習したLoRAとv2で学習したLoRA、rank次元数や``alpha``の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
### その他のオプション
@@ -138,7 +140,7 @@ v1で学習したLoRAとv2で学習したLoRA、次元数の異なるLoRAはマ
## 当リポジトリ内の画像生成スクリプトで生成する
gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim省略可の各オプションを追加してください。意味は学習時と同様です。
gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
@@ -176,6 +178,38 @@ Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRA
- --save_precision
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
## 画像リサイズスクリプト
(のちほどドキュメントを整理しますがとりあえずここに説明を書いておきます。)
Aspect Ratio Bucketingの機能拡張で、小さな画像については拡大しないでそのまま教師データとすることが可能になりました。元の教師画像を縮小した画像を、教師データに加えると精度が向上したという報告とともに前処理用のスクリプトをいただきましたので整備して追加しました。bmaltais氏に感謝します。
### スクリプトの実行方法
以下のように指定してください。元の画像そのまま、およびリサイズ後の画像が変換先フォルダに保存されます。リサイズ後の画像には、ファイル名に ``+512x512`` のようにリサイズ先の解像度が付け加えられます(画像サイズとは異なります)。リサイズ先の解像度より小さい画像は拡大されることはありません。
```
python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png
--copy_associated_files 元画像フォルダ 変換先フォルダ
```
元画像フォルダ内の画像ファイルが、指定した解像度(複数指定可)と同じ面積になるようにリサイズされ、変換先フォルダに保存されます。画像以外のファイルはそのままコピーされます。
``--max_resolution`` オプションにリサイズ先のサイズを例のように指定してください。面積がそのサイズになるようにリサイズします。複数指定すると、それぞれの解像度でリサイズされます。``512x512,384x384,256x256``なら、変換先フォルダの画像は、元サイズとリサイズ後サイズ×3の計4枚になります。
``--save_as_png`` オプションを指定するとpng形式で保存します。省略するとjpeg形式quality=100で保存されます。
``--copy_associated_files`` オプションを指定すると、拡張子を除き画像と同じファイル名(たとえばキャプションなど)のファイルが、リサイズ後の画像のファイル名と同じ名前でコピーされます。
### その他のオプション
- divisible_by
- リサイズ後の画像のサイズ(縦、横のそれぞれ)がこの値で割り切れるように、画像中心を切り出します。
- interpolation
- 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。
## 追加情報
### cloneofsimo氏のリポジトリとの違い

512
train_textual_inversion.py Normal file
View File

@@ -0,0 +1,512 @@
import importlib
import argparse
import gc
import math
import os
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
from library.train_util import DreamBoothDataset, FineTuningDataset
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
def collate_fn(examples):
return examples[0]
def train(args):
if args.output_name is None:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
# acceleratorを準備する
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# Convert the init_word to token_id
if args.init_word is not None:
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}")
else:
init_token_ids = None
# add new word to tokenizer, count is num_vectors_per_token
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
num_added_tokens = tokenizer.add_tokens(token_strings)
assert num_added_tokens == args.num_vectors_per_token, f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
if init_token_ids is not None:
for i, token_id in enumerate(token_ids):
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
embeddings = load_weights(args.weights)
assert len(token_ids) == len(
embeddings), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print(f"weighs loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
if use_dreambooth_method:
print("Use DreamBooth method.")
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
else:
print("Train with captions.")
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
print("use template for training captions. is object: {args.use_object_template}")
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
replace_to = " ".join(token_strings)
captions = []
for tmpl in templates:
captions.append(tmpl.format(replace_to))
train_dataset.add_replacement("", captions)
elif args.num_vectors_per_token > 1:
replace_to = " ".join(token_strings)
train_dataset.add_replacement(args.token_string, replace_to)
train_dataset.make_buckets()
if args.debug_dataset:
train_util.debug_dataset(train_dataset, show_input_ids=True)
return
if len(train_dataset) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset.cache_latents(vae)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
# 8-bit Adamを使う
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit
elif args.use_lion_optimizer:
try:
import lion_pytorch
except ImportError:
raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
print("use Lion optimizer")
optimizer_class = lion_pytorch.Lion
else:
optimizer_class = torch.optim.AdamW
trainable_params = text_encoder.get_input_embeddings().parameters()
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
# acceleratorがなんかよろしくやってくれるらしい
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler)
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
else:
unet.eval()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
text_encoder.to(weight_dtype)
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_train_timesteps=1000, clip_sample=False)
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
text_encoder.train()
loss_total = 0
bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = text_encoder.get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad():
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
loss_total += current_loss
avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
# d = updated_embs - bef_epo_embs
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
# end of epoch
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + '.' + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
print("model saved.")
def save_weights(file, updated_embs, save_dtype):
state_dict = {"emb_params": updated_embs}
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import save_file
save_file(state_dict, file)
else:
torch.save(state_dict, file) # can be loaded in Web UI
def load_weights(file):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file
data = load_file(file)
else:
# compatible to Web UI's file format
data = torch.load(file, map_location='cpu')
if type(data) != dict:
raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
if 'string_to_param' in data: # textual inversion embeddings
data = data['string_to_param']
if hasattr(data, '_parameters'): # support old PyTorch?
data = getattr(data, '_parameters')
emb = next(iter(data.values()))
if type(emb) != torch.Tensor:
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")
if len(emb.size()) == 1:
emb = emb.unsqueeze(0)
return emb
if __name__ == '__main__':
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt")
parser.add_argument("--weights", type=str, default=None,
help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument("--num_vectors_per_token", type=int, default=1,
help='number of vectors per token / トークンに割り当てるembeddingsの要素数')
parser.add_argument("--token_string", type=str, default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること")
parser.add_argument("--init_word", type=str, default=None,
help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument("--use_object_template", action='store_true',
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する")
parser.add_argument("--use_style_template", action='store_true',
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する")
args = parser.parse_args()
train(args)

63
train_ti_README-ja.md Normal file
View File

@@ -0,0 +1,63 @@
## Textual Inversionの学習について
[Textual Inversion](https://textual-inversion.github.io/)です。実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。
学習したモデルはWeb UIでもそのまま使えます。
なお恐らくSD2.xにも対応していますが現時点では未テストです。
## 学習方法
``train_textual_inversion.py`` を用います。
データの準備については ``train_network.py`` と全く同じですので、[そちらのドキュメント](./train_network_README-ja.md)を参照してください。
## オプション
以下はコマンドラインの例ですDreamBooth手法
```
accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py
--pretrained_model_name_or_path=..\models\model.ckpt
--train_data_dir=..\data\db\char1 --output_dir=..\ti_train1
--resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
--max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16
--save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
--token_string=mychar4 --init_word=cute --num_vectors_per_token=4
```
``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてくださいtoken_stringがmychar4なら、``mychar4 1girl`` など__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。
プロンプトにトークン文字列が含まれているかどうかは、``--debug_dataset`` で置換後のtoken idが表示されますので、以下のように ``49408`` 以降のtokenが存在するかどうかで確認できます。
```
input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407]])
```
tokenizerがすでに持っている単語一般的な単語は使用できません。
``--init_word`` にembeddingsを初期化するときのコピー元トークンの文字列を指定します。学ばせたい概念が近いものを選ぶとよいようです。二つ以上のトークンになる文字列は指定できません。
``--num_vectors_per_token`` にいくつのトークンをこの学習で使うかを指定します。多いほうが表現力が増しますが、その分多くのトークンを消費します。たとえばnum_vectors_per_token=8の場合、指定したトークン文字列は一般的なプロンプトの77トークン制限のうち8トークンを消費します。
その他、以下のオプションが指定できます。
* --weights
* 学習前に学習済みのembeddingsを読み込み、そこから追加で学習します。
* --use_object_template
* キャプションではなく既定の物体用テンプレート文字列(``a photo of a {}``など)で学習します。公式実装と同じになります。キャプションは無視されます。
* --use_style_template
* キャプションではなく既定のスタイル用テンプレート文字列で学習します(``a painting in the style of {}``など)。公式実装と同じになります。キャプションは無視されます。
## 当リポジトリ内の画像生成スクリプトで生成する
gen_img_diffusers.pyに、``--textual_inversion_embeddings`` オプションで学習したembeddingsファイルを指定してください複数可。プロンプトでembeddingsファイルのファイル名拡張子を除くを使うと、そのembeddingsが適用されます。