Compare commits

..

563 Commits

Author SHA1 Message Date
Kohya S
8f4ee8fc34 doc: update README for latest 2025-03-21 22:05:48 +09:00
Kohya S.
367f348430 Merge pull request #1964 from Nekotekina/main
Fix missing text encoder attn modules
2025-03-21 21:59:03 +09:00
Ivan Chikish
acdca2abb7 Fix [occasionally] missing text encoder attn modules
Should fix #1952
I added alternative name for CLIPAttention.
I have no idea why this name changed.
Now it should accept both names.
2025-03-01 20:35:45 +03:00
Kohya S.
6e3c1d0b58 Merge pull request #1879 from kohya-ss/dev
merge dev to main
2025-01-17 23:25:56 +09:00
Kohya S
345daaa986 update README for merging 2025-01-17 23:22:38 +09:00
Kohya S
6adb69be63 Merge branch 'main' into dev 2024-11-07 21:42:44 +09:00
Kohya S
e5ac095749 add about dev and sd3 branch to README 2024-11-07 21:39:47 +09:00
Kohya S
e070bd9973 Merge branch 'main' into dev 2024-10-27 10:19:55 +09:00
Kohya S
ca44e3e447 reduce VRAM usage, instead of increasing main RAM usage 2024-10-27 10:19:05 +09:00
Kohya S
900d551a6a Merge branch 'main' into dev 2024-10-26 22:02:36 +09:00
Kohya S
56b4ea963e Fix LoRA metadata hash calculation bug in svd_merge_lora.py, sdxl_merge_lora.py, and resize_lora.py closes #1722 2024-10-26 22:01:10 +09:00
Kohya S
b1e6504007 update README 2024-10-25 18:56:25 +09:00
Kohya S.
b8ae745d0c Merge pull request #1717 from catboxanon/fix/remove-vpred-warnings
Remove v-pred warnings
2024-10-25 18:49:40 +09:00
Kohya S.
c632af860e Merge pull request #1715 from catboxanon/vpred-ztsnr-fixes
Update debiased estimation loss function to accommodate V-pred
2024-10-25 18:48:14 +09:00
catboxanon
be14c06267 Remove v-pred warnings
Different model architectures, such as SDXL, can take advantage of
v-pred. It doesn't make sense to include these warnings anymore.
2024-10-22 12:13:51 -04:00
catboxanon
0e7c592933 Remove scale_v_pred_loss_like_noise_pred deprecation
https://github.com/kohya-ss/sd-scripts/pull/1715#issuecomment-2427876376
2024-10-22 11:19:34 -04:00
catboxanon
e1b63c2249 Only add warning for deprecated scaling vpred loss function 2024-10-21 08:12:53 -04:00
catboxanon
8fc30f8205 Fix training for V-pred and ztSNR
1) Updates debiased estimation loss function for V-pred.
2) Prevents now-deprecated scaling of loss if ztSNR is enabled.
2024-10-21 07:34:33 -04:00
Kohya S
012e7e63a5 fix to work linear/cosine scheduler closes #1651 ref #1393 2024-09-29 23:18:16 +09:00
Kohya S
1567549220 update help text #1632 2024-09-29 09:51:36 +09:00
Kohya S
fe2aa32484 adjust min/max bucket reso divisible by reso steps #1632 2024-09-29 09:49:25 +09:00
Kohya S
ce49ced699 update readme 2024-09-26 21:37:40 +09:00
Kohya S
a94bc84dec fix to work bitsandbytes optimizers with full path #1640 2024-09-26 21:37:31 +09:00
Kohya S.
4296e286b8 Merge pull request #1640 from sdbds/ademamix8bit
New optimizer:AdEMAMix8bit and PagedAdEMAMix8bit
2024-09-26 21:20:19 +09:00
Kohya S
bf91bea2e4 fix flip_aug, alpha_mask, random_crop issue in caching 2024-09-26 20:51:40 +09:00
sdbds
1beddd84e5 delete code for cleaning 2024-09-25 22:58:26 +08:00
Kohya S
e74f58148c update README 2024-09-25 20:55:50 +09:00
Kohya S.
c1d16a76d6 Merge pull request #1628 from recris/huber-timesteps
Make timesteps work in the standard way when Huber loss is used
2024-09-25 20:52:55 +09:00
sdbds
ab7b231870 init 2024-09-25 19:38:52 +08:00
Kohya S
29177d2f03 retain alpha in pil_resize backport #1619 2024-09-23 21:14:03 +09:00
recris
e1f23af1bc make timestep sampling behave in the standard way when huber loss is used 2024-09-21 13:21:56 +01:00
Kohya S.
0b7927e50b Merge pull request #1615 from Maru-mee/patch-1
Bug fix: alpha_mask load
2024-09-19 21:49:49 +09:00
Kohya S
d7e14721e2 Merge branch 'main' into dev 2024-09-19 21:15:19 +09:00
Kohya S
9c757c2fba fix SDXL block index to match LBW 2024-09-19 21:14:57 +09:00
Maru-mee
e7040669bc Bug fix: alpha_mask load 2024-09-19 15:47:06 +09:00
Kohya S
93d9fbf607 improve OFT implementation closes #944 2024-09-13 22:37:11 +09:00
Kohya S
43ad73860d Merge branch 'main' into dev 2024-09-13 21:29:51 +09:00
Kohya S
b755ebd0a4 add LBW support for SDXL merge LoRA 2024-09-13 21:29:31 +09:00
Kohya S
f4a0bea6dc format by black 2024-09-13 21:26:06 +09:00
terracottahaniwa
734d2e5b2b Support Lora Block Weight (LBW) to svd_merge_lora.py (#1575)
* support lora block weight

* solve license incompatibility

* Fix issue: lbw index calculation
2024-09-13 20:45:35 +09:00
Kohya S
9d2860760d Merge branch 'main' into dev 2024-09-13 19:48:53 +09:00
Kohya S
3387dc7306 formatting, update README 2024-09-13 19:45:42 +09:00
Kohya S
57ae44eb61 refactor to make safer 2024-09-13 19:45:00 +09:00
Maru-mee
1d7118a622 Support : OFT merge to base model (#1580)
* Support : OFT merge to base model

* Fix typo

* Fix typo_2

* Delete unused parameter 'eye'
2024-09-13 19:01:36 +09:00
Kohya S
c7c666b182 fix typo 2024-09-11 22:12:31 +09:00
Kohya S
6dbfd47a59 Fix to work PIECEWISE_CONSTANT, update requirement.txt and README #1393 2024-09-11 21:44:36 +09:00
青龍聖者@bdsqlsz
fd68703f37 Add New lr scheduler (#1393)
* add new lr scheduler

* fix bugs and use num_cycles / 2

* Update requirements.txt

* add num_cycles for min lr

* keep PIECEWISE_CONSTANT

* allow use float with warmup or decay ratio.

* Update train_util.py
2024-09-11 21:25:45 +09:00
Kohya S
62ec3e6424 Merge branch 'main' into dev 2024-09-07 10:52:49 +09:00
Kohya S.
de25945a93 Merge pull request #1550 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.24.3
Bump crate-ci/typos from 1.19.0 to 1.24.3
2024-09-07 10:50:46 +09:00
Kohya S
0005867ba5 update README, format code 2024-09-07 10:45:18 +09:00
Kohya S.
16bb5699ac Merge pull request #1426 from sdbds/resize
Replacing CV2 resize to Pil resize
2024-09-07 10:22:52 +09:00
Kohya S.
319e4d9831 Merge pull request #1433 from millie-v/sample-image-without-cuda
Generate sample images without having CUDA (such as on Macs)
2024-09-07 10:19:55 +09:00
dependabot[bot]
1bcf8d600b Bump crate-ci/typos from 1.19.0 to 1.24.3
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.19.0 to 1.24.3.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.19.0...v1.24.3)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-09-01 01:33:04 +00:00
Kohya S.
f8f5b16958 Merge pull request #1540 from kohya-ss/dependabot/pip/opencv-python-4.8.1.78
Bump opencv-python from 4.7.0.68 to 4.8.1.78
2024-08-31 21:37:07 +09:00
Kohya S.
826ab5ce2e Merge pull request #1532 from nandometzger/main
Update train_util.py, bug fix
2024-08-31 21:36:33 +09:00
dependabot[bot]
3a6154b7b0 Bump opencv-python from 4.7.0.68 to 4.8.1.78
Bumps [opencv-python](https://github.com/opencv/opencv-python) from 4.7.0.68 to 4.8.1.78.
- [Release notes](https://github.com/opencv/opencv-python/releases)
- [Commits](https://github.com/opencv/opencv-python/commits)

---
updated-dependencies:
- dependency-name: opencv-python
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-08-31 06:21:16 +00:00
Nando Metzger
2a3aefb4e4 Update train_util.py, bug fix 2024-08-30 08:15:05 +02:00
Kohya S
d5c076cf90 update readme 2024-08-24 21:21:39 +09:00
Kohya S.
4ca29edbff Merge pull request #1505 from liesened/patch-2
Add v-pred support for SDXL train
2024-08-24 21:16:53 +09:00
liesen
1e8108fec9 Handle args.v_parameterization properly for MinSNR and changed prediction target 2024-08-24 01:38:17 +03:00
Kohya S
afb971f9c3 fix SD1.5 LoRA extraction #1490 2024-08-22 21:33:15 +09:00
Kohya S
74f91c2ff7 correct option name closes #1446 2024-08-11 21:54:10 +09:00
sdbds
9ca7a5b6cc instead cv2 LANCZOS4 resize to pil resize 2024-07-20 21:59:11 +08:00
sdbds
1f16b80e88 Revert "judge image size for using diff interpolation"
This reverts commit 87526942a6.
2024-07-20 21:35:24 +08:00
Millie
2e67978ee2 Generate sample images without having CUDA (such as on Macs) 2024-07-18 11:52:58 -07:00
sdbds
87526942a6 judge image size for using diff interpolation 2024-07-12 22:56:38 +08:00
Kohya S
0b3e4f7ab6 show file name if error in load_image ref #1385 2024-06-25 20:03:09 +09:00
Kohya S
9dd1ee458c Merge branch 'main' into dev 2024-06-23 14:11:51 +09:00
Kohya S
25f961bc77 fix to work cache_latents/text_encoder_outputs 2024-06-23 13:24:30 +09:00
Kohya S
56bb81c9e6 add grad_hook after restore state closes #1344 2024-06-12 21:39:35 +09:00
Kohya S
22413a5247 Merge pull request #1359 from kohya-ss/train_resume_step
Train resume step
2024-06-11 19:52:03 +09:00
Kohya S
18d7597b0b update README 2024-06-11 19:51:30 +09:00
Kohya S
4a441889d4 Merge branch 'dev' into train_resume_step 2024-06-11 19:27:37 +09:00
Kohya S
3259928ce4 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2024-06-09 19:26:42 +09:00
Kohya S
1a104dc75e make forward/backward pathes same ref #1363 2024-06-09 19:26:36 +09:00
Kohya S
58fb64819a set static graph flag when DDP ref #1363 2024-06-09 19:26:09 +09:00
Kohya S
5bfe5e411b Merge pull request #1361 from shirayu/update/github_actions/crate-ci/typos-1.21.0
Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and updated _typos.toml (Close #1307)
2024-06-06 21:23:24 +09:00
Yuta Hayashibe
4ecbac131a Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and updated _typos.toml (Close #1307) 2024-06-05 16:31:55 +09:00
Kohya S
4dbcef429b update for corner cases 2024-06-04 21:26:55 +09:00
Kohya S
321e24d83b Merge pull request #1353 from KohakuBlueleaf/train_resume_step
Resume correct step for "resume from state" feature.
2024-06-04 19:30:11 +09:00
Kohya S
e5bab69e3a fix alpha mask without disk cache closes #1351, ref #1339 2024-06-02 21:11:40 +09:00
Kohaku-Blueleaf
3eb27ced52 Skip the final 1 step 2024-05-31 12:24:15 +08:00
Kohaku-Blueleaf
b2363f1021 Final implementation 2024-05-31 12:20:20 +08:00
Kohya S
0d96e10b3e Merge pull request #1339 from kohya-ss/alpha-masked-loss
Alpha masked loss
2024-05-27 21:41:16 +09:00
Kohya S
fc85496f7e update docs for masked loss 2024-05-27 21:25:06 +09:00
Kohya S
2870be9b52 Merge branch 'dev' into alpha-masked-loss 2024-05-27 21:08:43 +09:00
Kohya S
71ad3c0f45 Update masked_loss_README-ja.md
add sample images
2024-05-27 21:07:57 +09:00
Kohya S
ffce3b5098 Merge pull request #1349 from rockerBOO/patch-4
Update issue link
2024-05-27 21:00:46 +09:00
Kohya S
a4c3155148 add doc for mask loss 2024-05-27 20:59:40 +09:00
Kohya S
58cadf476b Merge branch 'dev' into alpha-masked-loss 2024-05-27 20:02:32 +09:00
Dave Lage
d50c1b3c5c Update issue link 2024-05-27 01:11:01 -04:00
Kohya S
e8cfd4ba1d fix to work cond mask and alpha mask 2024-05-26 22:01:37 +09:00
Kohya S
fb12b6d8e5 Merge pull request #1347 from rockerBOO/lora-plus-log-info
Add LoRA+ LR Ratio info message to logger
2024-05-26 19:45:03 +09:00
rockerBOO
00513b9b70 Add LoRA+ LR Ratio info message to logger 2024-05-23 22:27:12 -04:00
Kohya S
da6fea3d97 simplify and update alpha mask to work with various cases 2024-05-19 21:26:18 +09:00
Kohya S
f2dd43e198 revert kwargs to explicit declaration 2024-05-19 19:23:59 +09:00
u-haru
db6752901f 画像のアルファチャンネルをlossのマスクとして使用するオプションを追加 (#1223)
* Add alpha_mask parameter and apply masked loss

* Fix type hint in trim_and_resize_if_required function

* Refactor code to use keyword arguments in train_util.py

* Fix alpha mask flipping logic

* Fix alpha mask initialization

* Fix alpha_mask transformation

* Cache alpha_mask

* Update alpha_masks to be on CPU

* Set flipped_alpha_masks to Null if option disabled

* Check if alpha_mask is None

* Set alpha_mask to None if option disabled

* Add description of alpha_mask option to docs
2024-05-19 19:07:25 +09:00
Kohya S
febc5c59fa update README 2024-05-19 19:03:43 +09:00
Kohya S
4c798129b0 update README 2024-05-19 19:00:32 +09:00
Kohya S
38e4c602b1 Merge pull request #1277 from Cauldrath/negative_learning
Allow negative learning rate
2024-05-19 18:55:50 +09:00
Kohya S
e4d9e3c843 remove dependency for omegaconf #ref 1284 2024-05-19 17:46:07 +09:00
Kohya S
de0e0b9468 Merge pull request #1284 from sdbds/fix_traincontrolnet
Fix train controlnet
2024-05-19 17:39:15 +09:00
Kohya S
c68baae480 add --log_config option to enable/disable output training config 2024-05-19 17:21:04 +09:00
Kohya S
47187f7079 Merge pull request #1285 from ccharest93/main
Hyperparameter tracking
2024-05-19 16:31:33 +09:00
Kohya S
e3ddd1fbbe update README and format code 2024-05-19 16:26:10 +09:00
Kohya S
0640f017ab Merge pull request #1322 from aria1th/patch-1
Accelerate: fix get_trainable_params in controlnet-llite training
2024-05-19 16:23:01 +09:00
Kohya S
2f19175dfe update README 2024-05-19 15:38:37 +09:00
Kohya S
146edce693 support Diffusers' based SDXL LoRA key for inference 2024-05-18 11:05:04 +09:00
Kohya S
153764a687 add prompt option '--f' for filename 2024-05-15 20:21:49 +09:00
Kohya S
589c2aa025 update README 2024-05-13 21:20:37 +09:00
Kohya S
16677da0d9 fix create_network_from_weights doesn't work 2024-05-12 22:15:07 +09:00
Kohya S
a384bf2187 Merge pull request #1313 from rockerBOO/patch-3
Add caption_separator to output for subset
2024-05-12 21:36:56 +09:00
Kohya S
1c296f7229 Merge pull request #1312 from rockerBOO/patch-2
Fix caption_separator missing in subset schema
2024-05-12 21:33:12 +09:00
Kohya S
e96a5217c3 Merge pull request #1291 from frodo821/patch-1
removed unnecessary `torch` import on line 115
2024-05-12 21:14:50 +09:00
Kohya S
39b82f26e5 update readme 2024-05-12 20:58:45 +09:00
Kohya S
3701507874 raise original error if error is occured in checking latents 2024-05-12 20:56:56 +09:00
Kohya S
78020936d2 Merge pull request #1278 from Cauldrath/catch_latent_error_file
Display name of error latent file
2024-05-12 20:46:25 +09:00
Kohya S
9ddb4d7a01 update readme and help message etc. 2024-05-12 17:55:08 +09:00
Kohya S
8d1b1acd33 Merge pull request #1266 from Zovjsra/feature/disable-mmap
Add "--disable_mmap_load_safetensors" parameter
2024-05-12 17:43:44 +09:00
Kohya S
02298e3c4a Merge pull request #1331 from kohya-ss/lora-plus
Lora plus
2024-05-12 17:04:58 +09:00
Kohya S
44190416c6 update docs etc. 2024-05-12 17:01:20 +09:00
Kohya S
3c8193f642 revert lora+ for lora_fa 2024-05-12 17:00:51 +09:00
Kohya S
c6a437054a Merge branch 'dev' into lora-plus 2024-05-12 16:18:57 +09:00
Kohya S
1ffc0b330a fix typo 2024-05-12 16:18:43 +09:00
Kohya S
e01e148705 Merge branch 'dev' into lora-plus 2024-05-12 16:17:52 +09:00
Kohya S
e9f3a622f4 Merge branch 'dev' into lora-plus 2024-05-12 16:17:27 +09:00
Kohya S
7983d3db5f Merge pull request #1319 from kohya-ss/fused-backward-pass
Fused backward pass
2024-05-12 15:09:39 +09:00
Kohya S
bee8cee7e8 update README for fused optimizer 2024-05-12 15:08:52 +09:00
Kohya S
f3d2cf22ff update README for fused optimizer 2024-05-12 15:03:02 +09:00
Kohya S
6dbc23cf63 Merge branch 'dev' into fused-backward-pass 2024-05-12 14:21:56 +09:00
Kohya S
c1ba0b4356 update readme 2024-05-12 14:21:10 +09:00
Kohya S
607e041f3d chore: Refactor optimizer group 2024-05-12 14:16:41 +09:00
AngelBottomless
793aeb94da fix get_trainable_params in controlnet-llite training 2024-05-07 18:21:31 +09:00
Kohya S
b56d5f7801 add experimental option to fuse params to optimizer groups 2024-05-06 21:35:39 +09:00
Kohya S
017b82ebe3 update help message for fused_backward_pass 2024-05-06 15:05:42 +09:00
Kohya S
2a359e0a41 Merge pull request #1259 from 2kpr/fused_backward_pass
Adafactor fused backward pass and optimizer step, lowers SDXL (@ 1024 resolution) VRAM usage to BF16(10GB)/FP32(16.4GB)
2024-05-06 15:01:56 +09:00
Kohya S
3fd8cdc55d fix dylora loraplus 2024-05-06 14:03:19 +09:00
Kohya S
7fe81502d0 update loraplus on dylora/lofa_fa 2024-05-06 11:09:32 +09:00
Kohya S
52e64c69cf add debug log 2024-05-04 18:43:52 +09:00
Kohya S
58c2d856ae support block dim/lr for sdxl 2024-05-03 22:18:20 +09:00
Dave Lage
8db0cadcee Add caption_separator to output for subset 2024-05-02 18:08:28 -04:00
Dave Lage
dbb7bb288e Fix caption_separator missing in subset schema 2024-05-02 17:39:35 -04:00
Kohya S
969f82ab47 move loraplus args from args to network_args, simplify log lr desc 2024-04-29 20:04:25 +09:00
Kohya S
834445a1d6 Merge pull request #1233 from rockerBOO/lora-plus
Add LoRA+ support
2024-04-29 18:05:12 +09:00
frodo821
fdbb03c360 removed unnecessary torch import on line 115
as per #1290
2024-04-23 14:29:05 +09:00
Cauldrath
040e26ff1d Regenerate failed file
If a latent file fails to load, print out the path and the error, then return false to regenerate it
2024-04-21 13:46:31 -04:00
Kohya S
0540c33aca pop weights if available #1247 2024-04-21 17:45:29 +09:00
Kohya S
52652cba1a disable main process check for deepspeed #1247 2024-04-21 17:41:32 +09:00
青龍聖者@bdsqlsz
5cb145d13b Update train_util.py 2024-04-20 21:56:24 +08:00
Maatra
b886d0a359 Cleaned typing to be in line with accelerate hyperparameters type resctrictions 2024-04-20 14:36:47 +01:00
青龍聖者@bdsqlsz
4477116a64 fix train controlnet 2024-04-20 21:26:09 +08:00
Maatra
2c9db5d9f2 passing filtered hyperparameters to accelerate 2024-04-20 14:11:43 +01:00
Cauldrath
fc374375de Allow negative learning rate
This can be used to train away from a group of images you don't want
As this moves the model away from a point instead of towards it, the change in the model is unbounded
So, don't set it too low. -4e-7 seemed to work well.
2024-04-18 23:29:01 -04:00
Cauldrath
feefcf256e Display name of error latent file
When trying to load stored latents, if an error occurs, this change will tell you what file failed to load
Currently it will just tell you that something failed without telling you which file
2024-04-18 23:15:36 -04:00
Zovjsra
64916a35b2 add disable_mmap to args 2024-04-16 16:40:08 +08:00
2kpr
4f203ce40d Fused backward pass 2024-04-14 09:56:58 -05:00
rockerBOO
68467bdf4d Fix unset or invalid LR from making a param_group 2024-04-11 17:33:19 -04:00
rockerBOO
75833e84a1 Fix default LR, Add overall LoRA+ ratio, Add log
`--loraplus_ratio` added for both TE and UNet
Add log for lora+
2024-04-08 19:23:02 -04:00
Kohya S
71e2c91330 Merge pull request #1230 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.19.0
Bump crate-ci/typos from 1.17.2 to 1.19.0
2024-04-07 21:14:18 +09:00
Kohya S
bfb352bc43 change huber_schedule from exponential to snr 2024-04-07 21:07:52 +09:00
Kohya S
c973b29da4 update readme 2024-04-07 20:51:52 +09:00
Kohya S
683f3d6ab3 Merge pull request #1212 from kohya-ss/dev
Version 0.8.6
2024-04-07 20:42:41 +09:00
Kohya S
dfa30790a9 update readme 2024-04-07 20:34:26 +09:00
Kohya S
d30ebb205c update readme, add metadata for network module 2024-04-07 14:58:17 +09:00
kabachuha
90b18795fc Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption (#1228)
* add huber loss and huber_c compute to train_util

* add reduction modes

* add huber_c retrieval from timestep getter

* move get timesteps and huber to own function

* add conditional loss to all training scripts

* add cond loss to train network

* add (scheduled) huber_loss to args

* fixup twice timesteps getting

* PHL-schedule should depend on noise scheduler's num timesteps

* *2 multiplier to huber loss cause of 1/2 a^2 conv.

The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another

* add option for smooth l1 (huber / delta)

* unify huber scheduling

* add snr huber scheduler

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-04-07 13:54:21 +09:00
Kohya S
089727b5ee update readme 2024-04-07 12:42:49 +09:00
Kohya S
921036dd91 Merge pull request #1240 from kohya-ss/verify-command-line-args
verify command line args if wandb is enabled
2024-04-07 12:27:03 +09:00
ykume
cd587ce62c verify command line args if wandb is enabled 2024-04-05 08:23:03 +09:00
rockerBOO
1933ab4b48 Fix default_lr being applied 2024-04-03 12:46:34 -04:00
Kohya S
b748b48dbb fix attention couple+deep shink cause error in some reso 2024-04-03 12:43:08 +09:00
rockerBOO
c7691607ea Add LoRA-FA for LoRA+ 2024-04-01 15:43:04 -04:00
rockerBOO
f99fe281cb Add LoRA+ support 2024-04-01 15:38:26 -04:00
dependabot[bot]
80e9f72234 Bump crate-ci/typos from 1.17.2 to 1.19.0
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.17.2 to 1.19.0.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.17.2...v1.19.0)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-04-01 01:50:22 +00:00
Kohya S
2258a1b753 add save/load hook to remove U-Net/TEs from state 2024-03-31 15:50:35 +09:00
Kohya S
059ee047f3 fix typo 2024-03-30 23:02:24 +09:00
Kohya S
2c2ca9d726 update tagger doc 2024-03-30 22:55:56 +09:00
Kohya S
f5323e3c4b update tagger doc 2024-03-30 22:10:37 +09:00
Kohya S
cae5aa0a56 update wd14 tagger and doc 2024-03-30 21:48:22 +09:00
Kohya S
6ba84288d9 Merge pull request #1216 from Disty0/dev
Rating support for WD Tagger
2024-03-30 18:50:49 +09:00
Kohya S
434dc408f9 update readme 2024-03-30 17:12:36 +09:00
Kohya S
ae3f625739 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2024-03-30 14:57:43 +09:00
Kohya S
f1f30ab418 fix to work with num_beams>1 closes #1149 2024-03-30 14:57:39 +09:00
Disty0
bc586ce190 Add --use_rating_tags and --character_tags_first for WD Tagger 2024-03-29 13:56:42 +03:00
Disty0
4012fd24f6 IPEX fix pin_memory 2024-03-28 21:08:16 +03:00
Disty0
954731d564 fix typo 2024-03-27 22:00:59 +03:00
Disty0
dd9763be31 Rating support for WD Tagger 2024-03-27 21:53:40 +03:00
Kohya S
b86af6798d Merge pull request #1213 from Disty0/dev
Add OpenVINO and ROCm ONNX Runtime for WD14
2024-03-27 23:15:33 +09:00
Disty0
6f7e93d5cc Add OpenVINO and ROCm ONNX Runtime for WD14 2024-03-27 03:21:13 +03:00
Kohya S
6c08e97e1f update readme 2024-03-26 20:48:08 +09:00
Kohya S
78e0a7630c Merge pull request #1206 from kohya-ss/dataset-cache
Add metadata caching for DreamBooth dataset
2024-03-26 19:49:23 +09:00
Kohya S
c86e356013 Merge branch 'dev' into dataset-cache 2024-03-26 19:43:40 +09:00
Kohya S
5a2afb3588 Merge pull request #1207 from kohya-ss/masked-loss
Add masked loss
2024-03-26 19:41:31 +09:00
Kohya S
ab1e389347 Merge branch 'dev' into masked-loss 2024-03-26 19:39:30 +09:00
Kohya S
ea05e3fd5b Merge pull request #1139 from kohya-ss/deep-speed
Deep speed
2024-03-26 19:33:57 +09:00
Kohya S
a2b8531627 make each script consistent, fix to work w/o DeepSpeed 2024-03-25 22:28:46 +09:00
Kohya S
c24422fb9d Merge branch 'dev' into deep-speed 2024-03-25 22:11:05 +09:00
Kohya S
9c4492b58a fix pytorch version 2.1.1 to 2.1.2 2024-03-24 23:17:25 +09:00
Kohya S
9bbb28c361 update PyTorch version and reorganize dependencies 2024-03-24 22:06:37 +09:00
Kohya S
1648ade6da format by black 2024-03-24 20:55:48 +09:00
Kohya S
993b2ab4c1 Merge branch 'dev' into deep-speed 2024-03-24 18:45:59 +09:00
Kohya S
8d5858826f Merge branch 'dev' into masked-loss 2024-03-24 18:19:53 +09:00
Kohya S
025347214d refactor metadata caching for DreamBooth dataset 2024-03-24 18:09:32 +09:00
Kohaku-Blueleaf
ae97c8bfd1 [Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initilization (#1178)
* support meta cached dataset

* add cache meta scripts

* random ip_noise_gamma strength

* random noise_offset strength

* use correct settings for parser

* cache path/caption/size only

* revert mess up commit

* revert mess up commit

* Update requirements.txt

* Add arguments for meta cache.

* remove pickle implementation

* Return sizes when enable cache

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-03-24 15:40:18 +09:00
Kohya S
381c44955e update readme and typing hint 2024-03-24 11:27:18 +09:00
Kohya S
ad97410ba5 Merge pull request #1205 from feffy380/patch-1
register reg images with correct subset
2024-03-24 11:14:07 +09:00
Kohya S
691f04322a update readme 2024-03-24 11:10:26 +09:00
Kohya S
79d1c12ab0 disable sample_every_n_xxx if value less than 1 ref #1202 2024-03-24 11:06:37 +09:00
feffy380
0c7baea88c register reg images with correct subset 2024-03-23 17:28:02 +01:00
Kohya S
f4a4c11cd3 support multiline captions ref #1155 2024-03-23 18:51:37 +09:00
Kohya S
594c7f7050 format by black 2024-03-23 16:11:31 +09:00
Kohya S
d17c0f5084 update dataset config doc 2024-03-21 08:31:29 +09:00
Kohya S
a35e7bd595 Merge pull request #1200 from BootsofLagrangian/deep-speed
Fix sdxl_train.py in deepspeed branch
2024-03-20 21:32:35 +09:00
BootsofLagrangian
d9456020d7 Fix most of ZeRO stage uses optimizer partitioning
- we have to prepare optimizer and ds_model at the same time.
 - pull/1139#issuecomment-1986790007

Signed-off-by: BootsofLagrangian <hard2251@yonsei.ac.kr>
2024-03-20 20:52:59 +09:00
Kohya S
fbb98f144e Merge branch 'dev' into deep-speed 2024-03-20 18:15:26 +09:00
Kohya S
9b6b39f204 Merge branch 'dev' into masked-loss 2024-03-20 18:14:36 +09:00
Kohya S
855add067b update option help and readme 2024-03-20 18:14:05 +09:00
Kohya S
bf6cd4b9da Merge pull request #1168 from gesen2egee/save_state_on_train_end
Save state on train end
2024-03-20 18:02:13 +09:00
Kohya S
3b0db0f17f update readme 2024-03-20 17:45:35 +09:00
Kohya S
119cc99fb0 Merge pull request #1167 from Horizon1704/patch-1
Add "encoding='utf-8'" for --config_file
2024-03-20 17:39:08 +09:00
Kohya S
5f6196e4c7 update readme 2024-03-20 16:35:23 +09:00
Victor Espinoza-Guerra
46331a9e8e English Translation of config_README-ja.md (#1175)
* Add files via upload

Creating template to work on.

* Update config_README-en.md

Total Conversion from Japanese to English.

* Update config_README-en.md

* Update config_README-en.md

* Update config_README-en.md
2024-03-20 16:31:01 +09:00
Kohya S
cf09c6aa9f Merge pull request #1177 from KohakuBlueleaf/random-strength-noise
Random strength for Noise Offset and input perturbation noise
2024-03-20 16:17:16 +09:00
Kohya S
80dbbf5e48 tagger now stores model under repo_id subdir 2024-03-20 16:14:57 +09:00
Kohya S
7da41be281 Merge pull request #1192 from sdbds/main
Add WDV3 support
2024-03-20 15:49:55 +09:00
Kohya S
e281e867e6 Merge branch 'main' into dev 2024-03-20 15:49:08 +09:00
青龍聖者@bdsqlsz
6c51c971d1 fix typo 2024-03-20 09:35:21 +08:00
青龍聖者@bdsqlsz
a71c35ccd9 Update requirements.txt 2024-03-18 22:31:59 +08:00
青龍聖者@bdsqlsz
5410a8c79b Update requirements.txt 2024-03-18 22:31:00 +08:00
青龍聖者@bdsqlsz
a7dff592d3 Update tag_images_by_wd14_tagger.py
add WDV3
2024-03-18 22:29:05 +08:00
Kohya S
f9317052ed update readme for timestep embs bug 2024-03-18 08:53:23 +09:00
Kohya S
86e40fabbc Merge branch 'dev' into deep-speed 2024-03-17 19:30:42 +09:00
Kohya S
3419c3de0d common masked loss func, apply to all training script 2024-03-17 19:30:20 +09:00
Kohya S
7081a0cf0f extension of src image could be different than target image 2024-03-17 18:09:15 +09:00
Kohya S
0ef4fe70f0 Merge branch 'dev' into masked-loss 2024-03-17 11:18:18 +09:00
Kohya S
443f02942c fix doc 2024-03-15 21:35:14 +09:00
Kohya S
0a8ec5224e Merge branch 'main' into dev 2024-03-15 21:33:07 +09:00
Kohya S
6b1520a46b Merge pull request #1187 from kohya-ss/fix-timeemb
fix sdxl timestep embedding
2024-03-15 21:17:13 +09:00
Kohya S
f811b115ba fix sdxl timestep embedding 2024-03-15 21:05:00 +09:00
kblueleaf
53954a1e2e use correct settings for parser 2024-03-13 18:21:49 +08:00
kblueleaf
86399407b2 random noise_offset strength 2024-03-13 18:21:49 +08:00
kblueleaf
948029fe61 random ip_noise_gamma strength 2024-03-13 18:21:49 +08:00
Kohya S
97524f1bda Merge branch 'dev' into deep-speed 2024-03-12 20:41:41 +09:00
Kohya S
74c266a597 Merge branch 'dev' into masked-loss 2024-03-12 20:40:57 +09:00
gesen2egee
d282c45002 Update train_network.py 2024-03-11 23:56:09 +08:00
gesen2egee
095b8035e6 save state on train end 2024-03-10 23:33:38 +08:00
Horizon1704
124ec45876 Add "encoding='utf-8'" 2024-03-10 22:53:05 +08:00
Kohya S
14c9372a38 add doc about Colab/rich issue 2024-03-03 21:47:37 +09:00
Kohya S
a9b64ffba8 support masked loss in sdxl_train ref #589 2024-02-27 21:43:55 +09:00
Kohya S
e3ccf8fbf7 make deepspeed_utils 2024-02-27 21:30:46 +09:00
Kohya S
0e4a5738df Merge pull request #1101 from BootsofLagrangian/deepspeed
support deepspeed
2024-02-27 18:59:00 +09:00
Kohya S
eefb3cc1e7 Merge branch 'deep-speed' into deepspeed 2024-02-27 18:57:42 +09:00
Kohya S
074d32af20 Merge branch 'main' into dev 2024-02-27 18:53:43 +09:00
Kohya S
2d7389185c Merge pull request #1094 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.17.2
Bump crate-ci/typos from 1.16.26 to 1.17.2
2024-02-27 18:23:41 +09:00
Kohya S
4a5546d40e fix typo 2024-02-26 23:39:56 +09:00
Kohya S
175193623b update readme 2024-02-26 23:29:41 +09:00
Kohya S
f2c727fc8c add minimal impl for masked loss 2024-02-26 23:19:58 +09:00
Kohya S
577e9913ca add some new dataset settings 2024-02-26 20:01:25 +09:00
Kohya S
fccbee2727 revert logging #1137 2024-02-25 10:43:14 +09:00
Kohya S
e0acb10f31 Merge pull request #1137 from shirayu/replace_print_with_logger
Replaced print with logger
2024-02-25 10:34:19 +09:00
Yuta Hayashibe
5d5f39b6e6 Replaced print with logger 2024-02-25 01:24:11 +09:00
Kohya S
e69d34103b Merge pull request #1136 from kohya-ss/dev
v0.8.4
2024-02-24 21:15:46 +09:00
Kohya S
a21218bdd5 update readme 2024-02-24 21:09:59 +09:00
Kohya S
81e8af6519 fix ipex init 2024-02-24 20:51:26 +09:00
Kohya S
8b7c14246a some log output to print 2024-02-24 20:50:00 +09:00
Kohya S
52b3799989 fix format, add new conv rank to metadata comment 2024-02-24 20:49:41 +09:00
Kohya S
738c397e1a Merge pull request #1102 from mgz-dev/resize_lora-add-rank-for-conv
Resize lora add new rank for conv
2024-02-24 20:10:20 +09:00
Kohya S
0e703608f9 Merge branch 'dev' into resize_lora-add-rank-for-conv 2024-02-24 20:09:38 +09:00
Kohya S
fb9110bac1 format by black 2024-02-24 20:00:57 +09:00
Kohya S
24092e6f21 update einops to 0.7.0 #1122 2024-02-24 19:51:51 +09:00
Kohya S
f4132018c5 fix to work with cpu_count() == 1 closes #1134 2024-02-24 19:25:31 +09:00
Kohya S
488d1870ab Merge pull request #1126 from tamlog06/DyLoRA-xl
Fix dylora create_modules error when training sdxl
2024-02-24 19:19:33 +09:00
Kohya S
86279c8855 Merge branch 'dev' into DyLoRA-xl 2024-02-24 19:18:36 +09:00
BootsofLagrangian
4d5186d1cf refactored codes, some function moved into train_utils.py 2024-02-22 16:20:53 +09:00
tamlog06
a6f1ed2e14 fix dylora create_modules error 2024-02-18 13:20:47 +00:00
Kohya S
d1fb480887 format by black 2024-02-18 09:13:24 +09:00
Kohya S
75e4a951d0 update readme 2024-02-17 12:04:12 +09:00
Kohya S
42f3318e17 Merge pull request #1116 from kohya-ss/dev_device_support
Dev device support
2024-02-17 11:58:02 +09:00
Kohya S
baa0e97ced Merge branch 'dev' into dev_device_support 2024-02-17 11:54:07 +09:00
Kohya S
71ebcc5e25 update readme and gradual latent doc 2024-02-12 14:52:19 +09:00
Kohya S
93bed60762 fix to work --console_log_xxx options 2024-02-12 14:49:29 +09:00
Kohya S
41d32c0be4 Merge pull request #1117 from kohya-ss/gradual_latent_hires_fix
Gradual latent hires fix
2024-02-12 14:21:27 +09:00
Kohya S
cbe9c5dc06 supprt deep shink with regional lora, add prompter module 2024-02-12 14:17:27 +09:00
Kohya S
d3745db764 add args for logging 2024-02-12 13:15:21 +09:00
Kohya S
358ca205a3 Merge branch 'dev' into dev_device_support 2024-02-12 13:01:54 +09:00
Kohya S
c748719115 fix indent 2024-02-12 12:59:45 +09:00
Kohya S
98f42d3a0b Merge branch 'dev' into gradual_latent_hires_fix 2024-02-12 12:59:25 +09:00
Kohya S
35c6053de3 Merge pull request #1104 from kohya-ss/dev_improve_log
replace print with logger
2024-02-12 11:33:32 +09:00
Kohya S
20ae603221 Merge branch 'dev' into gradual_latent_hires_fix 2024-02-12 11:26:36 +09:00
Kohya S
672851e805 Merge branch 'dev' into dev_improve_log 2024-02-12 11:24:33 +09:00
Kohya S
e579648ce9 fix help for highvram arg 2024-02-12 11:12:41 +09:00
Kohya S
e24d9606a2 add clean_memory_on_device and use it from training 2024-02-12 11:10:52 +09:00
Kohya S
75ecb047e2 Merge branch 'dev' into dev_device_support 2024-02-11 19:51:28 +09:00
Kohya S
f897d55781 Merge pull request #1113 from kohya-ss/dev_multi_gpu_sample_gen
Dev multi gpu sample gen
2024-02-11 19:49:08 +09:00
Kohya S
7202596393 log to print tag frequencies 2024-02-10 09:59:12 +09:00
BootsofLagrangian
03f0816f86 the reason not working grad accum steps found. it was becasue of my accelerate settings 2024-02-09 17:47:49 +09:00
Kohya S
5d9e2873f6 make rich to output to stderr instead of stdout 2024-02-08 21:38:02 +09:00
Kohya S
055f02e1e1 add logging args for training scripts 2024-02-08 21:16:42 +09:00
Kohya S
9b8ea12d34 update log initialization without rich 2024-02-08 21:06:39 +09:00
Kohya S
74fe0453b2 add comment for get_preferred_device 2024-02-08 20:58:54 +09:00
BootsofLagrangian
a98fecaeb1 forgot setting mixed_precision for deepspeed. sorry 2024-02-07 17:19:46 +09:00
BootsofLagrangian
2445a5b74e remove test requirements 2024-02-07 16:48:18 +09:00
BootsofLagrangian
62556619bd fix full_fp16 compatible and train_step 2024-02-07 16:42:05 +09:00
BootsofLagrangian
7d2a9268b9 apply offloading method runable for all trainer 2024-02-05 22:42:06 +09:00
BootsofLagrangian
3970bf4080 maybe fix branch to run offloading 2024-02-05 22:40:43 +09:00
BootsofLagrangian
4295f91dcd fix all trainer about vae 2024-02-05 20:19:56 +09:00
BootsofLagrangian
2824312d5e fix vae type error during training sdxl 2024-02-05 20:13:28 +09:00
BootsofLagrangian
64873c1b43 fix offload_optimizer_device typo 2024-02-05 17:11:50 +09:00
Kohya S
efd3b58973 Add logging arguments and update logging setup 2024-02-04 20:44:10 +09:00
Kohya S
6279b33736 fallback to basic logging if rich is not installed 2024-02-04 18:28:54 +09:00
Yuta Hayashibe
5f6bf29e52 Replace print with logger if they are logs (#905)
* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-02-04 18:14:34 +09:00
Kohya S
e793d7780d reduce peak VRAM in sample gen 2024-02-04 17:31:01 +09:00
mgz
1492bcbfa2 add --new_conv_rank option
update script to also take a separate conv rank value
2024-02-03 23:18:55 -06:00
mgz
bf2de5620c fix formatting in resize_lora.py 2024-02-03 20:09:37 -06:00
BootsofLagrangian
dfe08f395f support deepspeed 2024-02-04 03:12:42 +09:00
Kohya S
6269682c56 unificaition of gen scripts for SD and SDXL, work in progress 2024-02-03 23:33:48 +09:00
Kohya S
2f9a344297 fix typo 2024-02-03 23:26:57 +09:00
Kohya S
11aced3500 simplify multi-GPU sample generation 2024-02-03 22:25:29 +09:00
DKnight54
1567ce1e17 Enable distributed sample image generation on multi-GPU enviroment (#1061)
* Update train_util.py

Modifying to attempt enable multi GPU inference

* Update train_util.py

additional VRAM checking, refactor check_vram_usage to return string for use with accelerator.print

* Update train_network.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

remove sample image debug outputs

* Update train_util.py

* Update train_util.py

* Update train_network.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_network.py

* Update train_util.py

* Update train_network.py

* Update train_network.py

* Update train_network.py

* Cleanup of debugging outputs

* adopt more elegant coding

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update train_util.py

Fix leftover debugging code
attempt to refactor inference into separate function

* refactor in function generate_per_device_prompt_list() generation of distributed prompt list

* Clean up missing variables

* fix syntax error

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* true random sample image generation

update code to reinitialize random seed to true random if seed was set

* true random sample image generation

* simplify per process prompt

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_util.py

* Update train_network.py

* Update train_network.py

* Update train_network.py

---------

Co-authored-by: Aarni Koskela <akx@iki.fi>
2024-02-03 21:46:31 +09:00
Kohya S
5cca1fdc40 add highvram option and do not clear cache in caching latents 2024-02-01 21:55:55 +09:00
Kohya S
9f0f0d573d Merge pull request #1092 from Disty0/dev_device_support
Fix IPEX support and add XPU device to device_utils
2024-02-01 20:41:21 +09:00
dependabot[bot]
716a92cbed Bump crate-ci/typos from 1.16.26 to 1.17.2
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.16.26 to 1.17.2.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.16.26...v1.17.2)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-02-01 01:57:52 +00:00
Disty0
a6a2b5a867 Fix IPEX support and add XPU device to device_utils 2024-01-31 17:32:37 +03:00
Kohya S
2ca4d0c831 Merge pull request #1054 from akx/mps
Device support improvements (MPS)
2024-01-31 21:30:12 +09:00
Kohya S
7f948db158 Merge pull request #1087 from mgz-dev/fix-imports-on-svd_merge_lora
fix broken import in svd_merge_lora script
2024-01-31 21:08:40 +09:00
Kohya S
9d7729c00d Merge pull request #1086 from Disty0/dev
Update IPEX Libs
2024-01-31 21:06:34 +09:00
Disty0
988dee02b9 IPEX torch.tensor FP64 workaround 2024-01-30 01:52:32 +03:00
mgz
d4b9568269 fix broken import in svd_merge_lora script
remove missing import, and remove unused imports
2024-01-28 11:59:07 -06:00
Disty0
ccc3a481e7 Update IPEX Libs 2024-01-28 14:14:31 +03:00
Kohya S
8f6f734a6f Merge branch 'dev' into gradual_latent_hires_fix 2024-01-28 08:21:15 +09:00
Kohya S
cd19df49cd Merge pull request #1085 from kohya-ss/dev
Dev
2024-01-27 18:32:06 +09:00
Kohya S
736365bdd5 update README.md 2024-01-27 18:31:01 +09:00
Kohya S
6ceedb9448 Merge branch 'main' into dev 2024-01-27 18:23:52 +09:00
Kohya S
930a3912a7 Merge pull request #1084 from fireicewolf/devel
Fix network multiplier cause crashed while use multi-GPUs
2024-01-27 18:22:00 +09:00
Kohya S
cf790d87c4 Merge pull request #1079 from feffy380/fix/fp8savestate
Update safetensors to fix a crash with `--fp8_base --save_state`
2024-01-26 22:34:35 +09:00
DukeG
4e67fb8444 test 2024-01-26 20:22:49 +08:00
DukeG
50f631c768 test 2024-01-26 20:02:48 +08:00
DukeG
85bc371ebc test 2024-01-26 18:58:47 +08:00
feffy380
322ee52c77 Update requirements.txt
Update safetensors to fix a crash when using `--fp8_base --save_state`
2024-01-25 19:15:53 +01:00
Kohya S
c576f80639 Fix ControlNetLLLite training issue #1069 2024-01-25 18:43:07 +09:00
Aarni Koskela
478156b4f7 Refactor device determination to function; add MPS fallback 2024-01-23 14:29:03 +02:00
Aarni Koskela
afc38707d5 Refactor memory cleaning into a single function 2024-01-23 14:28:50 +02:00
Aarni Koskela
2e4bee6f24 Log accelerator device 2024-01-23 14:20:40 +02:00
Kohya S
d5ab97b69b Merge pull request #1067 from kohya-ss/dev
Dev
2024-01-23 21:04:16 +09:00
Kohya S
7cb44e4502 update readme 2024-01-23 21:02:40 +09:00
Kohya S
7a20df5ad5 Merge pull request #1064 from KohakuBlueleaf/fix-grad-sync
Avoid grad sync on each step even when doing accumulation
2024-01-23 20:33:55 +09:00
Kohya S
bea4362e21 Merge pull request #1060 from akx/refactor-xpu-init
Deduplicate ipex initialization code
2024-01-23 20:25:37 +09:00
Kohya S
6805cafa9b fix TI training crashes in multigpu #1019 2024-01-23 20:17:19 +09:00
Kohaku-Blueleaf
711b40ccda Avoid always sync 2024-01-23 11:49:03 +08:00
Kohya S
696dd7f668 Fix dtype issue in PyTorch 2.0 for generating samples in training sdxl network 2024-01-22 12:43:37 +09:00
Kohya S
e0a3c69223 update readme 2024-01-20 18:47:10 +09:00
Kohya S
c59249a664 Add options to reduce memory usage in extract_lora_from_models.py closes #1059 2024-01-20 18:45:54 +09:00
Kohya S
fef172966f Add network_multiplier for dataset and train LoRA 2024-01-20 16:24:43 +09:00
Kohya S
5a1ebc4c7c format by black 2024-01-20 13:10:45 +09:00
Kohya S
2a0f45aea9 update readme 2024-01-20 11:08:20 +09:00
Kohya S
1f77bb6e73 fix to work sample generation in fp8 ref #1057 2024-01-20 10:57:42 +09:00
Kohya S
a7ef6422b6 fix to work with torch 2.0 2024-01-20 10:00:30 +09:00
Kohaku-Blueleaf
9cfa68c92f [Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057)
* Add fp8 support

* remove some debug prints

* Better implementation for te

* Fix some misunderstanding

* as same as unet, add explicit convert

* better impl for convert TE to fp8

* fp8 for not only unet

* Better cache TE and TE lr

* match arg name

* Fix with list

* Add timeout settings

* Fix arg style

* Add custom seperator

* Fix typo

* Fix typo again

* Fix dtype error

* Fix gradient problem

* Fix req grad

* fix merge

* Fix merge

* Resolve merge

* arrangement and document

* Resolve merge error

* Add assert for mixed precision
2024-01-20 09:46:53 +09:00
Aarni Koskela
6f3f701d3d Deduplicate ipex initialization code 2024-01-19 18:07:36 +02:00
Kohya S
d2a99a19d4 Merge pull request #1056 from kohya-ss/dev
fix vram usage in LoRA training
2024-01-17 21:41:36 +09:00
Kohya S
0395a35543 Merge branch 'main' into dev 2024-01-17 21:39:13 +09:00
Kohya S
987d4a969d update readme 2024-01-17 21:38:49 +09:00
Kohya S
976d092c68 fix text encodes are on gpu even when not trained 2024-01-17 21:31:50 +09:00
Kohya S
e6b15c7e4a Merge pull request #1053 from akx/sdpa
Fix typo `--spda` (it's `--sdpa`)
2024-01-16 21:50:45 +09:00
Aarni Koskela
ef50436464 Fix typo --spda (it's --sdpa) 2024-01-16 14:32:48 +02:00
Kohya S
26d35794e3 Merge pull request #1052 from kohya-ss/dev
merge dev
2024-01-15 21:39:02 +09:00
Kohya S
dcf0eeb5b6 update readme 2024-01-15 21:35:26 +09:00
Kohya S
32b759a328 Add wandb_run_name parameter to init_kwargs #1032 2024-01-14 22:02:03 +09:00
Kohya S
09ef3ffa8b Merge branch 'main' into dev 2024-01-14 21:49:25 +09:00
Kohya S
aab265e431 Fix an issue with saving as diffusers sd1/2 model close #1033 2024-01-04 21:43:50 +09:00
Kohya S
da9b34fa26 Merge branch 'dev' into gradual_latent_hires_fix 2024-01-04 19:53:46 +09:00
Kohya S
716bad188b Update dependencies ref #1024 2024-01-04 19:53:25 +09:00
Kohya S
4f93bf10f0 Merge pull request #1032 from hopl1t/wandb_session_name_support
Added cli argument for wandb session name
2024-01-04 11:10:31 +09:00
Kohya S
07bf2a21ac Merge pull request #1024 from p1atdev/main
Add support for `torch.compile`
2024-01-04 10:49:52 +09:00
Kohya S
8ac2d2a92f Merge pull request #1030 from Disty0/dev
Update IPEX Libs
2024-01-04 10:46:07 +09:00
Kohya S
76aee71257 Merge branch 'main' into dev 2024-01-04 10:42:16 +09:00
Kohya S
1db5d790ed Merge pull request #1029 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.16.26
Bump crate-ci/typos from 1.16.15 to 1.16.26
2024-01-04 10:41:07 +09:00
Kohya S
663b481029 fix TI training with full_fp16/bf16 ref #1019 2024-01-03 23:22:00 +09:00
Kohya S
1ab6493268 Merge branch 'main' into dev 2024-01-03 21:36:31 +09:00
Nir Weingarten
ab716302e4 Added cli argument for wandb session name 2024-01-03 11:52:38 +02:00
Disty0
b9d2181192 Cleanup 2024-01-02 11:51:29 +03:00
Disty0
49148eb36e Disable Diffusers slicing if device is not XPU 2024-01-02 11:50:08 +03:00
Disty0
479bac447e Fix typo 2024-01-01 12:51:23 +03:00
Disty0
15d5e78ac2 Update IPEX Libs 2024-01-01 12:44:26 +03:00
dependabot[bot]
fd7f27f044 Bump crate-ci/typos from 1.16.15 to 1.16.26
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.16.15 to 1.16.26.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.16.15...v1.16.26)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-01-01 01:28:55 +00:00
Plat
62e7516537 feat: support torch.compile 2023-12-27 02:17:24 +09:00
Plat
20296b4f0e chore: bump eniops version due to support torch.compile 2023-12-27 02:17:24 +09:00
Kohya S
5cae6db804 Fix to work with DDP TextualInversionTrainer ref #1019 2023-12-24 22:05:56 +09:00
Kohya S
1a36f9dc65 Merge pull request #1020 from kohya-ss/dev
Fix convert_diffusers20_original_sd.py and add metadata & variant options
2023-12-24 21:48:25 +09:00
Kohya S
c2497877ca Merge branch 'main' into dev 2023-12-24 21:46:05 +09:00
Kohya S
3b5c1a1d4b Fix issue with tools/convert_diffusers20_original_sd.py 2023-12-24 21:45:51 +09:00
Kohya S
9a2e385f12 Merge pull request #1016 from Disty0/dev
Fix convert_diffusers20_original_sd.py and add --variant option for loading
2023-12-24 21:41:12 +09:00
Disty0
7080e1a11c Fix convert_diffusers20_original_sd.py and add metadata & variant options 2023-12-22 22:40:03 +03:00
Kohya S
0a52b83c6a Merge pull request #1012 from kohya-ss/dev
merge dev to main
2023-12-21 22:18:39 +09:00
Kohya S
11ed8e2a6d update readme 2023-12-21 22:16:55 +09:00
Kohya S
bb20c09a9a update readme 2023-12-21 22:10:47 +09:00
Kohya S
04ef8d395f speed up nan replace in sdxl training ref #1009 2023-12-21 21:44:03 +09:00
Kohya S
0676f1a86f Merge pull request #1009 from liubo0902/main
speed up latents nan replace
2023-12-21 21:37:16 +09:00
Kohya S
6b7823df07 Merge branch 'main' into dev 2023-12-21 21:33:43 +09:00
Kohya S
2186e417ba fix size of bucket < min_size ref #1008 2023-12-20 22:12:21 +09:00
Kohya S
1519e3067c Merge pull request #1008 from Cauldrath/zero_height_error
Fix zero height buckets
2023-12-20 22:09:04 +09:00
Kohya S
35e5424255 Merge pull request #1007 from Disty0/dev
IPEX fix SDPA
2023-12-20 21:53:11 +09:00
liubo0902
8c7d05afd2 speed up latents nan replace 2023-12-20 09:35:17 +08:00
Cauldrath
f8360a4831 Fix zero height buckets
If max_size is too large relative to max_reso, it will calculate a height of zero for some buckets.
This causes a crash later when it divides the width by the height.

This change also simplifies some math and consolidates the redundant "size" variable into "width".
2023-12-19 18:35:09 -05:00
Disty0
8556b9d7f5 IPEX fix SDPA 2023-12-19 22:59:06 +03:00
Kohya S
3efd90b2ad fix sampling in training with mutiple gpus ref #989 2023-12-15 22:35:54 +09:00
Kohya S
7adcd9cd1a Merge pull request #1003 from Disty0/dev
IPEX support for Torch 2.1 and fix dtype erros
2023-12-15 08:18:48 +09:00
Disty0
aff05e043f IPEX support for Torch 2.1 and fix dtype erros 2023-12-13 19:40:38 +03:00
Kohya S
ff2c0c192e update readme 2023-12-13 23:13:22 +09:00
Kohya S
d309a27a51 change option names, add ddp kwargs if needed ref #1000 2023-12-13 21:02:26 +09:00
Kohya S
471d274803 Merge pull request #1000 from Isotr0py/dev
Fix multi-gpu SDXL training
2023-12-13 20:52:11 +09:00
Kohya S
35f4c9b5c7 fix an error when keep_tokens_separator is not set ref #975 2023-12-12 21:43:21 +09:00
Kohya S
034a49c69d Merge pull request #975 from Linaqruf/dev
Add keep_tokens_separator as alternative for keep_tokens
2023-12-12 21:28:32 +09:00
Kohya S
3b6825d7e2 Merge pull request #986 from CjangCjengh/dev
Fixed the path error in finetune/make_captions.py
2023-12-12 21:17:24 +09:00
Isotr0py
bb5ae389f7 fix DDP SDXL training 2023-12-12 19:58:44 +08:00
Kohya S
d61ecb26fd enable comment in prompt file, record raw prompt to metadata 2023-12-12 08:20:36 +09:00
Kohya S
07ef03d340 fix controlnet to work with gradual latent 2023-12-12 08:03:27 +09:00
Kohya S
9278031e60 Merge branch 'dev' into gradual_latent_hires_fix 2023-12-12 07:49:36 +09:00
Kohya S
4a2cef887c fix lllite training not working ref #913 2023-12-10 09:23:37 +09:00
Kohya S
42750f7846 fix error on pool_workaround in sdxl TE training ref #994 2023-12-10 09:18:33 +09:00
Kohya S
e8c3a02830 Merge branch 'dev' into gradual_latent_hires_fix 2023-12-08 08:23:53 +09:00
CjangCjengh
d31aa143f4 fix path error 2023-12-08 00:27:32 +08:00
CjangCjengh
710e777a92 fix path error 2023-12-08 00:22:13 +08:00
Kohya S
912dca8f65 fix duplicated sample gen for every epoch ref #907 2023-12-07 22:13:38 +09:00
Isotr0py
db84530074 Fix gradients synchronization for multi-GPUs training (#989)
* delete DDP wrapper

* fix train_db vae and train_network

* fix train_db vae and train_network unwrap

* network grad sync

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2023-12-07 22:01:42 +09:00
Kohya S
72bbaac96d Merge pull request #985 from Disty0/dev
Update IPEX hijacks
2023-12-07 21:39:24 +09:00
Kohya S
5713d63dc5 add temporary workaround for playground-v2 2023-12-06 23:08:02 +09:00
CjangCjengh
d653e594c2 fix path error 2023-12-06 09:48:42 +08:00
Disty0
dd7bb33ab6 IPEX fix torch.UntypedStorage.is_cuda 2023-12-05 22:18:47 +03:00
Disty0
a9c6182b3f Cleanup IPEX libs 2023-12-05 19:52:31 +03:00
Disty0
3d70137d31 Disable IPEX attention if the GPU supports 64 bit 2023-12-05 19:40:16 +03:00
Disty0
bce9a081db Update IPEX hijacks 2023-12-05 14:17:31 +03:00
Kohya S
46cf41cc93 Merge pull request #961 from rockerBOO/attention-processor
Add attention processor
2023-12-03 21:24:12 +09:00
Kohya S
81a440c8e8 Merge pull request #955 from xzuyn/paged_adamw
Add PagedAdamW
2023-12-03 21:22:38 +09:00
Kohya S
f24a3b5282 show seed in generating samples 2023-12-03 21:15:30 +09:00
Kohya S
383b4a2c3e Merge pull request #907 from shirayu/add_option_sample_at_first
Add option --sample_at_first
2023-12-03 21:00:32 +09:00
Kohya S
df59822a27 Merge pull request #906 from shirayu/accept_scheduler_designation_in_training
Accept sampler designation in sampling of training
2023-12-03 20:46:16 +09:00
Kohya S
0908c5414d Merge pull request #978 from kohya-ss/dev
Dev
2023-12-03 18:27:29 +09:00
Kohya S
ee46134fa7 update readme 2023-12-03 18:24:50 +09:00
Kohya S
7a4e50705c add target_x flag (not sure this impl is correct) 2023-12-03 17:59:41 +09:00
Kohya S
2952bca520 fix strength error 2023-12-01 21:56:08 +09:00
Kohya S
39bb319d4c fix to work with cfg scale=1 2023-11-29 12:42:12 +09:00
Kohya S
29b6fa6212 add unsharp mask 2023-11-28 22:33:22 +09:00
Furqanil Taqwa
1bdd83a85f remove unnecessary debug print 2023-11-28 17:26:27 +07:00
Furqanil Taqwa
1624c239c2 added keep_tokens_separator to dynamically keep token for being shuffled 2023-11-28 17:23:55 +07:00
Furqanil Taqwa
4a913ce61e initialize keep_tokens_separator to dataset config 2023-11-28 17:22:35 +07:00
Kohya S
2c50ea0403 apply unsharp mask 2023-11-27 23:50:21 +09:00
Kohya S
298c6c2343 fix gradual latent cannot be disabled 2023-11-26 21:48:36 +09:00
Kohya S
2897a89dfd Merge branch 'dev' into gradual_latent_hires_fix 2023-11-26 18:12:24 +09:00
Kohya S
764e333fa2 make slicing vae compatible with latest diffusers 2023-11-26 18:12:04 +09:00
Kohya S
c61e3bf4c9 make separate U-Net for inference 2023-11-26 18:11:30 +09:00
Kohya S
fc8649d80f Merge pull request #934 from feffy380/fix-minsnr-vpred-zsnr
Fix min-snr-gamma for v-prediction and ZSNR.
2023-11-25 21:19:39 +09:00
Kohya S
0fb9ecf1f3 format by black, add ja comment 2023-11-25 21:05:55 +09:00
Kohya S
97958400fb Merge pull request #936 from wkpark/model_util-update
use **kwargs and change svd() calling convention to make svd() reusable
2023-11-25 21:01:14 +09:00
Kohya S
610566fbb9 Update README.md 2023-11-23 22:22:36 +09:00
Kohya S
684954695d add gradual latent 2023-11-23 22:17:49 +09:00
Kohya S
6d6d86260b add Deep Shrink 2023-11-23 19:40:48 +09:00
rockerBOO
c856ea4249 Add attention processor 2023-11-19 12:11:36 -05:00
Kohya S
d0923d6710 add caption_separator option 2023-11-19 21:44:52 +09:00
Kohya S
f312522cef Merge pull request #913 from KohakuBlueleaf/custom-seperator
Add custom seperator for shuffle caption
2023-11-19 21:32:01 +09:00
xzuyn
da5a144589 Add PagedAdamW 2023-11-18 07:47:27 -05:00
Won-Kyu Park
2c1e669bd8 add min_diff, clamp_quantile args
based on https://github.com/bmaltais/kohya_ss/pull/1332 a9ec90c40a
2023-11-10 02:35:55 +09:00
Won-Kyu Park
e20e9f61ac use **kwargs and change svd() calling convention to make svd() reusable
* add required attributes to model_org, model_tuned, save_to
 * set "*_alpha" using str(float(foo))
2023-11-10 02:35:10 +09:00
feffy380
6b3148fd3f Fix min-snr-gamma for v-prediction and ZSNR.
This fixes min-snr for vpred+zsnr by dividing directly by SNR+1.
The old implementation did it in two steps: (min-snr/snr) * (snr/(snr+1)), which causes division by zero when combined with --zero_terminal_snr
2023-11-07 23:02:25 +01:00
Kohya S
95ae56bd22 Update README.md 2023-11-05 21:10:26 +09:00
Kohya S
990192d077 Merge pull request #927 from kohya-ss/dev
Dev
2023-11-05 19:31:41 +09:00
Kohya S
f3e69531c3 update readme 2023-11-05 19:30:52 +09:00
Kohya S
0cb3272bda update readme 2023-11-05 19:26:35 +09:00
Kohya S
6231aa91e2 common lr logging, set default None to ddp_timeout 2023-11-05 19:09:17 +09:00
Kohaku-Blueleaf
489b728dbc Fix typo again 2023-10-30 20:19:51 +08:00
Kohaku-Blueleaf
583e2b2d01 Fix typo 2023-10-30 20:02:04 +08:00
Kohaku-Blueleaf
5dc2a0d3fd Add custom seperator 2023-10-30 19:55:30 +08:00
Yuta Hayashibe
2c731418ad Added sample_images() for --sample_at_first 2023-10-29 22:08:42 +09:00
Yuta Hayashibe
5c150675bf Added --sample_at_first description 2023-10-29 21:46:47 +09:00
Yuta Hayashibe
fea810b437 Added --sample_at_first to generate sample images before training 2023-10-29 21:44:57 +09:00
Kohya S
96d877be90 support separate LR for Text Encoder for SD1/2 2023-10-29 21:30:32 +09:00
Yuta Hayashibe
40d917b0fe Removed incorrect comments 2023-10-29 21:02:44 +09:00
Kohya S
e72020ae01 update readme 2023-10-29 20:52:43 +09:00
Kohya S
01d929ee2a support separate learning rates for TE1/2 2023-10-29 20:38:01 +09:00
Yuta Hayashibe
cf876fcdb4 Accept --ss to set sample_sampler dynamically 2023-10-29 20:15:04 +09:00
Yuta Hayashibe
291c29caaf Added a function line_to_prompt_dict() and removed duplicated initializations 2023-10-29 19:57:25 +09:00
Yuta Hayashibe
01e00ac1b0 Make a function get_my_scheduler() 2023-10-29 19:46:02 +09:00
Kohya S
a9ed4ed8a8 Merge pull request #900 from xzuyn/paged_adamw_32bit
Add PagedAdamW32bit
2023-10-29 15:01:55 +09:00
Kohya S
9d6a5a0c79 Merge pull request #899 from shirayu/use_moving_average
Show moving average loss in the progress bar
2023-10-29 14:37:58 +09:00
Kohya S
fb97a7aab1 Merge pull request #898 from shirayu/update_repare_buckets_latents
Fix a typo and add assertions in making buckets
2023-10-29 14:29:53 +09:00
Kohaku-Blueleaf
1cefb2a753 Better implementation for te autocast (#895)
* Better implementation for te

* Fix some misunderstanding

* as same as unet, add explicit convert

* Better cache TE and TE lr

* Fix with list

* Add timeout settings

* Fix arg style
2023-10-28 15:49:59 +09:00
Yuta Hayashibe
63992b81c8 Fix initialize place of loss_recorder 2023-10-27 21:13:29 +09:00
xzuyn
d8f68674fb Update train_util.py 2023-10-27 07:05:53 -04:00
Yuta Hayashibe
9d00c8eea2 Use LossRecorder 2023-10-27 18:31:36 +09:00
Yuta Hayashibe
0d21925bdf Use @property 2023-10-27 18:14:27 +09:00
Yuta Hayashibe
efef5c8ead Show "avr_loss" instead of "loss" because it is moving average 2023-10-27 17:59:58 +09:00
Yuta Hayashibe
3d2bb1a8f1 Add LossRecorder and use moving average in all places 2023-10-27 17:49:49 +09:00
Yuta Hayashibe
837a4dddb8 Added assertions 2023-10-26 13:34:36 +09:00
Yuta Hayashibe
b2626bc7a9 Fix a typo 2023-10-26 00:51:17 +09:00
青龍聖者@bdsqlsz
202f2c3292 Debias Estimation loss (#889)
* update for bnb 0.41.1

* fixed generate_controlnet_subsets_config for training

* Revert "update for bnb 0.41.1"

This reverts commit 70bd3612d8.

* add debiased_estimation_loss

* add train_network

* Revert "add train_network"

This reverts commit 6539363c5c.

* Update train_network.py
2023-10-23 22:59:14 +09:00
Kohya S
2a23713f71 Merge pull request #872 from kohya-ss/dev
fix make_captions_by_git, improve image generation scripts
2023-10-11 07:56:39 +09:00
Kohya S
681034d001 update readme 2023-10-11 07:54:30 +09:00
Kohya S
17813ff5b4 remove workaround for transfomers bs>1 close #869 2023-10-11 07:40:12 +09:00
Kohya S
3e81bd6b67 fix network_merge, add regional mask as color code 2023-10-09 23:07:14 +09:00
Kohya S
23ae358e0f Merge branch 'main' into dev 2023-10-09 21:42:13 +09:00
Kohya S
f611726364 add network_merge_n_models option 2023-10-09 21:41:50 +09:00
Kohya S
33ee0acd35 Merge pull request #867 from kohya-ss/dev
onnx support in wd14 tagger, OFT
2023-10-09 18:04:17 +09:00
Kohya S
8b79e3b06c fix typos 2023-10-09 18:00:45 +09:00
Kohya S
cf49e912fc update readme 2023-10-09 17:59:31 +09:00
Kohya S
66741c035c add OFT 2023-10-09 17:59:24 +09:00
Kohya S
406511c333 add error message if model.onnx doesn't exist 2023-10-09 17:08:58 +09:00
Kohya S
8a2d68d63e Merge pull request #864 from Isotr0py/onnx
Add `--onnx` to wd14 tagger
2023-10-09 15:14:11 +09:00
Kohya S
07d297fdbe Merge branch 'dev' into onnx 2023-10-09 15:13:40 +09:00
Kohya S
0d4e8b50d0 change option to append_tags, minor update 2023-10-09 15:09:54 +09:00
Kohya S
1d7c5c2a98 Merge pull request #858 from a-l-e-x-d-s-9/main
Add append_captions feature to wd14 tagger
2023-10-09 14:31:54 +09:00
Kohya S
0faa350175 Merge pull request #865 from kohya-ss/dev
Support JPEG-XL on windows, dropout for LyCORIS
2023-10-09 14:11:49 +09:00
Kohya S
8a7509db75 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-10-09 14:07:02 +09:00
Kohya S
025368f51c may work dropout in LyCORIS #859 2023-10-09 14:06:58 +09:00
Kohya S
5fe52ed322 Merge pull request #856 from Isotr0py/jxl
Fix JPEG-XL support
2023-10-09 13:55:03 +09:00
Kohya S
8b247a330b Merge pull request #851 from kohya-ss/dependabot/github_actions/actions/checkout-4
Bump actions/checkout from 3 to 4
2023-10-09 11:45:47 +09:00
Isotr0py
d6f458fcb3 fix dependency 2023-10-08 23:51:18 +08:00
Isotr0py
b8b84021e5 fix a typo 2023-10-08 20:49:03 +08:00
Isotr0py
70fe7e18be add onnx to wd14 tagger 2023-10-08 20:31:10 +08:00
alexds9
9378da3c82 Fix comment 2023-10-05 21:29:46 +03:00
alexds9
a4857fa764 Add append_captions feature to wd14 tagger
This feature allows for appending new tags to the existing content of caption files.
If the caption file for an image already exists, the tags generated from the current
run are appended to the existing ones. Duplicate tags are checked and avoided.
2023-10-05 21:26:09 +03:00
Isotr0py
592014923f Support JPEG-XL on windows 2023-10-04 21:48:25 +08:00
dependabot[bot]
6d06b215bf Bump actions/checkout from 3 to 4
Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-10-01 22:51:32 +00:00
Kohya S
2d87bb648f Merge pull request #850 from kohya-ss/dev
fix typos
2023-10-02 07:51:05 +09:00
Kohya S
56ebef35b0 Merge pull request #848 from shirayu/update_typos
Update typos to the latest version and add dependabot.yml
2023-10-02 07:45:29 +09:00
Yuta Hayashibe
13d8b22d25 Add dependabot 2023-10-01 21:52:16 +09:00
Yuta Hayashibe
27f9b6ffeb updated typos to v1.16.15 and fix typos 2023-10-01 21:51:24 +09:00
Yuta Hayashibe
c8fcfd4581 Add "venv" to extend-exclude 2023-10-01 21:48:50 +09:00
Kohya S
49c24285c7 Merge pull request #847 from kohya-ss/sdxl
merge sdxl into main
2023-10-01 20:40:33 +09:00
Kohya S
c918489259 update readme 2023-10-01 20:34:12 +09:00
Kohya S
93155242fa Merge pull request #846 from kohya-ss/dev
Fix to work training U-Net only LoRA for SD1/2
2023-10-01 16:44:13 +09:00
Kohya S
4cc919607a fix placing of requires_grad_ of U-Net 2023-10-01 16:41:48 +09:00
Kohya S
81419f7f32 Fix to work training U-Net only LoRA for SD1/2 2023-10-01 16:37:23 +09:00
Kohya S
6bd6cd9c51 update doc 2023-10-01 12:17:54 +09:00
Kohya S
35a1d68eb6 Merge pull request #844 from kohya-ss/dev
IPEX update, concat LoRA
2023-10-01 12:06:36 +09:00
Kohya S
365a06bdb6 Merge pull request #839 from laksjdjf/sdxl
Support concat LoRA
2023-10-01 11:16:46 +09:00
Kohya S
8e117f9f92 Merge pull request #841 from Disty0/dev
IPEX Attention optimizations
2023-10-01 10:58:19 +09:00
Disty0
209eafb631 IPEX attention optimizations 2023-09-28 14:02:25 +03:00
laksjdjf
14aa2923cf Support concat LoRA 2023-09-28 14:39:32 +09:00
Kohya S
1e395ed285 Merge branch 'main' into sdxl 2023-09-24 17:51:08 +09:00
Kohya S
98615166b0 Merge pull request #831 from kohya-ss/dev
update versions of accelerate and diffusers
2023-09-24 17:50:40 +09:00
Kohya S
28272de97a update readme 2023-09-24 17:48:51 +09:00
Kohya S
7e736da30c update versions of accelerate and diffusers 2023-09-24 17:46:57 +09:00
Kohya S
20e929e27e fix to work iter_same_seed 2023-09-24 16:04:50 +09:00
Kohya S
477b5260aa fix sai metadata for sdxl closes #824 2023-09-24 14:47:13 +09:00
Kohya S
d39f1a3427 Merge pull request #808 from rockerBOO/metadata
Add ip_noise_gamma metadata
2023-09-24 14:35:18 +09:00
Kohya S
3757855231 rename train_lllite_alt to train_lllite 2023-09-24 14:34:31 +09:00
Kohya S
d846431015 Merge branch 'dev' into sdxl 2023-09-24 14:30:16 +09:00
Kohya S
624edf428f Merge pull request #825 from Disty0/dev
Intel ARC support with IPEX
2023-09-24 14:29:03 +09:00
Kohya S
54500b861d Merge pull request #830 from kohya-ss/dev2
add extension checking for resize_lora.py
2023-09-24 12:12:32 +09:00
Kohya S
f2491ee0ac change block name doesn't contain '.' at end 2023-09-24 12:10:56 +09:00
Kohya S
1f169ee7fb Merge pull request #760 from Symbiomatrix/bugfix1
Update resize_lora.py
2023-09-24 11:59:18 +09:00
Kohya S
66817992c1 revert formatting 2023-09-24 11:50:44 +09:00
Kohya S
8052bcd5cd format by black 2023-09-24 11:26:28 +09:00
Kohya S
55886a0116 add .pt and .pth for available extension 2023-09-24 11:25:54 +09:00
Kohya S
33e90cc6a0 Merge pull request #815 from jvkap/patch-1
Update resize_lora.py
2023-09-24 11:02:12 +09:00
青龍聖者@bdsqlsz
d5be8125b0 update bitsandbytes for 0.41.1 and fixed bugs with generate_controlnet_subsets_config for training (#823)
* update for bnb 0.41.1

* fixed generate_controlnet_subsets_config for training

* Revert "update for bnb 0.41.1"

This reverts commit 70bd3612d8.
2023-09-24 10:51:47 +09:00
Disty0
b99cd2a920 Update getDeviceIdListForCard 2023-09-20 17:16:06 +03:00
Disty0
b64389c8a9 Intel ARC support with IPEX 2023-09-19 18:05:05 +03:00
jvkap
a0e05fa291 Update resize_lora.py 2023-09-11 11:41:33 -03:00
jvkap
e33c007cd0 Update resize_lora.py 2023-09-11 11:29:06 -03:00
rockerBOO
80aca1ccc7 Add ip_noise_gamma metadata 2023-09-05 15:20:15 -04:00
Kohya S
6b3a580ee5 Merge pull request #804 from kohya-ss/dev
fix to work regional LoRA
2023-09-03 17:52:23 +09:00
Kohya S
74561dbdac update readme (#803)
* update readme

* update readme

* fix typo
2023-09-03 12:51:09 +09:00
Symbiomatrix
9d678a6f41 Update resize_lora.py 2023-08-16 00:08:09 +03:00
92 changed files with 14926 additions and 3850 deletions

7
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,7 @@
---
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "monthly"

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.13.10
uses: crate-ci/typos@v1.24.3

View File

@@ -3,21 +3,25 @@ Stable Diffusionの学習、画像生成、その他のスクリプトを入れ
[README in English](./README.md) ←更新情報はこちらにあります
開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています英語ですのであわせてご覧ください。bmaltais氏に感謝します。
以下のスクリプトがあります。
* DreamBooth、U-NetおよびText Encoderの学習をサポート
* fine-tuning、同上
* LoRAの学習をサポート
* 画像生成
* モデル変換Stable Diffision ckpt/safetensorsとDiffusersの相互変換
## 使用法について
当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください将来的にはすべてこちらへ移すかもしれません
* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
* [データセット設定](./docs/config_README-ja.md)
* [SDXL学習](./docs/train_SDXL-en.md) (英語版)
* [DreamBoothの学習について](./docs/train_db_README-ja.md)
* [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
* [LoRAの学習について](./docs/train_network_README-ja.md)
@@ -32,6 +36,8 @@ Python 3.10.6およびGitが必要です。
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
- git: https://git-scm.com/download/win
Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。
PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。
venvに限らずスクリプトの実行が可能になりますので注意してください。
@@ -41,11 +47,11 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
## Windows環境でのインストール
以下の例ではPyTorchは1.12.1CUDA 11.6版をインストールします。CUDA 11.3版やPyTorch 1.13を使う場合は適宜書き換えください
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します
なお、python -m venvの行で「python」とだけ表示された場合、py -m venvのようにpythonをpyに変更してください。
通常の管理者ではないPowerShellを開き以下を順に実行します。
PowerShellを使う場合、通常の管理者ではないPowerShellを開き以下を順に実行します。
```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
@@ -54,50 +60,23 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
accelerate config
```
<!--
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
pip install --use-pep517 --upgrade -r requirements.txt
pip install -U -I --no-deps xformers==0.0.16
-->
コマンドプロンプトでも同一です。
コマンドプロンプトでは以下になります
注:`bitsandbytes==0.44.0``prodigyopt==1.0``lion-pytorch==0.0.6``requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください
この例では PyTorch および xfomers は2.1.2CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
```bat
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config
```
(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。
PyTorch 2.2以降を用いる場合は、`torch==2.1.2``torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。
accelerate configの質問には以下のように答えてください。bf16で学習する場合、最後の質問にはbf16と答えてください。
※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます……。数字キーの0、1、2……で選択できますので、そちらを使ってください。
```txt
- This machine
- No distributed training
@@ -111,30 +90,6 @@ accelerate configの質問には以下のように答えてください。bf1
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``に「0」と答えてください。id `0`のGPUが使われます。
### PyTorchとxformersのバージョンについて
他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください。
### オプションLion8bitを使う
Lion8bitを使う場合には`bitsandbytes`を0.38.0以降にアップグレードする必要があります。`bitsandbytes`をアンインストールし、Windows環境では例えば[こちら](https://github.com/jllllll/bitsandbytes-windows-webui)などからWindows版のwhlファイルをインストールしてください。たとえば以下のような手順になります。
```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl
```
アップグレード時には`pip install .`でこのリポジトリを更新し、必要に応じて他のパッケージもアップグレードしてください。
### オプションPagedAdamW8bitとPagedLion8bitを使う
PagedAdamW8bitとPagedLion8bitを使う場合には`bitsandbytes`を0.39.0以降にアップグレードする必要があります。`bitsandbytes`をアンインストールし、Windows環境では例えば[こちら](https://github.com/jllllll/bitsandbytes-windows-webui)などからWindows版のwhlファイルをインストールしてください。たとえば以下のような手順になります。
```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
アップグレード時には`pip install .`でこのリポジトリを更新し、必要に応じて他のパッケージもアップグレードしてください。
## アップグレード
新しいリリースがあった場合、以下のコマンドで更新できます。
@@ -164,4 +119,47 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## その他の情報
### LoRAの名称について
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
<!--
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
-->
### 学習中のサンプル画像生成
プロンプトファイルは例えば以下のようになります。
```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
`( )` や `[ ]` などの重みづけも動作します。

824
README.md
View File

@@ -1,9 +1,16 @@
This repository contains training, generation and utility scripts for Stable Diffusion.
[__Change History__](#change-history) is moved to the bottom of the page.
[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。
[日本語版README](./README-ja.md)
Latest update: 2025-03-21 (Version 0.9.1)
[日本語版READMEはこちら](./README-ja.md)
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
FLUX.1 and SD3/SD3.5 support is done in the `sd3` branch. If you want to train them, please use the sd3 branch.
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
@@ -16,135 +23,13 @@ This repository contains the scripts for:
* Image generation
* Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for great work!!!
## About SDXL training
The feature of SDXL training is now available in sdxl branch as an experimental feature.
Sep 3, 2023: The feature will be merged into the main branch soon. Following are the changes from the previous version.
- ControlNet-LLLite is added. See [documentation](./docs/train_lllite_README.md) for details.
- JPEG XL is supported. [#786](https://github.com/kohya-ss/sd-scripts/pull/786)
- Peak memory usage is reduced. [#791](https://github.com/kohya-ss/sd-scripts/pull/791)
- Input perturbation noise is added. See [#798](https://github.com/kohya-ss/sd-scripts/pull/798) for details.
- Dataset subset now has `caption_prefix` and `caption_suffix` options. The strings are added to the beginning and the end of the captions before shuffling. You can specify the options in `.toml`.
- Other minor changes.
- Thanks for contributions from Isotr0py, vvern999, lansing and others!
Aug 13, 2023:
- LoRA-FA is added experimentally. Specify `--network_module networks.lora_fa` option instead of `--network_module networks.lora`. The trained model can be used as a normal LoRA model.
Aug 12, 2023:
- The default value of noise offset when omitted has been changed to 0 from 0.0357.
- The different learning rates for each U-Net block are now supported. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
Aug 6, 2023:
- [SAI Model Spec](https://github.com/Stability-AI/ModelSpec) metadata is now supported partially. `hash_sha256` is not supported yet.
- The main items are set automatically.
- You can set title, author, description, license and tags with `--metadata_xxx` options in each training script.
- Merging scripts also support minimum SAI Model Spec metadata. See the help message for the usage.
- Metadata editor will be available soon.
- SDXL LoRA has `sdxl_base_v1-0` now for `ss_base_model_version` metadata item, instead of `v0-9`.
Aug 4, 2023:
- `bitsandbytes` is now optional. Please install it if you want to use it. The insructions are in the later section.
- `albumentations` is not required anymore.
- An issue for pooled output for Textual Inversion training is fixed.
- `--v_pred_like_loss ratio` option is added. This option adds the loss like v-prediction loss in SDXL training. `0.1` means that the loss is added 10% of the v-prediction loss. The default value is None (disabled).
- In v-prediction, the loss is higher in the early timesteps (near the noise). This option can be used to increase the loss in the early timesteps.
- Arbitrary options can be used for Diffusers' schedulers. For example `--lr_scheduler_args "lr_end=1e-8"`.
- `sdxl_gen_imgs.py` supports batch size > 1.
- Fix ControlNet to work with attention couple and reginal LoRA in `gen_img_diffusers.py`.
Summary of the feature:
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
- The options are almost the same as `cache_latents.py' and `sdxl_train.py'. See the help message for the usage.
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
- I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
- In addition, the full bfloat16 training might be unstable. Please use it at your own risk.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- The image generation during training is now available. `--no_half_vae` option also works to avoid black images.
- `--weighted_captions` option is not supported yet for both scripts.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- `token_string` must be alphabet only currently, due to the limitation of the open-clip tokenizer.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
- Textual Inversion is supported, but the name for the embeds in the caption becomes alphabet only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`.
`requirements.txt` is updated to support SDXL training.
### Tips for SDXL training
- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (-8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
## About requirements.txt
These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.)
The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below.
The scripts are tested with PyTorch 1.12.1 and 2.0.1, Diffusers 0.18.2.
The scripts are tested with Pytorch 2.1.2. PyTorch 2.2 or later will work. Please install the appropriate version of PyTorch and xformers.
## Links to how-to-use documents
## Links to usage documentation
Most of the documents are written in Japanese.
@@ -152,11 +37,13 @@ Most of the documents are written in Japanese.
* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
* [Chinese version](./docs/train_README-zh.md)
* [SDXL training](./docs/train_SDXL-en.md) (English version)
* [Dataset config](./docs/config_README-ja.md)
* [English version](./docs/config_README-en.md)
* [DreamBooth training guide](./docs/train_db_README-ja.md)
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
* [training LoRA](./docs/train_network_README-ja.md)
* [training Textual Inversion](./docs/train_ti_README-ja.md)
* [Training LoRA](./docs/train_network_README-ja.md)
* [Training Textual Inversion](./docs/train_ti_README-ja.md)
* [Image generation](./docs/gen_img_README-ja.md)
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
@@ -167,6 +54,8 @@ Python 3.10.6 and Git:
- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe
- git: https://git-scm.com/download/win
Python 3.10.x, 3.11.x, and 3.12.x will work but not tested.
Give unrestricted script access to powershell so venv can work:
- Open an administrator powershell window
@@ -184,14 +73,20 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
accelerate config
```
__Note:__ Now bitsandbytes is optional. Please install any version of bitsandbytes as needed. Installation instructions are in the following section.
If `python -m venv` shows only `python`, change `python` to `py`.
Note: Now `bitsandbytes==0.44.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually.
This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`.
If you use PyTorch 2.2 or later, please change `torch==2.1.2` and `torchvision==0.16.2` and `xformers==0.0.23.post1` to the appropriate version.
<!--
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
@@ -210,73 +105,13 @@ Answers to accelerate config:
- fp16
```
note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occurred in training. In this case, answer `0` for the 6th question:
If you'd like to use bf16, please answer `bf16` to the last question.
Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occurred in training. In this case, answer `0` for the 6th question:
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``
(Single GPU with id `0` will be used.)
### Experimental: Use PyTorch 2.0
In this case, you need to install PyTorch 2.0 and xformers 0.0.20. Instead of the above, please type the following:
```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install xformers==0.0.20
accelerate config
```
Answers to accelerate config should be the same as above.
### about PyTorch and xformers
Other versions of PyTorch and xformers seem to have problems with training.
If there is no other reason, please install the specified version.
### Optional: Use `bitsandbytes` (8bit optimizer)
For 8bit optimizer, you need to install `bitsandbytes`. For Linux, please install `bitsandbytes` as usual (0.41.1 or later is recommended.)
For Windows, there are several versions of `bitsandbytes`:
- `bitsandbytes` 0.35.0: Stable version. AdamW8bit is available. `full_bf16` is not available.
- `bitsandbytes` 0.41.1: Lion8bit, PagedAdamW8bit and PagedLion8bit are available. `full_bf16` is available.
Note: `bitsandbytes`above 0.35.0 till 0.41.0 seems to have an issue: https://github.com/TimDettmers/bitsandbytes/issues/659
Follow the instructions below to install `bitsandbytes` for Windows.
### bitsandbytes 0.35.0 for Windows
Open a regular Powershell terminal and type the following inside:
```powershell
cd sd-scripts
.\venv\Scripts\activate
pip install bitsandbytes==0.35.0
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
```
This will install `bitsandbytes` 0.35.0 and copy the necessary files to the `bitsandbytes` directory.
### bitsandbytes 0.41.1 for Windows
Install the Windows version whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or other sources, like:
```powershell
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
```
## Upgrade
When a new release comes out you can upgrade your repo with the following command:
@@ -290,6 +125,10 @@ pip install --use-pep517 --upgrade -r requirements.txt
Once the commands have completed successfully you should be ready to use the new version.
### Upgrade PyTorch
If you want to upgrade PyTorch, you can upgrade it with `pip install` command in [Windows Installation](#windows-installation) section. `xformers` is also required to be upgraded when PyTorch is upgraded.
## Credits
The implementation for LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for great work!
@@ -306,218 +145,415 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## Change History
### 15 Jun. 2023, 2023/06/15
### Mar 21, 2025 / 2025-03-21 Version 0.9.1
- Prodigy optimizer is supported in each training script. It is a member of D-Adaptation and is effective for DyLoRA training. [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) Please see the PR for details. Thanks to sdbds!
- Install the package with `pip install prodigyopt`. Then specify the option like `--optimizer_type="prodigy"`.
- Arbitrary Dataset is supported in each training script (except XTI). You can use it by defining a Dataset class that returns images and captions.
- Prepare a Python script and define a class that inherits `train_util.MinimalDataset`. Then specify the option like `--dataset_class package.module.DatasetClass` in each training script.
- Please refer to `MinimalDataset` for implementation. I will prepare a sample later.
- The following features have been added to the generation script.
- Added an option `--highres_fix_disable_control_net` to disable ControlNet in the 2nd stage of Highres. Fix. Please try it if the image is disturbed by some ControlNet such as Canny.
- Added Variants similar to sd-dynamic-propmpts in the prompt.
- If you specify `{spring|summer|autumn|winter}`, one of them will be randomly selected.
- If you specify `{2$$chocolate|vanilla|strawberry}`, two of them will be randomly selected.
- If you specify `{1-2$$ and $$chocolate|vanilla|strawberry}`, one or two of them will be randomly selected and connected by ` and `.
- You can specify the number of candidates in the range `0-2`. You cannot omit one side like `-2` or `1-`.
- It can also be specified for the prompt option.
- If you specify `e` or `E`, all candidates will be selected and the prompt will be repeated multiple times (`--images_per_prompt` is ignored). It may be useful for creating X/Y plots.
- You can also specify `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`. In this case, 15 prompts will be generated with 5*3.
- There is no weighting function.
- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
- The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version.
- 各学習スクリプトでProdigyオプティマイザがサポートされました。D-Adaptationの仲間でDyLoRAの学習に有効とのことです。 [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) 詳細はPRをご覧ください。sdbds氏に感謝します。
- `pip install prodigyopt` としてパッケージをインストールしてください。また `--optimizer_type="prodigy"` のようにオプションを指定します。
- 各学習スクリプトで任意のDatasetをサポートしましたXTIを除く。画像とキャプションを返すDatasetクラスを定義することで、学習スクリプトから利用できます。
- Pythonスクリプトを用意し、`train_util.MinimalDataset`を継承するクラスを定義してください。そして各学習スクリプトのオプションで `--dataset_class package.module.DatasetClass` のように指定してください。
- 実装方法は `MinimalDataset` を参考にしてください。のちほどサンプルを用意します。
- 生成スクリプトに以下の機能追加を行いました。
- Highres. Fixの2nd stageでControlNetを無効化するオプション `--highres_fix_disable_control_net` を追加しました。Canny等一部のControlNetで画像が乱れる場合にお試しください。
- プロンプトでsd-dynamic-propmptsに似たVariantをサポートしました。
- `{spring|summer|autumn|winter}` のように指定すると、いずれかがランダムに選択されます。
- `{2$$chocolate|vanilla|strawberry}` のように指定すると、いずれか2個がランダムに選択されます。
- `{1-2$$ and $$chocolate|vanilla|strawberry}` のように指定すると、1個か2個がランダムに選択され ` and ` で接続されます。
- 個数のレンジ指定では`0-2`のように0個も指定可能です。`-2`や`1-`のような片側の省略はできません。
- プロンプトオプションに対しても指定可能です。
- `{e$$chocolate|vanilla|strawberry}` のように`e`または`E`を指定すると、すべての候補が選択されプロンプトが複数回繰り返されます(`--images_per_prompt`は無視されます。X/Y plotの作成に便利かもしれません。
- `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`のような指定も可能です。この場合、5*3で15回のプロンプトが生成されます。
- Weightingの機能はありません。
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
### 8 Jun. 2023, 2023/06/08
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
- bitsandbytes, transformers, accelerate and huggingface_hub are updated.
- If you encounter any issues, please report them.
- Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training.
- 重みづけキャプションでの学習時(`--weighted_captions`指定時および学習中のサンプル画像生成時にclip skipが機能しない不具合を修正しました。
- The dev branch is merged into main. The documentation is delayed, and I apologize for that. I will gradually improve it.
- The state just before the merge is released as Version 0.8.8, so please use it if you encounter any issues.
- The following changes are included.
### 6 Jun. 2023, 2023/06/06
#### Changes
- Fix `train_network.py` to probably work with older versions of LyCORIS.
- `gen_img_diffusers.py` now supports `BREAK` syntax.
- `train_network.py`がLyCORISの以前のバージョンでも恐らく動作するよう修正しました。
- `gen_img_diffusers.py` で `BREAK` 構文をサポートしました。
- Fixed a bug where the loss weight was incorrect when `--debiased_estimation_loss` was specified with `--v_parameterization`. PR [#1715](https://github.com/kohya-ss/sd-scripts/pull/1715) Thanks to catboxanon! See [the PR](https://github.com/kohya-ss/sd-scripts/pull/1715) for details.
- Removed the warning when `--v_parameterization` is specified in SDXL and SD1.5. PR [#1717](https://github.com/kohya-ss/sd-scripts/pull/1717)
### 3 Jun. 2023, 2023/06/03
- There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632)
- Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!
- Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR for details.
- Specify as `--scale_weight_norms=1.0`. It seems good to try from `1.0`.
- The networks other than LoRA in this repository (such as LyCORIS) do not support this option.
- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds!
- There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes).
- Three types of dropout have been added to `train_network.py` and LoRA network.
- Dropout is a technique to suppress overfitting and improve network performance by randomly setting some of the network outputs to 0.
- `--network_dropout` is a normal dropout at the neuron level. In the case of LoRA, it is applied to the output of down. Proposed in [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!
- `--network_dropout=0.1` specifies the dropout probability to `0.1`.
- Note that the specification method is different from LyCORIS.
- For LoRA network, `--network_args` can specify `rank_dropout` to dropout each rank with specified probability. Also `module_dropout` can be specified to dropout each module with specified probability.
- Specify as `--network_args "rank_dropout=0.2" "module_dropout=0.1"`.
- `--network_dropout`, `rank_dropout`, and `module_dropout` can be specified at the same time.
- Values of 0.1 to 0.3 may be good to try. Values greater than 0.5 should not be specified.
- `rank_dropout` and `module_dropout` are original techniques of this repository. Their effectiveness has not been verified yet.
- The networks other than LoRA in this repository (such as LyCORIS) do not support these options.
- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly.
- Added an option `--scale_v_pred_loss_like_noise_pred` to scale v-prediction loss like noise prediction in each training script.
- By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected.
- See [this article](https://xrg.hatenablog.com/entry/2023/06/02/202418) by xrg for details (written in Japanese). Thanks to xrg for the great suggestion!
- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris!
- Max Norm Regularizationが`train_network.py`で使えるようになりました。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) AI-Casanova氏に感謝します。
- Max Norm Regularizationは、ネットワークの重みのルムを制限することで、ネットワークの学習を安定させる手法です。LoRAの過学習の抑制、他のLoRAと併用した時の安定性の向上が期待できるかもしれません。詳細はPRを参照してください。
- `--scale_weight_norms=1.0`のように `--scale_weight_norms` で指定してください。`1.0`から試すと良いようです。
- LyCORIS等、当リポジトリ以外のネットワークは現時点では未対応です。
- Improvements in OFT (Orthogonal Finetuning) Implementation
1. Optimization of Calculation Order:
- Changed the calculation order in the forward method from (Wx)R to W(xR).
- This has improved computational efficiency and processing speed.
2. Correction of Bias Application:
- In the previous implementation, R was incorrectly applied to the bias.
- The new implementation now correctly handles bias by using F.conv2d and F.linear.
3. Efficiency Enhancement in Matrix Operations:
- Introduced einsum in both the forward and merge_to methods.
- This has optimized matrix operations, resulting in further speed improvements.
4. Proper Handling of Data Types:
- Improved to use torch.float32 during calculations and convert results back to the original data type.
- This maintains precision while ensuring compatibility with the original model.
5. Unified Processing for Conv2d and Linear Layers:
- Implemented a consistent method for applying OFT to both layer types.
- These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability.
- `train_network.py` およびLoRAに計三種類のdropoutを追加しました。
- dropoutはネットワークの一部の出力をランダムに0にすることで、過学習の抑制、ネットワークの性能向上等を図る手法です。
- `--network_dropout` はニューロン単位の通常のdropoutです。LoRAの場合、downの出力に対して適用されます。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) で提案されました。AI-Casanova氏に感謝します。
- `--network_dropout=0.1` などとすることで、dropoutの確率を指定できます。
- LyCORISとは指定方法が異なりますのでご注意ください。
- LoRAの場合、`--network_args`に`rank_dropout`を指定することで各rankを指定確率でdropoutします。また同じくLoRAの場合、`--network_args`に`module_dropout`を指定することで各モジュールを指定確率でdropoutします。
- `--network_args "rank_dropout=0.2" "module_dropout=0.1"` のように指定します。
- `--network_dropout`、`rank_dropout` 、 `module_dropout` は同時に指定できます。
- それぞれの値は0.1~0.3程度から試してみると良いかもしれません。0.5を超える値は指定しない方が良いでしょう。
- `rank_dropout`および`module_dropout`は当リポジトリ独自の手法です。有効性の検証はまだ行っていません。
- これらのdropoutはLyCORIS等、当リポジトリ以外のネットワークは現時点では未対応です。
- Additional Information
* Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation.
- 各学習スクリプトにv-prediction lossをnoise predictionと同様の値にスケールするオプション`--scale_v_pred_loss_like_noise_pred`を追加しました。
- タイムステップに応じてlossをスケールすることで、 大域的なノイズの予測と局所的なノイズの予測の重みが同じになり、ディテールの改善が期待できるかもしれません。
- 詳細はxrg氏のこちらの記事をご参照ください[noise_predictionモデルとv_predictionモデルの損失 - 勾配降下党青年局](https://xrg.hatenablog.com/entry/2023/06/02/202418) xrg氏の素晴らしい記事に感謝します。
* Performance Improvement: Training speed has been improved by approximately 30%.
### 31 May 2023, 2023/05/31
* Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL).
- Show warning when image caption file does not exist during training. [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) Thanks to TingTingin!
- Warning is also displayed when using class+identifier dataset. Please ignore if it is intended.
- `train_network.py` now supports merging network weights before training. [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) Thanks to u-haru!
- `--base_weights` option specifies LoRA or other model files (multiple files are allowed) to merge.
- `--base_weights_multiplier` option specifies multiplier of the weights to merge (multiple values are allowed). If omitted or less than `base_weights`, 1.0 is used.
- This is useful for incremental learning. See PR for details.
- Show warning and continue training when uploading to HuggingFace fails.
- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds!
- See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler.
- `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc.
- 学習時に画像のキャプションファイルが存在しない場合、警告が表示されるようになりました。 [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) TingTingin氏に感謝します。
- class+identifier方式のデータセットを利用している場合も警告が表示されます。意図している通りの場合は無視してください。
- `train_network.py` に学習前にモデルにnetworkの重みをマージする機能が追加されました。 [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) u-haru氏に感謝します。
- `--base_weights` オプションでLoRA等のモデルファイル複数可を指定すると、それらの重みをマージします。
- `--base_weights_multiplier` オプションでマージする重みの倍率(複数可)を指定できます。省略時または`base_weights`よりも数が少ない場合は1.0になります。
- 差分追加学習などにご利用ください。詳細はPRをご覧ください。
- HuggingFaceへのアップロードに失敗した場合、警告を表示しそのまま学習を続行するよう変更しました。
- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds!
### 25 May 2023, 2023/05/25
- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v!
- [D-Adaptation v3.0](https://github.com/facebookresearch/dadaptation) is now supported. [PR #530](https://github.com/kohya-ss/sd-scripts/pull/530) Thanks to sdbds!
- `--optimizer_type` now accepts `DAdaptAdamPreprint`, `DAdaptAdanIP`, and `DAdaptLion`.
- `DAdaptAdam` is now new. The old `DAdaptAdam` is available with `DAdaptAdamPreprint`.
- Simply specifying `DAdaptation` will use `DAdaptAdamPreprint` (same behavior as before).
- You need to install D-Adaptation v3.0. After activating venv, please do `pip install -U dadaptation`.
- See PR and D-Adaptation documentation for details.
- [D-Adaptation v3.0](https://github.com/facebookresearch/dadaptation)がサポートされました。 [PR #530](https://github.com/kohya-ss/sd-scripts/pull/530) sdbds氏に感謝します。
- `--optimizer_type`に`DAdaptAdamPreprint`、`DAdaptAdanIP`、`DAdaptLion` が追加されました。
- `DAdaptAdam`が新しくなりました。今までの`DAdaptAdam`は`DAdaptAdamPreprint`で使用できます。
- 単に `DAdaptation` を指定すると`DAdaptAdamPreprint`が使用されます(今までと同じ動き)。
- D-Adaptation v3.0のインストールが必要です。venvを有効にした後 `pip install -U dadaptation` としてください。
- 詳細はPRおよびD-Adaptationのドキュメントを参照してください。
- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened!
### 22 May 2023, 2023/05/22
- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr!
- The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower.
- Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available.
- Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`.
- Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size.
- PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`.
- Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts.
- Fixed several bugs.
- The state is saved even when the `--save_state` option is not specified in `fine_tune.py` and `train_db.py`. [PR #521](https://github.com/kohya-ss/sd-scripts/pull/521) Thanks to akshaal!
- Cannot load LoRA without `alpha`. [PR #527](https://github.com/kohya-ss/sd-scripts/pull/527) Thanks to Manjiz!
- Minor changes to console output during sample generation. [PR #515](https://github.com/kohya-ss/sd-scripts/pull/515) Thanks to yanhuifair!
- The generation script now uses xformers for VAE as well.
- いくつかのバグ修正を行いました。
- `fine_tune.py`と`train_db.py`で`--save_state`オプション未指定時にもstateが保存される。 [PR #521](https://github.com/kohya-ss/sd-scripts/pull/521) akshaal氏に感謝します。
- `alpha`を持たないLoRAを読み込めない。[PR #527](https://github.com/kohya-ss/sd-scripts/pull/527) Manjiz氏に感謝します。
- サンプル生成時のコンソール出力の軽微な変更。[PR #515](https://github.com/kohya-ss/sd-scripts/pull/515) yanhuifair氏に感謝します。
- 生成スクリプトでVAEについてもxformersを使うようにしました。
- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
- Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer.
- Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10.
- Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available.
- `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required.
- Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side.
### 16 May 2023, 2023/05/16
- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO!
- LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details.
- Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"`
- `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder.
- Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc.
- `network_module` `networks.lora` and `networks.dylora` are available.
- Fixed an issue where an error would occur if the encoding of the prompt file was different from the default. [PR #510](https://github.com/kohya-ss/sd-scripts/pull/510) Thanks to sdbds!
- Please save the prompt file in UTF-8.
- プロンプトファイルのエンコーディングがデフォルトと異なる場合にエラーが発生する問題を修正しました。 [PR #510](https://github.com/kohya-ss/sd-scripts/pull/510) sdbds氏に感謝します。
- プロンプトファイルはUTF-8で保存してください。
- The feature to use the transparency (alpha channel) of the image as a mask in the loss calculation has been added. PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) Thanks to u-haru!
- The transparent part is ignored during training. Specify the `--alpha_mask` option in the training script or specify `alpha_mask = true` in the dataset configuration file.
- See [About masked loss](./docs/masked_loss_README.md) for details.
### 15 May 2023, 2023/05/15
- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
- Specify the learning rate and dim (rank) for each block.
- See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only).
- Added [English translation of documents](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation) by darkstorm2150. Thank you very much!
- The prompt for sample generation during training can now be specified in `.toml` or `.json`. [PR #504](https://github.com/kohya-ss/sd-scripts/pull/504) Thanks to Linaqruf!
- For details on prompt description, please see the PR.
- Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath!
- The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended.
- When specifying from the command line, use `=` like `--learning_rate=-1e-7`.
- darkstorm2150氏に[ドキュメント類を英訳](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation)していただきました。ありがとうございます!
- 学習中のサンプル生成のプロンプトを`.toml`または`.json`で指定可能になりました。 [PR #504](https://github.com/kohya-ss/sd-scripts/pull/504) Linaqruf氏に感謝します。
- プロンプト記述の詳細は当該PRをご覧ください。
- Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa!
- Some settings, such as API keys and directory specifications, are not output due to security issues.
### 11 May 2023, 2023/05/11
- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds!
- Added an option `--dim_from_weights` to `train_network.py` to automatically determine the dim(rank) from the weight file. [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) Thanks to AI-Casanova!
- It is useful in combination with `resize_lora.py`. Please see the PR for details.
- Fixed a bug where the noise resolution was incorrect with Multires noise. [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) Thanks to sdbds!
- Please see the PR for details.
- The image generation scripts can now use img2img and highres fix at the same time.
- Fixed a bug where the hint image of ControlNet was incorrectly BGR instead of RGB in the image generation scripts.
- Added a feature to the image generation scripts to use the memory-efficient VAE.
- If you specify a number with the `--vae_slices` option, the memory-efficient VAE will be used. The maximum output size will be larger, but it will be slower. Please specify a value of about `16` or `32`.
- The implementation of the VAE is in `library/slicing_vae.py`.
- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf!
- This resolves the issue where the order of data loading from DataSet changes when resuming training.
- Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before).
- If `--resume` is specified, the step saved in the state is used.
- Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`).
- `train_network.py`にdim(rank)を重みファイルから自動決定するオプション`--dim_from_weights`が追加されました。 [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) AI-Casanova氏に感謝します。
- `resize_lora.py`と組み合わせると有用です。詳細はPRもご参照ください。
- Multires noiseでイズ解像度が正しくない不具合が修正されました。 [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) sdbds氏に感謝します。
- 詳細は当該PRをご参照ください。
- 生成スクリプトでimg2imgとhighres fixを同時に使用できるようにしました。
- 生成スクリプトでControlNetのhint画像が誤ってBGRだったのをRGBに修正しました。
- 生成スクリプトで省メモリ化VAEを使えるよう機能追加しました。
- `--vae_slices`オプションに数値を指定すると、省メモリ化VAEを用います。出力可能な最大サイズが大きくなりますが、遅くなります。`16`または`32`程度の値を指定してください。
- VAEの実装は`library/slicing_vae.py`にあります。
- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra!
- It seems that the model file loading is faster in the WSL environment etc.
- Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`.
### 7 May 2023, 2023/05/07
- When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath!
- The documentation has been moved to the `docs` folder. If you have links, please change them.
- Removed `gradio` from `requirements.txt`.
- DAdaptAdaGrad, DAdaptAdan, and DAdaptSGD are now supported by DAdaptation. [PR#455](https://github.com/kohya-ss/sd-scripts/pull/455) Thanks to sdbds!
- DAdaptation needs to be installed. Also, depending on the optimizer, DAdaptation may need to be updated. Please update with `pip install --upgrade dadaptation`.
- Added support for pre-calculation of LoRA weights in image generation scripts. Specify `--network_pre_calc`.
- The prompt option `--am` is available. Also, it is disabled when Regional LoRA is used.
- Added Adaptive noise scale to each training script. Specify a number with `--adaptive_noise_scale` to enable it.
- __Experimental option. It may be removed or changed in the future.__
- This is an original implementation that automatically adjusts the value of the noise offset according to the absolute value of the mean of each channel of the latents. It is expected that appropriate noise offsets will be set for bright and dark images, respectively.
- Specify it together with `--noise_offset`.
- The actual value of the noise offset is calculated as `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale`. Since the latent is close to a normal distribution, it may be a good idea to specify a value of about 1/10 to the same as the noise offset.
- Negative values can also be specified, in which case the noise offset will be clipped to 0 or more.
- Other minor fixes.
- Fixed an error that occurs when specifying `--max_dataloader_n_workers` in `tag_images_by_wd14_tagger.py` when Onnx is not used. PR [#1291](
https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290](
https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
- ドキュメントを`docs`フォルダに移動しました。リンク等を張られている場合は変更をお願いいたします。
- `requirements.txt`から`gradio`を削除しました。
- DAdaptationで新しくDAdaptAdaGrad、DAdaptAdan、DAdaptSGDがサポートされました。[PR#455](https://github.com/kohya-ss/sd-scripts/pull/455) sdbds氏に感謝します。
- dadaptationのインストールが必要です。またオプティマイザによってはdadaptationの更新が必要です。`pip install --upgrade dadaptation`で更新してください。
- 画像生成スクリプトでLoRAの重みの事前計算をサポートしました。`--network_pre_calc`を指定してください。
- プロンプトオプションの`--am`が利用できます。またRegional LoRA使用時には無効になります。
- 各学習スクリプトにAdaptive noise scaleを追加しました。`--adaptive_noise_scale`で数値を指定すると有効になります。
- __実験的オプションです。将来的に削除、仕様変更される可能性があります。__
- Noise offsetの値を、latentsの各チャネルの平均値の絶対値に応じて自動調整するオプションです。独自の実装で、明るい画像、暗い画像に対してそれぞれ適切なnoise offsetが設定されることが期待されます。
- `--noise_offset` と同時に指定してください。
- 実際のNoise offsetの値は `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale` で計算されます。 latentは正規分布に近いためnoise_offsetの1/10同程度の値を指定するとよいかもしれません
- 負の値も指定でき、その場合はnoise offsetは0以上にclipされます
- その他の細かい修正を行いました
- Fixed a bug that `caption_separator` cannot be specified in the subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) Thanks to rockerBOO!
- Fixed a potential bug in ControlNet-LLLite training. PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) Thanks to aria1th!
- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported.
#### 変更点
- devブランチがmainにマージされました。ドキュメントの整備が遅れており申し訳ありません。少しずつ整備していきます
- マージ直前の状態が Version 0.8.8 としてリリースされていますので、問題があればそちらをご利用ください
- 以下の変更が含まれます
- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
- `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。
- mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。
- バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。
- PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。
- 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。
- SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319)
- Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。
- `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。
- 任意の optimizer が使えますが、学習率を自動計算する optimizer D-Adaptation や Prodigy などは使えません。gradient accumulation は使えません。
- `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。
- 仕組みFused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。
- LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。
- LoRA の UP 側LoRA-Bの学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。
- `--network_args` で `loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"`
- `loraplus_unet_lr_ratio` と `loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。
- 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など
- `network_module` の `networks.lora` および `networks.dylora` で使用可能です。
- 画像の透明度アルファチャネルをロス計算時のマスクとして使用する機能が追加されました。PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) u-haru 氏に感謝します。
- 透明部分が学習時に無視されるようになります。学習スクリプトに `--alpha_mask` オプションを指定するか、データセット設定ファイルに `alpha_mask = true` を指定してください。
- 詳細は [マスクロスについて](./docs/masked_loss_README-ja.md) をご覧ください。
- SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331)
- ブロックごとに学習率および dim (rank) を指定することができます。
- 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。
- `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。
- 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。
- コマンドラインから指定する場合、`--learning_rate=-1e-7` のように`=` を使ってください。
- 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。
- API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。
- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。
- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。
- これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。
- `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりませんDataSet の最初から読み込みます)
- `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。
- `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。
- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。
- WSL 環境等でモデルファイルの読み込みが高速化されるようです。
- `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。
- ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。
- `tag_images_by_wd14_tagger.py` で Onnx 未使用時に `--max_dataloader_n_workers` を指定するとエラーになる不具合が修正されました。 PR [#1291](
https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290](
https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。
- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1313) rockerBOO 氏に感謝します。
- ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。
- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。
### Oct 27, 2024 / 2024-10-27:
- `svd_merge_lora.py` VRAM usage has been reduced. However, main memory usage will increase (32GB is sufficient).
- This will be included in the next release.
- `svd_merge_lora.py` のVRAM使用量を削減しました。ただし、メインメモリの使用量は増加します32GBあれば十分です
- これは次回リリースに含まれます。
### Oct 26, 2024 / 2024-10-26:
- Fixed a bug in `svd_merge_lora.py`, `sdxl_merge_lora.py`, and `resize_lora.py` where the hash value of LoRA metadata was not correctly calculated when the `save_precision` was different from the `precision` used in the calculation. See issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) for details. Thanks to JujoHotaru for raising the issue.
- It will be included in the next release.
- `svd_merge_lora.py`、`sdxl_merge_lora.py`、`resize_lora.py`で、保存時の精度が計算時の精度と異なる場合、LoRAメタデータのハッシュ値が正しく計算されない不具合を修正しました。詳細は issue [#1722](https://github.com/kohya-ss/sd-scripts/pull/1722) をご覧ください。問題提起していただいた JujoHotaru 氏に感謝します。
- 以上は次回リリースに含まれます。
### Sep 13, 2024 / 2024-09-13:
- `sdxl_merge_lora.py` now supports OFT. Thanks to Maru-mee for the PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580).
- `svd_merge_lora.py` now supports LBW. Thanks to terracottahaniwa. See PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) for details.
- `sdxl_merge_lora.py` also supports LBW.
- See [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) by hako-mikan for details on LBW.
- These will be included in the next release.
- `sdxl_merge_lora.py` が OFT をサポートされました。PR [#1580](https://github.com/kohya-ss/sd-scripts/pull/1580) Maru-mee 氏に感謝します。
- `svd_merge_lora.py` で LBW がサポートされました。PR [#1575](https://github.com/kohya-ss/sd-scripts/pull/1575) terracottahaniwa 氏に感謝します。
- `sdxl_merge_lora.py` でも LBW がサポートされました。
- LBW の詳細は hako-mikan 氏の [LoRA Block Weight](https://github.com/hako-mikan/sd-webui-lora-block-weight) をご覧ください。
- 以上は次回リリースに含まれます。
### Jun 23, 2024 / 2024-06-23:
- Fixed `cache_latents.py` and `cache_text_encoder_outputs.py` not working. (Will be included in the next release.)
- `cache_latents.py` および `cache_text_encoder_outputs.py` が動作しなくなっていたのを修正しました。(次回リリースに含まれます。)
### Apr 7, 2024 / 2024-04-07: v0.8.7
- The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.
- Scheduled Huber Loss の `huber_schedule` のデフォルト値を `exponential` から、より良い結果が期待できる `snr` に変更しました。
### Apr 7, 2024 / 2024-04-07: v0.8.6
#### Highlights
- The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
- Especially `imagesize` is newly added, so if you cannot update the libraries immediately, please install with `pip install imagesize==1.4.1` separately.
- `bitsandbytes==0.43.0`, `prodigyopt==1.0`, `lion-pytorch==0.0.6` are included in the requirements.txt.
- `bitsandbytes` no longer requires complex procedures as it now officially supports Windows.
- Also, the PyTorch version is updated to 2.1.2 (PyTorch does not need to be updated immediately). In the upgrade procedure, PyTorch is not updated, so please manually install or update torch, torchvision, xformers if necessary (see [Upgrade PyTorch](#upgrade-pytorch)).
- When logging to wandb is enabled, the entire command line is exposed. Therefore, it is recommended to write wandb API key and HuggingFace token in the configuration file (`.toml`). Thanks to bghira for raising the issue.
- A warning is displayed at the start of training if such information is included in the command line.
- Also, if there is an absolute path, the path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. In such cases, an INFO log is displayed.
- See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details.
- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.
- Other improvements include the addition of masked loss, scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details.
#### Training scripts
- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
- Fixed a bug that U-Net and Text Encoders are included in the state in `train_network.py` and `sdxl_train_network.py`. The saving and loading of the state are faster, the file size is smaller, and the memory usage when loading is reduced.
- DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details.
- The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#about-masked-loss) for details.
- Scheduled Huber Loss has been introduced to each training scripts. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) Thanks to kabachuha for the PR and cheald, drhead, and others for the discussion! See the PR and [Scheduled Huber Loss](#about-scheduled-huber-loss) for details.
- The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf!
- The options `--save_state_on_train_end` are added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee!
- The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and ignore them when a number less than or equal to `0` is specified. Thanks to S-Del for raising the issue.
#### Dataset settings
- The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150!
- The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704!
- Fixed a bug that the last subset settings are applied to all images when multiple subsets of regularization images are specified in the dataset settings. The settings for each subset are correctly applied to each image. PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) Thanks to feffy380!
- Some features are added to the dataset subset settings.
- `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping.
- Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped.
- `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. The multi-line caption is also enabled.
- `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end.
- The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order.
- See [Dataset config](./docs/config_README-en.md) for details.
- The dataset with DreamBooth method supports caching image information (size, caption). PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178) and [#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) Thanks to KohakuBlueleaf! See [DreamBooth method specific options](./docs/config_README-en.md#dreambooth-specific-options) for details.
#### Image tagging
- The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds!
- Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`.
- The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`.
- Some options are added to `tag_image_by_wd14_tagger.py`.
- Some are added in PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) Thanks to Disty0!
- Output rating tags `--use_rating_tags` and `--use_rating_tags_as_last_tag`
- Output character tags first `--character_tags_first`
- Expand character tags and series `--character_tag_expand`
- Specify tags to output first `--always_first_tags`
- Replace tags `--tag_replacement`
- See [Tagging documentation](./docs/wd14_tagger_README-en.md) for details.
- Fixed an error when specifying `--beam_search` and a value of 2 or more for `--num_beams` in `make_captions.py`.
#### About Masked loss
The masked loss is supported in each training script. To enable the masked loss, specify the `--masked_loss` option.
The feature is not fully tested, so there may be bugs. If you find any issues, please open an Issue.
ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
#### About Scheduled Huber Loss
Scheduled Huber Loss has been introduced to each training scripts. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data.
With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images.
To address this, the proposed method employs a clever application of the Huber loss function. By scheduling the use of Huber loss in the early stages of training (when noise is high) and MSE in the later stages, it strikes a balance between outlier robustness and fine detail reproduction.
Experimental results have confirmed that this method achieves higher accuracy on data containing outliers compared to pure Huber loss or MSE. The increase in computational cost is minimal.
The newly added arguments loss_type, huber_schedule, and huber_c allow for the selection of the loss function type (Huber, smooth L1, MSE), scheduling method (exponential, constant, SNR), and Huber's parameter. This enables optimization based on the characteristics of the dataset.
See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details.
- `loss_type`: Specify the loss function type. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, and `l2` for MSE loss. The default is `l2`, which is the same as before.
- `huber_schedule`: Specify the scheduling method. Choose `exponential`, `constant`, or `snr`. The default is `snr`.
- `huber_c`: Specify the Huber's parameter. The default is `0.1`.
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
#### 主要な変更点
- 依存ライブラリが更新されました。[アップグレード](./README-ja.md#アップグレード) を参照しライブラリを更新してください。
- 特に `imagesize` が新しく追加されていますので、すぐにライブラリの更新ができない場合は `pip install imagesize==1.4.1` で個別にインストールしてください。
- `bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` が requirements.txt に含まれるようになりました。
- `bitsandbytes` が公式に Windows をサポートしたため複雑な手順が不要になりました。
- また PyTorch のバージョンを 2.1.2 に更新しました。PyTorch はすぐに更新する必要はありません。更新時は、アップグレードの手順では PyTorch が更新されませんので、torch、torchvision、xformers を手動でインストールしてください。
- wandb へのログ出力が有効の場合、コマンドライン全体が公開されます。そのため、コマンドラインに wandb の API キーや HuggingFace のトークンなどが含まれる場合、設定ファイル(`.toml`)への記載をお勧めします。問題提起していただいた bghira 氏に感謝します。
- このような場合には学習開始時に警告が表示されます。
- また絶対パスの指定がある場合、そのパスが公開される可能性がありますので、相対パスを指定するか設定ファイルに記載することをお勧めします。このような場合は INFO ログが表示されます。
- 詳細は [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) および PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) をご覧ください。
- Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。
- その他、マスクロス追加、Scheduled Huber Loss 追加、DeepSpeed 対応、データセット設定の改善、画像タグ付けの改善などがあります。詳細は以下をご覧ください。
#### 学習スクリプト
- `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。
- `train_network.py` および `sdxl_train_network.py` で、state に U-Net および Text Encoder が含まれる不具合を修正しました。state の保存、読み込みが高速化され、ファイルサイズも小さくなり、また読み込み時のメモリ使用量も削減されます。
- DeepSpeed がサポートされました。PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) 、[#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) BootsofLagrangian 氏に感謝します。詳細は PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) をご覧ください。
- 各学習スクリプトでマスクロスをサポートしました。PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) 詳細は [マスクロスについて](#マスクロスについて) をご覧ください。
- 各学習スクリプトに Scheduled Huber Loss を追加しました。PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) ご提案いただいた kabachuha 氏、および議論を深めてくださった cheald 氏、drhead 氏を始めとする諸氏に感謝します。詳細は当該 PR および [Scheduled Huber Loss について](#scheduled-huber-loss-について) をご覧ください。
- 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。
- 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。
- 各学習スクリプトで `--sample_every_n_epochs` および `--sample_every_n_steps` オプションに `0` 以下の数値を指定した時、警告を表示するとともにそれらを無視するよう変更しました。問題提起していただいた S-Del 氏に感謝します。
#### データセット設定
- データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。
- データセット設定で、正則化画像のサブセットを複数指定した時、最後のサブセットの各種設定がすべてのサブセットの画像に適用される不具合が修正されました。それぞれのサブセットの設定が、それぞれの画像に正しく適用されます。PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) feffy380 氏に感謝します。
- データセットのサブセット設定にいくつかの機能を追加しました。
- シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。
- `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。また複数行キャプションも有効になります。
- `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。
- 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。
- 詳細は [データセット設定](./docs/config_README-ja.md) をご覧ください。
- DreamBooth 方式の DataSet で画像情報サイズ、キャプションをキャッシュする機能が追加されました。PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178)、[#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) KohakuBlueleaf 氏に感謝します。詳細は [データセット設定](./docs/config_README-ja.md#dreambooth-方式専用のオプション) をご覧ください。
- データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。
#### 画像のタグ付け
- `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。
- Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。
- `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。
- `tag_image_by_wd14_tagger.py` にいくつかのオプションを追加しました。
- 一部は PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) で追加されました。Disty0 氏に感謝します。
- レーティングタグを出力する `--use_rating_tags` および `--use_rating_tags_as_last_tag`
- キャラクタタグを最初に出力する `--character_tags_first`
- キャラクタタグとシリーズを展開する `--character_tag_expand`
- 常に最初に出力するタグを指定する `--always_first_tags`
- タグを置換する `--tag_replacement`
- 詳細は [タグ付けに関するドキュメント](./docs/wd14_tagger_README-ja.md) をご覧ください。
- `make_captions.py` で `--beam_search` を指定し `--num_beams` に2以上の値を指定した時のエラーを修正しました。
#### マスクロスについて
各学習スクリプトでマスクロスをサポートしました。マスクロスを有効にするには `--masked_loss` オプションを指定してください。
機能は完全にテストされていないため、不具合があるかもしれません。その場合は Issue を立てていただけると助かります。
マスクの指定には ControlNet データセットを使用します。マスク画像は RGB 画像である必要があります。R チャンネルのピクセル値 255 がロス計算対象、0 がロス計算対象外になります。0-255 の値は、0-1 の範囲に変換されます(つまりピクセル値 128 の部分はロスの重みが半分になります)。データセットの詳細は [LLLite ドキュメント](./docs/train_lllite_README-ja.md#データセットの準備) をご覧ください。
#### Scheduled Huber Loss について
各学習スクリプトに、学習データ中の異常値や外れ値data corruptionへの耐性を高めるための手法、Scheduled Huber Lossが導入されました。
従来のMSEL2損失関数では、異常値の影響を大きく受けてしまい、生成画像の品質低下を招く恐れがありました。一方、Huber損失関数は異常値の影響を抑えられますが、画像の細部再現性が損なわれがちでした。
この手法ではHuber損失関数の適用を工夫し、学習の初期段階イズが大きい場合ではHuber損失を、後期段階ではMSEを用いるようスケジューリングすることで、異常値耐性と細部再現性のバランスを取ります。
実験の結果では、この手法が純粋なHuber損失やMSEと比べ、異常値を含むデータでより高い精度を達成することが確認されています。また計算コストの増加はわずかです。
具体的には、新たに追加された引数loss_type、huber_schedule、huber_cで、損失関数の種類Huber, smooth L1, MSEとスケジューリング方法exponential, constant, SNRを選択できます。これによりデータセットに応じた最適化が可能になります。
詳細は PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) をご覧ください。
- `loss_type` : 損失関数の種類を指定します。`huber` で Huber損失、`smooth_l1` で smooth L1 損失、`l2` で MSE 損失を選択します。デフォルトは `l2` で、従来と同様です。
- `huber_schedule` : スケジューリング方法を指定します。`exponential` で指数関数的、`constant` で一定、`snr` で信号対雑音比に基づくスケジューリングを選択します。デフォルトは `snr` です。
- `huber_c` : Huber損失のパラメータを指定します。デフォルトは `0.1` です。
PR 内でいくつかの比較が共有されています。この機能を試す場合、最初は `--loss_type smooth_l1 --huber_schedule snr --huber_c 0.1` などで試してみるとよいかもしれません。
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
## Additional Information
### Naming of LoRA
The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
@@ -530,27 +566,14 @@ The LoRA supported by `train_network.py` has been named to avoid confusion. The
In addition to 1., LoRA for Conv2d layers with 3x3 kernel
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg).
<!--
LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
To use LoRA-C3Lier with Web UI, please use our extension.
To use LoRA-C3Lier with Web UI, please use our extension.
-->
### LoRAの名称について
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
## Sample image generation during training
### Sample image generation during training
A prompt file might look like this, for example
```
@@ -571,26 +594,3 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
* `--s` Specifies the number of steps in the generation.
The prompt weighting such as `( )` and `[ ]` are working.
## サンプル画像生成
プロンプトファイルは例えば以下のようになります。
```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
`( )` や `[ ]` などの重みづけも動作します。

View File

@@ -1,4 +1,7 @@
import torch
from library.device_utils import init_ipex
init_ipex()
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

View File

@@ -2,6 +2,7 @@
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
[default.extend-identifiers]
ddPn08="ddPn08"
[default.extend-words]
NIN="NIN"
@@ -9,7 +10,26 @@ parms="parms"
nin="nin"
extention="extention" # Intentionally left
nd="nd"
shs="shs"
sts="sts"
scs="scs"
cpc="cpc"
coc="coc"
cic="cic"
msm="msm"
usu="usu"
ici="ici"
lvl="lvl"
dii="dii"
muk="muk"
ori="ori"
hru="hru"
rik="rik"
koo="koo"
yos="yos"
wn="wn"
hime="hime"
[files]
extend-exclude = ["_typos.toml"]
extend-exclude = ["_typos.toml", "venv"]

386
docs/config_README-en.md Normal file
View File

@@ -0,0 +1,386 @@
Original Source by kohya-ss
First version:
A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150
Some parts are manually added.
# Config Readme
This README is about the configuration files that can be passed with the `--dataset_config` option.
## Overview
By passing a configuration file, users can make detailed settings.
* Multiple datasets can be configured
* For example, by setting `resolution` for each dataset, they can be mixed and trained.
* In training methods that support both the DreamBooth approach and the fine-tuning approach, datasets of the DreamBooth method and the fine-tuning method can be mixed.
* Settings can be changed for each subset
* A subset is a partition of the dataset by image directory or metadata. Several subsets make up a dataset.
* Options such as `keep_tokens` and `flip_aug` can be set for each subset. On the other hand, options such as `resolution` and `batch_size` can be set for each dataset, and their values are common among subsets belonging to the same dataset. More details will be provided later.
The configuration file format can be JSON or TOML. Considering the ease of writing, it is recommended to use [TOML](https://toml.io/ja/v1.0.0-rc.2). The following explanation assumes the use of TOML.
Here is an example of a configuration file written in TOML.
```toml
[general]
shuffle_caption = true
caption_extension = '.txt'
keep_tokens = 1
# This is a DreamBooth-style dataset
[[datasets]]
resolution = 512
batch_size = 4
keep_tokens = 2
[[datasets.subsets]]
image_dir = 'C:\hoge'
class_tokens = 'hoge girl'
# This subset uses keep_tokens = 2 (the value of the parent datasets)
[[datasets.subsets]]
image_dir = 'C:\fuga'
class_tokens = 'fuga boy'
keep_tokens = 3
[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg'
class_tokens = 'human'
keep_tokens = 1
# This is a fine-tuning dataset
[[datasets]]
resolution = [768, 768]
batch_size = 2
[[datasets.subsets]]
image_dir = 'C:\piyo'
metadata_file = 'C:\piyo\piyo_md.json'
# This subset uses keep_tokens = 1 (the value of [general])
```
In this example, three directories are trained as a DreamBooth-style dataset at 512x512 (batch size 4), and one directory is trained as a fine-tuning dataset at 768x768 (batch size 2).
## Settings for datasets and subsets
Settings for datasets and subsets are divided into several registration locations.
* `[general]`
* This is where options that apply to all datasets or all subsets are specified.
* If there are options with the same name in the dataset-specific or subset-specific settings, the dataset-specific or subset-specific settings take precedence.
* `[[datasets]]`
* `datasets` is where settings for datasets are registered. This is where options that apply individually to each dataset are specified.
* If there are subset-specific settings, the subset-specific settings take precedence.
* `[[datasets.subsets]]`
* `datasets.subsets` is where settings for subsets are registered. This is where options that apply individually to each subset are specified.
Here is an image showing the correspondence between image directories and registration locations in the previous example.
```
C:\
├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
├─ reg -> [[datasets.subsets]] No.3 ┘ |
└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
```
The image directory corresponds to each `[[datasets.subsets]]`. Then, multiple `[[datasets.subsets]]` are combined to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`.
The available options for each registration location may differ, but if the same option is specified, the value in the lower registration location will take precedence. You can check how the `keep_tokens` option is handled in the previous example for better understanding.
Additionally, the available options may vary depending on the method that the learning approach supports.
* Options specific to the DreamBooth method
* Options specific to the fine-tuning method
* Options available when using the caption dropout technique
When using both the DreamBooth method and the fine-tuning method, they can be used together with a learning approach that supports both.
When using them together, a point to note is that the method is determined based on the dataset, so it is not possible to mix DreamBooth method subsets and fine-tuning method subsets within the same dataset.
In other words, if you want to use both methods together, you need to set up subsets of different methods belonging to different datasets.
In terms of program behavior, if the `metadata_file` option exists, it is determined to be a subset of fine-tuning. Therefore, for subsets belonging to the same dataset, as long as they are either "all have the `metadata_file` option" or "all have no `metadata_file` option," there is no problem.
Below, the available options will be explained. For options with the same name as the command-line argument, the explanation will be omitted in principle. Please refer to other READMEs.
### Common options for all learning methods
These are options that can be specified regardless of the learning method.
#### Data set specific options
These are options related to the configuration of the data set. They cannot be described in `datasets.subsets`.
| Option Name | Example Setting | `[general]` | `[[datasets]]` |
| ---- | ---- | ---- | ---- |
| `batch_size` | `1` | o | o |
| `bucket_no_upscale` | `true` | o | o |
| `bucket_reso_steps` | `64` | o | o |
| `enable_bucket` | `true` | o | o |
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |
* `batch_size`
* This corresponds to the command-line argument `--train_batch_size`.
* `max_bucket_reso`, `min_bucket_reso`
* Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`.
These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
#### Options for Subsets
These options are related to subset configuration.
| Option Name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `color_aug` | `false` | o | o | o |
| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
| `flip_aug` | `true` | o | o | o |
| `keep_tokens` | `2` | o | o | o |
| `num_repeats` | `10` | o | o | o |
| `random_crop` | `false` | o | o | o |
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o |
| `caption_suffix` | `", from side"` | o | o | o |
| `caption_separator` | (not specified) | o | o | o |
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
* `num_repeats`
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
* `caption_prefix`, `caption_suffix`
* Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`.
* `caption_separator`
* Specifies the string to separate the tags. The default is `,`. This option is usually not necessary to set.
* `keep_tokens_separator`
* Specifies the string to separate the parts to be fixed in the caption. For example, if you specify `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` will remain, and the rest will be shuffled and dropped. The comma in between is not necessary. As a result, the prompt will be `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`, etc.
* `secondary_separator`
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
* `enable_wildcard`
* Enables wildcard notation. This will be explained later.
### DreamBooth-specific options
DreamBooth-specific options only exist as subsets-specific options.
#### Subset-specific options
Options related to the configuration of DreamBooth subsets.
| Option Name | Example Setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `"sks girl"` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |
Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`.
* `image_dir`
* Specifies the path to the image directory. This is a required option.
* Images must be placed directly under the directory.
* `class_tokens`
* Sets the class tokens.
* Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur.
* `cache_info`
* Specifies whether to cache the image size and caption. If not specified, it is set to `false`. The cache is saved in `metadata_cache.json` in `image_dir`.
* Caching speeds up the loading of the dataset after the first time. It is effective when dealing with thousands of images or more.
* `is_reg`
* Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization.
### Fine-tuning method specific options
The options for the fine-tuning method only exist for subset-specific options.
#### Subset-specific options
These options are related to the configuration of the fine-tuning method's subsets.
| Option name | Example setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o |
| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) |
* `image_dir`
* Specify the path to the image directory. Unlike the DreamBooth method, specifying it is not mandatory, but it is recommended to do so.
* The case where it is not necessary to specify is when the `--full_path` is added to the command line when generating the metadata file.
* The images must be placed directly under the directory.
* `metadata_file`
* Specify the path to the metadata file used for the subset. This is a required option.
* It is equivalent to the command-line argument `--in_json`.
* Due to the specification that a metadata file must be specified for each subset, it is recommended to avoid creating a metadata file with images from different directories as a single metadata file. It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets.
### Options available when caption dropout method can be used
The options available when the caption dropout method can be used exist only for subsets. Regardless of whether it's the DreamBooth method or fine-tuning method, if it supports caption dropout, it can be specified.
#### Subset-specific options
Options related to the setting of subsets that caption dropout can be used for.
| Option Name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- |
| `caption_dropout_every_n_epochs` | o | o | o |
| `caption_dropout_rate` | o | o | o |
| `caption_tag_dropout_rate` | o | o | o |
## Behavior when there are duplicate subsets
In the case of the DreamBooth dataset, if there are multiple `image_dir` directories with the same content, they are considered to be duplicate subsets. For the fine-tuning dataset, if there are multiple `metadata_file` files with the same content, they are considered to be duplicate subsets. If duplicate subsets exist in the dataset, subsequent subsets will be ignored.
However, if they belong to different datasets, they are not considered duplicates. For example, if you have subsets with the same `image_dir` in different datasets, they will not be considered duplicates. This is useful when you want to train with the same image but with different resolutions.
```toml
# If data sets exist separately, they are not considered duplicates and are both used for training.
[[datasets]]
resolution = 512
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
```
## Command Line Argument and Configuration File
There are options in the configuration file that have overlapping roles with command line argument options.
The following command line argument options are ignored if a configuration file is passed:
* `--train_data_dir`
* `--reg_data_dir`
* `--in_json`
The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file.
| Command Line Argument Option | Prioritized Configuration File Option |
| ------------------------------- | ------------------------------------- |
| `--bucket_no_upscale` | |
| `--bucket_reso_steps` | |
| `--caption_dropout_every_n_epochs` | |
| `--caption_dropout_rate` | |
| `--caption_extension` | |
| `--caption_tag_dropout_rate` | |
| `--color_aug` | |
| `--dataset_repeats` | `num_repeats` |
| `--enable_bucket` | |
| `--face_crop_aug_range` | |
| `--flip_aug` | |
| `--keep_tokens` | |
| `--min_bucket_reso` | |
| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--train_batch_size` | `batch_size` |
## Error Guide
Currently, we are using an external library to check if the configuration file is written correctly, but the development has not been completed, and there is a problem that the error message is not clear. In the future, we plan to improve this problem.
As a temporary measure, we will list common errors and their solutions. If you encounter an error even though it should be correct or if the error content is not understandable, please contact us as it may be a bug.
* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: This error occurs when a required option is not provided. It is highly likely that you forgot to specify the option or misspelled the option name.
* The error location is indicated by `...` in the error message. For example, if you encounter an error like `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']`, it means that the `image_dir` option does not exist in the 0th `subsets` of the 0th `datasets` setting.
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful.
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it.
## Miscellaneous
### Multi-line captions
By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption.
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
a girl with a microphone standing on a stage
detailed digital art of a girl with a microphone on a stage
```
It can be combined with wildcard notation.
In metadata files, you can also specify multiple-line captions. In the `.json` metadata file, use `\n` to represent a line break. If the caption file consists of multiple lines, `merge_captions_to_metadata.py` will create a metadata file in this format.
The tags in the metadata (`tags`) are added to each line of the caption.
```json
{
"/path/to/image.png": {
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
},
...
}
```
In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc.
### Example of configuration file : `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc.
```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]
[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator= "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above
[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
### Example of caption, secondary_separator notation: `secondary_separator = ";;;"`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped).
### Example of caption, enable_wildcard notation: `enable_wildcard = true`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`.
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`).
### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"`
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc.

View File

@@ -1,5 +1,3 @@
For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future.
`--dataset_config` で渡すことができる設定ファイルに関する説明です。
## 概要
@@ -120,6 +118,8 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
* `batch_size`
* コマンドライン引数の `--train_batch_size` と同等です。
* `max_bucket_reso`, `min_bucket_reso`
* bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。
これらの設定はデータセットごとに固定です。
つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
@@ -140,12 +140,28 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o |
| `caption_suffix` | `“, from side”` | o | o | o |
| `caption_separator` | (通常は設定しません) | o | o | o |
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
* `num_repeats`
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
* `caption_prefix`, `caption_suffix`
* キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。
* `caption_separator`
* タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。
* `keep_tokens_separator`
* キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb``ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh``aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。
* `secondary_separator`
* 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。
* `enable_wildcard`
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
### DreamBooth 方式専用のオプション
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
@@ -159,6 +175,7 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
| `image_dir` | `C:\hoge` | - | - | o必須 |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `“sks girl”` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |
まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats``class_tokens` で明示的に指定する必要があることに注意してください。
@@ -169,6 +186,9 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
* `class_tokens`
* クラストークンを設定します。
* 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。
* `cache_info`
* 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir``metadata_cache.json` というファイル名で保存されます。
* キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。
* `is_reg`
* サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
@@ -280,4 +300,89 @@ resolution = 768
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。
## その他
### 複数行キャプション
`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
a girl with a microphone standing on a stage
detailed digital art of a girl with a microphone on a stage
```
ワイルドカード記法と組み合わせることも可能です。
メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。
メタデータのタグ (`tags`) は、キャプションの各行に追加されます。
```json
{
"/path/to/image.png": {
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
},
...
}
```
この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...``test multiline caption2, open mouth, simple background ...` 等になります。
### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等
```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]
[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator= "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above
[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` drop されたケース)などになります。
### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
タグ文字列に `{``}` そのものを含めたい場合は `{{``}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。
### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general``1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。

View File

@@ -452,3 +452,36 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
- `--network_show_meta` : 追加ネットワークのメタデータを表示します。
---
# About Gradual Latent
Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options.
- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first.
- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size.
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers.
It is more effective with SD 1.5. It is quite subtle with SDXL.
# Gradual Latent について
latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、``sdxl_gen_img.py``gen_img_diffusers.py` に以下のオプションが追加されています。
- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。
それぞれのオプションは、プロンプトオプション、`--glt``--glr``--gls``--gle` でも指定できます。
サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。
SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。

View File

@@ -0,0 +1,57 @@
## マスクロスについて
マスクロスは、入力画像のマスクで指定された部分だけ損失計算することで、画像の一部分だけを学習することができる機能です。
たとえばキャラクタを学習したい場合、キャラクタ部分だけをマスクして学習することで、背景を無視して学習することができます。
マスクロスのマスクには、二種類の指定方法があります。
- マスク画像を用いる方法
- 透明度(アルファチャネル)を使用する方法
なお、サンプルは [ずんずんPJイラスト/3Dデータ](https://zunko.jp/con_illust.html) の「AI画像モデル用学習データ」を使用しています。
### マスク画像を用いる方法
学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。
- 学習画像
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441)
- マスク画像
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450)
```.toml
[[datasets.subsets]]
image_dir = "/path/to/a_zundamon"
caption_extension = ".txt"
conditioning_data_dir = "/path/to/a_zundamon_mask"
num_repeats = 8
```
マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。
DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存してください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。
### 透明度(アルファチャネル)を使用する方法
学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します127 ならおおむね 0.5)。
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e)
※それぞれの画像は透過PNG
学習時のスクリプトのオプションに `--alpha_mask` を指定するか、dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。
```toml
[[datasets.subsets]]
image_dir = "/path/to/image/dir"
caption_extension = ".txt"
num_repeats = 8
alpha_mask = true
```
## 学習時の注意事項
- 現時点では DreamBooth 方式の dataset のみ対応しています。
- マスクは latents のサイズ、つまり 1/8 に縮小されてから適用されます。そのため、細かい部分(たとえばアホ毛やイヤリングなど)はうまく学習できない可能性があります。マスクをわずかに拡張するなどの工夫が必要かもしれません。
- マスクロスを用いる場合、学習対象外の部分をキャプションに含める必要はないかもしれません。(要検証)
- `alpha_mask` の場合、マスクの有無を切り替えると latents キャッシュが自動的に再生成されます。

View File

@@ -0,0 +1,56 @@
## Masked Loss
Masked loss is a feature that allows you to train only part of an image by calculating the loss only for the part specified by the mask of the input image. For example, if you want to train a character, you can train only the character part by masking it, ignoring the background.
There are two ways to specify the mask for masked loss.
- Using a mask image
- Using transparency (alpha channel) of the image
The sample uses the "AI image model training data" from [ZunZunPJ Illustration/3D Data](https://zunko.jp/con_illust.html).
### Using a mask image
This is a method of preparing a mask image corresponding to each training image. Prepare a mask image with the same file name as the training image and save it in a different directory from the training image.
- Training image
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441)
- Mask image
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450)
```.toml
[[datasets.subsets]]
image_dir = "/path/to/a_zundamon"
caption_extension = ".txt"
conditioning_data_dir = "/path/to/a_zundamon_mask"
num_repeats = 8
```
The mask image is the same size as the training image, with the part to be trained drawn in white and the part to be ignored in black. It also supports grayscale (127 gives a loss weight of 0.5). The R channel of the mask image is used currently.
Use the dataset in the DreamBooth method, and save the mask image in the directory specified by `conditioning_data_dir`. It is the same as the ControlNet dataset, so please refer to [ControlNet-LLLite](train_lllite_README.md#Preparing-the-dataset) for details.
### Using transparency (alpha channel) of the image
The transparency (alpha channel) of the training image is used as a mask. The part with transparency 0 is ignored, the part with transparency 255 is trained. For semi-transparent parts, the loss weight changes according to the transparency (127 gives a weight of about 0.5).
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e)
※Each image is a transparent PNG
Specify `--alpha_mask` in the training script options or specify `alpha_mask` in the subset of the dataset configuration file. For example, it will look like this.
```toml
[[datasets.subsets]]
image_dir = "/path/to/image/dir"
caption_extension = ".txt"
num_repeats = 8
alpha_mask = true
```
## Notes on training
- At the moment, only the dataset in the DreamBooth method is supported.
- The mask is applied after the size is reduced to 1/8, which is the size of the latents. Therefore, fine details (such as ahoge or earrings) may not be learned well. Some dilations of the mask may be necessary.
- If using masked loss, it may not be necessary to include parts that are not to be trained in the caption. (To be verified)
- In the case of `alpha_mask`, the latents cache is automatically regenerated when the enable/disable state of the mask is switched.

View File

@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ
サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
- `--sample_at_first`
学習開始前にサンプル出力します。学習前との比較ができます。
- `--sample_prompts`
サンプル出力用プロンプトのファイルを指定します。
@@ -644,7 +648,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
詳細については各自お調べください。
任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。
任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--lr_scheduler_args`でオプション引数を指定してください。
### オプティマイザの指定について

View File

@@ -582,7 +582,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
有关详细信息,请自行研究。
要使用任何调度程序,请像使用任何优化器一样使用“--scheduler_args”指定可选参数。
要使用任何调度程序,请像使用任何优化器一样使用“--lr_scheduler_args”指定可选参数。
### 关于指定优化器
使用 --optimizer_args 选项指定优化器选项参数。可以以key=value的格式指定多个值。此外您可以指定多个值以逗号分隔。例如要指定 AdamW 优化器的参数,``--optimizer_args weight_decay=0.01 betas=.9,.999``。

84
docs/train_SDXL-en.md Normal file
View File

@@ -0,0 +1,84 @@
## SDXL training
The documentation will be moved to the training documentation in the future. The following is a brief explanation of the training scripts for SDXL.
### Training scripts for SDXL
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- The full bfloat16 training might be unstable. Please use it at your own risk.
- The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- `--weighted_captions` option is not supported yet for both scripts.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
### Utility scripts for SDXL
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
- The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage.
### Tips for SDXL training
- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (4 to 8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings for SDXL
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
### ControlNet-LLLite
ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details.

View File

@@ -21,9 +21,13 @@ ComfyUIのカスタムードを用意しています。: https://github.com/k
## モデルの学習
### データセットの準備
通常のdatasetに加え`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
DreamBooth 方式の dataset`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。
たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。
finetuning 方式の dataset はサポートしていません。)
conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。
```toml
[[datasets.subsets]]

View File

@@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1
### Preparing the dataset
In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
In addition to the normal DreamBooth method dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
(We do not support the finetuning method dataset.)
```toml
[[datasets.subsets]]

View File

@@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
* Text Encoderに関連するLoRAモジュールに、通常の学習率--learning_rateオプションで指定とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率5e-5などにしたほうが良い、という話もあるようです。
* `--network_args`
* 複数の引数を指定できます。後述します。
* `--alpha_mask`
* 画像のアルファ値をマスクとして使用します。透過画像を学習する際に使用します。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
`--network_train_unet_only``--network_train_text_encoder_only` の両方とも未指定時デフォルトはText EncoderとU-Netの両方のLoRAモジュールを有効にします。
@@ -183,12 +185,14 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf
フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
SDXL では down/up 9 個、middle 3 個の値を指定してください。
`--network_args` で以下の引数を指定してください。
- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個SDXL では 9 個)の数値を指定します。
- プリセットからの指定 : `"down_lr_weight=sine"` のように指定しますサインカーブで重みを指定します。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します0.25~1.25になります)。
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定しますSDXL の場合は 3 個)
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
@@ -213,6 +217,9 @@ network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_l
フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。
SDXL では 23 個の値を指定してください。一部のブロックにはLoRA が存在しませんが、`sdxl_train.py` の[階層別学習率](./train_SDXL-en.md) との互換性のためです。
対応は、`0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out` です。
`--network_args` で以下の引数を指定してください。
- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。
@@ -246,6 +253,8 @@ network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,
merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。
SDXL向けにはsdxl_merge_lora.pyを用意しています。オプション等は同一ですので、以下のmerge_lora.pyを読み替えてください。
### Stable DiffusionのモデルにLoRAのモデルをマージする
マージ後のモデルは通常のStable Diffusionのckptと同様に扱えます。たとえば以下のようなコマンドラインになります。
@@ -276,29 +285,29 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt
### 複数のLoRAのモデルをマージする
__複数のLoRAをマージする場合は原則として `svd_merge_lora.py` を使用してください__ 単純なup同士やdown同士のマージでは、計算結果が正しくなくなるためです
`merge_lora.py` によるマージは差分抽出法でLoRAを生成する場合等、ごく限られた場合でのみ有効です。
--concatオプションを指定すると、複数のLoRAを単純に結合して新しいLoRAモデルを作成できます。ファイルサイズおよびdim/rankは指定したLoRAの合計サイズになりますマージ時にdim (rank)を変更する場合は `svd_merge_lora.py` を使用してください
たとえば以下のようなコマンドラインになります。
```
python networks\merge_lora.py
python networks\merge_lora.py --save_precision bf16
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
--ratios 1.0 -1.0 --concat --shuffle
```
--sd_modelオプション指定不要です。
--concatオプション指定します。
また--shuffleオプションを追加し、重みをシャッフルします。シャッフルしないとマージ後のLoRAから元のLoRAを取り出せるため、コピー機学習などの場合には学習元データが明らかになります。ご注意ください。
--save_toオプションにマージ後のLoRAモデルの保存先を指定します.ckptまたは.safetensors、拡張子で自動判定
--modelsに学習したLoRAのモデルファイルを指定します。三つ以上も指定可能です。
--ratiosにそれぞれのモデルの比率どのくらい重みを元モデルに反映するかを0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
--ratiosにそれぞれのモデルの比率どのくらい重みを元モデルに反映するかを0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
v1で学習したLoRAとv2で学習したLoRA、rank次元数の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
### その他のオプション
* precision
@@ -306,6 +315,7 @@ v1で学習したLoRAとv2で学習したLoRA、rank次元数の異なるL
* save_precision
* モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。
他にもいくつかのオプションがありますので、--helpで確認してください。
## 複数のrankが異なるLoRAのモデルをマージする

View File

@@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中
* 当在Text Encoder相关的LoRA模块中使用与常规学习率`--learning_rate`选项指定不同的学习率时应指定此选项。可能最好将Text Encoder的学习率稍微降低例如5e-5
* `--network_args`
* 可以指定多个参数。将在下面详细说明。
* `--alpha_mask`
* 使用图像的 Alpha 值作为遮罩。这在学习透明图像时使用。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
当未指定`--network_train_unet_only``--network_train_text_encoder_only`默认情况将启用Text Encoder和U-Net的两个LoRA模块。

View File

@@ -0,0 +1,88 @@
# Image Tagging using WD14Tagger
This document is based on the information from this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger).
Using onnx for inference is recommended. Please install onnx with the following command:
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
```
The model weights will be automatically downloaded from Hugging Face.
# Usage
Run the script to perform tagging.
```powershell
python finetune/tag_images_by_wd14_tagger.py --onnx --repo_id <model repo id> --batch_size <batch size> <training data folder>
```
For example, if using the repository `SmilingWolf/wd-swinv2-tagger-v3` with a batch size of 4, and the training data is located in the parent folder `train_data`, it would be:
```powershell
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
```
On the first run, the model files will be automatically downloaded to the `wd14_tagger_model` folder (the folder can be changed with an option).
Tag files will be created in the same directory as the training data images, with the same filename and a `.txt` extension.
![Generated tag files](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
![Tags and image](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
## Example
To output in the Animagine XL 3.1 format, it would be as follows (enter on a single line in practice):
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
--batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
--use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
--always_first_tags "1girl,1boy" ..\train_data
```
## Available Repository IDs
[SmilingWolf's V2 and V3 models](https://huggingface.co/SmilingWolf) are available for use. Specify them in the format like `SmilingWolf/wd-vit-tagger-v3`. The default when omitted is `SmilingWolf/wd-v1-4-convnext-tagger-v2`.
# Options
## General Options
- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately.
- `--batch_size`: Number of images to process at once. Default is 1. Adjust according to VRAM capacity.
- `--caption_extension`: File extension for caption files. Default is `.txt`.
- `--max_data_loader_n_workers`: Maximum number of workers for DataLoader. Specifying a value of 1 or more will use DataLoader to speed up image loading. If unspecified, DataLoader will not be used.
- `--thresh`: Confidence threshold for outputting tags. Default is 0.35. Lowering the value will assign more tags but accuracy will decrease.
- `--general_threshold`: Confidence threshold for general tags. If omitted, same as `--thresh`.
- `--character_threshold`: Confidence threshold for character tags. If omitted, same as `--thresh`.
- `--recursive`: If specified, subfolders within the specified folder will also be processed recursively.
- `--append_tags`: Append tags to existing tag files.
- `--frequency_tags`: Output tag frequencies.
- `--debug`: Debug mode. Outputs debug information if specified.
## Model Download
- `--model_dir`: Folder to save model files. Default is `wd14_tagger_model`.
- `--force_download`: Re-download model files if specified.
## Tag Editing
- `--remove_underscore`: Remove underscores from output tags.
- `--undesired_tags`: Specify tags not to output. Multiple tags can be specified, separated by commas. For example, `black eyes,black hair`.
- `--use_rating_tags`: Output rating tags at the beginning of the tags.
- `--use_rating_tags_as_last_tag`: Add rating tags at the end of the tags.
- `--character_tags_first`: Output character tags first.
- `--character_tag_expand`: Expand character tag series names. For example, split the tag `chara_name_(series)` into `chara_name, series`.
- `--always_first_tags`: Specify tags to always output first when a certain tag appears in an image. Multiple tags can be specified, separated by commas. For example, `1girl,1boy`.
- `--caption_separator`: Separate tags with this string in the output file. Default is `, `.
- `--tag_replacement`: Perform tag replacement. Specify in the format `tag1,tag2;tag3,tag4`. If using `,` and `;`, escape them with `\`. \
For example, specify `aira tsubase,aira tsubase (uniform)` (when you want to train a specific costume), `aira tsubase,aira tsubase\, heir of shadows` (when the series name is not included in the tag).
When using `tag_replacement`, it is applied after `character_tag_expand`.
When specifying `remove_underscore`, specify `undesired_tags`, `always_first_tags`, and `tag_replacement` without including underscores.
When specifying `caption_separator`, separate `undesired_tags` and `always_first_tags` with `caption_separator`. Always separate `tag_replacement` with `,`.

View File

@@ -0,0 +1,88 @@
# WD14Taggerによるタグ付け
こちらのgithubページhttps://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。
onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
```
モデルの重みはHugging Faceから自動的にダウンロードしてきます。
# 使い方
スクリプトを実行してタグ付けを行います。
```
python fintune/tag_images_by_wd14_tagger.py --onnx --repo_id <モデルのrepo id> --batch_size <バッチサイズ> <教師データフォルダ>
```
レポジトリに `SmilingWolf/wd-swinv2-tagger-v3` を使用し、バッチサイズを4にして、教師データを親フォルダの `train_data`に置いた場合、以下のようになります。
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
```
初回起動時にはモデルファイルが `wd14_tagger_model` フォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。
タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。
![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
## 記述例
Animagine XL 3.1 方式で出力する場合、以下のようになります(実際には 1 行で入力してください)。
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
--batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
--use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
--always_first_tags "1girl,1boy" ..\train_data
```
## 使用可能なリポジトリID
[SmilingWolf 氏の V2、V3 のモデル](https://huggingface.co/SmilingWolf)が使用可能です。`SmilingWolf/wd-vit-tagger-v3` のように指定してください。省略時のデフォルトは `SmilingWolf/wd-v1-4-convnext-tagger-v2` です。
# オプション
## 一般オプション
- `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。
- `--batch_size` : 一度に処理する画像の数。デフォルトは1です。VRAMの容量に応じて増減してください。
- `--caption_extension` : キャプションファイルの拡張子。デフォルトは `.txt` です。
- `--max_data_loader_n_workers` : DataLoader の最大ワーカー数です。このオプションに 1 以上の数値を指定すると、DataLoader を用いて画像読み込みを高速化します。未指定時は DataLoader を用いません。
- `--thresh` : 出力するタグの信頼度の閾値。デフォルトは0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。
- `--general_threshold` : 一般タグの信頼度の閾値。省略時は `--thresh` と同じです。
- `--character_threshold` : キャラクタータグの信頼度の閾値。省略時は `--thresh` と同じです。
- `--recursive` : 指定すると、指定したフォルダ内のサブフォルダも再帰的に処理します。
- `--append_tags` : 既存のタグファイルにタグを追加します。
- `--frequency_tags` : タグの頻度を出力します。
- `--debug` : デバッグモード。指定するとデバッグ情報を出力します。
## モデルのダウンロード
- `--model_dir` : モデルファイルの保存先フォルダ。デフォルトは `wd14_tagger_model` です。
- `--force_download` : 指定するとモデルファイルを再ダウンロードします。
## タグ編集関連
- `--remove_underscore` : 出力するタグからアンダースコアを削除します。
- `--undesired_tags` : 出力しないタグを指定します。カンマ区切りで複数指定できます。たとえば `black eyes,black hair` のように指定します。
- `--use_rating_tags` : タグの最初にレーティングタグを出力します。
- `--use_rating_tags_as_last_tag` : タグの最後にレーティングタグを追加します。
- `--character_tags_first` : キャラクタータグを最初に出力します。
- `--character_tag_expand` : キャラクタータグのシリーズ名を展開します。たとえば `chara_name_(series)` のタグを `chara_name, series` に分割します。
- `--always_first_tags` : あるタグが画像に出力されたとき、そのタグを最初に出力するタグを指定します。カンマ区切りで複数指定できます。たとえば `1girl,1boy` のように指定します。
- `--caption_separator` : 出力するファイルでタグをこの文字列で区切ります。デフォルトは `, ` です。
- `--tag_replacement` : タグの置換を行います。`tag1,tag2;tag3,tag4` のように指定します。`,` および `;` を使う場合は `\` でエスケープしてください。\
たとえば `aira tsubase,aira tsubase (uniform)` (特定の衣装を学習させたいとき)、`aira tsubase,aira tsubase\, heir of shadows` (シリーズ名がタグに含まれないとき)のように指定します。
`tag_replacement``character_tag_expand` の後に適用されます。
`remove_underscore` 指定時は、`undesired_tags``always_first_tags``tag_replacement` はアンダースコアを含めずに指定してください。
`caption_separator` 指定時は、`undesired_tags``always_first_tags``caption_separator` で区切ってください。`tag_replacement` は必ず `,` で区切ってください。

View File

@@ -2,17 +2,29 @@
# XXX dropped option: hypernetwork training
import argparse
import gc
import math
import os
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
@@ -25,12 +37,15 @@ from library.custom_train_functions import (
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -43,11 +58,11 @@ def train(args):
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -73,14 +88,16 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
@@ -91,11 +108,12 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -146,15 +164,13 @@ def train(args):
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -181,27 +197,33 @@ def train(args):
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models:
m.requires_grad_(True)
params = []
for m in training_models:
params.extend(m.parameters())
params_to_optimize = params
trainable_params = []
if args.learning_rate_te is None or not args.train_text_encoder:
for m in training_models:
trainable_params.extend(m.parameters())
else:
trainable_params = [
{"params": list(unet.parameters()), "lr": args.learning_rate},
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
]
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -211,7 +233,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -228,16 +252,23 @@ def train(args):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
@@ -277,10 +308,20 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -288,16 +329,15 @@ def train(args):
for m in training_models:
m.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with accelerator.accumulate(*training_models):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
latents = latents * 0.18215
b_size = latents.shape[0]
@@ -320,7 +360,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
@@ -332,19 +374,25 @@ def train(args):
else:
target = noise
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -389,26 +437,20 @@ def train(args):
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -441,7 +483,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -451,22 +493,37 @@ def train(args):
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument(
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
)
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--learning_rate_te",
type=float,
default=None,
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser
@@ -475,6 +532,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -21,6 +21,10 @@ import torch.nn.functional as F
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class BLIP_Base(nn.Module):
def __init__(self,
@@ -130,8 +134,9 @@ class BLIP_Decoder(nn.Module):
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
image_embeds = self.visual_encoder(image)
if not sample:
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
# recent version of transformers seems to do repeat_interleave automatically
# if not sample:
# image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
@@ -235,6 +240,6 @@ def load_checkpoint(model,url_or_filename):
del state_dict[key]
msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
logger.info('load checkpoint from %s'%url_or_filename)
return model,msg

View File

@@ -8,6 +8,10 @@ import json
import re
from tqdm import tqdm
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
@@ -36,13 +40,13 @@ def clean_tags(image_key, tags):
tokens = tags.split(", rating")
if len(tokens) == 1:
# WD14 taggerのときはこちらになるのでメッセージは出さない
# print("no rating:")
# print(f"{image_key} {tags}")
# logger.info("no rating:")
# logger.info(f"{image_key} {tags}")
pass
else:
if len(tokens) > 2:
print("multiple ratings:")
print(f"{image_key} {tags}")
logger.info("multiple ratings:")
logger.info(f"{image_key} {tags}")
tags = tokens[0]
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
@@ -124,43 +128,43 @@ def clean_caption(caption):
def main(args):
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
logger.info(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
print("no metadata / メタデータファイルがありません")
logger.error("no metadata / メタデータファイルがありません")
return
print("cleaning captions and tags.")
logger.info("cleaning captions and tags.")
image_keys = list(metadata.keys())
for image_key in tqdm(image_keys):
tags = metadata[image_key].get('tags')
if tags is None:
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
else:
org = tags
tags = clean_tags(image_key, tags)
metadata[image_key]['tags'] = tags
if args.debug and org != tags:
print("FROM: " + org)
print("TO: " + tags)
logger.info("FROM: " + org)
logger.info("TO: " + tags)
caption = metadata[image_key].get('caption')
if caption is None:
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
else:
org = caption
caption = clean_caption(caption)
metadata[image_key]['caption'] = caption
if args.debug and org != caption:
print("FROM: " + org)
print("TO: " + caption)
logger.info("FROM: " + org)
logger.info("TO: " + caption)
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
logger.info(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
@@ -178,10 +182,10 @@ if __name__ == '__main__':
args, unknown = parser.parse_known_args()
if len(unknown) == 1:
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
print("All captions and tags in the metadata are processed.")
print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
print("メタデータ内のすべてのキャプションとタグが処理されます。")
logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
logger.warning("All captions and tags in the metadata are processed.")
logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
args.in_json = args.out_json
args.out_json = unknown[0]
elif len(unknown) > 0:

View File

@@ -9,14 +9,22 @@ from pathlib import Path
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder
from blip.blip import blip_decoder, is_url
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()
IMAGE_SIZE = 384
@@ -47,7 +55,7 @@ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
@@ -74,19 +82,21 @@ def main(args):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
cwd = os.getcwd()
print("Current Working Directory is: ", cwd)
logger.info(f"Current Working Directory is: {cwd}")
os.chdir("finetune")
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
args.caption_weights = os.path.join("..", args.caption_weights)
print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
print(f"loading BLIP caption: {args.caption_weights}")
logger.info(f"loading BLIP caption: {args.caption_weights}")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")
logger.info("BLIP loaded")
# captioningする
def run_batch(path_imgs):
@@ -106,7 +116,7 @@ def main(args):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f'{image_path} {caption}')
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
@@ -136,7 +146,7 @@ def main(args):
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, img_tensor))
@@ -146,7 +156,7 @@ def main(args):
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -5,12 +5,19 @@ import re
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -35,8 +42,8 @@ def remove_words(captions, debug):
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
logger.info(caption)
logger.info(cap)
removed_caps.append(cap)
return removed_caps
@@ -52,6 +59,9 @@ def collate_fn_remove_corrupted(batch):
def main(args):
r"""
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
@@ -65,23 +75,24 @@ def main(args):
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""
print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
# できればcacheに依存せず明示的にダウンロードしたい
print(f"loading GIT: {args.model_id}")
logger.info(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
logger.info("GIT loaded")
# captioningする
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]
curr_batch_size[0] = len(path_imgs)
# curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
@@ -93,7 +104,7 @@ def main(args):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f"{image_path} {caption}")
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
@@ -122,7 +133,7 @@ def main(args):
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
@@ -133,7 +144,7 @@ def main(args):
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -5,72 +5,96 @@ from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
assert not args.recursive or (
args.recursive and args.full_path
), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
print("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding='utf-8').strip()
logger.info("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding="utf-8").strip()
if not os.path.exists(caption_path):
caption_path = os.path.join(image_path, args.caption_extension)
if not os.path.exists(caption_path):
caption_path = os.path.join(image_path, args.caption_extension)
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
metadata[image_key]['caption'] = caption
if args.debug:
print(image_key, caption)
metadata[image_key]["caption"] = caption
if args.debug:
logger.info(f"{image_key} {caption}")
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
print("done!")
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument(
"--in_json",
type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む",
)
parser.add_argument(
"--caption_extention",
type=str,
default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument(
"--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子"
)
parser.add_argument(
"--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--recursive",
action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = parser.parse_args()
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)
main(args)

View File

@@ -5,67 +5,89 @@ from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
assert not args.recursive or (
args.recursive and args.full_path
), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
print("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding='utf-8').strip()
logger.info("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding="utf-8").strip()
if not os.path.exists(tags_path):
tags_path = os.path.join(image_path, args.caption_extension)
if not os.path.exists(tags_path):
tags_path = os.path.join(image_path, args.caption_extension)
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
metadata[image_key]['tags'] = tags
if args.debug:
print(image_key, tags)
metadata[image_key]["tags"] = tags
if args.debug:
logger.info(f"{image_key} {tags}")
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--caption_extension", type=str, default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument(
"--in_json",
type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む",
)
parser.add_argument(
"--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--recursive",
action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
)
parser.add_argument(
"--caption_extension",
type=str,
default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子",
)
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
main(args)
args = parser.parse_args()
main(args)

View File

@@ -8,13 +8,24 @@ from tqdm import tqdm
import numpy as np
from PIL import Image
import cv2
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
from library.utils import setup_logging
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = get_preferred_device()
IMAGE_TRANSFORMS = transforms.Compose(
[
@@ -51,22 +62,22 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive):
def main(args):
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
if args.bucket_reso_steps % 32 > 0:
print(
logger.warning(
f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
)
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
logger.info(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding="utf-8") as f:
metadata = json.load(f)
else:
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
weight_dtype = torch.float32
@@ -81,7 +92,9 @@ def main(args):
# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
assert (
len(max_reso) == 2
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_manager = train_util.BucketManager(
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
@@ -89,7 +102,7 @@ def main(args):
if not args.bucket_no_upscale:
bucket_manager.make_buckets()
else:
print(
logger.warning(
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
)
@@ -99,7 +112,7 @@ def main(args):
def process_batch(is_last):
for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False)
bucket.clear()
# 読み込みの高速化のためにDataLoaderを使うオプション
@@ -130,7 +143,7 @@ def main(args):
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
@@ -183,15 +196,15 @@ def main(args):
for i, reso in enumerate(bucket_manager.resos):
count = bucket_counts.get(reso, 0)
if count > 0:
print(f"bucket {i} {reso}: {count}")
logger.info(f"bucket {i} {reso}: {count}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error: {np.mean(img_ar_errors)}")
logger.info(f"mean ar error: {np.mean(img_ar_errors)}")
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
logger.info(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding="utf-8") as f:
json.dump(metadata, f, indent=2)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
@@ -200,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
parser.add_argument(
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
@@ -215,7 +230,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)",
)
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最解像度")
parser.add_argument(
"--bucket_reso_steps",
type=int,
@@ -223,10 +238,16 @@ def setup_parser() -> argparse.ArgumentParser:
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
)
parser.add_argument(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
"--bucket_no_upscale",
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
)
parser.add_argument(
"--full_path",
@@ -234,7 +255,15 @@ def setup_parser() -> argparse.ArgumentParser:
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
"--flip_aug",
action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
)
parser.add_argument(
"--alpha_mask",
type=str,
default="",
help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する",
)
parser.add_argument(
"--skip_existing",

View File

@@ -1,18 +1,22 @@
import argparse
import csv
import glob
import os
from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
from pathlib import Path
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
logger = logging.getLogger(__name__)
# from wd14 tagger
IMAGE_SIZE = 448
@@ -20,6 +24,7 @@ IMAGE_SIZE = 448
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
FILES_ONNX = ["model.onnx"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
@@ -37,8 +42,10 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
image = image.astype(np.float32)
return image
@@ -57,12 +64,12 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
tensor = torch.tensor(image)
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
return (image, img_path)
def collate_fn_remove_corrupted(batch):
@@ -76,101 +83,250 @@ def collate_fn_remove_corrupted(batch):
def main(args):
# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(args.model_dir, SUB_DIR),
force_download=True,
force_filename=file,
if not os.path.exists(model_location) or args.force_download:
os.makedirs(args.model_dir, exist_ok=True)
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files = ["selected_tags.csv"]
files += FILES_ONNX
else:
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(model_location, SUB_DIR),
force_download=True,
force_filename=file,
)
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
else:
logger.info("using existing wd14 tagger model")
# モデルを読み込む
if args.onnx:
import onnx
import onnxruntime as ort
onnx_path = f"{model_location}/model.onnx"
logger.info("Running wd14 tagger with onnx")
logger.info(f"loading onnx model: {onnx_path}")
if not os.path.exists(onnx_path):
raise Exception(
f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
+ " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
)
model = onnx.load(onnx_path)
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except Exception:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
# some rebatch model may use 'N' as dynamic axes
logger.warning(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
)
args.batch_size = batch_size
del model
if "OpenVINOExecutionProvider" in ort.get_available_providers():
# requires provider options for gpu support
# fp16 causes nonsense outputs
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU_FP32"}],
)
else:
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
),
)
else:
print("using existing wd14 tagger model")
from tensorflow.keras.models import load_model
# 画像を読み込む
model = load_model(args.model_dir)
model = load_model(f"{model_location}")
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
line = [row for row in reader]
header = line[0] # tag_id,name,category,count
rows = line[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
# preprocess tags in advance
if args.character_tag_expand:
for i, tag in enumerate(character_tags):
if tag.endswith(")"):
# chara_name_(series) -> chara_name, series
# chara_name_(costume)_(series) -> chara_name_(costume), series
tags = tag.split("(")
character_tag = "(".join(tags[:-1])
if character_tag.endswith("_"):
character_tag = character_tag[:-1]
series_tag = tags[-1].replace(")", "")
character_tags[i] = character_tag + args.caption_separator + series_tag
if args.remove_underscore:
rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags]
general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags]
character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags]
if args.tag_replacement is not None:
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;
escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
tag_replacements = escaped_tag_replacements.split(";")
for tag_replacement in tag_replacements:
tags = tag_replacement.split(",") # source, target
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
logger.info(f"replacing tag: {source} -> {target}")
if source in general_tags:
general_tags[general_tags.index(source)] = target
elif source in character_tags:
character_tags[character_tags.index(source)] = target
elif source in rating_tags:
rating_tags[rating_tags.index(source)] = target
# 画像を読み込む
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
tag_freq = {}
undesired_tags = set(args.undesired_tags.split(","))
caption_separator = args.caption_separator
stripped_caption_separator = caption_separator.strip()
undesired_tags = args.undesired_tags.split(stripped_caption_separator)
undesired_tags = set([tag.strip() for tag in undesired_tags if tag.strip() != ""])
always_first_tags = None
if args.always_first_tags is not None:
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
probs = model(imgs, training=False)
probs = probs.numpy()
if args.onnx:
# if len(imgs) < args.batch_size:
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
else:
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = label_names[:4]
# rating_index = ratings_names["probs"].argmax()
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
# Everything else is tags: pick any where prediction confidence > threshold
combined_tags = []
general_tag_text = ""
rating_tag_text = ""
character_tag_text = ""
general_tag_text = ""
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
tag_name = tag_name.replace("_", " ")
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name
general_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
if args.remove_underscore and len(tag_name) > 3:
tag_name = tag_name.replace("_", " ")
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name
combined_tags.append(tag_name)
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
combined_tags.append(tag_name)
# 最初の4つはratingなのでargmaxで選ぶ
# First 4 labels are actually ratings: pick one with argmax
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
ratings_probs = prob[:4]
rating_index = ratings_probs.argmax()
found_rating = rating_tags[rating_index]
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
# 一番最初に置くタグを指定する
# Always put some tags at the beginning
if always_first_tags is not None:
for tag in always_first_tags:
if tag in combined_tags:
combined_tags.remove(tag)
combined_tags.insert(0, tag)
# 先頭のカンマを取る
if len(general_tag_text) > 0:
general_tag_text = general_tag_text[2:]
general_tag_text = general_tag_text[len(caption_separator) :]
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
character_tag_text = character_tag_text[len(caption_separator) :]
tag_text = ", ".join(combined_tags)
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
tag_text = caption_separator.join(combined_tags)
if args.append_tags:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
# Check and remove repeating tags in tag_text
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
# Create new tag_text
tag_text = caption_separator.join(existing_tags + new_tags)
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
@@ -193,16 +349,14 @@ def main(args):
continue
image, image_path = data
if image is not None:
image = image.detach().numpy()
else:
if image is None:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
@@ -217,16 +371,18 @@ def main(args):
if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
print("\nTag frequencies:")
print("Tag frequencies:")
for tag, freq in sorted_tags:
print(f"{tag}: {freq}")
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--repo_id",
type=str,
@@ -240,9 +396,13 @@ def setup_parser() -> argparse.ArgumentParser:
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ",
)
parser.add_argument(
"--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします"
"--force_download",
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
@@ -255,8 +415,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument(
"--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
)
parser.add_argument(
"--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値"
)
parser.add_argument(
"--general_threshold",
type=float,
@@ -269,26 +433,74 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
)
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
parser.add_argument(
"--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する"
)
parser.add_argument(
"--remove_underscore",
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument(
"--undesired_tags",
type=str,
default="",
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
)
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
)
parser.add_argument(
"--always_first_tags",
type=str,
default=None,
help="comma-separated list of tags to always put at the beginning, e.g. `1girl,1boy`"
+ " / 必ず先頭に置くタグのカンマ区切りリスト、例 : `1girl,1boy`",
)
parser.add_argument(
"--caption_separator",
type=str,
default=", ",
help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
)
parser.add_argument(
"--tag_replacement",
type=str,
default=None,
help="tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\`. e.g. `tag1,tag2;tag3,tag4`"
+ " / タグの置換を `置換元1,置換先1;置換元2,置換先2; ...`で指定する。`\` で `,` と `;` をエスケープできる。例: `tag1,tag2;tag3,tag4`",
)
parser.add_argument(
"--character_tag_expand",
action="store_true",
help="expand tag tail parenthesis to another tag for character tags. `chara_name_(series)` becomes `chara_name, series`"
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
# スペルミスしていたオプションを復元する

3358
gen_img.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

106
library/adafactor_fused.py Normal file
View File

@@ -0,0 +1,106 @@
import math
import torch
from transformers import Adafactor
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
return
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = Adafactor._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
state["step"] += 1
state["RMS"] = Adafactor._rms(p_data_fp32)
lr = Adafactor._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad ** 2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
if p.dtype in {torch.float16, torch.bfloat16}:
p.copy_(p_data_fp32)
@torch.no_grad()
def adafactor_step(self, closure=None):
"""
Performs a single optimization step
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
adafactor_step_param(self, p, group)
return loss
def patch_adafactor_fused(optimizer: Adafactor):
optimizer.step_param = adafactor_step_param.__get__(optimizer)
optimizer.step = adafactor_step.__get__(optimizer)

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,12 @@ import argparse
import random
import re
from typing import List, Optional, Union
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def prepare_scheduler_for_custom_training(noise_scheduler, device):
@@ -21,7 +27,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
# fix beta: zero terminal SNR
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
def enforce_zero_terminal_snr(betas):
# Convert betas to alphas_bar_sqrt
@@ -49,18 +55,21 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# print("original:", noise_scheduler.betas)
# print("fixed:", betas)
# logger.info(f"original: {noise_scheduler.betas}")
# logger.info(f"fixed: {betas}")
noise_scheduler.betas = betas
noise_scheduler.alphas = alphas
noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
else:
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
loss = loss * snr_weight
return loss
@@ -76,17 +85,28 @@ def get_snr_scale(timesteps, noise_scheduler):
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1)
# # show debug info
# print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
# logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
return scale
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
scale = get_snr_scale(timesteps, noise_scheduler)
# print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
return loss
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
if v_prediction:
weight = 1 / (snr_t + 1)
else:
weight = 1 / torch.sqrt(snr_t)
loss = weight * loss
return loss
# TODO train_utilと分散しているのでどちらかに寄せる
@@ -108,6 +128,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
default=None,
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
)
parser.add_argument(
"--debiased_estimation_loss",
action="store_true",
help="debiased estimation loss / debiased estimation loss",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",
@@ -254,7 +279,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
@@ -457,6 +482,25 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise
def apply_masked_loss(loss, batch):
if "conditioning_images" in batch:
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image / 2 + 0.5
# print(f"conditioning_image: {mask_image.shape}")
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
# alpha mask is 0 to 1
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
else:
return loss
# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
loss = loss * mask_image
return loss
"""
##########################################
# Perlin Noise

139
library/deepspeed_utils.py Normal file
View File

@@ -0,0 +1,139 @@
import os
import argparse
import torch
from accelerate import DeepSpeedPlugin, Accelerator
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
parser.add_argument(
"--offload_optimizer_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--fp16_master_weights_and_gradients",
action="store_true",
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
)
def prepare_deepspeed_args(args: argparse.Namespace):
if not args.deepspeed:
return
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
args.max_data_loader_n_workers = 1
def prepare_deepspeed_plugin(args: argparse.Namespace):
if not args.deepspeed:
return None
try:
import deepspeed
except ImportError as e:
logger.error(
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
)
exit(1)
deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_clipping=args.max_grad_norm,
offload_optimizer_device=args.offload_optimizer_device,
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device,
offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag,
zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
if args.full_fp16 or args.fp16_master_weights_and_gradients:
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
logger.info("[DeepSpeed] full fp16 enable.")
else:
logger.info(
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
)
if args.offload_optimizer_device is not None:
logger.info("[DeepSpeed] start to manually build cpu_adam.")
deepspeed.ops.op_builder.CPUAdamBuilder().load()
logger.info("[DeepSpeed] building cpu_adam done.")
return deepspeed_plugin
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
def prepare_deepspeed_model(args: argparse.Namespace, **models):
# remove None from models
models = {k: v for k, v in models.items() if v is not None}
class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()
for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(torch.nn.ModuleDict({key: model}))
def get_models(self):
return self.models
ds_model = DeepSpeedWrapper(**models)
return ds_model

84
library/device_utils.py Normal file
View File

@@ -0,0 +1,84 @@
import functools
import gc
import torch
try:
HAS_CUDA = torch.cuda.is_available()
except Exception:
HAS_CUDA = False
try:
HAS_MPS = torch.backends.mps.is_available()
except Exception:
HAS_MPS = False
try:
import intel_extension_for_pytorch as ipex # noqa
HAS_XPU = torch.xpu.is_available()
except Exception:
HAS_XPU = False
def clean_memory():
gc.collect()
if HAS_CUDA:
torch.cuda.empty_cache()
if HAS_XPU:
torch.xpu.empty_cache()
if HAS_MPS:
torch.mps.empty_cache()
def clean_memory_on_device(device: torch.device):
r"""
Clean memory on the specified device, will be called from training scripts.
"""
gc.collect()
# device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda":
torch.cuda.empty_cache()
if device.type == "xpu":
torch.xpu.empty_cache()
if device.type == "mps":
torch.mps.empty_cache()
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
r"""
Do not call this function from training scripts. Use accelerator.device instead.
"""
if HAS_CUDA:
device = torch.device("cuda")
elif HAS_XPU:
device = torch.device("xpu")
elif HAS_MPS:
device = torch.device("mps")
else:
device = torch.device("cpu")
print(f"get_preferred_device() -> {device}")
return device
def init_ipex():
"""
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
This function should run right after importing torch and before doing anything else.
If IPEX is not available, this function does nothing.
"""
try:
if HAS_XPU:
from library.ipex import ipex_init
is_initialized, error_message = ipex_init()
if not is_initialized:
print("failed to initialize ipex:", error_message)
else:
return
except Exception as e:
print("failed to initialize ipex:", e)

View File

@@ -4,7 +4,10 @@ from pathlib import Path
import argparse
import os
from library.utils import fire_in_thread
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
api = HfApi(
@@ -33,9 +36,9 @@ def upload(
try:
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
print("===========================================")
print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
print("===========================================")
logger.error("===========================================")
logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
logger.error("===========================================")
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
@@ -56,9 +59,9 @@ def upload(
path_in_repo=path_in_repo,
)
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
print("===========================================")
print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
print("===========================================")
logger.error("===========================================")
logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
logger.error("===========================================")
if args.async_upload and not force_sync_upload:
fire_in_thread(uploader)

180
library/ipex/__init__.py Normal file
View File

@@ -0,0 +1,180 @@
import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def ipex_init(): # pylint: disable=too-many-statements
try:
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
return True, "Skipping IPEX hijack"
else:
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
torch.cuda.device_count = torch.xpu.device_count
torch.cuda.device_of = torch.xpu.device_of
torch.cuda.get_device_name = torch.xpu.get_device_name
torch.cuda.get_device_properties = torch.xpu.get_device_properties
torch.cuda.init = torch.xpu.init
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.nn.Module.cuda = torch.nn.Module.xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
# Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
torch.cuda.memory_cached = torch.xpu.memory_reserved
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
# RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
torch.cuda.manual_seed = torch.xpu.manual_seed
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
torch.cuda.seed = torch.xpu.seed
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
# AMP:
torch.cuda.amp = torch.xpu.amp
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 2024
ipex._C._DeviceProperties.minor = 0
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.version.cuda = "12.1"
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
ipex_hijacks()
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
torch.cuda.is_xpu_hijacked = True
except Exception as e:
return False, e
return True, None

177
library/ipex/attention.py Normal file
View File

@@ -0,0 +1,177 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
# Find something divisible with the input_tokens
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
# Find slice sizes for SDPA
@cache
def find_sdpa_slice_sizes(query_shape, query_element_size):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > sdpa_slice_trigger_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
# Find slice sizes for BMM
@cache
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = input_tokens
split_3_slice_size = mat2_atten_shape
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
if input.device.type != "xpu":
return original_torch_bmm(input, mat2, out=out)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
# Slice BMM
if do_split:
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
out=out
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
)
torch.xpu.synchronize(input.device)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
# Slice SDPA
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
torch.xpu.synchronize(query.device)
else:
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
return hidden_states

312
library/ipex/diffusers.py Normal file
View File

@@ -0,0 +1,312 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND
from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
@cache
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
if slice_size is not None:
batch_size_attention = slice_size
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
do_split = False
do_split_2 = False
do_split_3 = False
if query_device_type != "xpu":
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
r"""
Processor for implementing sliced attention.
Args:
slice_size (`int`, *optional*):
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
`attention_head_dim` must be a multiple of the `slice_size`.
"""
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, shape_three = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None,
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
if do_split:
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def ipex_diffusers():
#ARC GPUs can't allocate more than 4GB to a single block:
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
diffusers.models.attention_processor.AttnProcessor = AttnProcessor

183
library/ipex/gradscaler.py Normal file
View File

@@ -0,0 +1,183 @@
from collections import defaultdict
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
device_supports_fp64 = torch.xpu.has_fp64_dtype()
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
# sync grad to master weight
if hasattr(optimizer, "sync_grad"):
optimizer.sync_grad()
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad
# -: is there a way to split by device and dtype without appending in the inner loop?
to_unscale = to_unscale.to("cpu")
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for _, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
core._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get("cpu"),
per_device_inv_scale.get("cpu"),
)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
if device_supports_fp64:
inv_scale = self._scale.double().reciprocal().float()
else:
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device="cpu", non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
to_device = _scale.device
_scale = _scale.to("cpu")
_growth_tracker = _growth_tracker.to("cpu")
core._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
_scale = _scale.to(to_device)
_growth_tracker = _growth_tracker.to(to_device)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def gradscaler_init():
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
torch.xpu.amp.GradScaler.unscale_ = unscale_
torch.xpu.amp.GradScaler.update = update
return torch.xpu.amp.GradScaler

313
library/ipex/hijacks.py Normal file
View File

@@ -0,0 +1,313 @@
import os
from functools import wraps
from contextlib import nullcontext
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import numpy as np
device_supports_fp64 = torch.xpu.has_fp64_dtype()
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return nullcontext()
@property
def is_cuda(self):
return self.device.type == 'xpu' or self.device.type == 'cuda'
def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
# Autocast
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
@wraps(torch.amp.autocast_mode.autocast.__init__)
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
if device_type == "cuda":
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
else:
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
# Latent Antialias CPU Offload:
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None or mode == 'bicubic':
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
else:
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
original_from_numpy = torch.from_numpy
@wraps(torch.from_numpy)
def from_numpy(ndarray):
if ndarray.dtype == float:
return original_from_numpy(ndarray.astype('float32'))
else:
return original_from_numpy(ndarray)
original_as_tensor = torch.as_tensor
@wraps(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
if check_device(device):
device = return_xpu(device)
if isinstance(data, np.ndarray) and data.dtype == float and not (
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
return original_as_tensor(data, dtype=torch.float32, device=device)
else:
return original_as_tensor(data, dtype=dtype, device=device)
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
# 32 bit attention workarounds for Alchemist:
try:
from .attention import torch_bmm_32_bit as original_torch_bmm
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
except Exception: # pylint: disable=broad-exception-caught
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
# Data Type Errors:
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
@wraps(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
if query.dtype != key.dtype:
key = key.to(dtype=query.dtype)
if query.dtype != value.dtype:
value = value.to(dtype=query.dtype)
if attn_mask is not None and query.dtype != attn_mask.dtype:
attn_mask = attn_mask.to(dtype=query.dtype)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
# A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm
@wraps(torch.nn.functional.group_norm)
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
# A1111 BF16
original_functional_layer_norm = torch.nn.functional.layer_norm
@wraps(torch.nn.functional.layer_norm)
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
# Training
original_functional_linear = torch.nn.functional.linear
@wraps(torch.nn.functional.linear)
def functional_linear(input, weight, bias=None):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_linear(input, weight, bias=bias)
original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
# A1111 Embedding BF16
original_torch_cat = torch.cat
@wraps(torch.cat)
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
# SwinIR BF16:
original_functional_pad = torch.nn.functional.pad
@wraps(torch.nn.functional.pad)
def functional_pad(input, pad, mode='constant', value=None):
if mode == 'reflect' and input.dtype == torch.bfloat16:
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
else:
return original_functional_pad(input, pad, mode=mode, value=value)
original_torch_tensor = torch.tensor
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
if check_device(device):
device = return_xpu(device)
if not device_supports_fp64:
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
if dtype == torch.float64:
dtype = torch.float32
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
dtype = torch.float32
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_to(self, device, *args, **kwargs)
original_Tensor_cuda = torch.Tensor.cuda
@wraps(torch.Tensor.cuda)
def Tensor_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_cuda(self, device, *args, **kwargs)
original_Tensor_pin_memory = torch.Tensor.pin_memory
@wraps(torch.Tensor.pin_memory)
def Tensor_pin_memory(self, device=None, *args, **kwargs):
if device is None:
device = "xpu"
if check_device(device):
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_pin_memory(self, device, *args, **kwargs)
original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
if check_device(device):
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
else:
return original_UntypedStorage_init(*args, device=device, **kwargs)
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
else:
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs):
if check_device(device):
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_empty(*args, device=device, **kwargs)
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, dtype=None, **kwargs):
if dtype == bytes:
dtype = None
if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_randn(*args, device=device, **kwargs)
original_torch_ones = torch.ones
@wraps(torch.ones)
def torch_ones(*args, device=None, **kwargs):
if check_device(device):
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_ones(*args, device=device, **kwargs)
original_torch_zeros = torch.zeros
@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs):
if check_device(device):
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_zeros(*args, device=device, **kwargs)
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
if check_device(device):
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_linspace(*args, device=device, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
if map_location is None:
map_location = "xpu"
if check_device(map_location):
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
else:
return original_torch_load(f, *args, map_location=map_location, **kwargs)
# Hijack Functions:
def ipex_hijacks():
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
torch.Tensor.pin_memory = Tensor_pin_memory
torch.UntypedStorage.__init__ = UntypedStorage_init
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.empty = torch_empty
torch.randn = torch_randn
torch.ones = torch_ones
torch.zeros = torch_zeros
torch.linspace = torch_linspace
torch.Generator = torch_Generator
torch.load = torch_load
torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel
torch.UntypedStorage.is_cuda = is_cuda
torch.amp.autocast_mode.autocast.__init__ = autocast_init
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
torch.nn.functional.group_norm = functional_group_norm
torch.nn.functional.layer_norm = functional_layer_norm
torch.nn.functional.linear = functional_linear
torch.nn.functional.conv2d = functional_conv2d
torch.nn.functional.interpolate = interpolate
torch.nn.functional.pad = functional_pad
torch.bmm = torch_bmm
torch.cat = torch_cat
if not device_supports_fp64:
torch.from_numpy = from_numpy
torch.as_tensor = as_tensor

View File

@@ -9,7 +9,7 @@ import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
import diffusers
from diffusers import SchedulerMixin, StableDiffusionPipeline
@@ -17,7 +17,6 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import logging
try:
from diffusers.utils import PIL_INTERPOLATION
except ImportError:
@@ -520,6 +519,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
image_encoder: CLIPVisionModelWithProjection = None,
clip_skip: int = 1,
):
super().__init__(
@@ -531,32 +531,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker,
image_encoder=image_encoder,
)
self.clip_skip = clip_skip
self.custom_clip_skip = clip_skip
self.__init__additional__()
# else:
# def __init__(
# self,
# vae: AutoencoderKL,
# text_encoder: CLIPTextModel,
# tokenizer: CLIPTokenizer,
# unet: UNet2DConditionModel,
# scheduler: SchedulerMixin,
# safety_checker: StableDiffusionSafetyChecker,
# feature_extractor: CLIPFeatureExtractor,
# ):
# super().__init__(
# vae=vae,
# text_encoder=text_encoder,
# tokenizer=tokenizer,
# unet=unet,
# scheduler=scheduler,
# safety_checker=safety_checker,
# feature_extractor=feature_extractor,
# )
# self.__init__additional__()
def __init__additional__(self):
if not hasattr(self, "vae_scale_factor"):
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
@@ -624,7 +603,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
clip_skip=self.custom_clip_skip,
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
@@ -646,7 +625,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
print(height, width)
logger.info(f'{height} {width}')
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (

View File

@@ -3,12 +3,20 @@
import math
import os
import torch
from library.device_utils import init_ipex
init_ipex()
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# DiffUsers版StableDiffusionのモデルパラメータ
NUM_TRAIN_TIMESTEPS = 1000
@@ -564,9 +572,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
# support checkpoint without position_ids (invalid checkpoint)
if "text_model.embeddings.position_ids" not in text_model_dict:
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
# remove position_ids for newer transformer, which causes error :(
if "text_model.embeddings.position_ids" in text_model_dict:
text_model_dict.pop("text_model.embeddings.position_ids")
return text_model_dict
@@ -940,7 +948,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
# logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
@@ -998,7 +1006,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
unet = UNet2DConditionModel(**unet_config).to(device)
info = unet.load_state_dict(converted_unet_checkpoint)
print("loading u-net:", info)
logger.info(f"loading u-net: {info}")
# Convert the VAE model.
vae_config = create_vae_diffusers_config()
@@ -1006,7 +1014,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
vae = AutoencoderKL(**vae_config).to(device)
info = vae.load_state_dict(converted_vae_checkpoint)
print("loading vae:", info)
logger.info(f"loading vae: {info}")
# convert text_model
if v2:
@@ -1040,7 +1048,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
# logging.set_verbosity_error() # don't show annoying warning
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
# logging.set_verbosity_warning()
# print(f"config: {text_model.config}")
# logger.info(f"config: {text_model.config}")
cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
@@ -1063,7 +1071,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
)
text_model = CLIPTextModel._from_config(cfg)
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
print("loading text encoder:", info)
logger.info(f"loading text encoder: {info}")
return text_model, vae, unet
@@ -1138,7 +1146,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
# 最後の層などを捏造するか
if make_dummy_weights:
print("make dummy weights for resblock.23, text_projection and logit scale.")
logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
keys = list(new_sd.keys())
for key in keys:
if key.startswith("transformer.resblocks.22."):
@@ -1235,8 +1243,13 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
if vae is None:
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
# original U-Net cannot be saved, so we need to convert it to the Diffusers version
# TODO this consumes a lot of memory
diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
diffusers_unet.load_state_dict(unet.state_dict())
pipeline = StableDiffusionPipeline(
unet=unet,
unet=diffusers_unet,
text_encoder=text_encoder,
vae=vae,
scheduler=scheduler,
@@ -1252,14 +1265,14 @@ VAE_PREFIX = "first_stage_model."
def load_vae(vae_id, dtype):
print(f"load VAE: {vae_id}")
logger.info(f"load VAE: {vae_id}")
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
# Diffusers local/remote
try:
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
except EnvironmentError as e:
print(f"exception occurs in loading vae: {e}")
print("retry with subfolder='vae'")
logger.error(f"exception occurs in loading vae: {e}")
logger.error("retry with subfolder='vae'")
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
return vae
@@ -1300,19 +1313,19 @@ def load_vae(vae_id, dtype):
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
max_width, max_height = max_reso
max_area = (max_width // divisible) * (max_height // divisible)
max_area = max_width * max_height
resos = set()
size = int(math.sqrt(max_area)) * divisible
resos.add((size, size))
width = int(math.sqrt(max_area) // divisible) * divisible
resos.add((width, width))
size = min_size
while size <= max_size:
width = size
height = min(max_size, (max_area // (width // divisible)) * divisible)
resos.add((width, height))
resos.add((height, width))
width = min_size
while width <= max_size:
height = min(max_size, int((max_area // width) // divisible) * divisible)
if height >= min_size:
resos.add((width, height))
resos.add((height, width))
# # make additional resos
# if width >= height and width - divisible >= min_size:
@@ -1322,7 +1335,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
# resos.add((width, height - divisible))
# resos.add((height - divisible, width))
size += divisible
width += divisible
resos = list(resos)
resos.sort()
@@ -1331,13 +1344,13 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
if __name__ == "__main__":
resos = make_bucket_resolutions((512, 768))
print(len(resos))
print(resos)
logger.info(f"{len(resos)}")
logger.info(f"{resos}")
aspect_ratios = [w / h for w, h in resos]
print(aspect_ratios)
logger.info(f"{aspect_ratios}")
ars = set()
for ar in aspect_ratios:
if ar in ars:
print("error! duplicate ar:", ar)
logger.error(f"error! duplicate ar: {ar}")
ars.add(ar)

View File

@@ -113,6 +113,10 @@ import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
@@ -131,7 +135,7 @@ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDo
UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
# region memory effcient attention
# region memory efficient attention
# FlashAttentionを使うCrossAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
@@ -361,6 +365,23 @@ def get_timestep_embedding(
return emb
# Deep Shrink: We do not common this function, because minimize dependencies.
def resize_like(x, target, mode="bicubic", align_corners=False):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.to(torch.float32)
if x.shape[-2:] != target.shape[-2:]:
if mode == "nearest":
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
else:
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
if org_dtype == torch.bfloat16:
x = x.to(org_dtype)
return x
class SampleOutput:
def __init__(self, sample):
self.sample = sample
@@ -569,6 +590,9 @@ class CrossAttention(nn.Module):
self.use_memory_efficient_attention_mem_eff = False
self.use_sdpa = False
# Attention processor
self.processor = None
def set_use_memory_efficient_attention(self, xformers, mem_eff):
self.use_memory_efficient_attention_xformers = xformers
self.use_memory_efficient_attention_mem_eff = mem_eff
@@ -590,7 +614,28 @@ class CrossAttention(nn.Module):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def forward(self, hidden_states, context=None, mask=None):
def set_processor(self):
return self.processor
def get_processor(self):
return self.processor
def forward(self, hidden_states, context=None, mask=None, **kwargs):
if self.processor is not None:
(
hidden_states,
encoder_hidden_states,
attention_mask,
) = translate_attention_names_from_diffusers(
hidden_states=hidden_states, context=context, mask=mask, **kwargs
)
return self.processor(
attn=self,
hidden_states=hidden_states,
encoder_hidden_states=context,
attention_mask=mask,
**kwargs
)
if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
if self.use_memory_efficient_attention_mem_eff:
@@ -703,6 +748,21 @@ class CrossAttention(nn.Module):
out = self.to_out[0](out)
return out
def translate_attention_names_from_diffusers(
hidden_states: torch.FloatTensor,
context: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
# HF naming
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None
):
# translate from hugging face diffusers
context = context if context is not None else encoder_hidden_states
# translate from hugging face diffusers
mask = mask if mask is not None else attention_mask
return hidden_states, context, mask
# feedforward
class GEGLU(nn.Module):
@@ -1130,6 +1190,7 @@ class UpBlock2D(nn.Module):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -1205,9 +1266,9 @@ class CrossAttnUpBlock2D(nn.Module):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, spda):
def set_use_sdpa(self, sdpa):
for attn in self.attentions:
attn.set_use_sdpa(spda)
attn.set_use_sdpa(sdpa)
def forward(
self,
@@ -1221,6 +1282,7 @@ class CrossAttnUpBlock2D(nn.Module):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -1322,7 +1384,7 @@ class UNet2DConditionModel(nn.Module):
):
super().__init__()
assert sample_size is not None, "sample_size must be specified"
print(
logger.info(
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
)
@@ -1331,7 +1393,7 @@ class UNet2DConditionModel(nn.Module):
self.out_channels = OUT_CHANNELS
self.sample_size = sample_size
self.prepare_config()
self.prepare_config(sample_size=sample_size)
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
@@ -1418,8 +1480,8 @@ class UNet2DConditionModel(nn.Module):
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
def prepare_config(self, *args, **kwargs):
self.config = SimpleNamespace(**kwargs)
@property
def dtype(self) -> torch.dtype:
@@ -1456,7 +1518,7 @@ class UNet2DConditionModel(nn.Module):
def set_gradient_checkpointing(self, value=False):
modules = self.down_blocks + [self.mid_block] + self.up_blocks
for module in modules:
print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
module.gradient_checkpointing = value
# endregion
@@ -1519,7 +1581,6 @@ class UNet2DConditionModel(nn.Module):
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
@@ -1604,3 +1665,255 @@ class UNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(sample.shape[0])
return timesteps
class InferUNet2DConditionModel:
def __init__(self, original_unet: UNet2DConditionModel):
self.delegate = original_unet
# override original model's forward method: because forward is not called by `__call__`
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
self.delegate.forward = self.forward
# override original model's up blocks' forward method
for up_block in self.delegate.up_blocks:
if up_block.__class__.__name__ == "UpBlock2D":
def resnet_wrapper(func, block):
def forward(*args, **kwargs):
return func(block, *args, **kwargs)
return forward
up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
def cross_attn_up_wrapper(func, block):
def forward(*args, **kwargs):
return func(block, *args, **kwargs)
return forward
up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
logger.info("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
logger.info(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet in _self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if _self.upsamplers is not None:
for upsampler in _self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
def cross_attn_up_block_forward(
self,
_self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
):
for resnet, attn in zip(_self.resnets, _self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
if _self.upsamplers is not None:
for upsampler in _self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
) -> Union[Dict, Tuple]:
r"""
current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
"""
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a dict instead of a plain tuple.
Returns:
`SampleOutput` or `tuple`:
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
_self = self.delegate
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
default_overall_up_factor = 2**_self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
# 64で割り切れないときはupsamplerにサイズを伝える
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
# logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time
timesteps = timestep
timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
t_emb = _self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
# timestepsは重みを含まないので常にfloat32のテンソルを返す
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
# time_projでキャストしておけばいいんじゃね
t_emb = t_emb.to(dtype=_self.dtype)
emb = _self.time_embedding(t_emb)
# 2. pre-process
sample = _self.conv_in(sample)
down_block_res_samples = (sample,)
for depth, downsample_block in enumerate(_self.down_blocks):
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
self.ds_depth_2 is not None
and depth == self.ds_depth_2
and timesteps[0] < self.ds_timesteps_1
and timesteps[0] >= self.ds_timesteps_2
):
org_dtype = sample.dtype
if org_dtype == torch.bfloat16:
sample = sample.to(torch.float32)
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
# まあこちらのほうがわかりやすいかもしれない
if downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# skip connectionにControlNetの出力を追加する
if down_block_additional_residuals is not None:
down_block_res_samples = list(down_block_res_samples)
for i in range(len(down_block_res_samples)):
down_block_res_samples[i] += down_block_additional_residuals[i]
down_block_res_samples = tuple(down_block_res_samples)
# 4. mid
sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# ControlNetの出力を追加する
if mid_block_additional_residual is not None:
sample += mid_block_additional_residual
# 5. up
for i, upsample_block in enumerate(_self.up_blocks):
is_final_block = i == len(_self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
# if we have not reached the final block and need to forward the upsample size, we do it here
# 前述のように最後のブロック以外ではupsample_sizeを伝える
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = _self.conv_norm_out(sample)
sample = _self.conv_act(sample)
sample = _self.conv_out(sample)
if not return_dict:
return (sample,)
return SampleOutput(sample=sample)

View File

@@ -5,6 +5,10 @@ from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
r"""
# Metadata Example
@@ -231,7 +235,7 @@ def build_metadata(
# # assert all values are filled
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
print(f"Internal error: some metadata values are None: {metadata}")
logger.error(f"Internal error: some metadata values are None: {metadata}")
return metadata

View File

@@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if up1 is not None:
uncond_pool = up1
dtype = self.unet.dtype
unet_dtype = self.unet.dtype
dtype = unet_dtype
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
dtype = torch.float16
self.unet.to(dtype)
# 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image):
@@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if is_cancelled_callback is not None and is_cancelled_callback():
return None
self.unet.to(unet_dtype)
return latents
def latents_to_image(self, latents):

View File

@@ -1,4 +1,5 @@
import torch
import safetensors
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
@@ -7,7 +8,12 @@ from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
VAE_SCALE_FACTOR = 0.13025
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
@@ -100,7 +106,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
key = key.replace(".ln_final", ".final_layer_norm")
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
elif ".embeddings.position_ids" in key:
key = None # remove this key: make position_ids by ourselves
key = None # remove this key: position_ids is not used in newer transformers
return key
keys = list(checkpoint.keys())
@@ -126,13 +132,15 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
# original SD にはないので、position_idsを追加
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
# temporary workaround for text_projection.weight.weight for Playground-v2
if "text_projection.weight.weight" in new_sd:
logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
del new_sd["text_projection.weight.weight"]
return new_sd, logit_scale
@@ -158,17 +166,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
# model_version is reserved for future use
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
# Load the state dict
if model_util.is_safetensors(ckpt_path):
checkpoint = None
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path) # prevent device invalid Error
if disable_mmap:
state_dict = safetensors.torch.load(open(ckpt_path, "rb").read())
else:
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path) # prevent device invalid Error
epoch = None
global_step = None
else:
@@ -184,20 +195,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
checkpoint = None
# U-Net
print("building U-Net")
logger.info("building U-Net")
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
print("loading U-Net from checkpoint")
logger.info("loading U-Net from checkpoint")
unet_sd = {}
for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."):
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
print("U-Net: ", info)
logger.info(f"U-Net: {info}")
# Text Encoders
print("building text encoders")
logger.info("building text encoders")
# Text Encoder 1 is same to Stability AI's SDXL
text_model1_cfg = CLIPTextConfig(
@@ -250,7 +261,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
with init_empty_weights():
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
print("loading text encoders from checkpoint")
logger.info("loading text encoders from checkpoint")
te1_sd = {}
te2_sd = {}
for k in list(state_dict.keys()):
@@ -258,28 +269,28 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
if "text_model.embeddings.position_ids" not in te1_sd:
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
if "text_model.embeddings.position_ids" in te1_sd:
te1_sd.pop("text_model.embeddings.position_ids")
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
print("text encoder 1:", info1)
logger.info(f"text encoder 1: {info1}")
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
print("text encoder 2:", info2)
logger.info(f"text encoder 2: {info2}")
# prepare vae
print("building VAE")
logger.info("building VAE")
vae_config = model_util.create_vae_diffusers_config()
with init_empty_weights():
vae = AutoencoderKL(**vae_config)
print("loading VAE from checkpoint")
logger.info("loading VAE from checkpoint")
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
print("VAE:", info)
logger.info(f"VAE: {info}")
ckpt_info = (epoch, global_step) if epoch is not None else None
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info

View File

@@ -24,13 +24,18 @@
import math
from types import SimpleNamespace
from typing import Optional
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
IN_CHANNELS: int = 4
OUT_CHANNELS: int = 4
@@ -41,7 +46,7 @@ TIME_EMBED_DIM = 320 * 4
USE_REENTRANT = True
# region memory effcient attention
# region memory efficient attention
# FlashAttentionを使うCrossAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
@@ -266,6 +271,23 @@ def get_timestep_embedding(
return emb
# Deep Shrink: We do not common this function, because minimize dependencies.
def resize_like(x, target, mode="bicubic", align_corners=False):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.to(torch.float32)
if x.shape[-2:] != target.shape[-2:]:
if mode == "nearest":
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
else:
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
if org_dtype == torch.bfloat16:
x = x.to(org_dtype)
return x
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
if self.weight.dtype != torch.float32:
@@ -315,7 +337,7 @@ class ResnetBlock2D(nn.Module):
def forward(self, x, emb):
if self.training and self.gradient_checkpointing:
# print("ResnetBlock2D: gradient_checkpointing")
# logger.info("ResnetBlock2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -349,7 +371,7 @@ class Downsample2D(nn.Module):
def forward(self, hidden_states):
if self.training and self.gradient_checkpointing:
# print("Downsample2D: gradient_checkpointing")
# logger.info("Downsample2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -636,7 +658,7 @@ class BasicTransformerBlock(nn.Module):
def forward(self, hidden_states, context=None, timestep=None):
if self.training and self.gradient_checkpointing:
# print("BasicTransformerBlock: checkpointing")
# logger.info("BasicTransformerBlock: checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -779,7 +801,7 @@ class Upsample2D(nn.Module):
def forward(self, hidden_states, output_size=None):
if self.training and self.gradient_checkpointing:
# print("Upsample2D: gradient_checkpointing")
# logger.info("Upsample2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -996,109 +1018,6 @@ class SdxlUNet2DConditionModel(nn.Module):
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
)
# FreeU
self.freeU = False
self.freeUB1 = 1.0
self.freeUB2 = 1.0
self.freeUS1 = 1.0
self.freeUS2 = 1.0
self.freeURThres = 1
# implementation of FreeU
# FreeU: Free Lunch in Diffusion U-Net https://arxiv.org/abs/2309.11497
def set_free_u_enabled(self, enabled: bool, b1=1.0, b2=1.0, s1=1.0, s2=1.0, rthresh=1):
print(f"FreeU: {enabled}, b1={b1}, b2={b2}, s1={s1}, s2={s2}, rthresh={rthresh}")
self.freeU = enabled
self.freeUB1 = b1
self.freeUB2 = b2
self.freeUS1 = s1
self.freeUS2 = s2
self.freeURThres = rthresh
def spectral_modulation(self, skip_feature, sl=1.0, rthresh=1):
"""
スキップ特徴を周波数領域で修正する関数
:param skip_feature: スキップ特徴のテンソル [b, c, H, W]
:param sl: スケーリング係数
:param rthresh: 周波数の閾値
:return: 修正されたスキップ特徴
"""
import torch.fft
r"""
# 論文に従った実装
org_dtype = skip_feature.dtype
if org_dtype == torch.bfloat16:
skip_feature = skip_feature.to(torch.float32)
# FFTを計算
F = torch.fft.fftn(skip_feature, dim=(2, 3))
# 周波数領域での座標を計算
freq_x = torch.fft.fftfreq(skip_feature.size(2), d=1 / skip_feature.size(2)).to(skip_feature.device)
freq_y = torch.fft.fftfreq(skip_feature.size(3), d=1 / skip_feature.size(3)).to(skip_feature.device)
# 2Dグリッドを作成
freq_x = freq_x[:, None] # [H, 1]
freq_y = freq_y[None, :] # [1, W]
# ラジアス(距離)を計算
r = torch.sqrt(freq_x**2 + freq_y**2)
# 32,32: tensor(0., device='cuda:0') tensor(22.6274, device='cuda:0') tensor(12.2521, device='cuda:0')
# 64,64: tensor(0., device='cuda:0') tensor(45.2548, device='cuda:0') tensor(24.4908, device='cuda:0')
# 128,128: tensor(0., device='cuda:0') tensor(90.5097, device='cuda:0') tensor(48.9748, device='cuda:0')
# マスクを作成
mask = torch.ones_like(r)
mask[r < rthresh] = sl
# b,c,H,Wの形状にブロードキャスト
# TODO shapeごとに同じなのでキャッシュすると良さそう
mask = mask[None, None, :, :]
# 周波数領域での要素ごとの乗算
F_prime = F * mask
# 逆FFTを計算
modified_skip_feature = torch.fft.ifftn(F_prime, dim=(2, 3))
modified_skip_feature = modified_skip_feature.real # 実部のみを取得
"""
# 公式リポジトリの実装
org_dtype = skip_feature.dtype
x = skip_feature
threshold = rthresh
scale = sl
# FFT
x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
B, C, H, W = x_freq.shape
mask = torch.ones((B, C, H, W), device=x.device)
crow, ccol = H // 2, W // 2
mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
x_freq = x_freq * mask
# IFFT
x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
modified_skip_feature = x_filtered
# if org_dtype == torch.bfloat16:
modified_skip_feature = modified_skip_feature.to(org_dtype)
return modified_skip_feature
# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
@@ -1132,7 +1051,7 @@ class SdxlUNet2DConditionModel(nn.Module):
for block in blocks:
for module in block:
if hasattr(module, "set_use_memory_efficient_attention"):
# print(module.__class__.__name__)
# logger.info(module.__class__.__name__)
module.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa: bool) -> None:
@@ -1147,7 +1066,7 @@ class SdxlUNet2DConditionModel(nn.Module):
for block in blocks:
for module in block.modules():
if hasattr(module, "gradient_checkpointing"):
# print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
# logger.info(f{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
module.gradient_checkpointing = value
# endregion
@@ -1157,7 +1076,7 @@ class SdxlUNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
@@ -1166,6 +1085,96 @@ class SdxlUNet2DConditionModel(nn.Module):
# assert x.dtype == self.dtype
emb = emb + self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
# logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
if isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
# h = x.type(self.dtype)
h = x
for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)
h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)
return h
class InferSdxlUNet2DConditionModel:
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
self.delegate = original_unet
# override original model's forward method: because forward is not called by `__call__`
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
self.delegate.forward = self.forward
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
logger.info("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
logger.info(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
"""
_self = self.delegate
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
# assert x.dtype == _self.dtype
emb = emb + _self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
@@ -1180,37 +1189,44 @@ class SdxlUNet2DConditionModel(nn.Module):
# h = x.type(self.dtype)
h = x
for module in self.input_blocks:
for depth, module in enumerate(_self.input_blocks):
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
self.ds_depth_2 is not None
and depth == self.ds_depth_2
and timesteps[0] < self.ds_timesteps_1
and timesteps[0] >= self.ds_timesteps_2
):
# print("downsample", h.shape, self.ds_ratio)
org_dtype = h.dtype
if org_dtype == torch.bfloat16:
h = h.to(torch.float32)
h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
h = call_module(module, h, emb, context)
hs.append(h)
if self.freeU:
ch = h.shape[1]
s = self.freeUS1 if ch == 1280 else (self.freeUS2 if ch == 640 else 1.0)
if s == 1.0:
h_mod = h
else:
h_mod = self.spectral_modulation(h, s, self.freeURThres)
hs.append(h_mod)
else:
hs.append(h)
h = call_module(_self.middle_block, h, emb, context)
h = call_module(self.middle_block, h, emb, context)
for module in self.output_blocks:
if self.freeU:
ch = h.shape[1]
if ch == 1280:
h[:, : ch // 2] = h[:, : ch // 2] * self.freeUB1
elif ch == 640:
h[:, : ch // 2] = h[:, : ch // 2] * self.freeUB2
# else:
# print(f"disable freeU: {ch}")
for module in _self.output_blocks:
# Deep Shrink
if self.ds_depth_1 is not None:
if hs[-1].shape[-2:] != h.shape[-2:]:
# print("upsample", h.shape, hs[-1].shape)
h = resize_like(h, hs[-1])
h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)
# Deep Shrink: in case of depth 0
if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
# print("upsample", h.shape, x.shape)
h = resize_like(h, x)
h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)
h = call_module(_self.out, h, emb, context)
return h
@@ -1218,7 +1234,7 @@ class SdxlUNet2DConditionModel(nn.Module):
if __name__ == "__main__":
import time
print("create unet")
logger.info("create unet")
unet = SdxlUNet2DConditionModel()
unet.to("cuda")
@@ -1227,7 +1243,7 @@ if __name__ == "__main__":
unet.train()
# 使用メモリ量確認用の疑似学習ループ
print("preparing optimizer")
logger.info("preparing optimizer")
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
@@ -1242,12 +1258,12 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
logger.info("start training")
steps = 10
batch_size = 1
for step in range(steps):
print(f"step {step}")
logger.info(f"step {step}")
if step == 1:
time_start = time.perf_counter()
@@ -1267,4 +1283,4 @@ if __name__ == "__main__":
optimizer.zero_grad(set_to_none=True)
time_end = time.perf_counter()
print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")

View File

@@ -1,14 +1,24 @@
import argparse
import gc
import math
import os
from typing import Optional
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -17,11 +27,10 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
(
load_stable_diffusion_format,
@@ -38,6 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
weight_dtype,
accelerator.device if args.lowram else "cpu",
model_dtype,
args.disable_mmap_load_safetensors,
)
# work on low-ram device
@@ -47,24 +57,21 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
def _load_target_model(
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
):
# model_dtype only work with full fp16/bf16
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print(f"load StableDiffusion checkpoint: {name_or_path}")
logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
(
text_encoder1,
text_encoder2,
@@ -72,13 +79,13 @@ def _load_target_model(
unet,
logit_scale,
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline
variant = "fp16" if weight_dtype == torch.float16 else None
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
try:
try:
pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -86,12 +93,12 @@ def _load_target_model(
)
except EnvironmentError as ex:
if variant is not None:
print("try to load fp32 model")
logger.info("try to load fp32 model")
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
else:
raise ex
except EnvironmentError as ex:
print(
logger.error(
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
)
raise ex
@@ -114,7 +121,7 @@ def _load_target_model(
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
print("U-Net converted to original U-Net")
logger.info("U-Net converted to original U-Net")
logit_scale = None
ckpt_info = None
@@ -122,13 +129,13 @@ def _load_target_model(
# VAEを読み込む
if vae_path is not None:
vae = model_util.load_vae(vae_path, weight_dtype)
print("additional VAE loaded")
logger.info("additional VAE loaded")
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
logger.info("prepare tokenizers")
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
tokeniers = []
@@ -137,14 +144,14 @@ def load_tokenizers(args: argparse.Namespace):
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
print(f"load tokenizer from cache: {local_tokenizer_path}")
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
if tokenizer is None:
tokenizer = CLIPTokenizer.from_pretrained(original_path)
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
print(f"save Tokenizer to cache: {local_tokenizer_path}")
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
if i == 1:
@@ -153,7 +160,7 @@ def load_tokenizers(args: argparse.Namespace):
tokeniers.append(tokenizer)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
logger.info(f"update token length: {args.max_token_length}")
return tokeniers
@@ -329,28 +336,33 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
if args.clip_skip is not None:
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
# if args.multires_noise_iterations:
# print(
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
# )
# else:
# if args.noise_offset is None:
# args.noise_offset = DEFAULT_NOISE_OFFSET
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
# print(
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
# )
# print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
@@ -359,7 +371,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
print(
logger.warning(
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
)

View File

@@ -26,7 +26,10 @@ from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def slice_h(x, num_slices):
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
@@ -62,7 +65,7 @@ def cat_h(sliced):
return x
def resblock_forward(_self, num_slices, input_tensor, temb):
def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
assert _self.upsample is None and _self.downsample is None
assert _self.norm1.num_groups == _self.norm2.num_groups
assert temb is None
@@ -89,7 +92,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
# sliced_tensor = torch.chunk(x, num_div, dim=1)
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
# print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
# logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
# normed_tensor = []
# for i in range(num_div):
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
@@ -243,7 +246,7 @@ class SlicingEncoder(nn.Module):
self.num_slices = num_slices
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
# print(f"initial divisor: {div}")
# logger.info(f"initial divisor: {div}")
if div >= 2:
div = int(div)
for resnet in self.mid_block.resnets:
@@ -253,11 +256,11 @@ class SlicingEncoder(nn.Module):
for i, down_block in enumerate(self.down_blocks[::-1]):
if div >= 2:
div = int(div)
# print(f"down block: {i} divisor: {div}")
# logger.info(f"down block: {i} divisor: {div}")
for resnet in down_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
if down_block.downsamplers is not None:
# print("has downsample")
# logger.info("has downsample")
for downsample in down_block.downsamplers:
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
div *= 2
@@ -307,7 +310,7 @@ class SlicingEncoder(nn.Module):
def downsample_forward(self, _self, num_slices, hidden_states):
assert hidden_states.shape[1] == _self.channels
assert _self.use_conv and _self.padding == 0
print("downsample forward", num_slices, hidden_states.shape)
logger.info(f"downsample forward {num_slices} {hidden_states.shape}")
org_device = hidden_states.device
cpu_device = torch.device("cpu")
@@ -350,7 +353,7 @@ class SlicingEncoder(nn.Module):
hidden_states = torch.cat([hidden_states, x], dim=2)
hidden_states = hidden_states.to(org_device)
# print("downsample forward done", hidden_states.shape)
# logger.info(f"downsample forward done {hidden_states.shape}")
return hidden_states
@@ -426,7 +429,7 @@ class SlicingDecoder(nn.Module):
self.num_slices = num_slices
div = num_slices / (2 ** (len(self.up_blocks) - 1))
print(f"initial divisor: {div}")
logger.info(f"initial divisor: {div}")
if div >= 2:
div = int(div)
for resnet in self.mid_block.resnets:
@@ -436,11 +439,11 @@ class SlicingDecoder(nn.Module):
for i, up_block in enumerate(self.up_blocks):
if div >= 2:
div = int(div)
# print(f"up block: {i} divisor: {div}")
# logger.info(f"up block: {i} divisor: {div}")
for resnet in up_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
if up_block.upsamplers is not None:
# print("has upsample")
# logger.info("has upsample")
for upsample in up_block.upsamplers:
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
div *= 2
@@ -528,7 +531,7 @@ class SlicingDecoder(nn.Module):
del x
hidden_states = torch.cat(sliced, dim=2)
# print("us hidden_states", hidden_states.shape)
# logger.info(f"us hidden_states {hidden_states.shape}")
del sliced
hidden_states = hidden_states.to(org_device)

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,287 @@
import logging
import sys
import threading
import torch
from torchvision import transforms
from typing import *
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
import cv2
from PIL import Image
import numpy as np
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
threading.Thread(target=f, args=args, kwargs=kwargs).start()
def add_logging_arguments(parser):
parser.add_argument(
"--console_log_level",
type=str,
default=None,
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
)
parser.add_argument(
"--console_log_file",
type=str,
default=None,
help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
)
parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
def setup_logging(args=None, log_level=None, reset=False):
if logging.root.handlers:
if reset:
# remove all handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
else:
return
# log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
if log_level is None and args is not None:
log_level = args.console_log_level
if log_level is None:
log_level = "INFO"
log_level = getattr(logging, log_level)
msg_init = None
if args is not None and args.console_log_file:
handler = logging.FileHandler(args.console_log_file, mode="w")
else:
handler = None
if not args or not args.console_log_simple:
try:
from rich.logging import RichHandler
from rich.console import Console
from rich.logging import RichHandler
handler = RichHandler(console=Console(stderr=True))
except ImportError:
# print("rich is not installed, using basic logging")
msg_init = "rich is not installed, using basic logging"
if handler is None:
handler = logging.StreamHandler(sys.stdout) # same as print
handler.propagate = False
formatter = logging.Formatter(
fmt="%(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
logging.root.setLevel(log_level)
logging.root.addHandler(handler)
if msg_init is not None:
logger = logging.getLogger(__name__)
logger.info(msg_init)
def pil_resize(image, size, interpolation=Image.LANCZOS):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
if has_alpha:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
resized_pil = pil_image.resize(size, interpolation)
# Convert back to cv2 format
if has_alpha:
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
else:
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
return resized_cv2
# TODO make inf_utils.py
# region Gradual Latent hires fix
class GradualLatent:
def __init__(
self,
ratio,
start_timesteps,
every_n_steps,
ratio_step,
s_noise=1.0,
gaussian_blur_ksize=None,
gaussian_blur_sigma=0.5,
gaussian_blur_strength=0.5,
unsharp_target_x=True,
):
self.ratio = ratio
self.start_timesteps = start_timesteps
self.every_n_steps = every_n_steps
self.ratio_step = ratio_step
self.s_noise = s_noise
self.gaussian_blur_ksize = gaussian_blur_ksize
self.gaussian_blur_sigma = gaussian_blur_sigma
self.gaussian_blur_strength = gaussian_blur_strength
self.unsharp_target_x = unsharp_target_x
def __str__(self) -> str:
return (
f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
+ f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
+ f"unsharp_target_x={self.unsharp_target_x})"
)
def apply_unshark_mask(self, x: torch.Tensor):
if self.gaussian_blur_ksize is None:
return x
blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
# mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
mask = (x - blurred) * self.gaussian_blur_strength
sharpened = x + mask
return sharpened
def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.float()
x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
# apply unsharp mask / アンシャープマスクを適用する
if unsharp and self.gaussian_blur_ksize:
x = self.apply_unshark_mask(x)
return x
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.resized_size = None
self.gradual_latent = None
def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
self.resized_size = size
self.gradual_latent = gradual_latent
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`,
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
otherwise a tuple is returned where the first element is the sample tensor.
"""
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if not self.is_scale_input_called:
# logger.warning(
print(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
"See `StableDiffusionPipeline` for a usage example."
)
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
else:
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
sigma_from = self.sigmas[self.step_index]
sigma_to = self.sigmas[self.step_index + 1]
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
dt = sigma_down - sigma
device = model_output.device
if self.resized_size is None:
prev_sample = sample + derivative * dt
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
model_output.shape, dtype=model_output.dtype, device=device, generator=generator
)
s_noise = 1.0
else:
print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
s_noise = self.gradual_latent.s_noise
if self.gradual_latent.unsharp_target_x:
prev_sample = sample + derivative * dt
prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
else:
sample = self.gradual_latent.interpolate(sample, self.resized_size)
derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
prev_sample = sample + derivative * dt
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
dtype=model_output.dtype,
device=device,
generator=generator,
)
prev_sample = prev_sample + noise * sigma_up * s_noise
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
# endregion

View File

@@ -2,10 +2,13 @@ import argparse
import os
import torch
from safetensors.torch import load_file
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(file):
print(f"loading: {file}")
logger.info(f"loading: {file}")
if os.path.splitext(file)[1] == ".safetensors":
sd = load_file(file)
else:
@@ -15,7 +18,7 @@ def main(file):
keys = list(sd.keys())
for key in keys:
if "lora_up" in key or "lora_down" in key:
if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key:
values.append((key, sd[key]))
print(f"number of LoRA modules: {len(values)}")

View File

@@ -2,7 +2,10 @@ import os
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
@@ -125,7 +128,7 @@ class LLLiteModule(torch.nn.Module):
return
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
# logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
cx = self.conditioning1(cond_image)
if not self.is_conv2d:
# reshape / b,c,h,w -> b,h*w,c
@@ -155,7 +158,7 @@ class LLLiteModule(torch.nn.Module):
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
if self.use_zeros_for_batch_uncond:
cx[0::2] = 0.0 # uncond is zero
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
# logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
# downで入力の次元数を削減し、conditioning image embeddingと結合する
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
@@ -286,7 +289,7 @@ class ControlNetLLLite(torch.nn.Module):
# create module instances
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
def forward(self, x):
return x # dummy
@@ -319,7 +322,7 @@ class ControlNetLLLite(torch.nn.Module):
return info
def apply_to(self):
print("applying LLLite for U-Net...")
logger.info("applying LLLite for U-Net...")
for module in self.unet_modules:
module.apply_to()
self.add_module(module.lllite_name, module)
@@ -374,19 +377,19 @@ if __name__ == "__main__":
# sdxl_original_unet.USE_REENTRANT = False
# test shape etc
print("create unet")
logger.info("create unet")
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
unet.to("cuda").to(torch.float16)
print("create ControlNet-LLLite")
logger.info("create ControlNet-LLLite")
control_net = ControlNetLLLite(unet, 32, 64)
control_net.apply_to()
control_net.to("cuda")
print(control_net)
logger.info(control_net)
# print number of parameters
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
# logger.info number of parameters
logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}")
input()
@@ -398,12 +401,12 @@ if __name__ == "__main__":
# # visualize
# import torchviz
# print("run visualize")
# logger.info("run visualize")
# controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y)
# print("make_dot")
# logger.info("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render")
# logger.info("render")
# image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input()
@@ -414,12 +417,12 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
logger.info("start training")
steps = 10
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
for step in range(steps):
print(f"step {step}")
logger.info(f"step {step}")
batch_size = 1
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
@@ -439,7 +442,7 @@ if __name__ == "__main__":
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
print(sample_param)
logger.info(f"{sample_param}")
# from safetensors.torch import save_file

View File

@@ -6,7 +6,12 @@ import re
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
@@ -100,19 +105,15 @@ class LLLiteLinear(ORIGINAL_LINEAR):
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
self.cond_image = None
self.cond_emb = None
def set_cond_image(self, cond_image):
self.cond_image = cond_image
self.cond_emb = None
def forward(self, x):
if not self.enabled:
return super().forward(x)
if self.cond_emb is None:
self.cond_emb = self.lllite_conditioning1(self.cond_image)
cx = self.cond_emb
cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible
# reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape
@@ -156,9 +157,7 @@ class LLLiteConv2d(ORIGINAL_CONV2D):
if not self.enabled:
return super().forward(x)
if self.cond_emb is None:
self.cond_emb = self.lllite_conditioning1(self.cond_image)
cx = self.cond_emb
cx = self.lllite_conditioning1(self.cond_image)
cx = torch.cat([cx, self.down(x)], dim=1)
cx = self.mid(cx)
@@ -270,7 +269,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
# create module instances
self.lllite_modules = apply_to_modules(self, target_modules)
print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
# def prepare_optimizer_params(self):
def prepare_params(self):
@@ -281,8 +280,8 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
train_params.append(p)
else:
non_train_params.append(p)
print(f"count of trainable parameters: {len(train_params)}")
print(f"count of non-trainable parameters: {len(non_train_params)}")
logger.info(f"count of trainable parameters: {len(train_params)}")
logger.info(f"count of non-trainable parameters: {len(non_train_params)}")
for p in non_train_params:
p.requires_grad_(False)
@@ -388,7 +387,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
matches = pattern.findall(module_name)
if matches is not None:
for m in matches:
print(module_name, m)
logger.info(f"{module_name} {m}")
module_name = module_name.replace(m, m.replace("_", "@"))
module_name = module_name.replace("_", ".")
module_name = module_name.replace("@", "_")
@@ -407,7 +406,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
def replace_unet_linear_and_conv2d():
print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
@@ -419,10 +418,10 @@ if __name__ == "__main__":
replace_unet_linear_and_conv2d()
# test shape etc
print("create unet")
logger.info("create unet")
unet = SdxlUNet2DConditionModelControlNetLLLite()
print("enable ControlNet-LLLite")
logger.info("enable ControlNet-LLLite")
unet.apply_lllite(32, 64, None, False, 1.0)
unet.to("cuda") # .to(torch.float16)
@@ -439,14 +438,14 @@ if __name__ == "__main__":
# unet_sd[converted_key] = model_sd[key]
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
# print(info)
# logger.info(info)
# print(unet)
# logger.info(unet)
# print number of parameters
# logger.info number of parameters
params = unet.prepare_params()
print("number of parameters", sum(p.numel() for p in params))
# print("type any key to continue")
logger.info(f"number of parameters {sum(p.numel() for p in params)}")
# logger.info("type any key to continue")
# input()
unet.set_use_memory_efficient_attention(True, False)
@@ -455,12 +454,12 @@ if __name__ == "__main__":
# # visualize
# import torchviz
# print("run visualize")
# logger.info("run visualize")
# controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y)
# print("make_dot")
# logger.info("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render")
# logger.info("render")
# image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input()
@@ -471,13 +470,13 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
logger.info("start training")
steps = 10
batch_size = 1
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
for step in range(steps):
print(f"step {step}")
logger.info(f"step {step}")
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
x = torch.randn(batch_size, 4, 128, 128).cuda()
@@ -494,9 +493,9 @@ if __name__ == "__main__":
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
print(sample_param)
logger.info(sample_param)
# from safetensors.torch import save_file
# print("save weights")
# logger.info("save weights")
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)

View File

@@ -12,9 +12,17 @@
import math
import os
import random
from typing import List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import torch
from torch import nn
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class DyLoRAModule(torch.nn.Module):
@@ -165,7 +173,15 @@ class DyLoRAModule(torch.nn.Module):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
@@ -182,6 +198,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
if unit is not None:
unit = int(unit)
else:
@@ -197,6 +214,16 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
unit=unit,
varbose=True,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
@@ -223,7 +250,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -241,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
class DyLoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -266,12 +293,16 @@ class DyLoRANetwork(torch.nn.Module):
self.alpha = alpha
self.apply_to_conv = apply_to_conv
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info("create LoRA network from weights")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
if self.apply_to_conv:
print(f"apply LoRA to Conv2d with kernel size (3,3).")
logger.info("apply LoRA to Conv2d with kernel size (3,3).")
# create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
@@ -307,8 +338,22 @@ class DyLoRANetwork(torch.nn.Module):
loras.append(lora)
return loras
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
self.text_encoder_loras = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
logger.info(f"create LoRA for Text Encoder {index}")
else:
index = None
logger.info("create LoRA for Text Encoder")
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
# self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -316,7 +361,15 @@ class DyLoRANetwork(torch.nn.Module):
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras = create_modules(True, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def set_multiplier(self, multiplier):
self.multiplier = multiplier
@@ -336,12 +389,12 @@ class DyLoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -359,12 +412,12 @@ class DyLoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -375,30 +428,56 @@ class DyLoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
"""
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(loras):
params = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
for name, param in lora.named_parameters():
if ratio is not None and "lora_B" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue
params.append(param_data)
return params
if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
if self.unet_loras:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params = assemble_params(
self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
)
all_params.extend(params)
return all_params

View File

@@ -10,7 +10,10 @@ from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name):
if model_util.is_safetensors(file_name):
@@ -40,13 +43,13 @@ def split_lora_model(lora_sd, unit):
rank = value.size()[0]
if rank > max_rank:
max_rank = rank
print(f"Max rank: {max_rank}")
logger.info(f"Max rank: {max_rank}")
rank = unit
split_models = []
new_alpha = None
while rank < max_rank:
print(f"Splitting rank {rank}")
logger.info(f"Splitting rank {rank}")
new_sd = {}
for key, value in lora_sd.items():
if "lora_down" in key:
@@ -57,7 +60,7 @@ def split_lora_model(lora_sd, unit):
# なぜかscaleするとおかしくなる……
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
# scale = math.sqrt(this_rank / rank) # rank is > unit
# print(key, value.size(), this_rank, rank, value, scale)
# logger.info(key, value.size(), this_rank, rank, value, scale)
# new_alpha = value * scale # always same
# new_sd[key] = new_alpha
new_sd[key] = value
@@ -69,10 +72,10 @@ def split_lora_model(lora_sd, unit):
def split(args):
print("loading Model...")
logger.info("loading Model...")
lora_sd, metadata = load_state_dict(args.model)
print("Splitting Model...")
logger.info("Splitting Model...")
original_rank, split_models = split_lora_model(lora_sd, args.unit)
comment = metadata.get("ss_training_comment", "")
@@ -94,7 +97,7 @@ def split(args):
filename, ext = os.path.splitext(args.save_to)
model_file_name = filename + f"-{new_rank:04d}{ext}"
print(f"saving model to: {model_file_name}")
logger.info(f"saving model to: {model_file_name}")
save_to_file(model_file_name, state_dict, new_metadata)

View File

@@ -11,10 +11,13 @@ from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-4
# CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1
def save_to_file(file_name, model, state_dict, dtype):
@@ -29,7 +32,24 @@ def save_to_file(file_name, model, state_dict, dtype):
torch.save(model, file_name)
def svd(args):
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
):
def str_to_dtype(p):
if p == "float":
return torch.float
@@ -39,44 +59,65 @@ def svd(args):
return torch.bfloat16
return None
assert args.v2 != args.sdxl or (
not args.v2 and not args.sdxl
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if args.v_parameterization is None:
args.v_parameterization = args.v2
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if v_parameterization is None:
v_parameterization = v2
save_dtype = str_to_dtype(args.save_precision)
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision)
work_device = "cpu"
# load models
if not args.sdxl:
print(f"loading original SD model : {args.model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
if not sdxl:
logger.info(f"loading original SD model : {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
print(f"loading tuned SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
if load_dtype is not None:
text_encoder_o = text_encoder_o.to(load_dtype)
unet_o = unet_o.to(load_dtype)
logger.info(f"loading tuned SD model : {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
if load_dtype is not None:
text_encoder_t = text_encoder_t.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
print(f"loading original SDXL model : {args.model_org}")
device_org = load_original_model_to if load_original_model_to else "cpu"
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
logger.info(f"loading original SDXL model : {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
print(f"loading original SDXL model : {args.model_tuned}")
if load_dtype is not None:
text_encoder_o1 = text_encoder_o1.to(load_dtype)
text_encoder_o2 = text_encoder_o2.to(load_dtype)
unet_o = unet_o.to(load_dtype)
logger.info(f"loading original SDXL model : {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype is not None:
text_encoder_t1 = text_encoder_t1.to(load_dtype)
text_encoder_t2 = text_encoder_t2.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# create LoRA network to extract weights: Use dim (rank) as alpha
if args.conv_dim is None:
if conv_dim is None:
kwargs = {}
else:
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs)
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
@@ -88,50 +129,66 @@ def svd(args):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
# Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
diff = diff.float()
diffs[lora_name] = diff
# clear target Text Encoder to save memory
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.")
logger.warning("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = []
diffs = {}
diffs = {} # clear diffs
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = diff.float()
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
if args.device:
diff = diff.to(args.device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
diffs[lora_name] = diff
# clear LoRA network, target U-Net to save memory
del lora_network_o
del lora_network_t
del unet_t
# make LoRA with svd
print("calculating by svd")
logger.info("calculating by svd")
lora_weights = {}
with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())):
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
if args.device:
mat = mat.to(args.device)
mat = mat.to(torch.float) # calc by float
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
out_dim, in_dim = mat.size()[0:2]
if args.device:
mat = mat.to(args.device)
if device:
mat = mat.to(device)
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
# logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
if conv2d:
@@ -149,7 +206,7 @@ def svd(args):
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
hi_val = torch.quantile(dist, clamp_quantile)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
@@ -159,8 +216,8 @@ def svd(args):
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
U = U.to("cpu").contiguous()
Vh = Vh.to("cpu").contiguous()
U = U.to(work_device, dtype=save_dtype).contiguous()
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U, Vh)
@@ -176,36 +233,34 @@ def svd(args):
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}")
logger.info(f"Loading extracted LoRA weights: {info}")
dir_name = os.path.dirname(args.save_to)
dir_name = os.path.dirname(save_to)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
# minimum metadata
net_kwargs = {}
if args.conv_dim is not None:
net_kwargs["conv_dim"] = args.conv_dim
net_kwargs["conv_alpha"] = args.conv_dim
if conv_dim is not None:
net_kwargs["conv_dim"] = str(conv_dim)
net_kwargs["conv_alpha"] = str(float(conv_dim))
metadata = {
"ss_v2": str(args.v2),
"ss_v2": str(v2),
"ss_base_model_version": model_version,
"ss_network_module": "networks.lora",
"ss_network_dim": str(args.dim),
"ss_network_alpha": str(args.dim),
"ss_network_dim": str(dim),
"ss_network_alpha": str(float(dim)),
"ss_network_args": json.dumps(net_kwargs),
}
if not args.no_metadata:
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, args.v2, args.v_parameterization, False, True, False, time.time(), title=title
)
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
metadata.update(sai_metadata)
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}")
lora_network_save.save_weights(save_to, save_dtype, metadata)
logger.info(f"LoRA weights are saved to: {save_to}")
def setup_parser() -> argparse.ArgumentParser:
@@ -213,13 +268,20 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
parser.add_argument(
"--v_parameterization",
type=bool,
action="store_true",
default=None,
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する省略時はv2と同じ",
)
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
)
parser.add_argument(
"--load_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
)
parser.add_argument(
"--save_precision",
type=str,
@@ -231,16 +293,22 @@ def setup_parser() -> argparse.ArgumentParser:
"--model_org",
type=str,
default=None,
required=True,
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
)
parser.add_argument(
"--model_tuned",
type=str,
default=None,
required=True,
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
required=True,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument(
@@ -250,12 +318,37 @@ def setup_parser() -> argparse.ArgumentParser:
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--clamp_quantile",
type=float,
default=0.99,
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
)
parser.add_argument(
"--min_diff",
type=float,
default=0.01,
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
)
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--load_original_model_to",
type=str,
default=None,
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
parser.add_argument(
"--load_tuned_model_to",
type=str,
default=None,
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
return parser
@@ -264,4 +357,4 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
svd(args)
svd(**vars(args))

View File

@@ -11,7 +11,13 @@ from transformers import CLIPTextModel
import numpy as np
import torch
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -46,7 +52,7 @@ class LoRAModule(torch.nn.Module):
# if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else:
self.lora_dim = lora_dim
@@ -177,7 +183,7 @@ class LoRAInfModule(LoRAModule):
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
@@ -216,7 +222,7 @@ class LoRAInfModule(LoRAModule):
self.region_mask = None
def default_forward(self, x):
# print("default_forward", self.lora_name, x.size())
# logger.info(f"default_forward {self.lora_name} {x.size()}")
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x):
@@ -242,13 +248,13 @@ class LoRAInfModule(LoRAModule):
area = x.size()[1]
mask = self.network.mask_dic.get(area, None)
if mask is None:
# raise ValueError(f"mask is None for resolution {area}")
if mask is None or len(x.size()) == 2:
# emb_layers in SDXL doesn't have mask
# print(f"mask is None for resolution {area}, {x.size()}")
# if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4:
if len(x.size()) == 3:
mask = torch.reshape(mask, (1, -1, 1))
return mask
@@ -263,6 +269,8 @@ class LoRAInfModule(LoRAModule):
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx)
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
# mask = mask.squeeze(-1)
lx = lx * mask
x = self.org_forward(x)
@@ -291,7 +299,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
query[-self.network.batch_size :] = x[-self.network.batch_size :]
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
# logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}")
return query
def sub_prompt_forward(self, x):
@@ -306,7 +314,7 @@ class LoRAInfModule(LoRAModule):
lx = x[emb_idx :: self.network.num_sub_prompts]
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
# logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}")
x = self.org_forward(x)
x[emb_idx :: self.network.num_sub_prompts] += lx
@@ -314,7 +322,7 @@ class LoRAInfModule(LoRAModule):
return x
def to_out_forward(self, x):
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
# logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}")
if self.network.is_last_network:
masks = [None] * self.network.num_sub_prompts
@@ -332,7 +340,7 @@ class LoRAInfModule(LoRAModule):
)
self.network.shared[self.lora_name] = (lx, masks)
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
@@ -351,7 +359,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
# if num_sub_prompts > num of LoRAs, fill with zero
for i in range(len(masks)):
if masks[i] is None:
@@ -374,18 +382,18 @@ class LoRAInfModule(LoRAModule):
x1 = x1 + lx1
out[self.network.batch_size + i] = x1
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
# logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}")
return out
def parse_block_lr_kwargs(nw_kwargs):
def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]:
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
# 以上のいずれにも設定がない場合は無効としてNoneを返す
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
return None, None, None
return None
# extract learning rate weight for each block
if down_lr_weight is not None:
@@ -394,18 +402,16 @@ def parse_block_lr_kwargs(nw_kwargs):
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
if mid_lr_weight is not None:
mid_lr_weight = float(mid_lr_weight)
mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")]
if up_lr_weight is not None:
if "," in up_lr_weight:
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
return get_block_lr_weight(
is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
)
return down_lr_weight, mid_lr_weight, up_lr_weight
def create_network(
multiplier: float,
@@ -417,6 +423,9 @@ def create_network(
neuron_dropout: Optional[float] = None,
**kwargs,
):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
@@ -434,21 +443,21 @@ def create_network(
# block dim/alpha/lr
block_dims = kwargs.get("block_dims", None)
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
if block_dims is not None or block_lr_weight is not None:
block_alphas = kwargs.get("block_alphas", None)
conv_block_dims = kwargs.get("conv_block_dims", None)
conv_block_alphas = kwargs.get("conv_block_alphas", None)
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
)
# remove block dim/alpha without learning rate
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight
)
else:
@@ -481,10 +490,20 @@ def create_network(
conv_block_dims=conv_block_dims,
conv_block_alphas=conv_block_alphas,
varbose=True,
is_sdxl=is_sdxl,
)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
if block_lr_weight is not None:
network.set_block_lr_weight(block_lr_weight)
return network
@@ -494,9 +513,13 @@ def create_network(
# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている
def get_block_dims_and_alphas(
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
):
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
if not is_sdxl:
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS
else:
# 1+9+3+9+1=23, no LoRA for emb_layers (0)
num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
def parse_ints(s):
return [int(i) for i in s.split(",")]
@@ -507,11 +530,14 @@ def get_block_dims_and_alphas(
# block_dimsとblock_alphasをパースする。必ず値が入る
if block_dims is not None:
block_dims = parse_ints(block_dims)
assert (
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
assert len(block_dims) == num_total_blocks, (
f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given"
+ f" / block_dims{num_total_blocks}個指定してください(指定された個数: {len(block_dims)}"
)
else:
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
)
block_dims = [network_dim] * num_total_blocks
if block_alphas is not None:
@@ -520,7 +546,7 @@ def get_block_dims_and_alphas(
len(block_alphas) == num_total_blocks
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
else:
print(
logger.warning(
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
)
block_alphas = [network_alpha] * num_total_blocks
@@ -540,13 +566,13 @@ def get_block_dims_and_alphas(
else:
if conv_alpha is None:
conv_alpha = 1.0
print(
logger.warning(
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
)
conv_block_alphas = [conv_alpha] * num_total_blocks
else:
if conv_dim is not None:
print(
logger.warning(
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
)
conv_block_dims = [conv_dim] * num_total_blocks
@@ -558,15 +584,25 @@ def get_block_dims_and_alphas(
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出せるようにclass外に出しておく
# 戻り値は block ごとの倍率のリスト
def get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
) -> Tuple[List[float], List[float], List[float]]:
is_sdxl,
down_lr_weight: Union[str, List[float]],
mid_lr_weight: List[float],
up_lr_weight: Union[str, List[float]],
zero_threshold: float,
) -> Optional[List[float]]:
# パラメータ未指定時は何もせず、今までと同じ動作とする
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
return None, None, None
return None
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
if not is_sdxl:
max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS
max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS
else:
max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS
max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS
def get_list(name_with_suffix) -> List[float]:
import math
@@ -576,17 +612,20 @@ def get_block_lr_weight(
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
if name == "cosine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
return [
math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr
for i in reversed(range(max_len_for_down_or_up))
]
elif name == "sine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)]
elif name == "linear":
return [i / (max_len - 1) + base_lr for i in range(max_len)]
return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)]
elif name == "reverse_linear":
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))]
elif name == "zeros":
return [0.0 + base_lr] * max_len
return [0.0 + base_lr] * max_len_for_down_or_up
else:
print(
logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name)
)
@@ -597,99 +636,176 @@ def get_block_lr_weight(
if type(up_lr_weight) == str:
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or (
down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up
):
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up)
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up)
up_lr_weight = up_lr_weight[:max_len_for_down_or_up]
down_lr_weight = down_lr_weight[:max_len_for_down_or_up]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
print("down_weightもしくはup_weightがすぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid:
logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid)
logger.warning("mid_weightがすぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid)
mid_lr_weight = mid_lr_weight[:max_len_for_mid]
if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
if up_lr_weight != None and len(up_lr_weight) < max_len:
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or (
down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up
):
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up)
logger.warning(
"down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up
)
if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up:
down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight))
if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up:
up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight))
if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid:
logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid)
logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid)
mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("apply block learning rate / 階層別学習率を適用します。")
logger.info("apply block learning rate / 階層別学習率を適用します。")
if down_lr_weight != None:
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
else:
print("down_lr_weight: all 1.0, すべて1.0")
down_lr_weight = [1.0] * max_len_for_down_or_up
logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:", mid_lr_weight)
mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight]
logger.info(f"mid_lr_weight: {mid_lr_weight}")
else:
print("mid_lr_weight: 1.0")
mid_lr_weight = [1.0] * max_len_for_mid
logger.info("mid_lr_weight: all 1.0, すべて1.0")
if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
else:
print("up_lr_weight: all 1.0, すべて1.0")
up_lr_weight = [1.0] * max_len_for_down_or_up
logger.info("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight
lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight
if is_sdxl:
lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out
assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or (
is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
), f"lr_weight length is invalid: {len(lr_weight)}"
return lr_weight
# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく
def remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]]
):
# set 0 to block dim without learning rate to remove the block
if down_lr_weight != None:
for i, lr in enumerate(down_lr_weight):
if block_lr_weight is not None:
for i, lr in enumerate(block_lr_weight):
if lr == 0:
block_dims[i] = 0
if conv_block_dims is not None:
conv_block_dims[i] = 0
if mid_lr_weight != None:
if mid_lr_weight == 0:
block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
if conv_block_dims is not None:
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
if up_lr_weight != None:
for i, lr in enumerate(up_lr_weight):
if lr == 0:
block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
if conv_block_dims is not None:
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
# 外部から呼び出す可能性を考慮しておく
def get_block_index(lora_name: str) -> int:
def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
block_idx = -1 # invalid lora name
if not is_sdxl:
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2] == "resnets":
idx = 3 * i + j
elif g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2] == "resnets":
idx = 3 * i + j
elif g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
if g[0] == "down":
block_idx = 1 + idx # 0に該当するLoRAは存在しない
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
if g[0] == "down":
block_idx = 1 + idx # 0に該当するLoRAは存在しない
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
else:
# copy from sdxl_train
if lora_name.startswith("lora_unet_"):
name = lora_name[len("lora_unet_") :]
if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA
block_idx = 0 # 0
elif name.startswith("input_blocks_"): # 1-9
block_idx = 1 + int(name.split("_")[2])
elif name.startswith("middle_block_"): # 10-12
block_idx = 10 + int(name.split("_")[2])
elif name.startswith("output_blocks_"): # 13-21
block_idx = 13 + int(name.split("_")[2])
elif name.startswith("out_"): # 22, out, no LoRA
block_idx = 22
return block_idx
def convert_diffusers_to_sai_if_needed(weights_sd):
# only supports U-Net LoRA modules
found_up_down_blocks = False
for k in list(weights_sd.keys()):
if "down_blocks" in k:
found_up_down_blocks = True
break
if "up_blocks" in k:
found_up_down_blocks = True
break
if not found_up_down_blocks:
return
from library.sdxl_model_util import make_unet_conversion_map
unet_conversion_map = make_unet_conversion_map()
unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
# # add extra conversion
# unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
logger.info(f"Converting LoRA keys from Diffusers to SAI")
lora_unet_prefix = "lora_unet_"
for k in list(weights_sd.keys()):
if not k.startswith(lora_unet_prefix):
continue
unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]
# search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
for hf_module_name, sd_module_name in unet_conversion_map.items():
if hf_module_name in unet_module_name:
new_key = (
lora_unet_prefix
+ unet_module_name.replace(hf_module_name, sd_module_name)
+ k[len(lora_unet_prefix) + len(unet_module_name) :]
)
weights_sd[new_key] = weights_sd.pop(k)
found = True
break
if not found:
logger.warning(f"Key {k} is not found in unet_conversion_map")
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
@@ -698,6 +814,10 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
else:
weights_sd = torch.load(file, map_location="cpu")
# if keys are Diffusers based, convert to SAI based
if is_sdxl:
convert_diffusers_to_sai_if_needed(weights_sd)
# get dim/alpha mapping
modules_dim = {}
modules_alpha = {}
@@ -711,7 +831,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(lora_name, value.size(), dim)
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -721,23 +841,32 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
text_encoder,
unet,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
is_sdxl=is_sdxl,
)
# block lr
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
if block_lr_weight is not None:
network.set_block_lr_weight(block_lr_weight)
return network, weights_sd
class LoRANetwork(torch.nn.Module):
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
NUM_OF_MID_BLOCKS = 1
SDXL_NUM_OF_BLOCKS = 9 # SDXLのモデルでのinput/outputの層の数 total=1(base) 9(input) + 3(mid) + 9(output) + 1(out) = 23
SDXL_NUM_OF_MID_BLOCKS = 3
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -765,6 +894,7 @@ class LoRANetwork(torch.nn.Module):
modules_alpha: Optional[Dict[str, int]] = None,
module_class: Type[object] = LoRAModule,
varbose: Optional[bool] = False,
is_sdxl: Optional[bool] = False,
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -785,21 +915,31 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
logger.info(f"create LoRA network from block_dims")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
logger.info(f"conv_block_dims: {conv_block_dims}")
logger.info(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
# create module instances
def create_modules(
@@ -840,7 +980,7 @@ class LoRANetwork(torch.nn.Module):
alpha = modules_alpha[lora_name]
elif is_unet and block_dims is not None:
# U-Netでblock_dims指定あり
block_idx = get_block_index(lora_name)
block_idx = get_block_index(lora_name, is_sdxl)
if is_linear or is_conv2d_1x1:
dim = block_dims[block_idx]
alpha = block_alphas[block_idx]
@@ -884,15 +1024,15 @@ class LoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")
logger.info(f"create LoRA for Text Encoder {index}:")
else:
index = None
print(f"create LoRA for Text Encoder:")
logger.info(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -900,19 +1040,17 @@ class LoRANetwork(torch.nn.Module):
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0:
print(
logger.warning(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
print(f"\t{name}")
logger.info(f"\t{name}")
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
self.mid_lr_weight: float = None
self.block_lr_weight = None
self.block_lr = False
# assertion
@@ -926,6 +1064,10 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -939,12 +1081,12 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
@@ -966,12 +1108,12 @@ class LoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -982,84 +1124,120 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
def set_block_lr_weight(
self,
up_lr_weight: List[float] = None,
mid_lr_weight: float = None,
down_lr_weight: List[float] = None,
):
def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]):
self.block_lr = True
self.down_lr_weight = down_lr_weight
self.mid_lr_weight = mid_lr_weight
self.up_lr_weight = up_lr_weight
self.block_lr_weight = block_lr_weight
def get_lr_weight(self, lora: LoRAModule) -> float:
lr_weight = 1.0
block_idx = get_block_index(lora.lora_name)
if block_idx < 0:
return lr_weight
def get_lr_weight(self, block_idx: int) -> float:
if not self.block_lr or self.block_lr_weight is None:
return 1.0
return self.block_lr_weight[block_idx]
if block_idx < LoRANetwork.NUM_OF_BLOCKS:
if self.down_lr_weight != None:
lr_weight = self.down_lr_weight[block_idx]
elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
if self.mid_lr_weight != None:
lr_weight = self.mid_lr_weight
elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
if self.up_lr_weight != None:
lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
return lr_weight
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
# if (
# self.loraplus_lr_ratio is not None
# or self.loraplus_text_encoder_lr_ratio is not None
# or self.loraplus_unet_lr_ratio is not None
# ):
# assert (
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
def enumerate_params(loras):
params = []
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
return params
for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params, descriptions = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
if self.block_lr:
is_sdxl = False
for lora in self.unet_loras:
if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
is_sdxl = True
break
# 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
block_idx_to_lora = {}
for lora in self.unet_loras:
idx = get_block_index(lora.lora_name)
idx = get_block_index(lora.lora_name, is_sdxl)
if idx not in block_idx_to_lora:
block_idx_to_lora[idx] = []
block_idx_to_lora[idx].append(lora)
# blockごとにパラメータを設定する
for idx, block_loras in block_idx_to_lora.items():
param_data = {"params": enumerate_params(block_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
elif default_lr is not None:
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
if ("lr" in param_data) and (param_data["lr"] == 0):
continue
all_params.append(param_data)
params, descriptions = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
else:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params
return all_params, lr_descriptions
def enable_gradient_checkpointing(self):
# not supported
@@ -1113,7 +1291,7 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.set_network(self)
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
self.batch_size = batch_size
self.num_sub_prompts = num_sub_prompts
self.current_size = (height, width)
@@ -1128,7 +1306,7 @@ class LoRANetwork(torch.nn.Module):
device = ref_weight.device
def resize_add(mh, mw):
# print(mh, mw, mh * mw)
# logger.info(mh, mw, mh * mw)
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
m = m.to(device, dtype=dtype)
mask_dic[mh * mw] = m
@@ -1139,6 +1317,13 @@ class LoRANetwork(torch.nn.Module):
resize_add(h, w)
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
resize_add(h + h % 2, w + w % 2)
# deep shrink
if ds_ratio is not None:
hd = int(h * ds_ratio)
wd = int(w * ds_ratio)
resize_add(hd, wd)
h = (h + 1) // 2
w = (w + 1) // 2

View File

@@ -9,8 +9,15 @@ from diffusers import UNet2DConditionModel
import numpy as np
from tqdm import tqdm
from transformers import CLIPTextModel
import torch
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def make_unet_conversion_map() -> Dict[str, str]:
unet_conversion_map_layer = []
@@ -117,7 +124,7 @@ class LoRAModule(torch.nn.Module):
super().__init__()
self.lora_name = lora_name
if org_module.__class__.__name__ == "Conv2d":
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
@@ -126,7 +133,7 @@ class LoRAModule(torch.nn.Module):
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Conv2d":
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
@@ -166,7 +173,8 @@ class LoRAModule(torch.nn.Module):
self.org_module[0].forward = self.org_forward
# forward with lora
def forward(self, x):
# scale is used LoRACompatibleConv, but we ignore it because we have multiplier
def forward(self, x, scale=1.0):
if not self.enabled:
return self.org_forward(x)
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
@@ -247,7 +255,7 @@ def create_network_from_weights(
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -270,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -290,12 +298,12 @@ class LoRANetwork(torch.nn.Module):
super().__init__()
self.multiplier = multiplier
print(f"create LoRA network from weights")
logger.info("create LoRA network from weights")
# convert SDXL Stability AI's U-Net modules to Diffusers
converted = self.convert_unet_modules(modules_dim, modules_alpha)
if converted:
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
# create module instances
def create_modules(
@@ -318,15 +326,19 @@ class LoRANetwork(torch.nn.Module):
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_linear = (
child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
)
is_conv2d = (
child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
)
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
if lora_name not in modules_dim:
# print(f"skipped {lora_name} (not found in modules_dim)")
# logger.info(f"skipped {lora_name} (not found in modules_dim)")
skipped.append(lora_name)
continue
@@ -357,18 +369,18 @@ class LoRANetwork(torch.nn.Module):
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
if len(skipped_te) > 0:
print(f"skipped {len(skipped_te)} modules because of missing weight.")
logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
# extend U-Net target modules to include Conv2d 3x3
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras: List[LoRAModule]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
if len(skipped_un) > 0:
print(f"skipped {len(skipped_un)} modules because of missing weight.")
logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
# assertion
names = set()
@@ -415,11 +427,11 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
for lora in self.text_encoder_loras:
lora.apply_to(multiplier)
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
for lora in self.unet_loras:
lora.apply_to(multiplier)
@@ -428,16 +440,16 @@ class LoRANetwork(torch.nn.Module):
lora.unapply_to()
def merge_to(self, multiplier=1.0):
print("merge LoRA weights to original weights")
logger.info("merge LoRA weights to original weights")
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
lora.merge_to(multiplier)
print(f"weights are merged")
logger.info(f"weights are merged")
def restore_from(self, multiplier=1.0):
print("restore LoRA weights from original weights")
logger.info("restore LoRA weights from original weights")
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
lora.restore_from(multiplier)
print(f"weights are restored")
logger.info(f"weights are restored")
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
# convert SDXL Stability AI's state dict to Diffusers' based state dict
@@ -458,7 +470,7 @@ class LoRANetwork(torch.nn.Module):
my_state_dict = self.state_dict()
for key in state_dict.keys():
if state_dict[key].size() != my_state_dict[key].size():
# print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
# logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
return super().load_state_dict(state_dict, strict)
@@ -471,7 +483,7 @@ if __name__ == "__main__":
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
@@ -485,7 +497,7 @@ if __name__ == "__main__":
image_prefix = args.model_id.replace("/", "_") + "_"
# load Diffusers model
print(f"load model from {args.model_id}")
logger.info(f"load model from {args.model_id}")
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
if args.sdxl:
# use_safetensors=True does not work with 0.18.2
@@ -498,7 +510,7 @@ if __name__ == "__main__":
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
# load LoRA weights
print(f"load LoRA weights from {args.lora_weights}")
logger.info(f"load LoRA weights from {args.lora_weights}")
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -507,10 +519,10 @@ if __name__ == "__main__":
lora_sd = torch.load(args.lora_weights)
# create by LoRA weights and load weights
print(f"create LoRA network")
logger.info(f"create LoRA network")
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
print(f"load LoRA network weights")
logger.info(f"load LoRA network weights")
lora_network.load_state_dict(lora_sd)
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
@@ -539,34 +551,34 @@ if __name__ == "__main__":
random.seed(seed)
# create image with original weights
print(f"create image with original weights")
logger.info(f"create image with original weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "original.png")
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
print(f"apply LoRA network to the model")
logger.info(f"apply LoRA network to the model")
lora_network.apply_to(multiplier=1.0)
print(f"create image with applied LoRA")
logger.info(f"create image with applied LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "applied_lora.png")
# unapply LoRA network to the model
print(f"unapply LoRA network to the model")
logger.info(f"unapply LoRA network to the model")
lora_network.unapply_to()
print(f"create image with unapplied LoRA")
logger.info(f"create image with unapplied LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "unapplied_lora.png")
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
print(f"merge LoRA network to the model")
logger.info(f"merge LoRA network to the model")
lora_network.merge_to(multiplier=1.0)
print(f"create image with LoRA")
logger.info(f"create image with LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "merged_lora.png")
@@ -574,31 +586,31 @@ if __name__ == "__main__":
# restore (unmerge) LoRA weights: numerically unstable
# マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない
# 保存したstate_dictから元の重みを復元するのが確実
print(f"restore (unmerge) LoRA weights")
logger.info(f"restore (unmerge) LoRA weights")
lora_network.restore_from(multiplier=1.0)
print(f"create image without LoRA")
logger.info(f"create image without LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "unmerged_lora.png")
# restore original weights
print(f"restore original weights")
logger.info(f"restore original weights")
pipe.unet.load_state_dict(org_unet_sd)
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
if args.sdxl:
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
print(f"create image with restored original weights")
logger.info(f"create image with restored original weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "restore_original.png")
# use convenience function to merge LoRA weights
print(f"merge LoRA weights with convenience function")
logger.info(f"merge LoRA weights with convenience function")
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
print(f"create image with merged LoRA weights")
logger.info(f"create image with merged LoRA weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "convenience_merged_lora.png")

View File

@@ -14,7 +14,10 @@ from transformers import CLIPTextModel
import numpy as np
import torch
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -49,7 +52,7 @@ class LoRAModule(torch.nn.Module):
# if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else:
self.lora_dim = lora_dim
@@ -197,7 +200,7 @@ class LoRAInfModule(LoRAModule):
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
@@ -236,7 +239,7 @@ class LoRAInfModule(LoRAModule):
self.region_mask = None
def default_forward(self, x):
# print("default_forward", self.lora_name, x.size())
# logger.info("default_forward", self.lora_name, x.size())
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x):
@@ -278,7 +281,7 @@ class LoRAInfModule(LoRAModule):
# apply mask for LoRA result
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx)
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
# logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
lx = lx * mask
x = self.org_forward(x)
@@ -307,7 +310,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
query[-self.network.batch_size :] = x[-self.network.batch_size :]
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
# logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
return query
def sub_prompt_forward(self, x):
@@ -322,7 +325,7 @@ class LoRAInfModule(LoRAModule):
lx = x[emb_idx :: self.network.num_sub_prompts]
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
# logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
x = self.org_forward(x)
x[emb_idx :: self.network.num_sub_prompts] += lx
@@ -330,7 +333,7 @@ class LoRAInfModule(LoRAModule):
return x
def to_out_forward(self, x):
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
# logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
if self.network.is_last_network:
masks = [None] * self.network.num_sub_prompts
@@ -348,7 +351,7 @@ class LoRAInfModule(LoRAModule):
)
self.network.shared[self.lora_name] = (lx, masks)
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
@@ -367,7 +370,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# for i in range(len(masks)):
# if masks[i] is None:
# masks[i] = torch.zeros_like(masks[-1])
@@ -389,7 +392,7 @@ class LoRAInfModule(LoRAModule):
x1 = x1 + lx1
out[self.network.batch_size + i] = x1
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
# logger.info("to_out_forward", x.size(), out.size(), has_real_uncond)
return out
@@ -526,7 +529,7 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else:
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
block_dims = [network_dim] * num_total_blocks
if block_alphas is not None:
@@ -535,7 +538,7 @@ def get_block_dims_and_alphas(
len(block_alphas) == num_total_blocks
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
else:
print(
logger.warning(
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
)
block_alphas = [network_alpha] * num_total_blocks
@@ -555,13 +558,13 @@ def get_block_dims_and_alphas(
else:
if conv_alpha is None:
conv_alpha = 1.0
print(
logger.warning(
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
)
conv_block_alphas = [conv_alpha] * num_total_blocks
else:
if conv_dim is not None:
print(
logger.warning(
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
)
conv_block_dims = [conv_dim] * num_total_blocks
@@ -601,7 +604,7 @@ def get_block_lr_weight(
elif name == "zeros":
return [0.0 + base_lr] * max_len
else:
print(
logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name)
)
@@ -613,14 +616,14 @@ def get_block_lr_weight(
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
@@ -628,24 +631,24 @@ def get_block_lr_weight(
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("apply block learning rate / 階層別学習率を適用します。")
logger.info("apply block learning rate / 階層別学習率を適用します。")
if down_lr_weight != None:
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
else:
print("down_lr_weight: all 1.0, すべて1.0")
logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:", mid_lr_weight)
logger.info(f"mid_lr_weight: {mid_lr_weight}")
else:
print("mid_lr_weight: 1.0")
logger.info("mid_lr_weight: 1.0")
if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
else:
print("up_lr_weight: all 1.0, すべて1.0")
logger.info("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight
@@ -726,7 +729,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(lora_name, value.size(), dim)
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -752,7 +755,7 @@ class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -801,20 +804,20 @@ class LoRANetwork(torch.nn.Module):
self.module_dropout = module_dropout
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
logger.info(f"create LoRA network from block_dims")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
logger.info(f"conv_block_dims: {conv_block_dims}")
logger.info(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances
def create_modules(
@@ -899,15 +902,15 @@ class LoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")
logger.info(f"create LoRA for Text Encoder {index}:")
else:
index = None
print(f"create LoRA for Text Encoder:")
logger.info(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -915,15 +918,15 @@ class LoRANetwork(torch.nn.Module):
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0:
print(
logger.warning(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
print(f"\t{name}")
logger.info(f"\t{name}")
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
@@ -954,12 +957,12 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -981,12 +984,12 @@ class LoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -997,7 +1000,7 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
def set_block_lr_weight(
@@ -1144,7 +1147,7 @@ class LoRANetwork(torch.nn.Module):
device = ref_weight.device
def resize_add(mh, mw):
# print(mh, mw, mh * mw)
# logger.info(mh, mw, mh * mw)
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
m = m.to(device, dtype=dtype)
mask_dic[mh * mw] = m

View File

@@ -5,27 +5,34 @@ from library import model_util
import library.train_util as train_util
import argparse
from transformers import CLIPTokenizer
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = get_preferred_device()
def interrogate(args):
weights_dtype = torch.float16
# いろいろ準備する
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
args.pretrained_model_name_or_path = args.sd_model
args.vae = None
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
print(f"loading LoRA: {args.model}")
logger.info(f"loading LoRA: {args.model}")
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# text encoder向けの重みがあるかチェックする本当はlora側でやるのがいい
@@ -35,11 +42,11 @@ def interrogate(args):
has_te_weight = True
break
if not has_te_weight:
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
return
del vae
print("loading tokenizer")
logger.info("loading tokenizer")
if args.v2:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
@@ -53,7 +60,7 @@ def interrogate(args):
# トークンをひとつひとつ当たっていく
token_id_start = 0
token_id_end = max(tokenizer.all_special_ids)
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")
def get_all_embeddings(text_encoder):
embs = []
@@ -79,24 +86,24 @@ def interrogate(args):
embs.extend(encoder_hidden_states)
return torch.stack(embs)
print("get original text encoder embeddings.")
logger.info("get original text encoder embeddings.")
orig_embs = get_all_embeddings(text_encoder)
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
info = network.load_state_dict(weights_sd, strict=False)
print(f"Loading LoRA weights: {info}")
logger.info(f"Loading LoRA weights: {info}")
network.to(DEVICE, dtype=weights_dtype)
network.eval()
del unet
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
print("get text encoder embeddings with lora.")
logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
logger.info("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder)
# 比べる:とりあえず単純に差分の絶対値で
print("comparing...")
logger.info("comparing...")
diffs = {}
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
diff = torch.mean(torch.abs(orig_emb - lora_emb))

View File

@@ -7,7 +7,10 @@ from safetensors.torch import load_file, save_file
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -61,10 +64,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
@@ -73,10 +76,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
@@ -104,13 +107,13 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
@@ -118,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype):
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
@@ -151,13 +154,19 @@ def merge_lora_models(models, ratios, merge_dtype):
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None
lora_module_name = key[: key.rfind(".lora_")]
@@ -165,12 +174,16 @@ def merge_lora_models(models, ratios, merge_dtype):
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size()
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
@@ -178,9 +191,16 @@ def merge_lora_models(models, ratios, merge_dtype):
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same
dims_list = list(set(base_dims.values()))
@@ -222,7 +242,7 @@ def merge(args):
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
@@ -247,18 +267,18 @@ def merge(args):
)
if args.v2:
# TODO read sai modelspec
print(
logger.warning(
"Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
print(f"saving SD model to: {args.save_to}")
logger.info(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
)
else:
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype)
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...")
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -272,12 +292,12 @@ def merge(args):
)
if v2:
# TODO read sai modelspec
print(
logger.warning(
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
@@ -317,7 +337,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合するLoRAのdim(rank)は入力dimの合計になる",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
)
return parser

View File

@@ -6,7 +6,10 @@ import torch
from safetensors.torch import load_file, save_file
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
@@ -54,10 +57,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
@@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
@@ -96,10 +99,10 @@ def merge_lora_models(models, ratios, merge_dtype):
alpha = None
dim = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
@@ -117,7 +120,7 @@ def merge_lora_models(models, ratios, merge_dtype):
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}")
logger.info(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
@@ -142,19 +145,21 @@ def merge(args):
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"\nsaving SD model to: {args.save_to}")
logger.info("")
logger.info(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"\nsaving model to: {args.save_to}")
logger.info(f"")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)

459
networks/oft.py Normal file
View File

@@ -0,0 +1,459 @@
# OFT network module
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
import einops
from transformers import CLIPTextModel
import numpy as np
import torch
import torch.nn.functional as F
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
class OFTModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
):
"""
dim -> num blocks
alpha -> constraint
"""
super().__init__()
self.oft_name = oft_name
self.num_blocks = dim
if "Linear" in org_module.__class__.__name__:
out_dim = org_module.out_features
elif "Conv" in org_module.__class__.__name__:
out_dim = org_module.out_channels
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
# constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
# original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha
self.constraint = alpha * out_dim
self.register_buffer("alpha", torch.tensor(alpha))
self.block_size = out_dim // self.num_blocks
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) # cpu
self.out_dim = out_dim
self.shape = org_module.weight.shape
self.multiplier = multiplier
self.org_module = [org_module] # moduleにならないようにlistに入れる
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
if self.I.device != block_Q.device:
self.I = self.I.to(block_Q.device)
I = self.I
block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse())
block_R_weighted = self.multiplier * (block_R - I) + I
return block_R_weighted
def forward(self, x, scale=None):
if self.multiplier == 0.0:
return self.org_forward(x)
org_module = self.org_module[0]
org_dtype = x.dtype
R = self.get_weight().to(torch.float32)
W = org_module.weight.to(torch.float32)
if len(W.shape) == 4: # Conv2d
W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size)
RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped)
RW = einops.rearrange(RW, "k m ... -> (k m) ...")
result = F.conv2d(
x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups
)
else: # Linear
W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size)
RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped)
RW = einops.rearrange(RW, "k m p -> (k m) p")
result = F.linear(x, RW.to(org_dtype), org_module.bias)
return result
class OFTInfModule(OFTModule):
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(oft_name, org_module, multiplier, dim, alpha)
self.enabled = True
self.network: OFTNetwork = None
def set_network(self, network):
self.network = network
def forward(self, x, scale=None):
if not self.enabled:
return self.org_forward(x)
return super().forward(x, scale)
def merge_to(self, multiplier=None):
# get org weight
org_sd = self.org_module[0].state_dict()
org_weight = org_sd["weight"].to(torch.float32)
R = self.get_weight(multiplier).to(torch.float32)
weight = org_weight.reshape(self.num_blocks, self.block_size, -1)
weight = torch.einsum("k n m, k n ... -> k m ...", R, weight)
weight = weight.reshape(org_weight.shape)
# convert back to original dtype
weight = weight.to(org_sd["weight"].dtype)
# set weight to org_module
org_sd["weight"] = weight
self.org_module[0].load_state_dict(org_sd)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None: # should be set
logger.info(
"network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します"
)
network_alpha = 1e-3
elif network_alpha >= 1:
logger.warning(
"network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3"
" / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨"
)
enable_all_linear = kwargs.get("enable_all_linear", None)
enable_conv = kwargs.get("enable_conv", None)
if enable_all_linear is not None:
enable_all_linear = bool(enable_all_linear)
if enable_conv is not None:
enable_conv = bool(enable_conv)
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=network_dim,
alpha=network_alpha,
enable_all_linear=enable_all_linear,
enable_conv=enable_conv,
varbose=True,
)
return network
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# check dim, alpha and if weights have for conv2d
dim = None
alpha = None
has_conv2d = None
all_linear = None
for name, param in weights_sd.items():
if name.endswith(".alpha"):
if alpha is None:
alpha = param.item()
else:
if dim is None:
dim = param.size()[0]
if has_conv2d is None and "in_layers_2" in name:
has_conv2d = True
if all_linear is None and "_ff_" in name:
all_linear = True
if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None:
break
if has_conv2d is None:
has_conv2d = False
if all_linear is None:
all_linear = False
module_class = OFTInfModule if for_inference else OFTModule
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=dim,
alpha=alpha,
enable_all_linear=all_linear,
enable_conv=has_conv2d,
module_class=module_class,
)
return network, weights_sd
class OFTNetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
OFT_PREFIX_UNET = "oft_unet" # これ変えないほうがいいかな
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier: float = 1.0,
dim: int = 4,
alpha: float = 1,
enable_all_linear: Optional[bool] = False,
enable_conv: Optional[bool] = False,
module_class: Type[object] = OFTModule,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.dim = dim
self.alpha = alpha
logger.info(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}"
)
# create module instances
def create_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[OFTModule]:
prefix = self.OFT_PREFIX_UNET
ofts = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = "Linear" in child_module.__class__.__name__
is_conv2d = "Conv2d" in child_module.__class__.__name__
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_")
# logger.info(oft_name)
oft = module_class(
oft_name,
child_module,
self.multiplier,
dim,
alpha,
)
ofts.append(oft)
return ofts
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
if enable_all_linear:
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
else:
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
if enable_conv:
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
# assertion
names = set()
for oft in self.unet_ofts:
assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
names.add(oft.oft_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for oft in self.unet_ofts:
oft.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
assert apply_unet, "apply_unet must be True"
for oft in self.unet_ofts:
oft.apply_to()
self.add_module(oft.oft_name, oft)
# マージできるかどうかを返す
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
logger.info("enable OFT for U-Net")
for oft in self.unet_ofts:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(oft.oft_name):
sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
oft.load_state_dict(sd_for_lora, False)
oft.merge_to()
logger.info(f"weights are merged")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(ofts):
params = []
for oft in ofts:
params.extend(oft.parameters())
# logger.info num of params
num_params = 0
for p in params:
num_params += p.numel()
logger.info(f"OFT params: {num_params}")
return params
param_data = {"params": enumerate_params(self.unet_ofts)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
return all_params
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# 重みのバックアップを行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# 重みのリストアを行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# 事前計算を行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
oft.merge_to()
# sd = org_module.state_dict()
# org_weight = sd["weight"]
# lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
# sd["weight"] = org_weight + lora_weight
# assert sd["weight"].shape == org_weight.shape
# org_module.load_state_dict(sd)
org_module._lora_restored = False
oft.enabled = False

View File

@@ -2,80 +2,86 @@
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo
import os
import argparse
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np
from library import train_util
from library import model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6
# Model save and load functions
def load_state_dict(file_name, dtype):
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location='cpu')
metadata = None
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location="cpu")
metadata = None
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd, metadata
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if model_util.is_safetensors(file_name):
save_file(model, file_name, metadata)
else:
torch.save(model, file_name)
def save_to_file(file_name, state_dict, metadata):
if model_util.is_safetensors(file_name):
save_file(state_dict, file_name, metadata)
else:
torch.save(state_dict, file_name)
# Indexing functions
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S)-1))
return index
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_fro(S, target):
S_squared = S.pow(2)
s_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S)-1))
S_squared = S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S) - 1))
return index
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv/target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S)-1))
max_sv = S[0]
min_sv = max_sv / target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S) - 1))
return index
return index
# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size, kernel_size, _ = weight.size()
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
@@ -92,17 +98,17 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size = weight.size()
U, S, Vh = torch.linalg.svd(weight.to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
del U, S, Vh, weight
@@ -113,7 +119,7 @@ def merge_conv(lora_down, lora_up, device):
in_rank, in_size, kernel_size, k_ = lora_down.shape
out_size, out_rank, _, _ = lora_up.shape
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
@@ -127,233 +133,280 @@ def merge_linear(lora_down, lora_up, device):
in_rank, in_size = lora_down.shape
out_size, out_rank = lora_up.shape
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
weight = lora_up @ lora_down
del lora_up, lora_down
return weight
# Calculate new rank
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict = {}
if dynamic_method=="sv_ratio":
if dynamic_method == "sv_ratio":
# Calculate new dim and alpha based off ratio
new_rank = index_sv_ratio(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
elif dynamic_method=="sv_cumulative":
elif dynamic_method == "sv_cumulative":
# Calculate new dim and alpha based off cumulative sum
new_rank = index_sv_cumulative(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
elif dynamic_method=="sv_fro":
elif dynamic_method == "sv_fro":
# Calculate new dim and alpha based off sqrt sum of squares
new_rank = index_sv_fro(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
else:
new_rank = rank
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
new_rank = 1
new_alpha = float(scale*new_rank)
elif new_rank > rank: # cap max rank at rank
new_alpha = float(scale * new_rank)
elif new_rank > rank: # cap max rank at rank
new_rank = rank
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
# Calculate resize info
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))
S_squared = S.pow(2)
s_fro = torch.sqrt(torch.sum(S_squared))
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
fro_percent = float(s_red_fro/s_fro)
fro_percent = float(s_red_fro / s_fro)
param_dict["new_rank"] = new_rank
param_dict["new_alpha"] = new_alpha
param_dict["sum_retained"] = (s_rank)/s_sum
param_dict["sum_retained"] = (s_rank) / s_sum
param_dict["fro_retained"] = fro_percent
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
param_dict["max_ratio"] = S[0] / S[new_rank - 1]
return param_dict
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
network_alpha = None
network_dim = None
verbose_str = "\n"
fro_list = []
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
network_alpha = None
network_dim = None
verbose_str = "\n"
fro_list = []
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and "alpha" in key:
network_alpha = value
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
scale = network_alpha/network_dim
scale = network_alpha / network_dim
if dynamic_method:
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
if dynamic_method:
logger.info(
f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
)
lora_down_weight = None
lora_up_weight = None
lora_down_weight = None
lora_up_weight = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
weight_name = None
if 'lora_down' in key:
block_down_name = key.split(".")[0]
weight_name = key.split(".")[-1]
lora_down_weight = value
else:
continue
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
weight_name = None
if "lora_down" in key:
block_down_name = key.rsplit(".lora_down", 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
continue
# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
weights_loaded = lora_down_weight is not None and lora_up_weight is not None
if weights_loaded:
if weights_loaded:
conv2d = (len(lora_down_weight.size()) == 4)
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha/lora_down_weight.size()[0]
conv2d = len(lora_down_weight.size()) == 4
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha / lora_down_weight.size()[0]
if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
if verbose:
max_ratio = param_dict['max_ratio']
sum_retained = param_dict['sum_retained']
fro_retained = param_dict['fro_retained']
if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))
if verbose:
max_ratio = param_dict["max_ratio"]
sum_retained = param_dict["sum_retained"]
fro_retained = param_dict["fro_retained"]
if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))
verbose_str+=f"{block_down_name:75} | "
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
verbose_str += f"{block_down_name:75} | "
verbose_str += (
f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
)
if verbose and dynamic_method:
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str+=f"\n"
if verbose and dynamic_method:
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str += "\n"
new_alpha = param_dict['new_alpha']
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
new_alpha = param_dict["new_alpha"]
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
del param_dict
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
del param_dict
if verbose:
print(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
print("resizing complete")
return o_lora_sd, network_dim, new_alpha
if verbose:
print(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
logger.info("resizing complete")
return o_lora_sd, network_dim, new_alpha
def resize(args):
if args.save_to is None or not (
args.save_to.endswith(".ckpt")
or args.save_to.endswith(".pt")
or args.save_to.endswith(".pth")
or args.save_to.endswith(".safetensors")
):
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
if args.dynamic_method and not args.dynamic_param:
raise Exception("If using dynamic_method, then dynamic_param is required")
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
if args.dynamic_method and not args.dynamic_param:
raise Exception("If using dynamic_method, then dynamic_param is required")
print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
print("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
logger.info("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
# update metadata
if metadata is None:
metadata = {}
logger.info("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
)
comment = metadata.get("ss_training_comment", "")
# update metadata
if metadata is None:
metadata = {}
if not args.dynamic_method:
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
else:
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
metadata["ss_network_dim"] = 'Dynamic'
metadata["ss_network_alpha"] = 'Dynamic'
comment = metadata.get("ss_training_comment", "")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
if not args.dynamic_method:
conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})"
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
else:
metadata["ss_training_comment"] = (
f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
)
metadata["ss_network_dim"] = "Dynamic"
metadata["ss_network_alpha"] = "Dynamic"
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, metadata)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument("--verbose", action="store_true",
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
parser.add_argument("--dynamic_param", type=float, default=None,
help="Specify target for dynamic reduction")
return parser
parser.add_argument(
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat",
)
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument(
"--new_conv_rank",
type=int,
default=None,
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
)
parser.add_argument(
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--model",
type=str,
default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument(
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
)
parser.add_argument(
"--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する"
)
parser.add_argument(
"--dynamic_method",
type=str,
default=None,
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
)
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
resize(args)
args = parser.parse_args()
resize(args)

View File

@@ -1,13 +1,23 @@
import itertools
import math
import argparse
import os
import time
import concurrent.futures
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora
import oft
from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
@@ -25,36 +35,58 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, model, metadata):
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name, metadata=metadata)
else:
torch.save(model, file_name)
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
text_encoder1.to(merge_dtype)
def detect_method_from_training_model(models, dtype):
for model in models:
# TODO It is better to use key names to detect the method
lora_sd, _ = load_state_dict(model, dtype)
for key in tqdm(lora_sd.keys()):
if "lora_up" in key or "lora_down" in key:
return "LoRA"
elif "oft_blocks" in key:
return "OFT"
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype):
text_encoder1.to(merge_dtype)
text_encoder2.to(merge_dtype)
unet.to(merge_dtype)
# detect the method: OFT or LoRA_module
method = detect_method_from_training_model(models, merge_dtype)
logger.info(f"method:{method}")
if lbws:
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
else:
LBW_TARGET_IDX = []
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
if i <= 1:
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
if method == "LoRA":
if i <= 1:
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
else:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
elif method == "OFT":
prefix = oft.OFTNetwork.OFT_PREFIX_UNET
# ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
for name, module in root_module.named_modules():
@@ -65,65 +97,172 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
logger.info(f"merging...")
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if method == "LoRA":
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
if lbw:
index = get_lbw_block_index(key, True)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
# W <- W + U * D
weight = module.weight
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
elif method == "OFT":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for key in tqdm(lora_sd.keys()):
if "oft_blocks" in key:
oft_blocks = lora_sd[key]
dim = oft_blocks.shape[0]
break
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
oft_blocks = lora_sd[key]
alpha = oft_blocks.item()
break
def merge_to(key):
if "alpha" in key:
return
# find original module for this OFT
module_name = ".".join(key.split(".")[:-1])
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
logger.info(f"no module found for OFT weight: {key}")
return
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
# logger.info(f"apply {key} to {module}")
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
oft_blocks = lora_sd[key]
# W <- W + U * D
weight = module.weight
# print(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
if isinstance(module, torch.nn.Linear):
out_dim = module.out_features
elif isinstance(module, torch.nn.Conv2d):
out_dim = module.out_channels
num_blocks = dim
block_size = out_dim // dim
constraint = (0 if alpha is None else alpha) * out_dim
multiplier = 1
if lbw:
index = get_lbw_block_index(key, False)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
multiplier *= lbw_weights[index]
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = multiplier * block_R + (1 - multiplier) * I
R = torch.block_diag(*block_R_weighted)
# get org weight
org_sd = module.state_dict()
org_weight = org_sd["weight"].to(device)
R = R.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
weight = torch.einsum("oi, op -> pi", org_weight, R)
weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor
module.weight = torch.nn.Parameter(weight)
# TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
def merge_lora_models(models, ratios, merge_dtype):
def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
# detect the method: OFT or LoRA_module
method = detect_method_from_training_model(models, merge_dtype)
if method == "OFT":
raise ValueError(
"OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません"
)
if lbws:
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
else:
LBW_TARGET_IDX = []
merged_sd = {}
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if lora_metadata is not None:
if v2 is None:
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
@@ -154,26 +293,43 @@ def merge_lora_models(models, ratios, merge_dtype):
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
logger.info(f"merging...")
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None
lora_module_name = key[: key.rfind(".lora_")]
base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if lbw:
index = get_lbw_block_index(key, True)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size()
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
@@ -181,9 +337,16 @@ def merge_lora_models(models, ratios, merge_dtype):
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:, perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same
dims_list = list(set(base_dims.values()))
@@ -208,7 +371,15 @@ def merge_lora_models(models, ratios, merge_dtype):
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
assert len(args.models) == len(
args.ratios
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
if args.lbws:
assert len(args.models) == len(
args.lbws
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
else:
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
def str_to_dtype(p):
if p == "float":
@@ -225,7 +396,7 @@ def merge(args):
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
(
text_model1,
@@ -236,7 +407,7 @@ def merge(args):
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype)
if args.no_metadata:
sai_metadata = None
@@ -247,14 +418,20 @@ def merge(args):
None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
)
print(f"saving SD model to: {args.save_to}")
logger.info(f"saving SD model to: {args.save_to}")
sdxl_model_util.save_stable_diffusion_checkpoint(
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype)
state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...")
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -268,8 +445,8 @@ def merge(args):
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -295,18 +472,36 @@ def setup_parser() -> argparse.ArgumentParser:
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
"--models",
type=str,
nargs="*",
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合するLoRAのdim(rank)は入力dimの合計になる",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
)
return parser

View File

@@ -1,6 +1,8 @@
import math
import argparse
import itertools
import json
import os
import re
import time
import torch
from safetensors.torch import load_file, save_file
@@ -8,10 +10,196 @@ from tqdm import tqdm
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLAMP_QUANTILE = 0.99
ACCEPTABLE = [12, 17, 20, 26]
SDXL_LAYER_NUM = [12, 20]
LAYER12 = {
"BASE": True,
"IN00": False,
"IN01": False,
"IN02": False,
"IN03": False,
"IN04": True,
"IN05": True,
"IN06": False,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": False,
"OUT07": False,
"OUT08": False,
"OUT09": False,
"OUT10": False,
"OUT11": False,
}
LAYER17 = {
"BASE": True,
"IN00": False,
"IN01": True,
"IN02": True,
"IN03": False,
"IN04": True,
"IN05": True,
"IN06": False,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": False,
"OUT01": False,
"OUT02": False,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": True,
"OUT10": True,
"OUT11": True,
}
LAYER20 = {
"BASE": True,
"IN00": True,
"IN01": True,
"IN02": True,
"IN03": True,
"IN04": True,
"IN05": True,
"IN06": True,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": False,
"OUT10": False,
"OUT11": False,
}
LAYER26 = {
"BASE": True,
"IN00": True,
"IN01": True,
"IN02": True,
"IN03": True,
"IN04": True,
"IN05": True,
"IN06": True,
"IN07": True,
"IN08": True,
"IN09": True,
"IN10": True,
"IN11": True,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": True,
"OUT10": True,
"OUT11": True,
}
assert len([v for v in LAYER12.values() if v]) == 12
assert len([v for v in LAYER17.values() if v]) == 17
assert len([v for v in LAYER20.values() if v]) == 20
assert len([v for v in LAYER26.values() if v]) == 26
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int:
# lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder
if "text_model_encoder_" in lora_name: # LoRA for text encoder
return 0
# lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2
block_idx = -1 # invalid lora name
if not is_sdxl:
NUM_OF_BLOCKS = 12 # up/down blocks
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
up_down = g[0]
i = int(g[1])
j = int(g[3])
if up_down == "down":
if g[2] == "resnets" or g[2] == "attentions":
idx = 3 * i + j + 1
elif g[2] == "downsamplers":
idx = 3 * (i + 1)
else:
return block_idx # invalid lora name
elif up_down == "up":
if g[2] == "resnets" or g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers":
idx = 3 * i + 2
else:
return block_idx # invalid lora name
if g[0] == "down":
block_idx = 1 + idx # 1-based index, down block index
elif g[0] == "up":
block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index
elif "mid_block_" in lora_name:
block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block
else:
# SDXL: some numbers are skipped
if lora_name.startswith("lora_unet_"):
name = lora_name[len("lora_unet_") :]
if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts
block_idx = 1
elif name.startswith("input_blocks_"): # 1-8 to 2-9
block_idx = 1 + int(name.split("_")[2])
elif name.startswith("middle_block_"): # 13
block_idx = 13
elif name.startswith("output_blocks_"): # 0-8 to 14-22
block_idx = 14 + int(name.split("_")[2])
elif name.startswith("out_"): # 23, No LoRA in sd-scripts
block_idx = 23
return block_idx
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -28,25 +216,54 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, state_dict, metadata):
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(state_dict, file_name, metadata=metadata)
else:
torch.save(state_dict, file_name)
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
def format_lbws(lbws):
try:
# lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している
lbws = [json.loads(lbw) for lbw in lbws]
except Exception:
raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください")
assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください"
assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください"
assert all(
len(lbw) in ACCEPTABLE for lbw in lbws
), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
assert all(
all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws
), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください"
layer_num = len(lbws[0])
is_sdxl = True if layer_num in SDXL_LAYER_NUM else False
FLAGS = {
"12": LAYER12.values(),
"17": LAYER17.values(),
"20": LAYER20.values(),
"26": LAYER26.values(),
}[str(layer_num)]
LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag]
return lbws, is_sdxl, LBW_TARGET_IDX
def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
merged_sd = {}
v2 = None
v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
if lbws:
lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws)
else:
is_sdxl = False
LBW_TARGET_IDX = []
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
@@ -55,8 +272,14 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
if base_model is None:
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
# merge
print(f"merging...")
logger.info(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if "lora_down" not in key:
continue
@@ -73,15 +296,15 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
kernel_size = None if not conv2d else down_weight.size()[2:4]
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
# logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
if device:
weight = weight.to(device)
# merge to weight
if device:
@@ -91,6 +314,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
# W <- W + U * D
scale = alpha / network_dim
if lbw:
index = get_lbw_block_index(key, is_sdxl)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
if device: # and isinstance(scale, torch.Tensor):
scale = scale.to(device)
@@ -107,13 +336,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
merged_sd[lora_module_name] = weight
merged_sd[lora_module_name] = weight.to("cpu")
# extract from merged weights
print("extract new lora...")
logger.info("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
if device:
mat = mat.to(device)
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
@@ -152,7 +384,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu")
# build minimum metadata
dims = f"{new_rank}"
@@ -167,7 +399,15 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
assert len(args.models) == len(
args.ratios
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
if args.lbws:
assert len(args.models) == len(
args.lbws
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
else:
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
def str_to_dtype(p):
if p == "float":
@@ -185,10 +425,16 @@ def merge(args):
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
state_dict, metadata, v2, base_model = merge_lora_models(
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
)
print(f"calculating hashes and creating metadata...")
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -203,13 +449,13 @@ def merge(args):
)
if v2:
# TODO read sai modelspec
print(
logger.warning(
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -229,12 +475,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
"--models",
type=str,
nargs="*",
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument(
"--new_conv_rank",
@@ -242,7 +495,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
)
parser.add_argument(
"--no_metadata",
action="store_true",

View File

@@ -1,29 +1,42 @@
accelerate==0.19.0
transformers==4.30.2
diffusers[torch]==0.18.2
accelerate==0.30.0
transformers==4.44.0
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
# bitsandbytes==0.39.1
tensorboard==2.10.1
safetensors==0.3.1
bitsandbytes==0.44.0
prodigyopt==1.0
lion-pytorch==0.0.6
tensorboard
safetensors==0.4.2
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.15.1
# for loading Diffusers' SDXL
invisible-watermark==0.2.0
huggingface-hub==0.24.5
# for Image utils
imagesize==1.4.1
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
# fairscale==0.4.13
# for WD14 captioning
# for WD14 captioning (tensorflow)
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.15.0
# onnxruntime-gpu==1.17.1
# onnxruntime==1.17.1
# for cuda 12.1(default 11.8)
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
open-clip-torch==2.20.0
# open-clip-torch==2.20.0
# For logging
rich==13.7.0
# for kohya_ss library
-e .

File diff suppressed because it is too large Load Diff

View File

@@ -8,16 +8,28 @@ import os
import random
from einops import repeat
import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler
from PIL import Image
import open_clip
# import open_clip
from safetensors.torch import load_file
from library import model_util, sdxl_model_util
import networks.lora as lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
# scheduler: The settings around here seem to be the same as SD1/2
@@ -80,7 +92,7 @@ if __name__ == "__main__":
guidance_scale = 7
seed = None # 1
DEVICE = "cuda"
DEVICE = get_preferred_device()
DTYPE = torch.float16 # bfloat16 may work
parser = argparse.ArgumentParser()
@@ -94,7 +106,7 @@ if __name__ == "__main__":
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)",
help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
@@ -135,7 +147,7 @@ if __name__ == "__main__":
vae_dtype = DTYPE
if DTYPE == torch.float16:
print("use float32 for vae")
logger.info("use float32 for vae")
vae_dtype = torch.float32
vae.to(DEVICE, dtype=vae_dtype)
vae.eval()
@@ -146,12 +158,13 @@ if __name__ == "__main__":
text_model2.eval()
unet.set_use_memory_efficient_attention(True, False)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(True)
# Tokenizers
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
# tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name)
# LoRA
for weights_file in args.lora_weights:
@@ -182,9 +195,11 @@ if __name__ == "__main__":
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
# print("emb1", emb1.shape)
# logger.info("emb1", emb1.shape)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
uc_vector = c_vector.clone().to(
DEVICE, dtype=DTYPE
) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
# crossattn
@@ -207,13 +222,22 @@ if __name__ == "__main__":
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
# text encoder 2
with torch.no_grad():
tokens = tokenizer2(text2).to(DEVICE)
# tokens = tokenizer2(text2).to(DEVICE)
tokens = tokenizer2(
text,
truncation=True,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(DEVICE)
with torch.no_grad():
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
# print("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# logger.info("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# 連結して終了 concat and finish
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
@@ -221,7 +245,7 @@ if __name__ == "__main__":
# cond
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
# logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
# uncond
@@ -318,4 +342,4 @@ if __name__ == "__main__":
seed = int(seed)
generate_image(prompt, prompt2, negative_prompt, seed)
print("Done!")
logger.info("Done!")

View File

@@ -1,7 +1,6 @@
# training with captions
import argparse
import gc
import math
import os
from multiprocessing import Value
@@ -9,12 +8,26 @@ from typing import List
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util
from library import deepspeed_utils, sdxl_model_util
import library.train_util as train_util
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
@@ -27,6 +40,8 @@ from library.custom_train_functions import (
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -63,41 +78,34 @@ def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List
def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
lrs = lr_scheduler.get_last_lr()
lr_index = 0
names = []
block_index = 0
while lr_index < len(lrs):
while block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR + 2:
if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
name = f"block{block_index}"
if block_lrs[block_index] == 0:
block_index += 1
continue
names.append(f"block{block_index}")
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
name = "text_encoder1"
names.append("text_encoder1")
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
name = "text_encoder2"
else:
raise ValueError(f"unexpected block_index: {block_index}")
names.append("text_encoder2")
block_index += 1
logs["lr/" + name] = float(lrs[lr_index])
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
logs["lr/d*lr/" + name] = (
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
)
lr_index += 1
train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
assert (
not args.weighted_captions
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
assert (
not args.train_text_encoder or not args.cache_text_encoder_outputs
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
@@ -120,20 +128,20 @@ def train(args):
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -144,7 +152,7 @@ def train(args):
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -165,8 +173,8 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
@@ -174,7 +182,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
@@ -190,7 +198,7 @@ def train(args):
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -257,17 +265,16 @@ def train(args):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
training_models = []
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
training_models.append(unet)
train_unet = args.learning_rate != 0
train_text_encoder1 = False
train_text_encoder2 = False
if args.train_text_encoder:
# TODO each option for two text encoders?
@@ -275,10 +282,23 @@ def train(args):
if args.gradient_checkpointing:
text_encoder1.gradient_checkpointing_enable()
text_encoder2.gradient_checkpointing_enable()
training_models.append(text_encoder1)
training_models.append(text_encoder2)
# set require_grad=True later
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 != 0
train_text_encoder2 = lr_te2 != 0
# caching one text encoder output is not supported
if not train_text_encoder1:
text_encoder1.to(weight_dtype)
if not train_text_encoder2:
text_encoder2.to(weight_dtype)
text_encoder1.requires_grad_(train_text_encoder1)
text_encoder2.requires_grad_(train_text_encoder2)
text_encoder1.train(train_text_encoder1)
text_encoder2.train(train_text_encoder2)
else:
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
text_encoder1.requires_grad_(False)
text_encoder2.requires_grad_(False)
text_encoder1.eval()
@@ -287,7 +307,7 @@ def train(args):
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
with torch.no_grad():
with torch.no_grad(), accelerator.autocast():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
@@ -303,45 +323,94 @@ def train(args):
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models:
m.requires_grad_(True)
unet.requires_grad_(train_unet)
if not train_unet:
unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
if block_lrs is None:
params = []
for m in training_models:
params.extend(m.parameters())
params_to_optimize = params
training_models = []
params_to_optimize = []
if train_unet:
training_models.append(unet)
if block_lrs is None:
params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
else:
params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
# calculate number of trainable parameters
n_params = 0
for p in params:
if train_text_encoder1:
training_models.append(text_encoder1)
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
if train_text_encoder2:
training_models.append(text_encoder2)
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
# calculate number of trainable parameters
n_params = 0
for group in params_to_optimize:
for p in group["params"]:
n_params += p.numel()
else:
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
for m in training_models[1:]: # Text Encoders if exists
params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()
accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
if args.fused_optimizer_groups:
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
# This balances memory usage and management complexity.
# calculate total number of parameters
n_total_params = sum(len(params["params"]) for params in params_to_optimize)
params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
# split params into groups, keeping the learning rate the same for all params in a group
# this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
grouped_params = []
param_group = []
param_group_lr = -1
for group in params_to_optimize:
lr = group["lr"]
for p in group["params"]:
# if the learning rate is different for different params, start a new group
if lr != param_group_lr:
if param_group:
grouped_params.append({"params": param_group, "lr": param_group_lr})
param_group = []
param_group_lr = lr
param_group.append(p)
# if the group has enough parameters, start a new group
if len(param_group) == params_per_group:
grouped_params.append({"params": param_group, "lr": param_group_lr})
param_group = []
param_group_lr = -1
if param_group:
grouped_params.append({"params": param_group, "lr": param_group_lr})
# prepare optimizers for each group
optimizers = []
for group in grouped_params:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
optimizers.append(optimizer)
optimizer = optimizers[0] # avoid error in the following code
logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -351,13 +420,20 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
if args.fused_optimizer_groups:
# prepare lr schedulers for each optimizer
lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
lr_scheduler = lr_schedulers[0] # avoid error in the following code
else:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
@@ -377,27 +453,40 @@ def train(args):
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler
)
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
if train_text_encoder1:
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoder1 if train_text_encoder1 else None,
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
(unet,) = train_util.transform_models_if_DDP([unet])
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
@@ -405,11 +494,64 @@ def train(args):
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
parameter.register_post_accumulate_grad_hook(__grad_hook)
elif args.fused_optimizer_groups:
# prepare for additional optimizers and lr schedulers
for i in range(1, len(optimizers)):
optimizers[i] = accelerator.prepare(optimizers[i])
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
# counters are used to determine when to step the optimizer
global optimizer_hooked_count
global num_parameters_per_group
global parameter_optimizer_map
optimizer_hooked_count = {}
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}
for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def optimizer_hook(parameter: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)
parameter.register_post_accumulate_grad_hook(optimizer_hook)
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -422,7 +564,9 @@ def train(args):
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
@@ -441,10 +585,22 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -452,10 +608,13 @@ def train(args):
for m in training_models:
m.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
if args.fused_optimizer_groups:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
@@ -466,7 +625,7 @@ def train(args):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -487,6 +646,7 @@ def train(args):
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
@@ -496,6 +656,7 @@ def train(args):
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
@@ -517,7 +678,7 @@ def train(args):
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified")
# logger.info("text encoder outputs verified")
# get size embeddings
orig_size = batch["original_sizes_hw"]
@@ -531,7 +692,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -539,34 +702,60 @@ def train(args):
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
target = noise
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss:
if (
args.min_snr_gamma
or args.scale_v_pred_loss_like_noise_pred
or args.v_pred_like_loss
or args.debiased_estimation_loss
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if not (args.fused_backward_pass or args.fused_optimizer_groups):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.fused_optimizer_groups:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -613,29 +802,22 @@ def train(args):
if args.logging_dir is not None:
logs = {"loss": current_loss}
if block_lrs is None:
logs["lr"] = float(lr_scheduler.get_last_lr()[0])
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
else:
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs
accelerator.log(logs, step=global_step)
# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -682,7 +864,7 @@ def train(args):
accelerator.end_training()
if args.save_state: # and is_main_process:
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -704,22 +886,40 @@ def train(args):
logit_scale,
ckpt_info,
)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument(
"--learning_rate_te1",
type=float,
default=None,
help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
)
parser.add_argument(
"--learning_rate_te2",
type=float,
default=None,
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
)
parser.add_argument(
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
)
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--no_half_vae",
@@ -733,7 +933,12 @@ def setup_parser() -> argparse.ArgumentParser:
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
)
parser.add_argument(
"--fused_optimizer_groups",
type=int,
default=None,
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
)
return parser
@@ -741,6 +946,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,5 +1,7 @@
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
import argparse
import gc
import json
import math
import os
@@ -10,12 +12,18 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
import accelerate
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
import library.model_util as model_util
import library.train_util as train_util
@@ -33,8 +41,15 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import networks.control_net_lllite as control_net_lllite
import networks.control_net_lllite_for_train as control_net_lllite_for_train
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO 他のスクリプトと共通化する
@@ -55,6 +70,7 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
@@ -68,11 +84,11 @@ def train(args):
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -95,8 +111,8 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
@@ -104,7 +120,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -114,7 +130,9 @@ def train(args):
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else:
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
logger.warning(
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
)
if args.cache_text_encoder_outputs:
assert (
@@ -122,7 +140,7 @@ def train(args):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -141,9 +159,6 @@ def train(args):
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
@@ -157,9 +172,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -177,36 +190,67 @@ def train(args):
)
accelerator.wait_for_everyone()
# prepare ControlNet
network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
network.apply_to()
# prepare ControlNet-LLLite
control_net_lllite_for_train.replace_unet_linear_and_conv2d()
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
with accelerate.init_empty_weights():
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
unet_lllite.to(accelerator.device, dtype=weight_dtype)
unet_sd = unet.state_dict()
info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
else:
# cosumes large memory, so send to GPU before creating the LLLite model
accelerator.print("sending U-Net to GPU")
unet.to(accelerator.device, dtype=weight_dtype)
unet_sd = unet.state_dict()
# init LLLite weights
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
if args.lowram:
with accelerate.init_on_device(accelerator.device):
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
else:
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
unet_lllite.to(weight_dtype)
info = unet_lllite.load_lllite_weights(None, unet_sd)
accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
del unet_sd, unet
unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
del unet_lllite
unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
network.enable_gradient_checkpointing() # may have no effect
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(network.prepare_optimizer_params())
print(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
trainable_params = list(unet.prepare_params())
logger.info(f"trainable params count: {len(trainable_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -216,7 +260,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -225,44 +271,38 @@ def train(args):
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
network.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
unet.to(weight_dtype)
network.to(weight_dtype)
# if args.full_fp16:
# assert (
# args.mixed_precision == "fp16"
# ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
# accelerator.print("enable full fp16 training.")
# unet.to(weight_dtype)
# elif args.full_bf16:
# assert (
# args.mixed_precision == "bf16"
# ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
# accelerator.print("enable full bf16 training.")
# unet.to(weight_dtype)
unet.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
network: control_net_lllite.ControlNetLLLite
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare (train_network here only)
unet, network = train_util.transform_models_if_DDP([unet, network])
if isinstance(unet, DDP):
unet._set_static_graph() # avoid error for multiple use of the parameter
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
unet.eval()
network.prepare_grad_etc()
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
@@ -293,8 +333,10 @@ def train(args):
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -310,18 +352,25 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
def save_model(
ckpt_name,
unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
steps,
epoch_no,
force_sync_upload=False,
):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
@@ -329,7 +378,7 @@ def train(args):
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -344,22 +393,20 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
network.on_epoch_start() # train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
with accelerator.accumulate(unet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -396,7 +443,9 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -405,10 +454,9 @@ def train(args):
with accelerator.autocast():
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
# 内部でcond_embに変換される / it will be converted to cond_emb inside
network.set_cond_image(controlnet_image)
# それらの値を使いつつ、U-Netでイズを予測する / predict noise with U-Net using those values
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
if args.v_parameterization:
# v-parameterization training
@@ -416,24 +464,28 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = network.get_trainable_params()
params_to_clip = accelerator.unwrap_model(unet).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -452,7 +504,7 @@ def train(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -463,14 +515,9 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
@@ -481,7 +528,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -491,7 +538,7 @@ def train(args):
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
@@ -506,26 +553,28 @@ def train(args):
# end of epoch
if is_main_process:
network = accelerator.unwrap_model(network)
unet = accelerator.unwrap_model(unet)
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -538,8 +587,12 @@ def setup_parser() -> argparse.ArgumentParser:
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument(
"--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
)
parser.add_argument(
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
)
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
parser.add_argument(
"--network_dropout",
@@ -567,6 +620,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,8 +1,4 @@
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
import argparse
import gc
import json
import math
import os
@@ -13,13 +9,16 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
import accelerate
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
import library.model_util as model_util
import library.train_util as train_util
@@ -37,8 +36,15 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import networks.control_net_lllite_for_train as control_net_lllite_for_train
import networks.control_net_lllite as control_net_lllite
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO 他のスクリプトと共通化する
@@ -59,6 +65,7 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
@@ -72,11 +79,11 @@ def train(args):
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -99,8 +106,8 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
@@ -108,7 +115,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -118,7 +125,9 @@ def train(args):
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else:
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
logger.warning(
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
)
if args.cache_text_encoder_outputs:
assert (
@@ -126,7 +135,7 @@ def train(args):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -145,6 +154,9 @@ def train(args):
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
@@ -158,9 +170,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -178,67 +188,36 @@ def train(args):
)
accelerator.wait_for_everyone()
# prepare ControlNet-LLLite
control_net_lllite_for_train.replace_unet_linear_and_conv2d()
# prepare ControlNet
network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
network.apply_to()
if args.network_weights is not None:
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
with accelerate.init_empty_weights():
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
unet_lllite.to(accelerator.device, dtype=weight_dtype)
unet_sd = unet.state_dict()
info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
else:
# cosumes large memory, so send to GPU before creating the LLLite model
accelerator.print("sending U-Net to GPU")
unet.to(accelerator.device, dtype=weight_dtype)
unet_sd = unet.state_dict()
# init LLLite weights
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
if args.lowram:
with accelerate.init_on_device(accelerator.device):
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
else:
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
unet_lllite.to(weight_dtype)
info = unet_lllite.load_lllite_weights(None, unet_sd)
accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
del unet_sd, unet
unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
del unet_lllite
unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
info = network.load_weights(args.network_weights)
accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
network.enable_gradient_checkpointing() # may have no effect
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(unet.prepare_params())
print(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
trainable_params = list(network.prepare_optimizer_params())
logger.info(f"trainable params count: {len(trainable_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -248,7 +227,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -257,39 +238,40 @@ def train(args):
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
# if args.full_fp16:
# assert (
# args.mixed_precision == "fp16"
# ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
# accelerator.print("enable full fp16 training.")
# unet.to(weight_dtype)
# elif args.full_bf16:
# assert (
# args.mixed_precision == "bf16"
# ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
# accelerator.print("enable full bf16 training.")
# unet.to(weight_dtype)
unet.to(weight_dtype)
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
network.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
unet.to(weight_dtype)
network.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare (train_network here only)
unet = train_util.transform_models_if_DDP([unet])[0]
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
network: control_net_lllite.ControlNetLLLite
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
unet.eval()
network.prepare_grad_etc()
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
@@ -320,8 +302,10 @@ def train(args):
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -340,21 +324,14 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
def save_model(
ckpt_name,
unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
steps,
epoch_no,
force_sync_upload=False,
):
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
@@ -362,7 +339,7 @@ def train(args):
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -377,20 +354,22 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
network.on_epoch_start() # train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(unet):
with accelerator.accumulate(network):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -427,7 +406,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -436,9 +415,10 @@ def train(args):
with accelerator.autocast():
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
# 内部でcond_embに変換される / it will be converted to cond_emb inside
network.set_cond_image(controlnet_image)
# それらの値を使いつつ、U-Netでイズを予測する / predict noise with U-Net using those values
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
if args.v_parameterization:
# v-parameterization training
@@ -446,24 +426,26 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = unet.get_trainable_params()
params_to_clip = network.get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -482,7 +464,7 @@ def train(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -493,14 +475,9 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
@@ -511,7 +488,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -521,7 +498,7 @@ def train(args):
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
@@ -536,7 +513,7 @@ def train(args):
# end of epoch
if is_main_process:
unet = accelerator.unwrap_model(unet)
network = accelerator.unwrap_model(network)
accelerator.end_training()
@@ -545,17 +522,19 @@ def train(args):
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -568,8 +547,12 @@ def setup_parser() -> argparse.ArgumentParser:
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument(
"--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
)
parser.add_argument(
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
)
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
parser.add_argument(
"--network_dropout",
@@ -597,6 +580,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,8 +1,15 @@
import argparse
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
@@ -55,36 +62,36 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
print("move vae and unet to cpu to save memory")
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
with accelerator.autocast():
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
if not args.lowram:
print("move vae and unet back to original device")
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device)
text_encoders[1].to(accelerator.device)
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -114,6 +121,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
@@ -135,7 +143,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified")
# logger.info("text encoder outputs verified")
return encoder_hidden_states1, encoder_hidden_states2, pool2
@@ -170,6 +178,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = SdxlNetworkTrainer()

View File

@@ -2,8 +2,11 @@ import argparse
import os
import regex
import torch
import open_clip
from library.device_utils import init_ipex
init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util
import train_textual_inversion
@@ -57,6 +60,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2
@@ -127,6 +131,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = SdxlTextualInversionTrainer()

View File

@@ -16,9 +16,15 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
# check cache latents arg
@@ -41,18 +47,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -63,7 +69,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -86,11 +92,12 @@ def cache_to_disk(args: argparse.Namespace) -> None:
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
args.deepspeed = False
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -98,13 +105,13 @@ def cache_to_disk(args: argparse.Namespace) -> None:
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
print("load model")
logger.info("load model")
if args.sdxl:
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
else:
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
@@ -113,14 +120,14 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# dataloaderを準備する
train_dataset_group.set_caching_mode("latents")
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -133,6 +140,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
b_size = len(batch["images"])
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
flip_aug = batch["flip_aug"]
alpha_mask = batch["alpha_mask"]
random_crop = batch["random_crop"]
bucket_reso = batch["bucket_reso"]
@@ -151,14 +159,16 @@ def cache_to_disk(args: argparse.Namespace) -> None:
image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
print(f"Skipping {image_info.latents_npz} because it already exists.")
if train_util.is_disk_cached_latents_is_expected(
image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask
):
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)
if len(image_infos) > 0:
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
@@ -167,6 +177,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)

View File

@@ -16,9 +16,13 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
# check cache arg
@@ -48,18 +52,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -70,7 +74,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -91,18 +95,19 @@ def cache_to_disk(args: argparse.Namespace) -> None:
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
args.deepspeed = False
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, _ = train_util.prepare_dtype(args)
# モデルを読み込む
print("load model")
logger.info("load model")
if args.sdxl:
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
text_encoders = [text_encoder1, text_encoder2]
@@ -118,14 +123,14 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# dataloaderを準備する
train_dataset_group.set_caching_mode("text")
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -147,7 +152,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1
@@ -168,6 +173,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)

View File

@@ -1,6 +1,10 @@
import argparse
import cv2
import logging
from library.utils import setup_logging
setup_logging()
logger = logging.getLogger(__name__)
def canny(args):
img = cv2.imread(args.input)
@@ -10,7 +14,7 @@ def canny(args):
# canny_img = 255 - canny_img
cv2.imwrite(args.output, canny_img)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -6,7 +6,10 @@ import torch
from diffusers import StableDiffusionPipeline
import library.model_util as model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def convert(args):
# 引数を確認する
@@ -23,21 +26,23 @@ def convert(args):
is_load_ckpt = os.path.isfile(args.model_to_load)
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
# assert (
# is_save_ckpt or args.reference_model is not None
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
# モデルを読み込む
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
print(f"loading {msg}: {args.model_to_load}")
logger.info(f"loading {msg}: {args.model_to_load}")
if is_load_ckpt:
v2_model = args.v2
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection)
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
)
else:
pipe = StableDiffusionPipeline.from_pretrained(
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
)
text_encoder = pipe.text_encoder
vae = pipe.vae
@@ -46,26 +51,37 @@ def convert(args):
if args.v1 == args.v2:
# 自動判定する
v2_model = unet.config.cross_attention_dim == 1024
print("checking model version: model is " + ("v2" if v2_model else "v1"))
logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
else:
v2_model = not args.v1
# 変換して保存する
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
print(f"converting and saving as {msg}: {args.model_to_save}")
logger.info(f"converting and saving as {msg}: {args.model_to_save}")
if is_save_ckpt:
original_model = args.model_to_load if is_load_ckpt else None
key_count = model_util.save_stable_diffusion_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae
v2_model,
args.model_to_save,
text_encoder,
unet,
original_model,
args.epoch,
args.global_step,
None if args.metadata is None else eval(args.metadata),
save_dtype=save_dtype,
vae=vae,
)
print(f"model saved. total converted state_dict keys: {key_count}")
logger.info(f"model saved. total converted state_dict keys: {key_count}")
else:
print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}")
logger.info(
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
)
model_util.save_diffusers_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
)
print(f"model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
@@ -77,7 +93,9 @@ def setup_parser() -> argparse.ArgumentParser:
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
)
parser.add_argument(
"--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にするstabilityaiのモデルと合わせる"
"--unet_use_linear_projection",
action="store_true",
help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にするstabilityaiのモデルと合わせる",
)
parser.add_argument(
"--fp16",
@@ -99,6 +117,18 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
)
parser.add_argument(
"--metadata",
type=str,
default=None,
help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
)
parser.add_argument(
"--reference_model",
type=str,

View File

@@ -15,6 +15,10 @@ import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
logger = logging.getLogger(__name__)
KP_REYE = 11
KP_LEYE = 19
@@ -24,7 +28,7 @@ SCORE_THRES = 0.90
def detect_faces(detector, image, min_size):
preds = detector(image) # bgr
# print(len(preds))
# logger.info(len(preds))
faces = []
for pred in preds:
@@ -78,7 +82,7 @@ def process(args):
assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
# アニメ顔検出モデルを読み込む
print("loading face detector.")
logger.info("loading face detector.")
detector = create_detector('yolov3')
# cropの引数を解析する
@@ -97,7 +101,7 @@ def process(args):
crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
# 画像を処理する
print("processing.")
logger.info("processing.")
output_extension = ".png"
os.makedirs(args.dst_dir, exist_ok=True)
@@ -111,7 +115,7 @@ def process(args):
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
if image.shape[2] == 4:
print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
h, w = image.shape[:2]
@@ -144,11 +148,11 @@ def process(args):
# 顔サイズを基準にリサイズする
scale = args.resize_face_size / face_size
if scale < cur_crop_width / w:
print(
logger.warning(
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい顔が相対的に大きすぎるので顔サイズが変わります: {path}")
scale = cur_crop_width / w
if scale < cur_crop_height / h:
print(
logger.warning(
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい顔が相対的に大きすぎるので顔サイズが変わります: {path}")
scale = cur_crop_height / h
elif crop_h_ratio is not None:
@@ -157,10 +161,10 @@ def process(args):
else:
# 切り出しサイズ指定あり
if w < cur_crop_width:
print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
scale = cur_crop_width / w
if h < cur_crop_height:
print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
scale = cur_crop_height / h
if args.resize_fit:
scale = max(cur_crop_width / w, cur_crop_height / h)
@@ -168,7 +172,10 @@ def process(args):
if scale != 1.0:
w = int(w * scale + .5)
h = int(h * scale + .5)
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
if scale < 1.0:
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
else:
face_img = pil_resize(face_img, (w, h))
cx = int(cx * scale + .5)
cy = int(cy * scale + .5)
fw = int(fw * scale + .5)
@@ -198,7 +205,7 @@ def process(args):
face_img = face_img[y:y + cur_crop_height]
# # debug
# print(path, cx, cy, angle)
# logger.info(path, cx, cy, angle)
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
# cv2.imshow("image", crp)
# if cv2.waitKey() == 27:

View File

@@ -11,10 +11,16 @@ from typing import Dict, List
import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torch import nn
from tqdm import tqdm
from PIL import Image
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
@@ -216,7 +222,7 @@ class Upscaler(nn.Module):
upsampled_images = upsampled_images / 127.5 - 1.0
# convert upsample images to latents with batch size
# print("Encoding upsampled (LANCZOS4) images...")
# logger.info("Encoding upsampled (LANCZOS4) images...")
upsampled_latents = []
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
@@ -227,7 +233,7 @@ class Upscaler(nn.Module):
upsampled_latents = torch.cat(upsampled_latents, dim=0)
# upscale (refine) latents with this model with batch size
print("Upscaling latents...")
logger.info("Upscaling latents...")
upscaled_latents = []
for i in range(0, upsampled_latents.shape[0], batch_size):
with torch.no_grad():
@@ -242,7 +248,7 @@ def create_upscaler(**kwargs):
weights = kwargs["weights"]
model = Upscaler()
print(f"Loading weights from {weights}...")
logger.info(f"Loading weights from {weights}...")
if os.path.splitext(weights)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -255,20 +261,20 @@ def create_upscaler(**kwargs):
# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()
us_dtype = torch.float16 # TODO: support fp32/bf16
os.makedirs(args.output_dir, exist_ok=True)
# load VAE with Diffusers
assert args.vae_path is not None, "VAE path is required"
print(f"Loading VAE from {args.vae_path}...")
logger.info(f"Loading VAE from {args.vae_path}...")
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
vae.to(DEVICE, dtype=us_dtype)
# prepare model
print("Preparing model...")
logger.info("Preparing model...")
upscaler: Upscaler = create_upscaler(weights=args.weights)
# print("Loading weights from", args.weights)
# logger.info("Loading weights from", args.weights)
# upscaler.load_state_dict(torch.load(args.weights))
upscaler.eval()
upscaler.to(DEVICE, dtype=us_dtype)
@@ -303,14 +309,14 @@ def upscale_images(args: argparse.Namespace):
image_debug.save(dest_file_name)
# upscale
print("Upscaling...")
logger.info("Upscaling...")
upscaled_latents = upscaler.upscale(
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
)
upscaled_latents /= 0.18215
# decode with batch
print("Decoding...")
logger.info("Decoding...")
upscaled_images = []
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
with torch.no_grad():

View File

@@ -5,7 +5,10 @@ import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def is_unet_key(key):
# VAE or TextEncoder, the last one is for SDXL
@@ -45,10 +48,10 @@ def merge(args):
# check if all models are safetensors
for model in args.models:
if not model.endswith("safetensors"):
print(f"Model {model} is not a safetensors model")
logger.info(f"Model {model} is not a safetensors model")
exit()
if not os.path.isfile(model):
print(f"Model {model} does not exist")
logger.info(f"Model {model} does not exist")
exit()
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
@@ -65,7 +68,7 @@ def merge(args):
if merged_sd is None:
# load first model
print(f"Loading model {model}, ratio = {ratio}...")
logger.info(f"Loading model {model}, ratio = {ratio}...")
merged_sd = {}
with safe_open(model, framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
@@ -81,11 +84,11 @@ def merge(args):
value = ratio * value.to(dtype) # first model's value * ratio
merged_sd[key] = value
print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
continue
# load other models
print(f"Loading model {model}, ratio = {ratio}...")
logger.info(f"Loading model {model}, ratio = {ratio}...")
with safe_open(model, framework="pt", device=args.device) as f:
model_keys = f.keys()
@@ -93,7 +96,7 @@ def merge(args):
_, new_key = replace_text_encoder_key(key)
if new_key not in merged_sd:
if args.show_skipped and new_key not in first_model_keys:
print(f"Skip: {new_key}")
logger.info(f"Skip: {new_key}")
continue
value = f.get_tensor(key)
@@ -104,7 +107,7 @@ def merge(args):
for key in merged_sd.keys():
if key in model_keys:
continue
print(f"Key {key} not in model {model}, use first model's value")
logger.warning(f"Key {key} not in model {model}, use first model's value")
if key in supplementary_key_ratios:
supplementary_key_ratios[key] += ratio
else:
@@ -112,7 +115,7 @@ def merge(args):
# add supplementary keys' value (including VAE and TextEncoder)
if len(supplementary_key_ratios) > 0:
print("add first model's value")
logger.info("add first model's value")
with safe_open(args.models[0], framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
_, new_key = replace_text_encoder_key(key)
@@ -120,7 +123,7 @@ def merge(args):
continue
if is_unet_key(new_key): # not VAE or TextEncoder
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
value = f.get_tensor(key) # original key
@@ -134,7 +137,7 @@ def merge(args):
if not output_file.endswith(".safetensors"):
output_file = output_file + ".safetensors"
print(f"Saving to {output_file}...")
logger.info(f"Saving to {output_file}...")
# convert to save_dtype
for k in merged_sd.keys():
@@ -142,7 +145,7 @@ def merge(args):
save_file(merged_sd, output_file)
print("Done!")
logger.info("Done!")
if __name__ == "__main__":

View File

@@ -7,7 +7,10 @@ from safetensors.torch import load_file
from library.original_unet import UNet2DConditionModel, SampleOutput
import library.model_util as model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class ControlNetInfo(NamedTuple):
unet: Any
@@ -51,7 +54,7 @@ def load_control_net(v2, unet, model):
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
# state dictを読み込む
print(f"ControlNet: loading control SD model : {model}")
logger.info(f"ControlNet: loading control SD model : {model}")
if model_util.is_safetensors(model):
ctrl_sd_sd = load_file(model)
@@ -61,7 +64,7 @@ def load_control_net(v2, unet, model):
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
is_difference = "difference" in ctrl_sd_sd
print("ControlNet: loading difference:", is_difference)
logger.info(f"ControlNet: loading difference: {is_difference}")
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
# またTransfer Controlの元weightとなる
@@ -89,13 +92,13 @@ def load_control_net(v2, unet, model):
# ControlNetのU-Netを作成する
ctrl_unet = UNet2DConditionModel(**unet_config)
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
print("ControlNet: loading Control U-Net:", info)
logger.info(f"ControlNet: loading Control U-Net: {info}")
# U-Net以外のControlNetを作成する
# TODO support middle only
ctrl_net = ControlNet()
info = ctrl_net.load_state_dict(zero_conv_sd)
print("ControlNet: loading ControlNet:", info)
logger.info("ControlNet: loading ControlNet: {info}")
ctrl_unet.to(unet.device, dtype=unet.dtype)
ctrl_net.to(unet.device, dtype=unet.dtype)
@@ -117,7 +120,7 @@ def load_preprocess(prep_type: str):
return canny
print("Unsupported prep type:", prep_type)
logger.info(f"Unsupported prep type: {prep_type}")
return None
@@ -174,13 +177,26 @@ def call_unet_and_control_net(
cnet_idx = step % cnet_cnt
cnet_info = control_nets[cnet_idx]
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
return original_unet(sample, timestep, encoder_hidden_states)
guided_hint = guided_hints[cnet_idx]
# gradual latent support: match the size of guided_hint to the size of sample
if guided_hint.shape[-2:] != sample.shape[-2:]:
# print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}")
org_dtype = guided_hint.dtype
if org_dtype == torch.bfloat16:
guided_hint = guided_hint.to(torch.float32)
guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic")
if org_dtype == torch.bfloat16:
guided_hint = guided_hint.to(org_dtype)
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
outs = unet_forward(
True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net
)
outs = [o * cnet_info.weight for o in outs]
# U-Net
@@ -192,7 +208,7 @@ def call_unet_and_control_net(
# ControlNet
cnet_outs_list = []
for i, cnet_info in enumerate(control_nets):
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
continue
guided_hint = guided_hints[i]
@@ -232,7 +248,7 @@ def unet_forward(
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
print("Forward upsample size to force interpolation output size.")
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time

View File

@@ -6,7 +6,10 @@ import shutil
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
logger = logging.getLogger(__name__)
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
# Split the max_resolution string by "," and strip any whitespaces
@@ -21,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
# Select interpolation method
if interpolation == 'lanczos4':
cv2_interpolation = cv2.INTER_LANCZOS4
pil_interpolation = Image.LANCZOS
elif interpolation == 'cubic':
cv2_interpolation = cv2.INTER_CUBIC
pil_interpolation = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA
@@ -61,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
if cv2_interpolation:
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
else:
new_height, new_width = img.shape[0:2]
@@ -83,7 +89,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
proc = "Resized" if current_pixels > max_pixels else "Saved"
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
# If other files with same basename, copy them with resolution suffix
if copy_associated_files:
@@ -94,7 +100,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
continue
for max_resolution in max_resolutions:
new_asoc_file = base + '+' + max_resolution + ext
print(f"Copy {asoc_file} as {new_asoc_file}")
logger.info(f"Copy {asoc_file} as {new_asoc_file}")
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))

View File

@@ -1,6 +1,10 @@
import json
import argparse
from safetensors import safe_open
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
@@ -10,10 +14,10 @@ with safe_open(args.model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:
print("No metadata found")
logger.error("No metadata found")
else:
# metadata is json dict, but not pretty printed
# sort by key and pretty print
print(json.dumps(metadata, indent=4, sort_keys=True))

View File

@@ -1,16 +1,22 @@
import argparse
import gc
import json
import math
import os
import random
import time
from multiprocessing import Value
from types import SimpleNamespace
# from omegaconf import OmegaConf
import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
@@ -30,6 +36,12 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO 他のスクリプトと共通化する
@@ -51,6 +63,7 @@ def train(args):
# training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
@@ -64,11 +77,11 @@ def train(args):
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -91,14 +104,16 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -109,7 +124,7 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -137,8 +152,10 @@ def train(args):
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": [5, 10, 20, 20],
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 4,
@@ -168,8 +185,10 @@ def train(args):
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": 8,
"out_channels": 4,
"sample_size": 64,
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
@@ -182,7 +201,23 @@ def train(args):
"resnet_time_scale_shift": "default",
"projection_class_embeddings_input_dim": None,
}
unet.config = SimpleNamespace(**unet.config)
# unet.config = OmegaConf.create(unet.config)
# make unet.config iterable and accessible by attribute
class CustomConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __contains__(self, name):
return name in self.__dict__
unet.config = CustomConfig(**unet.config)
controlnet = ControlNetModel.from_unet(unet)
@@ -214,9 +249,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -226,19 +259,19 @@ def train(args):
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = controlnet.parameters()
trainable_params = list(controlnet.parameters())
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -248,7 +281,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -304,8 +339,10 @@ def train(args):
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -326,12 +363,17 @@ def train(args):
)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
@@ -365,6 +407,11 @@ def train(args):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
)
# training loop
for epoch in range(num_train_epochs):
if is_main_process:
@@ -376,7 +423,7 @@ def train(args):
with accelerator.accumulate(controlnet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -399,13 +446,10 @@ def train(args):
)
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(b_size,),
device=latents.device,
timesteps, huber_c = train_util.get_timesteps_and_huber_c(
args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -436,14 +480,16 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -493,14 +539,9 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
@@ -511,7 +552,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -550,7 +591,7 @@ def train(args):
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく
@@ -559,15 +600,17 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, controlnet, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -599,6 +642,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,7 +1,6 @@
# DreamBooth training
# XXX dropped option: fine_tune
import gc
import argparse
import itertools
import math
@@ -10,7 +9,14 @@ from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
@@ -28,7 +34,15 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# perlin_noise,
@@ -36,6 +50,8 @@ from library.custom_train_functions import (
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -46,13 +62,13 @@ def train(args):
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -71,12 +87,14 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
if args.no_token_padding:
train_dataset_group.disable_token_padding()
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
@@ -87,13 +105,13 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
if args.gradient_accumulation_steps > 1:
print(
logger.warning(
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
)
print(
logger.warning(
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
)
@@ -101,6 +119,7 @@ def train(args):
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -125,15 +144,13 @@ def train(args):
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -156,21 +173,27 @@ def train(args):
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
if train_text_encoder:
# wightout list, adamw8bit is crashed
trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
if args.learning_rate_te is None:
# wightout list, adamw8bit is crashed
trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
else:
trainable_params = [
{"params": list(unet.parameters()), "lr": args.learning_rate},
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
]
else:
trainable_params = unet.parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -180,7 +203,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -201,15 +226,25 @@ def train(args):
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
training_models = [ds_model]
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
else:
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
training_models = [unet, text_encoder]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
training_models = [unet]
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
@@ -253,12 +288,16 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs)
loss_list = []
loss_total = 0.0
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -277,12 +316,14 @@ def train(args):
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
if len(training_models) == 2:
training_models = training_models[0] # remove text_encoder from training_models
with accelerator.accumulate(unet):
with accelerator.accumulate(*training_models):
with torch.no_grad():
# latentに変換
if cache_latents:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
@@ -307,7 +348,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -319,16 +360,20 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -376,30 +421,20 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -433,7 +468,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -443,20 +478,29 @@ def train(args):
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--learning_rate_te",
type=float,
default=None,
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
)
parser.add_argument(
"--no_token_padding",
action="store_true",
@@ -468,6 +512,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser
@@ -476,6 +525,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,6 +1,5 @@
import importlib
import argparse
import gc
import math
import os
import sys
@@ -11,15 +10,18 @@ from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
from library import deepspeed_utils, model_util
import library.train_util as train_util
from library.train_util import (
DreamBoothDataset,
)
from library.train_util import DreamBoothDataset
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
@@ -33,7 +35,15 @@ from library.custom_train_functions import (
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
class NetworkTrainer:
@@ -43,7 +53,15 @@ class NetworkTrainer:
# TODO 他のスクリプトと共通化する
def generate_step_logs(
self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
self,
args: argparse.Namespace,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
keys_scaled=None,
mean_norm=None,
maximum_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
@@ -53,39 +71,31 @@ class NetworkTrainer:
logs["max_norm/max_key_norm"] = maximum_norm
lrs = lr_scheduler.get_last_lr()
if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
if args.network_train_unet_only:
logs["lr/unet"] = float(lrs[0])
elif args.network_train_text_encoder_only:
logs["lr/textencoder"] = float(lrs[0])
for i, lr in enumerate(lrs):
if lr_descriptions is not None:
lr_desc = lr_descriptions[i]
else:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
idx = i - (0 if args.network_train_unet_only else -1)
if idx == -1:
lr_desc = "textencoder"
else:
if len(lrs) > 2:
lr_desc = f"group{idx}"
else:
lr_desc = "unet"
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value of unet.
logs["lr/d*lr"] = (
lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
logs[f"lr/{lr_desc}"] = lr
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
# tracking d*lr value
logs[f"lr/d*lr/{lr_desc}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
else:
idx = 0
if not args.network_train_unet_only:
logs["lr/textencoder"] = float(lrs[0])
idx = 1
for i in range(idx, len(lrs)):
logs[f"lr/group{i}"] = float(lrs[i])
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower():
logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
return logs
def assert_extra_args(self, args, train_dataset_group):
pass
train_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -98,11 +108,14 @@ class NetworkTrainer:
def is_text_encoder_outputs_cached(self, args):
return False
def is_train_text_encoder(self, args):
return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
def cache_text_encoder_outputs_if_needed(
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
):
for t_enc in text_encoders:
t_enc.to(accelerator.device)
t_enc.to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids = batch["input_ids"].to(accelerator.device)
@@ -113,6 +126,11 @@ class NetworkTrainer:
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
return noise_pred
def all_reduce_network(self, accelerator, network):
for param in network.parameters():
if param.grad is not None:
param.grad = accelerator.reduce(param.grad, reduction="mean")
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
@@ -121,6 +139,8 @@ class NetworkTrainer:
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
@@ -136,20 +156,20 @@ class NetworkTrainer:
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config:
print(f"Loading dataset config from {args.dataset_config}")
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -160,7 +180,7 @@ class NetworkTrainer:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -182,14 +202,14 @@ class NetworkTrainer:
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -202,7 +222,7 @@ class NetworkTrainer:
self.assert_extra_args(args, train_dataset_group)
# acceleratorを準備する
print("preparing accelerator")
logger.info("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -251,13 +271,12 @@ class NetworkTrainer:
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
)
@@ -273,7 +292,10 @@ class NetworkTrainer:
if args.dim_from_weights:
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
else:
# LyCORIS will work with this...
if "dropout" not in net_kwargs:
# workaround for LyCORIS (;^ω^)
net_kwargs["dropout"] = args.network_dropout
network = network_module.create_network(
1.0,
args.network_dim,
@@ -286,20 +308,22 @@ class NetworkTrainer:
)
if network is None:
return
network_has_multiplier = hasattr(network, "set_multiplier")
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
print(
logger.warning(
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
)
args.scale_weight_norms = False
train_unet = not args.network_train_text_encoder_only
train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
train_text_encoder = self.is_train_text_encoder(args)
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
# FIXME consider alpha of weights
info = network.load_weights(args.network_weights)
accelerator.print(f"load network weights from {args.network_weights}: {info}")
@@ -315,24 +339,42 @@ class NetworkTrainer:
# 後方互換性を確保するよ
try:
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
except TypeError:
accelerator.print(
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
)
results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
if type(results) is tuple:
trainable_params = results[0]
lr_descriptions = results[1]
else:
trainable_params = results
lr_descriptions = None
except TypeError as e:
# logger.warning(f"{e}")
# accelerator.print(
# "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
# )
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
lr_descriptions = None
# if len(trainable_params) == 0:
# accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした")
# for params in trainable_params:
# for k, v in params.items():
# if type(v) == float:
# pass
# else:
# v = len(v)
# accelerator.print(f"trainable_params: {k} = {v}")
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -366,51 +408,59 @@ class NetworkTrainer:
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.requires_grad_(False)
unet.to(dtype=weight_dtype)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
# acceleratorがなんかよろしくやってくれるらしい
# TODO めちゃくちゃ冗長なのでコードを整理する
if train_unet and train_text_encoder:
if len(text_encoders) > 1:
unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoders = [t_enc1, t_enc2]
del t_enc1, t_enc2
else:
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
text_encoders = [text_encoder]
elif train_unet:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
elif train_text_encoder:
if len(text_encoders) > 1:
t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoders = [t_enc1, t_enc2]
del t_enc1, t_enc2
else:
text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
text_encoders = [text_encoder]
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoders[0] if train_text_encoder else None,
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
network=network,
)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_model = ds_model
else:
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare (train_network here only)
text_encoders = train_util.transform_models_if_DDP(text_encoders)
unet, network = train_util.transform_models_if_DDP([unet, network])
training_model = network
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
@@ -419,7 +469,9 @@ class NetworkTrainer:
t_enc.train()
# set top parameter requires_grad = True for gradient checkpointing works
t_enc.text_model.embeddings.requires_grad_(True)
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
@@ -427,7 +479,7 @@ class NetworkTrainer:
del t_enc
network.prepare_grad_etc(text_encoder, unet)
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
vae.requires_grad_(False)
@@ -438,6 +490,51 @@ class NetworkTrainer:
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# before resuming make hook for saving/loading to save/load the network weights only
def save_model_hook(models, weights, output_dir):
# pop weights of other models than network to save only network weights
# only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606
if accelerator.is_main_process or args.deepspeed:
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
if len(weights) > i:
weights.pop(i)
# print(f"save model hook: {len(weights)} weights will be saved")
# save current ecpoch and step
train_state_file = os.path.join(output_dir, "train_state.json")
# +1 is needed because the state is saved before current_step is set from global_step
logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
with open(train_state_file, "w", encoding="utf-8") as f:
json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
steps_from_state = None
def load_model_hook(models, input_dir):
# remove models except network
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
models.pop(i)
# print(f"load model hook: {len(models)} models will be loaded")
# load current epoch and step to
nonlocal steps_from_state
train_state_file = os.path.join(input_dir, "train_state.json")
if os.path.exists(train_state_file):
with open(train_state_file, "r", encoding="utf-8") as f:
data = json.load(f)
steps_from_state = data["current_step"]
logger.info(f"load train state from {train_state_file}: {data}")
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
@@ -509,6 +606,13 @@ class NetworkTrainer:
"ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_ip_noise_gamma": args.ip_noise_gamma,
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
"ss_noise_offset_random_strength": args.noise_offset_random_strength,
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
"ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule,
"ss_huber_c": args.huber_c,
}
if use_user_config:
@@ -544,6 +648,11 @@ class NetworkTrainer:
"random_crop": bool(subset.random_crop),
"shuffle_caption": bool(subset.shuffle_caption),
"keep_tokens": subset.keep_tokens,
"keep_tokens_separator": subset.keep_tokens_separator,
"secondary_separator": subset.secondary_separator,
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
}
image_dir_or_metadata_file = None
@@ -666,7 +775,54 @@ class NetworkTrainer:
if key in metadata:
minimum_metadata[key] = metadata[key]
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
# calculate steps to skip when resuming or starting from a specific step
initial_step = 0
if args.initial_epoch is not None or args.initial_step is not None:
# if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
if steps_from_state is not None:
logger.warning(
"steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
)
if args.initial_step is not None:
initial_step = args.initial_step
else:
# num steps per epoch is calculated by num_processes and gradient_accumulation_steps
initial_step = (args.initial_epoch - 1) * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
else:
# if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
if steps_from_state is not None:
initial_step = steps_from_state
steps_from_state = None
if initial_step > 0:
assert (
args.max_train_steps > initial_step
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
progress_bar = tqdm(
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
)
epoch_to_start = 0
if initial_step > 0:
if args.skip_until_initial_step:
# if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
if not args.resume:
logger.info(
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
)
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
initial_step *= args.gradient_accumulation_steps
# set epoch to start to make initial_step less than len(train_dataloader)
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
else:
# if not, only epoch no is skipped for informative purpose
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
initial_step = 0 # do not skip
global_step = 0
noise_scheduler = DDPMScheduler(
@@ -678,19 +834,22 @@ class NetworkTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# callback for step start
if hasattr(network, "on_step_start"):
on_step_start = network.on_step_start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
on_step_start = accelerator.unwrap_model(network).on_step_start
else:
on_step_start = lambda *args, **kwargs: None
@@ -718,35 +877,63 @@ class NetworkTrainer:
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# training loop
for epoch in range(num_train_epochs):
if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
initial_step -= len(train_dataloader)
global_step = initial_step
for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1)
network.on_epoch_start(text_encoder, unet)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader):
skipped_dataloader = None
if initial_step > 0:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
initial_step = 1
for step, batch in enumerate(skipped_dataloader or train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
if initial_step > 0:
initial_step -= 1
continue
with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
with torch.set_grad_enabled(train_text_encoder):
# get multiplier for each sample
if network_has_multiplier:
multipliers = batch["network_multipliers"]
# if all multipliers are same, use single multiplier
if torch.all(multipliers == multipliers[0]):
multipliers = multipliers[0].item()
else:
raise NotImplementedError("multipliers for each sample is not supported yet")
# print(f"set multiplier: {multipliers}")
accelerator.unwrap_model(network).set_multiplier(multipliers)
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
@@ -764,14 +951,28 @@ class NetworkTrainer:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
for x in noisy_latents:
x.requires_grad_(True)
for t in text_encoder_conds:
t.requires_grad_(True)
# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
if args.v_parameterization:
@@ -780,32 +981,40 @@ class NetworkTrainer:
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = network.get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
@@ -835,28 +1044,25 @@ class NetworkTrainer:
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm
)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -888,27 +1094,32 @@ class NetworkTrainer:
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument(
"--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
)
parser.add_argument(
"--save_model_as",
type=str,
@@ -920,10 +1131,17 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール")
parser.add_argument(
"--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)"
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
)
parser.add_argument(
"--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
)
parser.add_argument(
"--network_dim",
type=int,
default=None,
help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
)
parser.add_argument(
"--network_alpha",
@@ -938,14 +1156,25 @@ def setup_parser() -> argparse.ArgumentParser:
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする0またはNoneはdropoutなし、1は全ニューロンをdropout",
)
parser.add_argument(
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
parser.add_argument(
"--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する"
"--network_args",
type=str,
default=None,
nargs="*",
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
)
parser.add_argument(
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
"--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
)
parser.add_argument(
"--network_train_text_encoder_only",
action="store_true",
help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
)
parser.add_argument(
"--training_comment",
type=str,
default=None,
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
)
parser.add_argument(
"--dim_from_weights",
@@ -977,6 +1206,28 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--skip_until_initial_step",
action="store_true",
help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
)
parser.add_argument(
"--initial_epoch",
type=int,
default=None,
help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`."
+ " / 初期エポック数、1で最初のエポック未指定時と同じ。注意initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
)
parser.add_argument(
"--initial_step",
type=int,
default=None,
help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch."
+ " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ未指定時と同じ。initial_epochを上書きする",
)
# parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
# parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
# parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
return parser
@@ -984,6 +1235,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = NetworkTrainer()

View File

@@ -1,16 +1,21 @@
import argparse
import gc
import math
import os
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import model_util
from library import deepspeed_utils, model_util
import library.train_util as train_util
import library.huggingface_util as huggingface_util
@@ -25,7 +30,15 @@ from library.custom_train_functions import (
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
imagenet_templates_small = [
"a photo of a {}",
@@ -86,7 +99,7 @@ class TextualInversionTrainer:
self.is_sdxl = False
def assert_extra_args(self, args, train_dataset_group):
pass
train_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -162,6 +175,7 @@ class TextualInversionTrainer:
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -172,7 +186,7 @@ class TextualInversionTrainer:
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -257,7 +271,7 @@ class TextualInversionTrainer:
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
if args.dataset_config is not None:
accelerator.print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -282,7 +296,7 @@ class TextualInversionTrainer:
]
}
else:
print("Train with captions.")
logger.info("Train with captions.")
user_config = {
"datasets": [
{
@@ -305,8 +319,8 @@ class TextualInversionTrainer:
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
@@ -357,9 +371,7 @@ class TextualInversionTrainer:
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -376,13 +388,13 @@ class TextualInversionTrainer:
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -407,15 +419,11 @@ class TextualInversionTrainer:
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
elif len(text_encoders) == 2:
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
@@ -434,9 +442,10 @@ class TextualInversionTrainer:
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
unwrapped_text_encoder.text_model.encoder.requires_grad_(False)
unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False)
unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
unet.requires_grad_(False)
@@ -496,10 +505,12 @@ class TextualInversionTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)
# function for saving/removing
@@ -521,6 +532,20 @@ class TextualInversionTrainer:
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(
accelerator,
args,
0,
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -536,10 +561,10 @@ class TextualInversionTrainer:
with accelerator.accumulate(text_encoders[0]):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
latents = latents * self.vae_scale_factor
# Get the text embedding for conditioning
@@ -547,7 +572,7 @@ class TextualInversionTrainer:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
@@ -563,24 +588,28 @@ class TextualInversionTrainer:
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = text_encoder.get_input_embeddings().parameters()
params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -592,9 +621,11 @@ class TextualInversionTrainer:
for text_encoder, orig_embeds_params, index_no_updates in zip(
text_encoders, orig_embeds_params_list, index_no_updates_list
):
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
# if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
index_no_updates
] = orig_embeds_params[index_no_updates]
]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -702,27 +733,29 @@ class TextualInversionTrainer:
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = accelerator.unwrap_model(text_encoder)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
@@ -735,7 +768,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
)
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
@@ -745,7 +780,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
)
parser.add_argument(
"--use_object_template",
action="store_true",
@@ -769,6 +806,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = TextualInversionTrainer()

View File

@@ -1,13 +1,18 @@
import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
@@ -27,9 +32,17 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
apply_masked_loss,
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
imagenet_templates_small = [
"a photo of a {}",
@@ -88,12 +101,13 @@ def train(args):
if args.output_name is None:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
setup_logging(args, reset=True)
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
print(
logger.warning(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
)
assert (
@@ -108,7 +122,7 @@ def train(args):
tokenizer = train_util.load_tokenizer(args)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -121,7 +135,7 @@ def train(args):
if args.init_word is not None:
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
logger.warning(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
)
else:
@@ -135,7 +149,7 @@ def train(args):
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}")
logger.info(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
@@ -163,7 +177,7 @@ def train(args):
tokenizer.add_tokens(token_strings_XTI)
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
print(f"tokens are added (XTI): {token_ids_XTI}")
logger.info(f"tokens are added (XTI): {token_ids_XTI}")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
@@ -172,7 +186,7 @@ def train(args):
if init_token_ids is not None:
for i, token_id in enumerate(token_ids_XTI):
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
@@ -180,22 +194,22 @@ def train(args):
assert len(token_ids) == len(
embeddings
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
# logger.info(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids_XTI, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print(f"weighs loaded")
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
logger.info(f"weighs loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.info(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -203,14 +217,14 @@ def train(args):
else:
use_dreambooth_method = args.in_json is None
if use_dreambooth_method:
print("Use DreamBooth method.")
logger.info("Use DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else:
print("Train with captions.")
logger.info("Train with captions.")
user_config = {
"datasets": [
{
@@ -229,12 +243,12 @@ def train(args):
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
print(f"use template for training captions. is object: {args.use_object_template}")
logger.info(f"use template for training captions. is object: {args.use_object_template}")
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
replace_to = " ".join(token_strings)
captions = []
@@ -258,7 +272,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
return
if len(train_dataset_group) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
if cache_latents:
@@ -280,9 +294,7 @@ def train(args):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -291,18 +303,18 @@ def train(args):
text_encoder.gradient_checkpointing_enable()
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
logger.info("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -312,7 +324,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
logger.info(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -325,11 +339,8 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
@@ -367,15 +378,17 @@ def train(args):
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
logger.info("running training / 学習開始")
logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
logger.info(f" num epochs / epoch数: {num_train_epochs}")
logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}")
logger.info(
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
)
logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
@@ -389,16 +402,21 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs
)
# function for saving/removing
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"\nsaving checkpoint: {ckpt_file}")
logger.info("")
logger.info(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -406,12 +424,13 @@ def train(args):
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
logger.info(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs):
print(f"\nepoch {epoch+1}/{num_train_epochs}")
logger.info("")
logger.info(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
text_encoder.train()
@@ -423,7 +442,7 @@ def train(args):
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -442,7 +461,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -454,16 +473,20 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -568,7 +591,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
@@ -579,7 +602,7 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def save_weights(file, updated_embs, save_dtype):
@@ -640,9 +663,12 @@ def load_weights(file):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
@@ -655,7 +681,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
)
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
@@ -665,7 +693,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
)
parser.add_argument(
"--use_object_template",
action="store_true",
@@ -684,6 +714,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)