Compare commits

...

467 Commits

Author SHA1 Message Date
Kohya S
2a23713f71 Merge pull request #872 from kohya-ss/dev
fix make_captions_by_git, improve image generation scripts
2023-10-11 07:56:39 +09:00
Kohya S
681034d001 update readme 2023-10-11 07:54:30 +09:00
Kohya S
17813ff5b4 remove workaround for transfomers bs>1 close #869 2023-10-11 07:40:12 +09:00
Kohya S
3e81bd6b67 fix network_merge, add regional mask as color code 2023-10-09 23:07:14 +09:00
Kohya S
23ae358e0f Merge branch 'main' into dev 2023-10-09 21:42:13 +09:00
Kohya S
f611726364 add network_merge_n_models option 2023-10-09 21:41:50 +09:00
Kohya S
33ee0acd35 Merge pull request #867 from kohya-ss/dev
onnx support in wd14 tagger, OFT
2023-10-09 18:04:17 +09:00
Kohya S
8b79e3b06c fix typos 2023-10-09 18:00:45 +09:00
Kohya S
cf49e912fc update readme 2023-10-09 17:59:31 +09:00
Kohya S
66741c035c add OFT 2023-10-09 17:59:24 +09:00
Kohya S
406511c333 add error message if model.onnx doesn't exist 2023-10-09 17:08:58 +09:00
Kohya S
8a2d68d63e Merge pull request #864 from Isotr0py/onnx
Add `--onnx` to wd14 tagger
2023-10-09 15:14:11 +09:00
Kohya S
07d297fdbe Merge branch 'dev' into onnx 2023-10-09 15:13:40 +09:00
Kohya S
0d4e8b50d0 change option to append_tags, minor update 2023-10-09 15:09:54 +09:00
Kohya S
1d7c5c2a98 Merge pull request #858 from a-l-e-x-d-s-9/main
Add append_captions feature to wd14 tagger
2023-10-09 14:31:54 +09:00
Kohya S
0faa350175 Merge pull request #865 from kohya-ss/dev
Support JPEG-XL on windows, dropout for LyCORIS
2023-10-09 14:11:49 +09:00
Kohya S
8a7509db75 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-10-09 14:07:02 +09:00
Kohya S
025368f51c may work dropout in LyCORIS #859 2023-10-09 14:06:58 +09:00
Kohya S
5fe52ed322 Merge pull request #856 from Isotr0py/jxl
Fix JPEG-XL support
2023-10-09 13:55:03 +09:00
Kohya S
8b247a330b Merge pull request #851 from kohya-ss/dependabot/github_actions/actions/checkout-4
Bump actions/checkout from 3 to 4
2023-10-09 11:45:47 +09:00
Isotr0py
d6f458fcb3 fix dependency 2023-10-08 23:51:18 +08:00
Isotr0py
b8b84021e5 fix a typo 2023-10-08 20:49:03 +08:00
Isotr0py
70fe7e18be add onnx to wd14 tagger 2023-10-08 20:31:10 +08:00
alexds9
9378da3c82 Fix comment 2023-10-05 21:29:46 +03:00
alexds9
a4857fa764 Add append_captions feature to wd14 tagger
This feature allows for appending new tags to the existing content of caption files.
If the caption file for an image already exists, the tags generated from the current
run are appended to the existing ones. Duplicate tags are checked and avoided.
2023-10-05 21:26:09 +03:00
Isotr0py
592014923f Support JPEG-XL on windows 2023-10-04 21:48:25 +08:00
dependabot[bot]
6d06b215bf Bump actions/checkout from 3 to 4
Bumps [actions/checkout](https://github.com/actions/checkout) from 3 to 4.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v3...v4)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-10-01 22:51:32 +00:00
Kohya S
2d87bb648f Merge pull request #850 from kohya-ss/dev
fix typos
2023-10-02 07:51:05 +09:00
Kohya S
56ebef35b0 Merge pull request #848 from shirayu/update_typos
Update typos to the latest version and add dependabot.yml
2023-10-02 07:45:29 +09:00
Yuta Hayashibe
13d8b22d25 Add dependabot 2023-10-01 21:52:16 +09:00
Yuta Hayashibe
27f9b6ffeb updated typos to v1.16.15 and fix typos 2023-10-01 21:51:24 +09:00
Yuta Hayashibe
c8fcfd4581 Add "venv" to extend-exclude 2023-10-01 21:48:50 +09:00
Kohya S
49c24285c7 Merge pull request #847 from kohya-ss/sdxl
merge sdxl into main
2023-10-01 20:40:33 +09:00
Kohya S
c918489259 update readme 2023-10-01 20:34:12 +09:00
Kohya S
93155242fa Merge pull request #846 from kohya-ss/dev
Fix to work training U-Net only LoRA for SD1/2
2023-10-01 16:44:13 +09:00
Kohya S
4cc919607a fix placing of requires_grad_ of U-Net 2023-10-01 16:41:48 +09:00
Kohya S
81419f7f32 Fix to work training U-Net only LoRA for SD1/2 2023-10-01 16:37:23 +09:00
Kohya S
6bd6cd9c51 update doc 2023-10-01 12:17:54 +09:00
Kohya S
35a1d68eb6 Merge pull request #844 from kohya-ss/dev
IPEX update, concat LoRA
2023-10-01 12:06:36 +09:00
Kohya S
365a06bdb6 Merge pull request #839 from laksjdjf/sdxl
Support concat LoRA
2023-10-01 11:16:46 +09:00
Kohya S
8e117f9f92 Merge pull request #841 from Disty0/dev
IPEX Attention optimizations
2023-10-01 10:58:19 +09:00
Disty0
209eafb631 IPEX attention optimizations 2023-09-28 14:02:25 +03:00
laksjdjf
14aa2923cf Support concat LoRA 2023-09-28 14:39:32 +09:00
Kohya S
1e395ed285 Merge branch 'main' into sdxl 2023-09-24 17:51:08 +09:00
Kohya S
98615166b0 Merge pull request #831 from kohya-ss/dev
update versions of accelerate and diffusers
2023-09-24 17:50:40 +09:00
Kohya S
28272de97a update readme 2023-09-24 17:48:51 +09:00
Kohya S
7e736da30c update versions of accelerate and diffusers 2023-09-24 17:46:57 +09:00
Kohya S
20e929e27e fix to work iter_same_seed 2023-09-24 16:04:50 +09:00
Kohya S
477b5260aa fix sai metadata for sdxl closes #824 2023-09-24 14:47:13 +09:00
Kohya S
d39f1a3427 Merge pull request #808 from rockerBOO/metadata
Add ip_noise_gamma metadata
2023-09-24 14:35:18 +09:00
Kohya S
3757855231 rename train_lllite_alt to train_lllite 2023-09-24 14:34:31 +09:00
Kohya S
d846431015 Merge branch 'dev' into sdxl 2023-09-24 14:30:16 +09:00
Kohya S
624edf428f Merge pull request #825 from Disty0/dev
Intel ARC support with IPEX
2023-09-24 14:29:03 +09:00
Kohya S
54500b861d Merge pull request #830 from kohya-ss/dev2
add extension checking for resize_lora.py
2023-09-24 12:12:32 +09:00
Kohya S
f2491ee0ac change block name doesn't contain '.' at end 2023-09-24 12:10:56 +09:00
Kohya S
1f169ee7fb Merge pull request #760 from Symbiomatrix/bugfix1
Update resize_lora.py
2023-09-24 11:59:18 +09:00
Kohya S
66817992c1 revert formatting 2023-09-24 11:50:44 +09:00
Kohya S
8052bcd5cd format by black 2023-09-24 11:26:28 +09:00
Kohya S
55886a0116 add .pt and .pth for available extension 2023-09-24 11:25:54 +09:00
Kohya S
33e90cc6a0 Merge pull request #815 from jvkap/patch-1
Update resize_lora.py
2023-09-24 11:02:12 +09:00
青龍聖者@bdsqlsz
d5be8125b0 update bitsandbytes for 0.41.1 and fixed bugs with generate_controlnet_subsets_config for training (#823)
* update for bnb 0.41.1

* fixed generate_controlnet_subsets_config for training

* Revert "update for bnb 0.41.1"

This reverts commit 70bd3612d8.
2023-09-24 10:51:47 +09:00
Disty0
b99cd2a920 Update getDeviceIdListForCard 2023-09-20 17:16:06 +03:00
Disty0
b64389c8a9 Intel ARC support with IPEX 2023-09-19 18:05:05 +03:00
Kohya S
db7a28ac25 fix to work highres_fix_latents_upscaling 2023-09-18 21:12:41 +09:00
Kohya S
d337bbf8a0 get pool from CLIPVisionModel in img2img 2023-09-13 20:58:37 +09:00
Kohya S
90c47140b8 add support model without position_ids 2023-09-13 17:59:34 +09:00
Kohya S
0ecfd91a20 fix VAE becomes last one 2023-09-13 17:59:14 +09:00
jvkap
a0e05fa291 Update resize_lora.py 2023-09-11 11:41:33 -03:00
jvkap
e33c007cd0 Update resize_lora.py 2023-09-11 11:29:06 -03:00
rockerBOO
80aca1ccc7 Add ip_noise_gamma metadata 2023-09-05 15:20:15 -04:00
Kohya S
6b3a580ee5 Merge pull request #804 from kohya-ss/dev
fix to work regional LoRA
2023-09-03 17:52:23 +09:00
Kohya S
207fc8b256 fix to work regional LoRA 2023-09-03 17:50:27 +09:00
Kohya S
74561dbdac update readme (#803)
* update readme

* update readme

* fix typo
2023-09-03 12:51:09 +09:00
Kohya S
867e7d3238 fix typo 2023-09-03 12:49:51 +09:00
Kohya S
5f08a21d12 update readme 2023-09-03 12:48:35 +09:00
Kohya S
95bc6e8749 update readme 2023-09-03 12:46:40 +09:00
Kohya S
4530b96c67 Merge pull request #802 from kohya-ss/dev
reduce fp16/bf16 memory usage, input pertubation noise, fix bug
2023-09-03 12:30:19 +09:00
Kohya S
360af27749 fix ControlNetDataset not working 2023-09-03 12:27:58 +09:00
Kohya S
0ee75fd75d fix typos, add comments etc. 2023-09-03 12:24:15 +09:00
Kohya S
2eae9b66d0 Merge pull request #798 from vvern999/vvern999-patch-1
add input perturbation noise
2023-09-03 10:51:23 +09:00
Kohya S
f6d417e26d Merge pull request #791 from Isotr0py/dev
Intergrate fp16/bf16 support to sdxl model loading
2023-09-03 10:35:09 +09:00
Kohya S
903825af6f Merge pull request #800 from kohya-ss/dev
support jpeg xl, add caption prefix/suffix
2023-09-02 16:20:05 +09:00
Kohya S
948cf17499 add caption_prefix/suffix to dataset 2023-09-02 16:17:12 +09:00
Kohya S
cd59003003 Merge branch 'sdxl' into dev 2023-09-02 15:54:56 +09:00
Kohya S
f19a48a28c Merge branch 'main' into sdxl 2023-09-02 15:53:48 +09:00
Kohya S
4c6f3125fc Merge pull request #793 from tgxs002/tgxs002-patch-1
Update train_README-zh.md, fix a few translation errors.
2023-09-02 15:53:24 +09:00
Kohya S
497051c14b Merge pull request #786 from Isotr0py/jxl
Support JPEG XL
2023-09-02 15:30:07 +09:00
Kohya S
6400116715 Merge pull request #774 from lansing/lansing/sdxl-fix-gen-memleak
fix: VRAM memory leak in sdxl_gen_img.py
2023-09-02 15:20:32 +09:00
Kohya S
f77bdf96d8 Merge pull request #799 from kohya-ss/dev
support diffusers' new VAE
2023-09-02 14:56:37 +09:00
Kohya S
c06a86706a support diffusers' new VAE 2023-09-02 14:54:42 +09:00
vvern999
e0beb6a999 add input perturbation noise
from https://arxiv.org/abs/2301.11706
2023-09-02 07:33:27 +03:00
Kohya S
633bb8d339 Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into sdxl 2023-09-01 07:59:33 +09:00
Kohya S
7e850f3b7e Merge branch 'main' into sdxl 2023-09-01 07:59:26 +09:00
Kohya S
59c9a8e7ae Merge pull request #717 from reid3333/main
load model may fail if symbolic link points to relative path
2023-09-01 07:57:38 +09:00
Blakey Wu
c2419ddabf Update train_README-zh.md, fix a few translation errors. 2023-08-29 08:08:40 +08:00
Isotr0py
2e0942d5c8 delet missed line 2023-08-27 20:45:40 +08:00
Isotr0py
6155f9c171 intergrate fp16/bf16 to model loading 2023-08-27 19:16:23 +08:00
Kohya S
f64c78b777 Merge pull request #787 from kohya-ss/dev
alternative impl of ControlNet-LLLite training
2023-08-25 21:22:31 +09:00
Kohya S
3d12cdc643 fix typo 2023-08-25 21:18:09 +09:00
Kohya S
526488feaa alternative impl of ControlNet-LLLite training 2023-08-25 21:16:11 +09:00
Isotr0py
5d88351bb5 support jpeg xl 2023-08-25 11:07:02 +08:00
Kohya S
a46a4781e8 fix "\\" to "/" for compatiblity 2023-08-24 19:19:53 +09:00
Kohya S
b44644bcec Merge pull request #783 from kohya-ss/dev
add .toml example for lllite doc
2023-08-24 07:53:04 +09:00
Kohya S
1f4a495e16 Merge branch 'sdxl' into dev 2023-08-24 07:51:12 +09:00
Kohya S
d97a1638d3 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-08-24 07:50:47 +09:00
Kohya S
ef28a919d2 add .toml example for lllite doc 2023-08-24 07:50:40 +09:00
Kohya S
71369ac98b Merge pull request #776 from kohya-ss/dev
add multiplier, steps range, dataset synthesis
2023-08-22 20:55:19 +09:00
ykume
85f1114c4a add about dataset synthesis for LLLite doc 2023-08-22 20:52:33 +09:00
ykume
927c687628 Merge branch 'sdxl' into dev 2023-08-22 19:15:11 +09:00
Kohya S
6d5cffaee9 add multiplier, steps range 2023-08-22 08:17:21 +09:00
Max Lansing
fbc550d02e fix: VRAM memory leak in sdxl_gen_img.py 2023-08-20 19:04:16 -07:00
Kohya S
014c4b47c9 Merge pull request #770 from kohya-ss/dev
update doc and minor fix
2023-08-20 18:33:50 +09:00
Kohya S
9be19ad777 update doc 2023-08-20 18:30:49 +09:00
Kohya S
1161a5c6da fix debug_dataset for controlnet dataset 2023-08-20 17:39:48 +09:00
Kohya S
9947197a84 fix typos (;^ω^) 2023-08-20 13:53:00 +09:00
Kohya S
50c6aaae62 update lllite doc 2023-08-20 13:37:37 +09:00
Kohya S
edd314cc8a Update train_lllite_README.md 2023-08-20 13:09:01 +09:00
Kohya S
8b2a11fd5e Merge pull request #768 from kohya-ss/dev
ControlNet-LLLite
2023-08-20 13:07:21 +09:00
Kohya S
15b463d18d update lllite doc 2023-08-20 12:56:44 +09:00
Kohya S
0c1975501c Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-08-20 12:55:52 +09:00
Kohya S
98f8785a4f Update train_lllite_README.md 2023-08-20 12:55:24 +09:00
Kohya S
b74dfba215 update lllite doc 2023-08-20 12:50:37 +09:00
Kohya S
bee5c3f1b8 update lllite doc 2023-08-20 12:45:56 +09:00
Kohya S
e191892824 fix bucketing doesn't work in controlnet training 2023-08-20 12:24:40 +09:00
ykume
2841927dba Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-08-20 00:09:13 +09:00
ykume
0646112010 fix a bug x is updated inplace 2023-08-20 00:09:09 +09:00
Kohya S
782b11b844 Update train_lll_README-ja.md add sample images 2023-08-19 21:41:54 +09:00
ykume
5a86bbc0a0 fix typos, update readme 2023-08-19 18:54:31 +09:00
ykume
fef7eb73ad rename and update 2023-08-19 18:44:40 +09:00
ykume
62fa4734fe Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-08-18 12:22:03 +09:00
ykume
b5db90c8a8 modify to attn1/attn2 only 2023-08-18 09:00:22 +09:00
Kohya S
3e1591661e add readme about controlnet-lora 2023-08-17 22:02:07 +09:00
Kohya S
1e52fe6e09 add comments 2023-08-17 20:49:39 +09:00
ykume
809fca0be9 fix error in generation 2023-08-17 18:31:29 +09:00
Kohya S
5fa473d5f3 add cond/uncond, update config 2023-08-17 16:25:23 +09:00
ykume
784a90c3a6 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-08-17 13:17:47 +09:00
ykume
6111151f50 add skip input blocks to lora control net 2023-08-17 13:17:43 +09:00
Kohya S
afc03af3ca read dim/rank from weights 2023-08-17 12:10:52 +09:00
ykume
306ee24c90 change to use_reentrant=False 2023-08-17 10:19:14 +09:00
Kohya S
3f7235c36f add lora controlnet train/gen temporarily 2023-08-17 10:08:02 +09:00
Symbiomatrix
9d678a6f41 Update resize_lora.py 2023-08-16 00:08:09 +03:00
Kohya S
983698dd1b add lora controlnet temporarily 2023-08-15 18:23:22 +09:00
Kohya S
9a60b8a0ba Merge pull request #755 from kohya-ss/dev
add lora_fa
2023-08-13 15:20:49 +09:00
Kohya S
adf99a332e update readme 2023-08-13 15:17:29 +09:00
Kohya S
d713e4c757 add lora_fa experimentally 2023-08-13 13:30:34 +09:00
Kohya S
a90c9c2776 add original size for negative cond 2023-08-13 11:17:41 +09:00
Kohya S
d43fcd638e update readme 2023-08-12 13:52:54 +09:00
Kohya S
e32e24adf5 Merge pull request #750 from kohya-ss/dev
block lr for U-Net with SDXL etc.
2023-08-12 13:17:06 +09:00
Kohya S
e2c2689f5c support block lr for U-Net 2023-08-12 13:13:59 +09:00
Kohya S
8415014de6 suppress waning for scheduler args #748 2023-08-11 21:31:55 +09:00
Kohya S
3307ccb2dc revert default noise offset to 0 (None) in sdxl 2023-08-11 20:35:46 +09:00
Kohya S
6889ee2b85 add warning for bucket_reso_steps with SDXL 2023-08-11 19:02:36 +09:00
Kohya S
bf31f18c46 Merge pull request #744 from kohya-ss/dev
fix sample gen failed in sdxl training
2023-08-11 17:00:52 +09:00
Kohya S
e73d103eca fix sample gen failed in sdxl training 2023-08-11 16:58:52 +09:00
Kohya S
12e58ab37f Merge pull request #741 from kohya-ss/dev
fix to work when input_ids has multiple EOS tokens
2023-08-10 20:17:56 +09:00
Kohya S
daad50e384 fix to work when input_ids has multiple EOS tokens 2023-08-10 20:13:59 +09:00
Kohya S
4e339bb101 Merge pull request #733 from kohya-ss/dev
fix sd1/2 lora saving error etc
2023-08-08 21:11:38 +09:00
Kohya S
b83ce0c352 modify import #368 2023-08-08 21:09:08 +09:00
Kohya S
6f80fe17fc fix crashing in saving lora with clipskip 2023-08-08 21:03:16 +09:00
Kohya S
7ea38f90d7 add merge script 2023-08-07 23:40:49 +09:00
Kohya S
f4a2bc6cf8 Merge pull request #722 from kohya-ss/dev
SAI model spec etc.
2023-08-07 08:08:51 +09:00
Kohya S
78226f8574 change assert to print 2023-08-06 22:35:01 +09:00
Kohya S
04b1defaf9 update readme 2023-08-06 22:19:00 +09:00
Kohya S
3cdbbb43be fix error in huggingface_path_in_repo=None 2023-08-06 22:08:30 +09:00
Kohya S
92f41f1051 update sdxl ver in lora metadata from v0-9 to v1-0 2023-08-06 22:06:48 +09:00
Kohya S
c142dadb46 support sai model spec 2023-08-06 21:50:05 +09:00
Kohya S
cd54af019a Merge pull request #720 from kohya-ss/dev
fix training textencoder in sdxl not working
2023-08-05 21:24:24 +09:00
Kohya S
e5f9772a35 fix training textencoder in sdxl not working 2023-08-05 21:22:50 +09:00
reid3333
a02056c566 fix: load may fail if symbolic link points to relative path 2023-08-05 17:47:43 +09:00
Kohya S
2dfa26cca0 Merge pull request #716 from kohya-ss/dev
fix sdxl_gen_img not working
2023-08-05 09:33:19 +09:00
Kohya S
25d8cd473e fix sdxl_gen_img not working 2023-08-05 09:32:01 +09:00
Kohya S
f4935dd6be Merge pull request #714 from kohya-ss/dev
pool output fix, v_pred loss like etc.
2023-08-04 22:36:25 +09:00
Kohya S
9d855091bf make bitsandbytes optional 2023-08-04 22:29:14 +09:00
Kohya S
f3be995c28 remove debug print 2023-08-04 08:44:17 +09:00
Kohya S
9d7619d1eb remove debug print 2023-08-04 08:42:54 +09:00
Kohya S
c6d52fdea4 Add workaround for clip's bug for pooled output 2023-08-04 08:38:27 +09:00
Kohya S
cf6832896f fix ControlNet with regional LoRA 2023-08-03 21:48:11 +09:00
Kohya S
6b1cf6c4fd fix ControlNet with regional LoRA, add shuffle cap 2023-08-03 21:41:46 +09:00
Kohya S
db80c5a2e7 format by black 2023-08-03 20:14:04 +09:00
Kohya S
89aae3e04f fix vae crashes in large reso 2023-07-31 21:48:19 +09:00
Kohya S
0636399c8c add adding v-pred like loss for noise pred 2023-07-31 08:23:28 +09:00
Kohya S
7e474d21ca fix recorded seed in highres fix 2023-07-30 16:48:52 +09:00
Kohya S
f61996b425 remove dependency for albumenations 2023-07-30 16:29:53 +09:00
Kohya S
496c3f2732 arbitrary args for diffusers lr scheduler 2023-07-30 14:36:03 +09:00
Kohya S
8856c19c76 fix batch generation not working 2023-07-30 14:19:25 +09:00
Kohya S
0eacadfa99 fix ControlNet not working 2023-07-30 14:09:43 +09:00
Kohya S
2a4ae88f18 format by black 2023-07-30 14:03:54 +09:00
Kohya S
a296654c1b refactor optimizer selection for bnb 2023-07-30 13:43:29 +09:00
Kohya S
b62185b821 change method name, add comments 2023-07-30 13:34:07 +09:00
Kohya S
e6034b7eb6 move releasing cache outside of the loop 2023-07-30 13:30:42 +09:00
Kohya S
54a4aa22ed Merge pull request #658 from pamparamm/cache_latents_leak_fix
Cache latents VRAM leak fix
2023-07-30 13:22:00 +09:00
青龍聖者@bdsqlsz
9ec70252d0 Add Paged/ adam8bit/lion8bit for Sdxl bitsandbytes 0.39.1 cuda118 on windows (#623)
* ADD libbitsandbytes.dll for 0.38.1

* Delete libbitsandbytes_cuda116.dll

* Delete cextension.py

* add main.py

* Update requirements.txt for bitsandbytes 0.38.1

* Update README.md for bitsandbytes-windows

* Update README-ja.md  for bitsandbytes 0.38.1

* Update main.py for return cuda118

* Update train_util.py for lion8bit

* Update train_README-ja.md for lion8bit

* Update train_util.py for add DAdaptAdan and DAdaptSGD

* Update train_util.py for DAdaptadam

* Update train_network.py for dadapt

* Update train_README-ja.md for DAdapt

* Update train_util.py for DAdapt

* Update train_network.py for DAdaptAdaGrad

* Update train_db.py for DAdapt

* Update fine_tune.py for DAdapt

* Update train_textual_inversion.py for DAdapt

* Update train_textual_inversion_XTI.py for DAdapt

* Revert "Merge branch 'qinglong' into main"

This reverts commit b65c023083, reversing
changes made to f6fda20caf.

* Revert "Update requirements.txt for bitsandbytes 0.38.1"

This reverts commit 83abc60dfa.

* Revert "Delete cextension.py"

This reverts commit 3ba4dfe046.

* Revert "Update README.md for bitsandbytes-windows"

This reverts commit 4642c52086.

* Revert "Update README-ja.md  for bitsandbytes 0.38.1"

This reverts commit fa6d7485ac.

* Update train_util.py for DAdaptLion

* Update train_README-zh.md for dadaptlion

* Update train_README-ja.md for DAdaptLion

* add DAdatpt V3

* Alignment

* Update train_util.py for experimental

* Update train_util.py V3

* Update train_util.py

* Update requirements.txt

* Update train_README-zh.md

* Update train_README-ja.md

* Update train_util.py fix

* Update train_util.py

* support Prodigy

* add lower

* Update main.py

* support PagedAdamW8bit/PagedLion8bit

* Update requirements.txt

* update for PageAdamW8bit and PagedLion8bit

* Revert

* revert main

* Update train_util.py

* update for bitsandbytes 0.39.1

* Update requirements.txt

* vram leak fix

---------

Co-authored-by: Pam <pamhome21@gmail.com>
2023-07-30 13:15:13 +09:00
Kohya S
e20b6acfe9 Merge pull request #676 from Isotr0py/sdxl
Fix RAM leak when loading SDXL model in lowram device
2023-07-30 12:46:23 +09:00
Isotr0py
d9180c03f6 fix typos for _load_state_dict 2023-07-29 22:25:00 +08:00
Kohya S
4072f723c1 Merge branch 'main' into sdxl 2023-07-29 14:55:03 +09:00
Kohya S
cf8021020f Merge pull request #688 from Noyii/main
fix typo
2023-07-29 14:53:04 +09:00
Kohya S
fb1054b5e3 Merge pull request #694 from kohya-ss/dev
support ckpt without position id in sd v1 #687
2023-07-29 14:52:42 +09:00
Kohya S
1e4512b2c8 support ckpt without position id in sd v1 #687 2023-07-29 14:19:25 +09:00
Kohya S
3a7326ae46 Merge pull request #693 from kohya-ss/dev
Support for bitsandbytes 0.39.1 with Paged Optimizer
2023-07-29 13:30:46 +09:00
Kohya S
38b59a93de Merge branch 'main' into dev 2023-07-29 13:20:58 +09:00
Isotr0py
1199eacb72 fix typo 2023-07-28 13:49:37 +08:00
Isotr0py
fdb58b0b62 fix mismatch dtype 2023-07-28 13:47:54 +08:00
Isotr0py
315fbc11e5 refactor model loading to catch error 2023-07-28 13:10:38 +08:00
Noyii
4a1b92d309 Update README.md 2023-07-28 12:31:14 +08:00
Isotr0py
272dd993e6 Merge branch 'sdxl' into sdxl 2023-07-28 10:19:37 +08:00
Isotr0py
96a52d9810 add dtype to u-net loading 2023-07-27 23:58:25 +08:00
Isotr0py
50544b7805 fix pipeline dtype 2023-07-27 23:16:58 +08:00
Kohya S
b78c0e2a69 remove unused func 2023-07-25 19:07:26 +09:00
Kohya S
2b969e9c42 support sdxl 2023-07-24 22:20:21 +09:00
Kohya S
e83ee217d3 format by black 2023-07-24 21:28:37 +09:00
Kohya S
b1e44e96bc fix to show batch size for each dataset refs #637 2023-07-23 15:39:56 +09:00
Kohya S
7ae0cde754 fix max mul embeds doesn't work. closes #656 2023-07-23 15:18:27 +09:00
Kohya S
c1d5c24bc7 fix LoRA with text encoder can't merge closes #660 2023-07-23 15:01:41 +09:00
Isotr0py
eec6aaddda fix safetensors error: device invalid 2023-07-23 13:29:29 +08:00
Isotr0py
bb167f94ca init unet with empty weights 2023-07-23 13:17:11 +08:00
Kohya S
2e4783bcdf Merge branch 'main' into sdxl 2023-07-23 13:53:13 +09:00
Kohya S
7b31c0830f Merge pull request #663 from DingSiuyo/main
fixed some Chinese translation errors
2023-07-23 13:52:32 +09:00
Kohya S
8f645d354e Merge pull request #615 from shirayu/patch-1
Fix a typo
2023-07-23 13:34:48 +09:00
Kohya S
7ec9a7af79 support Diffusers format 2023-07-23 13:33:14 +09:00
Kohya S
50b53e183e re-organize import 2023-07-23 13:33:02 +09:00
青龍聖者@bdsqlsz
d131bde183 Support for bitsandbytes 0.39.1 with Paged Optimizer(AdamW8bit and Lion8bit) (#631)
* ADD libbitsandbytes.dll for 0.38.1

* Delete libbitsandbytes_cuda116.dll

* Delete cextension.py

* add main.py

* Update requirements.txt for bitsandbytes 0.38.1

* Update README.md for bitsandbytes-windows

* Update README-ja.md  for bitsandbytes 0.38.1

* Update main.py for return cuda118

* Update train_util.py for lion8bit

* Update train_README-ja.md for lion8bit

* Update train_util.py for add DAdaptAdan and DAdaptSGD

* Update train_util.py for DAdaptadam

* Update train_network.py for dadapt

* Update train_README-ja.md for DAdapt

* Update train_util.py for DAdapt

* Update train_network.py for DAdaptAdaGrad

* Update train_db.py for DAdapt

* Update fine_tune.py for DAdapt

* Update train_textual_inversion.py for DAdapt

* Update train_textual_inversion_XTI.py for DAdapt

* Revert "Merge branch 'qinglong' into main"

This reverts commit b65c023083, reversing
changes made to f6fda20caf.

* Revert "Update requirements.txt for bitsandbytes 0.38.1"

This reverts commit 83abc60dfa.

* Revert "Delete cextension.py"

This reverts commit 3ba4dfe046.

* Revert "Update README.md for bitsandbytes-windows"

This reverts commit 4642c52086.

* Revert "Update README-ja.md  for bitsandbytes 0.38.1"

This reverts commit fa6d7485ac.

* Update train_util.py

* Update requirements.txt

* support PagedAdamW8bit/PagedLion8bit

* Update requirements.txt

* update for PageAdamW8bit and PagedLion8bit

* Revert

* revert main
2023-07-22 19:45:32 +09:00
Kohya S
d1864e2430 add invisible watermark to req.txt 2023-07-22 19:34:22 +09:00
Kohya S
8ba02ac829 fix to work text encoder only network with bf16 2023-07-22 09:56:36 +09:00
Kohya S
73a08c0be0 Merge pull request #630 from ddPn08/sdxl
make tracker init_kwargs configurable
2023-07-20 22:05:55 +09:00
Kohya S
c45d2f214b Merge branch 'main' into sdxl 2023-07-20 22:02:29 +09:00
Kohya S
9a67e0df39 Merge pull request #610 from lubobill1990/patch-1
Update huggingface hub to resolve error in windows
2023-07-20 21:45:38 +09:00
Kohya S
acf16c063a make to work with PyTorch 1.12 2023-07-20 21:41:16 +09:00
Kohya S
86a8cbd002 fix original w/h prompt opt shows wrong number 2023-07-20 14:52:04 +09:00
Kohya S
fc276a51fb fix invalid args checking in sdxl TI training 2023-07-20 14:50:57 +09:00
Kohya S
771f33d17d Merge pull request #641 from kaibioinfo/patch-1
fix typo in sdxl_train_textual_inversion
2023-07-20 08:28:11 +09:00
DingSiuyo
e6d1f509a0 fixed some translation errors 2023-07-19 04:30:37 +00:00
Kohya S
225e871819 enable full bf16 trainint in train_network 2023-07-19 08:41:42 +09:00
Kohya S
7875ca8fb5 Merge pull request #645 from Ttl/prepare_order
Cast weights to correct precision before transferring them to GPU
2023-07-19 08:33:32 +09:00
Kohya S
6d2d8dfd2f add zero_terminal_snr option 2023-07-18 23:17:23 +09:00
Kohya S
0ec7166098 make crop top/left same as stabilityai's prep 2023-07-18 21:39:36 +09:00
Kohya S
3d66a234b0 enable different prompt for text encoders 2023-07-18 21:39:01 +09:00
Pam
8a073ee49f vram leak fix 2023-07-17 17:51:26 +05:00
Kohya S
7e20c6d1a1 add convenience function to merge LoRA 2023-07-17 10:30:57 +09:00
Kohya S
1d4672d747 fix typos 2023-07-17 09:05:50 +09:00
Kohya S
39e62b948e add lora for Diffusers 2023-07-16 19:57:21 +09:00
Kohya S
41d195715d fix scheduler steps with gradient accumulation 2023-07-16 15:56:29 +09:00
Kohya S
3db97f8897 update readme 2023-07-16 15:14:49 +09:00
Kohya S
516f64f4d9 add caching to disk for text encoder outputs 2023-07-16 14:53:47 +09:00
Kohya S
62dd99bee5 update readme 2023-07-15 18:34:13 +09:00
Kohya S
94c151aea3 refactor caching latents (flip in same npz, etc) 2023-07-15 18:28:33 +09:00
Kohya S
81fa54837f fix sampling in multi GPU training 2023-07-15 11:21:14 +09:00
Kohya S
9de357e373 fix tokenizer 2 is not same as open clip tokenizer 2023-07-14 12:27:19 +09:00
Kohya S
b4a3824ce4 change tokenizer from open clip to transformers 2023-07-13 20:49:26 +09:00
Kohya S
3bb80ebf20 fix sampling gen fails in lora training 2023-07-13 19:02:34 +09:00
Henrik Forstén
cdffd19f61 Cast weights to correct precision before transferring them to GPU 2023-07-13 12:45:28 +03:00
kaibioinfo
a7ce2633f3 fix typo in sdxl_train_textual_inversion
bug appears when continue training on an existing TI
2023-07-12 15:06:20 +02:00
Kohya S
8fa5fb2816 support diffusers format for SDXL 2023-07-12 21:57:14 +09:00
Kohya S
8df948565a remove unnecessary code 2023-07-12 21:53:02 +09:00
Kohya S
3c67e595b8 fix gradient accumulation doesn't work 2023-07-12 21:35:57 +09:00
Kohya S
814996b14f fix NaN in sampling image 2023-07-11 23:18:35 +09:00
Kohya S
2e67d74df4 add no_half_vae option 2023-07-11 22:19:14 +09:00
ddPn08
b841dd78fe make tracker init_kwargs configurable 2023-07-11 10:21:45 +09:00
Kohya S
68ca0ea995 Fix to show template type 2023-07-10 22:28:26 +09:00
Kohya S
f54b784d88 support textual inversion training 2023-07-10 22:04:02 +09:00
Kohya S
b6e328ea8f don't hold latent on memory for finetuning dataset 2023-07-10 08:46:15 +09:00
Kohya S
5c80117fbd update readme 2023-07-09 21:37:46 +09:00
Kohya S
c2ceb6de5f fix uncond/cond order 2023-07-09 21:14:12 +09:00
Kohya S
77ec70d145 fix conditioning 2023-07-09 19:00:38 +09:00
Kohya S
a380502c01 fix pad token is not handled 2023-07-09 18:13:49 +09:00
Kohya S
0416f26a76 support multi gpu in caching text encoder outputs 2023-07-09 16:02:56 +09:00
Kohya S
3579b4570f Merge pull request #628 from KohakuBlueleaf/full_bf16
Full bf16 support
2023-07-09 14:22:44 +09:00
Kohya S
256ff5b56c Merge pull request #626 from ddPn08/sdxl
support avif
2023-07-09 14:14:28 +09:00
Kohya S
7502f662ab Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into sdxl 2023-07-09 14:12:05 +09:00
Kohaku-Blueleaf
d974959738 Update train_util.py for full_bf16 support 2023-07-09 12:47:26 +08:00
Kohaku-Blueleaf
5f348579d1 Update sdxl_train.py 2023-07-09 12:46:35 +08:00
ykume
8371a7a3aa update readme 2023-07-09 13:38:48 +09:00
ykume
1d25703ac3 add generation script 2023-07-09 13:33:26 +09:00
ykume
fe7ede5af3 fix wrapper tokenizer not work for weighted prompt 2023-07-09 13:33:16 +09:00
ddPn08
d599394f60 support avif 2023-07-08 15:47:56 +09:00
Kohya S
66c03be45f Fix TE key names for SD1/2 LoRA are invalid 2023-07-08 09:56:38 +09:00
Kohya S
c1d62383c6 update readme 2023-07-07 21:17:56 +09:00
Kohya S
73ab110260 Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into sdxl 2023-07-07 21:16:49 +09:00
Kohya S
cc3d40ca44 support sdxl in prepare scipt 2023-07-07 21:16:41 +09:00
Kohya S
288efddf2f Update README.md 2023-07-06 07:43:30 +09:00
Kohya S
4a34e5804e fix to work with .ckpt from comfyui 2023-07-05 21:55:43 +09:00
Kohya S
3d0375daa6 fix to work sdxl state dict without logit_scale 2023-07-05 21:45:30 +09:00
Kohya S
3060eb5baf remove debug print 2023-07-05 21:44:46 +09:00
Kohya S
ce46aa0c3b remove debug print 2023-07-04 21:34:18 +09:00
Kohya S
3b35547da0 fix dtype for vae 2023-07-04 21:30:37 +09:00
Kohya S
6aa62b9b66 update readme 2023-07-03 21:06:58 +09:00
Kohya S
2febbfe4b0 add error message for old npz 2023-07-03 20:58:35 +09:00
Kohya S
ea182461d3 add min/max_timestep 2023-07-03 20:44:42 +09:00
Kohya S
5863676ccb update readme 2023-07-02 16:49:18 +09:00
Kohya S
97611e89ca remove debug code 2023-07-02 16:49:11 +09:00
Kohya S
64cf922841 add feature to sample images during sdxl training 2023-07-02 16:42:19 +09:00
Kohya S
227a62e4c4 fix to work with dreambooth ds without toml 2023-06-30 07:40:22 +09:00
Kohya S
38e21f5c1a update transfomer to fix sdxl text model with bf16 2023-06-29 13:03:00 +09:00
Kohya S
d395bc0647 fix max_token_length not works for sdxl 2023-06-29 13:02:19 +09:00
Yuta Hayashibe
afce13d101 Fix a typo 2023-06-28 21:17:20 +09:00
Kohya S
8521ab7990 fix to work 2023-06-28 13:09:02 +09:00
Kohya S
71a6d49d06 fix to work train_network with fine-tuning dataset 2023-06-28 07:50:53 +09:00
Kohya S
07d5c71090 update readme 2023-06-27 23:24:56 +09:00
Kohya S
a751dc25d6 use CLIPTextModelWithProjection 2023-06-27 20:48:06 +09:00
Kohya S
753c63e11b update readme 2023-06-26 21:24:28 +09:00
Kohya S
b0dfbe7086 update readme 2023-06-26 21:20:49 +09:00
Kohya S
31018d57b6 update for sdxl 2023-06-26 21:18:22 +09:00
Kohya S
9ebebb22db fix typos 2023-06-26 20:43:34 +09:00
Kohya S
2c461e4ad3 Add no_half_vae for SDXL training, add nan check 2023-06-26 20:38:09 +09:00
Kohya S
56ca5dfa15 fix warning messages are shown every step 2023-06-26 20:37:14 +09:00
Kohya S
747af145ed add sdxl fine-tuning and LoRA 2023-06-26 08:07:24 +09:00
Bo Lu
7981ee186f Update huggingface hub to resolve error in windows
https://github.com/huggingface/huggingface_hub/issues/1423
2023-06-26 01:53:23 +08:00
Kohya S
9e9df2b501 update dataset to return size, refactor ctrlnet ds 2023-06-24 17:56:02 +09:00
Kohya S
f7f762c676 add minimal inference code for sdxl 2023-06-24 11:52:26 +09:00
Kohya S
0b730d904f Merge branch 'original-u-net' into sdxl 2023-06-24 09:37:00 +09:00
Kohya S
11e8c7d8ff fix to work controlnet training 2023-06-24 09:35:33 +09:00
Kohya S
663f953a78 Merge branch 'original-u-net' into sdxl 2023-06-24 08:49:38 +09:00
Kohya S
bfd909ab79 Merge branch 'main' into original-u-net 2023-06-24 08:49:07 +09:00
Kohya S
0cfcb5a49c fix lr/d*lr is not logged with prodigy in finetune 2023-06-24 08:36:09 +09:00
Kohya S
6a86de1927 add sdxl unet 2023-06-24 00:01:50 +09:00
Kohya S
5114e8daf1 fix training scripts except controlnet not working 2023-06-22 08:46:53 +09:00
Kohya S
1c09867b3e update Diffusers, remove BLIP deps 2023-06-22 08:38:44 +09:00
Kohya S
2b4229fa51 Merge pull request #551 from ddPn08/dev
add controlnet training
2023-06-17 22:02:34 +09:00
Kohya S
92e50133f8 Merge branch 'original-u-net' into dev 2023-06-17 21:57:08 +09:00
Kohya S
c4269b5efa Merge branch 'main' into original-u-net 2023-06-17 21:48:57 +09:00
Kohya S
19dfa24abb Merge branch 'main' into original-u-net 2023-06-16 20:59:34 +09:00
Kohya S
c7fd336c5d Merge pull request #594 from kohya-ss/dev
fix same random seed is used in multiple generation
2023-06-16 12:14:20 +09:00
Kohya S
ed30af8343 Merge branch 'main' into dev 2023-06-16 12:10:59 +09:00
Kohya S
1e0b059982 fix same seed is used for multiple generation 2023-06-16 12:10:18 +09:00
Kohya S
038c09f552 Merge pull request #590 from kohya-ss/dev
prodigyopt, arbitrary dataset etc.
2023-06-15 22:30:10 +09:00
Kohya S
5d1b54de45 update readme 2023-06-15 22:27:47 +09:00
Kohya S
18156bf2a1 fix same replacement multiple times in dyn prompt 2023-06-15 22:22:12 +09:00
Kohya S
5845de7d7c common lr checking for dadaptation and prodigy 2023-06-15 21:47:37 +09:00
青龍聖者@bdsqlsz
e97d67a681 Support for Prodigy(Dadapt variety for Dylora) (#585)
* Update train_util.py for DAdaptLion

* Update train_README-zh.md for dadaptlion

* Update train_README-ja.md for DAdaptLion

* add DAdatpt V3

* Alignment

* Update train_util.py for experimental

* Update train_util.py V3

* Update train_README-zh.md

* Update train_README-ja.md

* Update train_util.py fix

* Update train_util.py

* support Prodigy

* add lower
2023-06-15 21:12:53 +09:00
Kohya S
f0bb3ae825 add an option to disable controlnet in 2nd stage 2023-06-15 20:56:12 +09:00
Kohya S
9806b00f74 add arbitrary dataset feature to each script 2023-06-15 20:39:39 +09:00
Kohya S
f2989b36c2 fix typos, add comment 2023-06-15 20:37:01 +09:00
Kohya S
624fbadea2 fix dynamic prompt with from_file 2023-06-15 19:19:16 +09:00
Kohya S
d4ba37f543 supprot dynamic prompt variants 2023-06-15 13:22:06 +09:00
Kohya S
449ad7502c use original unet for HF models, don't download TE 2023-06-14 22:26:05 +09:00
Kohya S
44404fcd6d Merge branch 'main' into original-u-net 2023-06-14 12:49:51 +09:00
Kohya S
1da6d43109 Merge branch 'main' into dev 2023-06-14 12:49:37 +09:00
Kohya S
9aee793078 support arbitrary dataset for train_network.py 2023-06-14 12:49:12 +09:00
Kohya S
89c3033401 Merge pull request #581 from mio2333/patch-1
Update make_captions.py
2023-06-12 22:15:30 +09:00
Kohya S
67f09b7d7e change ver no for Diffusers VAE changing 2023-06-12 12:29:44 +09:00
ykume
0dfffcd88a remove unnecessary import 2023-06-11 21:46:05 +09:00
ykume
9e1683cf2b support sdpa 2023-06-11 21:26:15 +09:00
ykume
4d0c06e397 support both 0.10.2 and 0.17.0 for Diffusers 2023-06-11 18:54:50 +09:00
ykume
0315611b11 remove workaround for accelerator=0.15, fix XTI 2023-06-11 18:32:14 +09:00
ykume
33a6234b52 Merge branch 'main' into original-u-net 2023-06-11 17:35:20 +09:00
ykume
4b7b3bc04a fix saved SD dict is invalid for VAE 2023-06-11 17:35:00 +09:00
ykume
035dd3a900 fix mem_eff_attn does not work 2023-06-11 17:08:21 +09:00
ykume
4e25c8f78e fix to work with Diffusers 0.17.0 2023-06-11 16:57:17 +09:00
ykume
7f6b581ef8 support memory efficient attn (not xformers) 2023-06-11 16:54:41 +09:00
ykume
cc274fb7fb update diffusers ver, remove tensorflow 2023-06-11 16:54:10 +09:00
mio
334d07bf96 Update make_captions.py
Append sys path for make_captions.py to load blip module in the same folder to fix the error when you don't run this script under the folder
2023-06-08 23:39:06 +08:00
Kohya S
6417f5d7c1 Merge pull request #580 from kohya-ss/dev
fix clip skip not working in weighted caption training and sample gen
2023-06-08 22:10:30 +09:00
Kohya S
8088c04a71 update readme 2023-06-08 22:06:34 +09:00
Kohya S
f7b1911f1b Merge branch 'main' into dev 2023-06-08 22:03:06 +09:00
Kohya S
045cd38b6e fix clip_skip not work in weight capt, sample gen 2023-06-08 22:02:46 +09:00
Kohya S
dccdb8771c support sample generation in training 2023-06-07 08:12:52 +09:00
Kohya S
d4b5cab7f7 Merge branch 'main' into original-u-net 2023-06-07 07:42:27 +09:00
Kohya S
363f1dfab9 Merge pull request #569 from kohya-ss/dev
older lycoris support, BREAK support
2023-06-06 22:07:21 +09:00
Kohya S
4e24733f1c update readme 2023-06-06 22:03:21 +09:00
Kohya S
bb91a10b5f fix to work LyCORIS<0.1.6 2023-06-06 21:59:57 +09:00
Kohya S
98635ebde2 Merge branch 'main' into dev 2023-06-06 21:54:29 +09:00
Kohya S
24823b061d support BREAK in generation script 2023-06-06 21:53:58 +09:00
Kohya S
0fe1afd4ef Merge pull request #562 from u-haru/hotfix/max_mean_logs_with_loss
loss表示追加
2023-06-05 21:42:25 +09:00
Kohya S
c0a7df9ee1 fix eps value, enable xformers, etc. 2023-06-03 21:29:27 +09:00
u-haru
5907bbd9de loss表示追加 2023-06-03 21:20:26 +09:00
Kohya S
5db792b10b initial commit for original U-Net 2023-06-03 19:24:47 +09:00
Kohya S
7c38c33ed6 Merge pull request #560 from kohya-ss/dev
move max_norm to lora to avoid crashing in lycoris
2023-06-03 12:46:02 +09:00
Kohya S
5bec05e045 move max_norm to lora to avoid crashing in lycoris 2023-06-03 12:42:32 +09:00
Kohya S
6084611508 Merge pull request #559 from kohya-ss/dev
max norm, dropout, scale v-pred loss
2023-06-03 11:40:56 +09:00
Kohya S
71a7a27319 update readme 2023-06-03 11:33:18 +09:00
Kohya S
ec2efe52e4 scale v-pred loss like noise pred 2023-06-03 10:52:22 +09:00
Kohya S
0f0158ddaa scale in rank dropout, check training in dropout 2023-06-02 07:29:59 +09:00
Kohya S
dde7807b00 add rank dropout/module dropout 2023-06-01 22:21:36 +09:00
ddPn08
1e3daa247b fix bucketing 2023-06-01 21:58:45 +09:00
ddPn08
3bd00b88c2 support for controlnet in sample output 2023-06-01 20:48:30 +09:00
ddPn08
62d00b4520 add controlnet training 2023-06-01 20:48:25 +09:00
ddPn08
4f8ce00477 update diffusers to 1.16 | finetune 2023-06-01 20:47:54 +09:00
ddPn08
1214f35985 update diffusers to 1.16 | train_db 2023-06-01 20:39:31 +09:00
ddPn08
e743ee5d5c update diffusers to 1.16 | dylora 2023-06-01 20:39:30 +09:00
ddPn08
23c4e5cb01 update diffusers to 1.16 | train_textual_inversion 2023-06-01 20:39:29 +09:00
ddPn08
1f1cae6c5a make the device of snr_weight the same as loss 2023-06-01 20:39:28 +09:00
ddPn08
c8d209d36c update diffusers to 1.16 | train_network 2023-06-01 20:39:26 +09:00
Kohya S
f8e8df5a04 fix crash gen script, change to network_dropout 2023-06-01 20:07:04 +09:00
Kohya S
f4c9276336 add scaling to max norm 2023-06-01 19:46:17 +09:00
Kohya S
a5c38e5d5b fix crashing when max_norm is diabled 2023-06-01 19:32:22 +09:00
AI-Casanova
9c7237157d Dropout and Max Norm Regularization for LoRA training (#545)
* Instantiate max_norm

* minor

* Move to end of step

* argparse

* metadata

* phrasing

* Sqrt ratio and logging

* fix logging

* Dropout test

* Dropout Args

* Dropout changed to affect LoRA only

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2023-06-01 14:58:38 +09:00
TingTingin
5931948adb Adjusted English grammar in logs to be more clear (#554)
* Update train_network.py

* Update train_network.py

* Update train_network.py

* Update train_network.py

* Update train_network.py

* Update train_network.py
2023-06-01 12:31:33 +09:00
Kohya S
8a5e3904a0 Merge pull request #553 from kohya-ss/dev
no caption warning, network merging before training
2023-05-31 21:04:50 +09:00
Kohya S
d679dc4de1 Merge branch 'main' into dev 2023-05-31 20:58:32 +09:00
Kohya S
a002d10a4d update readme 2023-05-31 20:57:01 +09:00
Kohya S
3a06968332 warn and continue if huggingface uploading failed 2023-05-31 20:48:33 +09:00
Kohya S
6fbd526931 show multiplier for base weights to console 2023-05-31 20:23:19 +09:00
Kohya S
c437dce056 change option name for merging network weights 2023-05-30 23:19:29 +09:00
Kohya S
fc00691898 enable multiple module weights 2023-05-30 23:10:41 +09:00
Kohya S
990ceddd14 show warning if no caption and no class token 2023-05-30 22:53:50 +09:00
Kohya S
226db64736 Merge pull request #542 from u-haru/feature/differential_learning
差分学習機能追加
2023-05-29 08:38:46 +09:00
Kohya S
2429ac73b2 Merge pull request #533 from TingTingin/main
Added warning on training without captions
2023-05-29 08:37:33 +09:00
u-haru
dd8e17cb37 差分学習機能追加 2023-05-27 05:15:02 +09:00
TingTingin
db756e9a34 Update train_util.py
I removed the sleep since it triggers per subset and if someone had a lot of subsets it would trigger multiple times
2023-05-26 08:08:34 -04:00
Kohya S
16e5981d31 Merge pull request #538 from kohya-ss/dev
update train_network doc. add warning to merge_lora.py
2023-05-25 22:24:16 +09:00
Kohya S
575c51fd3b Merge branch 'main' into dev 2023-05-25 22:14:40 +09:00
Kohya S
5b2447f71d add warning to merge_lora.py 2023-05-25 22:14:21 +09:00
Kohya S
0ccb4d4a3a Merge pull request #537 from kohya-ss/dev
support D-Adaptation v3.0
2023-05-25 22:05:24 +09:00
Kohya S
b5bb8bec67 update readme 2023-05-25 22:03:04 +09:00
青龍聖者@bdsqlsz
5cdf4e34a1 support for dadapaption V3 (#530)
* Update train_util.py for DAdaptLion

* Update train_README-zh.md for dadaptlion

* Update train_README-ja.md for DAdaptLion

* add DAdatpt V3

* Alignment

* Update train_util.py for experimental

* Update train_util.py V3

* Update train_README-zh.md

* Update train_README-ja.md

* Update train_util.py fix

* Update train_util.py

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2023-05-25 21:52:36 +09:00
TingTingin
061e157191 Update train_util.py 2023-05-23 02:02:39 -04:00
TingTingin
d859a3a925 Update train_util.py
fix mistake
2023-05-23 02:00:33 -04:00
TingTingin
5a1a14f9fc Update train_util.py
Added feature to add "." if missing in caption_extension
Added warning on training without captions
2023-05-23 01:57:35 -04:00
Kohya S
b6ba4cac83 Merge pull request #528 from kohya-ss/dev
save_state handling, old LoRA support etc.
2023-05-22 18:51:18 +09:00
Kohya S
99b607c60c update readme 2023-05-22 18:46:57 +09:00
Kohya S
289298b17d Merge pull request #527 from Manjiz/main
fix: support old LoRA without alpha raise "TypeError: argument of typ…
2023-05-22 18:36:34 +09:00
琴动我心
f7a1868fc2 fix: support old LoRA without alpha raise "TypeError: argument of type 'int' is not iterable " 2023-05-22 17:15:51 +08:00
Kohya S
02bb8e0ac3 use xformers in VAE in gen script 2023-05-21 12:59:01 +09:00
Kohya S
bc909e8359 Merge pull request #521 from akshaal/fix/save_state
fix: don't save state if no --save-state arg given
2023-05-21 08:48:48 +09:00
Kohya S
c971d9319c Merge pull request #515 from yanhuifair/main
new line with print "generating sample images"
2023-05-21 08:39:22 +09:00
Evgeny Chukreev
0c942106bf fix: don't save state if no --save-state arg given 2023-05-18 20:09:06 +02:00
Fair
c0c4d4ddc6 new line with print "generating sample images" 2023-05-17 10:59:06 +08:00
Kohya S
c924c47f37 Merge pull request #514 from kohya-ss/dev
fix encoding error for prompt file
2023-05-16 07:11:07 +09:00
Kohya S
5b54086663 update readme 2023-05-16 07:09:21 +09:00
Kohya S
9e797cc151 Merge branch 'main' into dev 2023-05-16 07:05:11 +09:00
Kohya S
cc10a62e16 Merge pull request #510 from sdbds/bug_fix
BUG fix for different encoding
2023-05-16 07:03:43 +09:00
青龍聖者@bdsqlsz
7e5b6154d0 Update train_util.py 2023-05-16 00:09:53 +08:00
Kohya S
6d6df18387 Update README.md 2023-05-15 23:23:38 +09:00
Kohya S
ca36f47dfc Merge pull request #509 from kohya-ss/dev
.toml for sample generation etc.
2023-05-15 23:22:11 +09:00
Kohya S
45f9cc9e0e update readme 2023-05-15 23:18:38 +09:00
Kohya S
3699a90645 add adaptive noise scale to metadata 2023-05-15 23:18:16 +09:00
Kohya S
714846e1e1 revert perlin_noise 2023-05-15 23:12:11 +09:00
Kohya S
08d85d4013 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-05-15 20:58:04 +09:00
Kohya S
0ec7743436 show loading model path 2023-05-15 20:57:53 +09:00
Kohya S
a72d80aa85 Merge pull request #507 from HkingAuditore/main
Added support for Perlin noise in Noise Offset
2023-05-15 20:56:46 +09:00
Kohya S
b556fc43bc Merge pull request #504 from Linaqruf/main
TOML support for sample prompt
2023-05-15 20:45:22 +09:00
HkingAuditore
dbb9c19669 Merge pull request #1 from kohya-ss/main
Update to newest
2023-05-15 11:22:02 +08:00
hkinghuang
bca6a44974 Perlin noise 2023-05-15 11:16:08 +08:00
Linaqruf
8ab5c8cb28 feat: added json support as well 2023-05-14 19:49:54 +07:00
Linaqruf
774c4059fb feat: added toml support for sample prompt 2023-05-14 19:38:44 +07:00
hkinghuang
5f1d07d62f init 2023-05-12 21:38:07 +08:00
Kohya S
cd984992cf Merge pull request #501 from kohya-ss/dev
fix to work with fp16, crash with some reso
2023-05-12 21:47:10 +09:00
Kohya S
99f4940eb7 Merge branch 'main' into dev 2023-05-12 21:44:42 +09:00
Kohya S
41dd835a89 fix to work with fp16, crash with some reso 2023-05-12 21:44:07 +09:00
Kohya S
ee42c5cd42 Merge pull request #495 from kohya-ss/dev
dim from weights, fix multires noise, update gen script etc.
2023-05-11 22:19:33 +09:00
Kohya S
47b6101465 update readme 2023-05-11 22:17:32 +09:00
Kohya S
7889a52f95 add callback for step start 2023-05-11 22:00:41 +09:00
青龍聖者@bdsqlsz
8d562ecf48 fix pynoise code bug (#489)
* fix pynoise

* Update custom_train_functions.py for default

* Update custom_train_functions.py for note

* Update custom_train_functions.py for default

* Revert "Update custom_train_functions.py for default"

This reverts commit ca79915d73.

* Update custom_train_functions.py for default

* Revert "Update custom_train_functions.py for default"

This reverts commit 483577e137.

* default value change
2023-05-11 21:48:51 +09:00
Kohya S
2767a0f9f2 common block lr args processing in create 2023-05-11 21:47:59 +09:00
Kohya S
af08c56ce0 remove unnecessary newline 2023-05-11 21:20:18 +09:00
Kohya S
dfc56e9227 Merge branch 'main' into dev 2023-05-11 21:12:33 +09:00
Kohya S
84d157995e Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-05-11 21:12:28 +09:00
Kohya S
ed5bfda372 Fix controlnet input to rgb from bgr 2023-05-11 21:12:06 +09:00
Kohya S
a59822540f Merge pull request #491 from AI-Casanova/size-from-weights
Size from network weights
2023-05-11 21:06:20 +09:00
Kohya S
968bbd2f47 Merge pull request #480 from yanhuifair/main
fix print "saving" and "epoch" in newline
2023-05-11 21:05:37 +09:00
Kohya S
1b4bdff331 enable i2i with highres fix, add slicing VAE 2023-05-10 23:09:25 +09:00
AI-Casanova
678fe003e3 Merge branch 'kohya-ss:main' into size-from-weights 2023-05-09 08:30:18 -05:00
Kohya S
3b1af3f1a6 Merge pull request #484 from kohya-ss/dev
more dadapataion optimizer, move docs, adaptive noise scale etc.
2023-05-07 21:20:55 +09:00
Kohya S
437501cde3 update readme 2023-05-07 21:18:13 +09:00
Kohya S
8bd2072e19 update readme 2023-05-07 21:15:20 +09:00
Kohya S
85df289190 remove gradio from requirements 2023-05-07 21:00:06 +09:00
Kohya S
8856496aac update link to documents 2023-05-07 20:59:02 +09:00
Kohya S
a7df7db464 move documents to docs folder 2023-05-07 20:56:42 +09:00
Kohya S
59507c7c02 update documents 2023-05-07 20:50:19 +09:00
Kohya S
09c719c926 add adaptive noise scale 2023-05-07 18:09:08 +09:00
Kohya S
e54b6311ef do not save cuda_rng_state if no cuda closes #390 2023-05-07 10:23:25 +09:00
Kohya S
fdbdb4748a pre calc LoRA in generating 2023-05-07 09:57:54 +09:00
AI-Casanova
76a2b14cdb Instantiate size_from_weights 2023-05-06 20:06:02 +00:00
Fair
b08154dc36 fix print "saving" and "epoch" in newline 2023-05-07 02:51:01 +08:00
Kohya S
165fc43655 fix comment 2023-05-06 18:25:26 +09:00
Kohya S
42cbf75cfa Merge branch 'main' into dev 2023-05-06 18:22:45 +09:00
Kohya S
2127907dd3 refactor selection and logging for DAdaptation 2023-05-06 18:14:16 +09:00
青龍聖者@bdsqlsz
164a1978de Support for more Dadaptation (#455)
* Update train_util.py for add DAdaptAdan and DAdaptSGD

* Update train_util.py for DAdaptadam

* Update train_network.py for dadapt

* Update train_README-ja.md for DAdapt

* Update train_util.py for DAdapt

* Update train_network.py for DAdaptAdaGrad

* Update train_db.py for DAdapt

* Update fine_tune.py for DAdapt

* Update train_textual_inversion.py for DAdapt

* Update train_textual_inversion_XTI.py for DAdapt
2023-05-06 17:30:09 +09:00
78 changed files with 24615 additions and 3593 deletions

7
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,7 @@
---
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "monthly"

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.13.10
uses: crate-ci/typos@v1.16.15

View File

@@ -1,3 +1,7 @@
SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。
SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。
## リポジトリについて
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
@@ -9,20 +13,19 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
* DreamBooth、U-NetおよびText Encoderの学習をサポート
* fine-tuning、同上
* LoRAの学習をサポート
* 画像生成
* モデル変換Stable Diffision ckpt/safetensorsとDiffusersの相互変換
## 使用法について
当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください将来的にはすべてこちらへ移すかもしれません
* [学習について、共通編](./train_README-ja.md) : データ整備やオプションなど
* [データセット設定](./config_README-ja.md)
* [DreamBoothの学習について](./train_db_README-ja.md)
* [fine-tuningのガイド](./fine_tune_README_ja.md):
* [LoRAの学習について](./train_network_README-ja.md)
* [Textual Inversionの学習について](./train_ti_README-ja.md)
* note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
* [データセット設定](./docs/config_README-ja.md)
* [DreamBoothの学習について](./docs/train_db_README-ja.md)
* [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
* [LoRAの学習について](./docs/train_network_README-ja.md)
* [Textual Inversionの学習について](./docs/train_ti_README-ja.md)
* [画像生成スクリプト](./docs/gen_img_README-ja.md)
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
## Windowsでの動作に必要なプログラム
@@ -41,11 +44,13 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
## Windows環境でのインストール
以下の例ではPyTorchは1.12.1CUDA 11.6版をインストールします。CUDA 11.3版やPyTorch 1.13を使う場合は適宜書き換えください
スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます
以下の例ではPyTorchは2.0.1CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。
なお、python -m venvの行で「python」とだけ表示された場合、py -m venvのようにpythonをpyに変更してください。
通常の管理者ではないPowerShellを開き以下を順に実行します。
PowerShellを使う場合、通常の管理者ではないPowerShellを開き以下を順に実行します。
```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
@@ -54,43 +59,14 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
pip install xformers==0.0.20
accelerate config
```
<!--
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
pip install --use-pep517 --upgrade -r requirements.txt
pip install -U -I --no-deps xformers==0.0.16
-->
コマンドプロンプトでは以下になります。
```bat
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config
```
コマンドプロンプトでも同一です。
(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。
@@ -111,19 +87,40 @@ accelerate configの質問には以下のように答えてください。bf1
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``に「0」と答えてください。id `0`のGPUが使われます。
### PyTorchとxformersのバージョンについて
### オプション:`bitsandbytes`8bit optimizerを使う
他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください
`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます0.41.1または以降のバージョンを推奨)
### オプションLion8bitを使う
Windowsでは0.35.0または0.41.1を推奨します。
Lion8bitを使う場合には`bitsandbytes`0.38.0以降にアップグレードする必要があります。`bitsandbytes`をアンインストールし、Windows環境では例えば[こちら](https://github.com/jllllll/bitsandbytes-windows-webui)などからWindows版のwhlファイルをインストールしてください。たとえば以下のような手順になります
- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません
- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。
注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659
以下の手順に従い、`bitsandbytes`をインストールしてください。
### 0.35.0を使う場合
PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。
```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl
cd sd-scripts
.\venv\Scripts\activate
pip install bitsandbytes==0.35.0
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
```
アップグレード時には`pip install .`でこのリポジトリを更新し、必要に応じて他のパッケージもアップグレードしてください。
### 0.41.1を使う場合
jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。
```powershell
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
```
## アップグレード

259
README.md
View File

@@ -1,9 +1,11 @@
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
This repository contains training, generation and utility scripts for Stable Diffusion.
[__Change History__](#change-history) is moved to the bottom of the page.
[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。
[日本語版README](./README-ja.md)
[日本語版READMEはこちら](./README-ja.md)
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
@@ -12,30 +14,30 @@ This repository contains the scripts for:
* DreamBooth training, including U-Net and Text Encoder
* Fine-tuning (native training), including U-Net and Text Encoder
* LoRA training
* Texutl Inversion training
* Textual Inversion training
* Image generation
* Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ (SD 1.x based only) Thank you for great work!!!
## About requirements.txt
These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.)
The scripts are tested with PyTorch 1.12.1 and 1.13.0, Diffusers 0.10.2.
The scripts are tested with Pytorch 2.0.1. 1.12.1 is not tested but should work.
## Links to how-to-use documents
## Links to usage documentation
Most of the documents are written in Japanese.
* [Training guide - common](./train_README-ja.md) : data preparation, options etc...
* [Chinese version](./train_README-zh.md)
* [Dataset config](./config_README-ja.md)
* [DreamBooth training guide](./train_db_README-ja.md)
* [Step by Step fine-tuning guide](./fine_tune_README_ja.md):
* [training LoRA](./train_network_README-ja.md)
* [training Textual Inversion](./train_ti_README-ja.md)
* note.com [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
[English translation by darkstorm2150 is here](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation). Thanks to darkstorm2150!
* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
* [Chinese version](./docs/train_README-zh.md)
* [Dataset config](./docs/config_README-ja.md)
* [DreamBooth training guide](./docs/train_db_README-ja.md)
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
* [training LoRA](./docs/train_network_README-ja.md)
* [training Textual Inversion](./docs/train_ti_README-ja.md)
* [Image generation](./docs/gen_img_README-ja.md)
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
## Windows Required Dependencies
@@ -62,19 +64,20 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
pip install xformers==0.0.20
accelerate config
```
update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python).
__Note:__ Now bitsandbytes is optional. Please install any version of bitsandbytes as needed. Installation instructions are in the following section.
<!--
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
-->
Answers to accelerate config:
```txt
@@ -92,20 +95,42 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
(Single GPU with id `0` will be used.)
### about PyTorch and xformers
### Optional: Use `bitsandbytes` (8bit optimizer)
Other versions of PyTorch and xformers seem to have problems with training.
If there is no other reason, please install the specified version.
For 8bit optimizer, you need to install `bitsandbytes`. For Linux, please install `bitsandbytes` as usual (0.41.1 or later is recommended.)
### Optional: Use Lion8bit
For Windows, there are several versions of `bitsandbytes`:
For Lion8bit, you need to upgrade `bitsandbytes` to 0.38.0 or later. Uninstall `bitsandbytes`, and for Windows, install the Windows version whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or other sources, like:
- `bitsandbytes` 0.35.0: Stable version. AdamW8bit is available. `full_bf16` is not available.
- `bitsandbytes` 0.41.1: Lion8bit, PagedAdamW8bit and PagedLion8bit are available. `full_bf16` is available.
Note: `bitsandbytes`above 0.35.0 till 0.41.0 seems to have an issue: https://github.com/TimDettmers/bitsandbytes/issues/659
Follow the instructions below to install `bitsandbytes` for Windows.
### bitsandbytes 0.35.0 for Windows
Open a regular Powershell terminal and type the following inside:
```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl
cd sd-scripts
.\venv\Scripts\activate
pip install bitsandbytes==0.35.0
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
```
For upgrading, upgrade this repo with `pip install .`, and upgrade necessary packages manually.
This will install `bitsandbytes` 0.35.0 and copy the necessary files to the `bitsandbytes` directory.
### bitsandbytes 0.41.1 for Windows
Install the Windows version whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or other sources, like:
```powershell
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
```
## Upgrade
@@ -136,31 +161,169 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## SDXL training
The documentation in this section will be moved to a separate document later.
### Training scripts for SDXL
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- The full bfloat16 training might be unstable. Please use it at your own risk.
- The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- `--weighted_captions` option is not supported yet for both scripts.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
### Utility scripts for SDXL
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
- The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage.
### Tips for SDXL training
- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (4 to 8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings for SDXL
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
### ControlNet-LLLite
ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details.
## Change History
### 3 May 2023, 2023/05/03
### Oct 11, 2023 / 2023/10/11
- Fix to work `make_captions_by_git.py` with the latest version of transformers.
- Improve `gen_img_diffusers.py` and `sdxl_gen_img.py`. Both scripts now support the following options:
- `--network_merge_n_models` option can be used to merge some of the models. The remaining models aren't merged, so the multiplier can be changed, and the regional LoRA also works.
- `--network_regional_mask_max_color_codes` is added. Now you can use up to 7 regions.
- When this option is specified, the mask of the regional LoRA is the color code based instead of the channel based. The value is the maximum number of the color codes (up to 7).
- You can specify the mask for each LoRA by colors: 0x0000ff, 0x00ff00, 0x00ffff, 0xff0000, 0xff00ff, 0xffff00, 0xffffff.
- When saving v2 models in Diffusers format in training scripts and conversion scripts, it was found that the U-Net configuration is different from those of Hugging Face's stabilityai models (this repository is `"use_linear_projection": false`, stabilityai is `true`). Please note that the weight shapes are different, so please be careful when using the weight files directly. We apologize for the inconvenience.
- Since the U-Net model is created based on the configuration, it should not cause any problems in training or inference.
- Added `--unet_use_linear_projection` option to `convert_diffusers20_original_sd.py` script. If you specify this option, you can save a Diffusers format model with the same configuration as stabilityai's model from an SD format model (a single `*.safetensors` or `*.ckpt` file). Unfortunately, it is not possible to convert a Diffusers format model to the same format.
- `make_captions_by_git.py` が最新の transformers で動作するように修正しました。
- `gen_img_diffusers.py``sdxl_gen_img.py` を更新し、以下のオプションを追加しました。
- `--network_merge_n_models` オプションで一部のモデルのみマージできます。残りのモデルはマージされないため、重みを変更したり、領域別LoRAを使用したりできます。
- `--network_regional_mask_max_color_codes` を追加しました。最大7つの領域を使用できます。
- このオプションを指定すると、領域別LoRAのマスクはチャンネルベースではなくカラーコードベースになります。値はカラーコードの最大数最大7です。
- 各LoRAに対してマスクをカラーで指定できます0x0000ff、0x00ff00、0x00ffff、0xff0000、0xff00ff、0xffff00、0xffffff。
- Lion8bit optimizer is supported. [PR #447](https://github.com/kohya-ss/sd-scripts/pull/447) Thanks to sdbds!
- Currently it is optional because you need to update `bitsandbytes` version. See "Optional: Use Lion8bit" in installation instructions to use it.
- Multi-GPU training with DDP is supported in each training script. [PR #448](https://github.com/kohya-ss/sd-scripts/pull/448) Thanks to Isotr0py!
- Multi resolution noise (pyramid noise) is supported in each training script. [PR #471](https://github.com/kohya-ss/sd-scripts/pull/471) Thanks to pamparamm!
- See PR and this page [Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2) for details.
### Oct 9. 2023 / 2023/10/9
- 学習スクリプトや変換スクリプトでDiffusers形式でv2モデルを保存するとき、U-Netの設定がHugging Faceのstabilityaiのモデルと異なることがわかりました当リポジトリでは `"use_linear_projection": false`、stabilityaiは`true`)。重みの形状が異なるため、直接重みファイルを利用する場合にはご注意ください。ご不便をお掛けし申し訳ありません。
- U-Netのモデルは設定に基づいて作成されるため、通常、学習や推論で問題になることはないと思われます。
- `convert_diffusers20_original_sd.py`スクリプトに`--unet_use_linear_projection`オプションを追加しました。これを指定するとSD形式のモデル単一の`*.safetensors`または`*.ckpt`ファイルから、stabilityaiのモデルと同じ形状の重みファイルを持つDiffusers形式モデルが保存できます。なお、Diffusers形式のモデルを同形式に変換することはできません。
- `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!
- `--onnx` option is added. If you use Onnx, specify `--onnx` option.
- Please install Onnx and other required packages.
1. Uninstall TensorFlow.
1. `pip install tensorboard==2.14.1` This is required for the specified version of protobuf.
1. `pip install protobuf==3.20.3` This is required for Onnx.
1. `pip install onnx==1.14.1`
1. `pip install onnxruntime-gpu==1.16.0` or `pip install onnxruntime==1.16.0`
- `--append_tags` option is added to `tag_images_by_wd_14_tagger.py`. This option appends the tags to the existing tags, instead of replacing them. [#858](https://github.com/kohya-ss/sd-scripts/pull/858) Thanks to a-l-e-x-d-s-9!
- [OFT](https://oft.wyliu.com/) is now supported.
- You can use `networks.oft` for the network module in `sdxl_train_network.py`. The usage is the same as `networks.lora`. Some options are not supported.
- `sdxl_gen_img.py` also supports OFT as `--network_module`.
- OFT only supports SDXL currently. Because current OFT tweaks Q/K/V and O in the transformer, and SD1/2 have extremely fewer transformers than SDXL.
- The implementation is heavily based on laksjdjf's [OFT implementation](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py). Thanks to laksjdjf!
- Other bug fixes and improvements.
- `tag_images_by_wd_14_tagger.py` が Onnx をサポートしました。Onnx を使用する場合は TensorFlow は不要です。[#864](https://github.com/kohya-ss/sd-scripts/pull/864) Isotr0py氏に感謝します。
- Onnxを使用する場合は、`--onnx` オプションを指定してください。
- Onnx とその他の必要なパッケージをインストールしてください。
1. TensorFlow をアンインストールしてください。
1. `pip install tensorboard==2.14.1` protobufの指定バージョンにこれが必要。
1. `pip install protobuf==3.20.3` Onnxのために必要。
1. `pip install onnx==1.14.1`
1. `pip install onnxruntime-gpu==1.16.0` または `pip install onnxruntime==1.16.0`
- `tag_images_by_wd_14_tagger.py``--append_tags` オプションが追加されました。このオプションを指定すると、既存のタグに上書きするのではなく、新しいタグのみが既存のタグに追加されます。 [#858](https://github.com/kohya-ss/sd-scripts/pull/858) a-l-e-x-d-s-9氏に感謝します。
- [OFT](https://oft.wyliu.com/) をサポートしました。
- `sdxl_train_network.py``--network_module``networks.oft` を指定してください。使用方法は `networks.lora` と同様ですが一部のオプションは未サポートです。
- `sdxl_gen_img.py` でも同様に OFT を指定できます。
- OFT は現在 SDXL のみサポートしています。OFT は現在 transformer の Q/K/V と O を変更しますが、SD1/2 は transformer の数が SDXL よりも極端に少ないためです。
- 実装は laksjdjf 氏の [OFT実装](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py) を多くの部分で参考にしています。laksjdjf 氏に感謝します。
- その他のバグ修正と改善。
### Oct 1. 2023 / 2023/10/1
- SDXL training is now available in the main branch. The sdxl branch is merged into the main branch.
- [SAI Model Spec](https://github.com/Stability-AI/ModelSpec) metadata is now supported partially. `hash_sha256` is not supported yet.
- The main items are set automatically.
- You can set title, author, description, license and tags with `--metadata_xxx` options in each training script.
- Merging scripts also support minimum SAI Model Spec metadata. See the help message for the usage.
- Metadata editor will be available soon.
- `bitsandbytes` is now optional. Please install it if you want to use it. The insructions are in the later section.
- `albumentations` is not required anymore.
- `--v_pred_like_loss ratio` option is added. This option adds the loss like v-prediction loss in SDXL training. `0.1` means that the loss is added 10% of the v-prediction loss. The default value is None (disabled).
- In v-prediction, the loss is higher in the early timesteps (near the noise). This option can be used to increase the loss in the early timesteps.
- Arbitrary options can be used for Diffusers' schedulers. For example `--lr_scheduler_args "lr_end=1e-8"`.
- LoRA-FA is added experimentally. Specify `--network_module networks.lora_fa` option instead of `--network_module networks.lora`. The trained model can be used as a normal LoRA model.
- JPEG XL is supported. [#786](https://github.com/kohya-ss/sd-scripts/pull/786)
- Input perturbation noise is added. See [#798](https://github.com/kohya-ss/sd-scripts/pull/798) for details.
- Dataset subset now has `caption_prefix` and `caption_suffix` options. The strings are added to the beginning and the end of the captions before shuffling. You can specify the options in `.toml`.
- Intel ARC support with IPEX is added. [#825](https://github.com/kohya-ss/sd-scripts/pull/825)
- Other bug fixes and improvements.
- Lion8bitオプティマイザがサポートされました。[PR #447](https://github.com/kohya-ss/sd-scripts/pull/447) sdbds氏に感謝します。
- `bitsandbytes`のバージョンを更新する必要があるため、現在はオプションです。使用するにはインストール手順の「[オプションLion8bitを使う](./README-ja.md#オプションlion8bitを使う)」を参照してください。
- 各学習スクリプトでDDPによるマルチGPU学習がサポートされました。[PR #448](https://github.com/kohya-ss/sd-scripts/pull/448) Isotr0py氏に感謝します。
- Multi resolution noise (pyramid noise) が各学習スクリプトでサポートされました。[PR #471](https://github.com/kohya-ss/sd-scripts/pull/471) pamparamm氏に感謝します。
- 詳細はPRおよびこちらのページ [Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2) を参照してください。
- `--multires_noise_iterations` に数値を指定すると有効になります。`6`~`10`程度の値が良いようです。
- `--multires_noise_discount` に`0.1`~`0.3` 程度の値LoRA学習等比較的データセットが小さい場合のPR作者の推奨、ないしは`0.8`程度の値(元記事の推奨)を指定してください(デフォルトは `0.3`)。
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。

View File

@@ -1,133 +1,131 @@
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
def unet_forward_XTI(self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
from library.original_unet import SampleOutput
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
def unet_forward_XTI(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[Dict, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a dict instead of a plain tuple.
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
Returns:
`SampleOutput` or `tuple`:
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
default_overall_up_factor = 2**self.num_upsamplers
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
# 64で割り切れないときはupsamplerにサイズを伝える
forward_upsample_size = False
upsample_size = None
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
# logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
# 1. time
timesteps = timestep
timesteps = self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
t_emb = self.time_proj(timesteps)
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
# timestepsは重みを含まないので常にfloat32のテンソルを返す
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
# time_projでキャストしておけばいいんじゃね
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
down_i = 0
for downsample_block in self.down_blocks:
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
# まあこちらのほうがわかりやすいかもしれない
if downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states[down_i : down_i + 2],
)
down_i += 2
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
# 3. down
down_block_res_samples = (sample,)
down_i = 0
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
)
down_i += 2
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
# 4. mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])
# 5. up
up_i = 7
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
# 5. up
up_i = 7
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the upsample size, we do it here
# 前述のように最後のブロック以外ではupsample_sizeを伝える
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states[up_i : up_i + 3],
upsample_size=upsample_size,
)
up_i += 3
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
upsample_size=upsample_size,
)
up_i += 3
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
if not return_dict:
return (sample,)
return SampleOutput(sample=sample)
return UNet2DConditionOutput(sample=sample)
def downblock_forward_XTI(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
@@ -166,6 +164,7 @@ def downblock_forward_XTI(
return hidden_states, output_states
def upblock_forward_XTI(
self,
hidden_states,
@@ -199,11 +198,11 @@ def upblock_forward_XTI(
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample
i += 1
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
return hidden_states

View File

@@ -9,7 +9,25 @@ parms="parms"
nin="nin"
extention="extention" # Intentionally left
nd="nd"
shs="shs"
sts="sts"
scs="scs"
cpc="cpc"
coc="coc"
cic="cic"
msm="msm"
usu="usu"
ici="ici"
lvl="lvl"
dii="dii"
muk="muk"
ori="ori"
hru="hru"
rik="rik"
koo="koo"
yos="yos"
wn="wn"
[files]
extend-exclude = ["_typos.toml"]
extend-exclude = ["_typos.toml", "venv"]

Binary file not shown.

View File

@@ -1,166 +1,166 @@
"""
extract factors the build is dependent on:
[X] compute capability
[ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
- CuBLAS-LT: full-build 8-bit optimizer
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
evaluation:
- if paths faulty, return meaningful error
- else:
- determine CUDA version
- determine capabilities
- based on that set the default path
"""
import ctypes
from .paths import determine_cuda_runtime_lib_path
def check_cuda_result(cuda, result_val):
# 3. Check for CUDA errors
if result_val != 0:
error_str = ctypes.c_char_p()
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
print(f"CUDA exception! Error code: {error_str.value.decode()}")
def get_cuda_version(cuda, cudart_path):
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
try:
cudart = ctypes.CDLL(cudart_path)
except OSError:
# TODO: shouldn't we error or at least warn here?
print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
return None
version = ctypes.c_int()
check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
version = int(version.value)
major = version//1000
minor = (version-(major*1000))//10
if major < 11:
print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
return f'{major}{minor}'
def get_cuda_lib_handle():
# 1. find libcuda.so library (GPU driver) (/usr/lib)
try:
cuda = ctypes.CDLL("libcuda.so")
except OSError:
# TODO: shouldn't we error or at least warn here?
print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
return None
check_cuda_result(cuda, cuda.cuInit(0))
return cuda
def get_compute_capabilities(cuda):
"""
1. find libcuda.so library (GPU driver) (/usr/lib)
init_device -> init variables -> call function by reference
2. call extern C function to determine CC
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
3. Check for CUDA errors
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
"""
nGpus = ctypes.c_int()
cc_major = ctypes.c_int()
cc_minor = ctypes.c_int()
device = ctypes.c_int()
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
ccs = []
for i in range(nGpus.value):
check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
ref_major = ctypes.byref(cc_major)
ref_minor = ctypes.byref(cc_minor)
# 2. call extern C function to determine CC
check_cuda_result(
cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
)
ccs.append(f"{cc_major.value}.{cc_minor.value}")
return ccs
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
def get_compute_capability(cuda):
"""
Extracts the highest compute capbility from all available GPUs, as compute
capabilities are downwards compatible. If no GPUs are detected, it returns
None.
"""
ccs = get_compute_capabilities(cuda)
if ccs is not None:
# TODO: handle different compute capabilities; for now, take the max
return ccs[-1]
return None
def evaluate_cuda_setup():
print('')
print('='*35 + 'BUG REPORT' + '='*35)
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
print('='*80)
return "libbitsandbytes_cuda116.dll" # $$$
binary_name = "libbitsandbytes_cpu.so"
#if not torch.cuda.is_available():
#print('No GPU detected. Loading CPU library...')
#return binary_name
cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None:
print(
"WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
)
return binary_name
print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
cuda_version_string = get_cuda_version(cuda, cudart_path)
if cc == '':
print(
"WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
)
return binary_name
# 7.5 is the minimum CC vor cublaslt
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
# TODO:
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
# (2) Multiple CUDA versions installed
# we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler
print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
def get_binary_name():
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
bin_base_name = "libbitsandbytes_cuda"
if has_cublaslt:
return f"{bin_base_name}{cuda_version_string}.so"
else:
return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
binary_name = get_binary_name()
return binary_name
"""
extract factors the build is dependent on:
[X] compute capability
[ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
- CuBLAS-LT: full-build 8-bit optimizer
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
evaluation:
- if paths faulty, return meaningful error
- else:
- determine CUDA version
- determine capabilities
- based on that set the default path
"""
import ctypes
from .paths import determine_cuda_runtime_lib_path
def check_cuda_result(cuda, result_val):
# 3. Check for CUDA errors
if result_val != 0:
error_str = ctypes.c_char_p()
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
print(f"CUDA exception! Error code: {error_str.value.decode()}")
def get_cuda_version(cuda, cudart_path):
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
try:
cudart = ctypes.CDLL(cudart_path)
except OSError:
# TODO: shouldn't we error or at least warn here?
print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
return None
version = ctypes.c_int()
check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
version = int(version.value)
major = version//1000
minor = (version-(major*1000))//10
if major < 11:
print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
return f'{major}{minor}'
def get_cuda_lib_handle():
# 1. find libcuda.so library (GPU driver) (/usr/lib)
try:
cuda = ctypes.CDLL("libcuda.so")
except OSError:
# TODO: shouldn't we error or at least warn here?
print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
return None
check_cuda_result(cuda, cuda.cuInit(0))
return cuda
def get_compute_capabilities(cuda):
"""
1. find libcuda.so library (GPU driver) (/usr/lib)
init_device -> init variables -> call function by reference
2. call extern C function to determine CC
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
3. Check for CUDA errors
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
"""
nGpus = ctypes.c_int()
cc_major = ctypes.c_int()
cc_minor = ctypes.c_int()
device = ctypes.c_int()
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
ccs = []
for i in range(nGpus.value):
check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
ref_major = ctypes.byref(cc_major)
ref_minor = ctypes.byref(cc_minor)
# 2. call extern C function to determine CC
check_cuda_result(
cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
)
ccs.append(f"{cc_major.value}.{cc_minor.value}")
return ccs
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
def get_compute_capability(cuda):
"""
Extracts the highest compute capbility from all available GPUs, as compute
capabilities are downwards compatible. If no GPUs are detected, it returns
None.
"""
ccs = get_compute_capabilities(cuda)
if ccs is not None:
# TODO: handle different compute capabilities; for now, take the max
return ccs[-1]
return None
def evaluate_cuda_setup():
print('')
print('='*35 + 'BUG REPORT' + '='*35)
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
print('='*80)
return "libbitsandbytes_cuda116.dll" # $$$
binary_name = "libbitsandbytes_cpu.so"
#if not torch.cuda.is_available():
#print('No GPU detected. Loading CPU library...')
#return binary_name
cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None:
print(
"WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
)
return binary_name
print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
cuda_version_string = get_cuda_version(cuda, cudart_path)
if cc == '':
print(
"WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
)
return binary_name
# 7.5 is the minimum CC vor cublaslt
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
# TODO:
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
# (2) Multiple CUDA versions installed
# we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler
print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
def get_binary_name():
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
bin_base_name = "libbitsandbytes_cuda"
if has_cublaslt:
return f"{bin_base_name}{cuda_version_string}.so"
else:
return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
binary_name = get_binary_name()
return binary_name

View File

@@ -138,9 +138,13 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `num_repeats` | `10` | o | o | o |
| `random_crop` | `false` | o | o | o |
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o |
| `caption_suffix` | `“, from side”` | o | o | o |
* `num_repeats`
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
* `caption_prefix`, `caption_suffix`
* キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。
### DreamBooth 方式専用のオプション

View File

@@ -153,7 +153,9 @@ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
- `--network_mul`:使用する追加ネットワークの重みを何倍にするかを指定します。デフォルトは`1`です。`--network_mul 0.8`のように指定します。複数のLoRAを使用する場合は`--network_mul 0.4 0.5 0.7`のように指定します。引数の数は`--network_module`で指定した数と同じにしてください。
- `--network_merge`:使用する追加ネットワークの重みを`--network_mul`に指定した重みであらかじめマージします。プロンプトオプションの`--am`は使用できなくなりますが、LoRA未使用時と同じ程度まで生成が高速化されます。
- `--network_merge`:使用する追加ネットワークの重みを`--network_mul`に指定した重みであらかじめマージします。`--network_pre_calc` と同時に使用できません。プロンプトオプションの`--am`、およびRegional LoRAは使用できなくなりますが、LoRA未使用時と同じ程度まで生成が高速化されます。
- `--network_pre_calc`:使用する追加ネットワークの重みを生成ごとにあらかじめ計算します。プロンプトオプションの`--am`が使用できます。LoRA未使用時と同じ程度まで生成は高速化されますが、生成前に重みを計算する時間が必要で、またメモリ使用量も若干増加します。Regional LoRA使用時は無効になります 。
# 主なオプションの指定例

View File

@@ -295,7 +295,7 @@ Stable Diffusion のv1は512\*512で学習されていますが、それに加
また任意の解像度で学習するため、事前に画像データの縦横比を統一しておく必要がなくなります。
設定で有効、向こうが切り替えられますが、ここまでの設定ファイルの記述例では有効になっています(`true` が設定されています)。
設定で有効、無効が切り替えられますが、ここまでの設定ファイルの記述例では有効になっています(`true` が設定されています)。
学習解像度はパラメータとして与えられた解像度の面積メモリ使用量を超えない範囲で、64ピクセル単位デフォルト、変更可で縦横に調整、作成されます。
@@ -463,27 +463,6 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
xformersオプションを指定するとxformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用しますxformersよりも速度は遅くなります
- `--save_precision`
保存時のデータ精度を指定します。save_precisionオプションにfloat、fp16、bf16のいずれかを指定すると、その形式でモデルを保存しますDreamBooth、fine tuningでDiffusers形式でモデルを保存する場合は無効です。モデルのサイズを削減したい場合などにお使いください。
- `--save_every_n_epochs` / `--save_state` / `--resume`
save_every_n_epochsオプションに数値を指定すると、そのエポックごとに学習途中のモデルを保存します。
save_stateオプションを同時に指定すると、optimizer等の状態も含めた学習状態を合わせて保存します保存したモデルからも学習再開できますが、それに比べると精度の向上、学習時間の短縮が期待できます。保存先はフォルダになります。
学習状態は保存先フォルダに `<output_name>-??????-state`??????はエポック数)という名前のフォルダで出力されます。長時間にわたる学習時にご利用ください。
保存された学習状態から学習を再開するにはresumeオプションを使います。学習状態のフォルダ`output_dir` ではなくその中のstateのフォルダを指定してください。
なおAcceleratorの仕様により、エポック数、global stepは保存されておらず、resumeしたときにも1からになりますがご容赦ください。
- `--save_model_as` DreamBooth, fine tuning のみ)
モデルの保存形式を`ckpt, safetensors, diffusers, diffusers_safetensors` から選べます。
`--save_model_as=safetensors` のように指定します。Stable Diffusion形式ckptまたはsafetensorsを読み込み、Diffusers形式で保存する場合、不足する情報はHugging Faceからv1.5またはv2.1の情報を落としてきて補完します。
- `--clip_skip`
`2` を指定すると、Text Encoder (CLIP) の後ろから二番目の層の出力を用います。1またはオプション省略時は最後の層を用います。
@@ -502,6 +481,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。
- `--weighted_captions`
指定するとAutomatic1111氏のWeb UIと同様の重み付きキャプションが有効になります。「Textual Inversion と XTI」以外の学習に使用できます。キャプションだけでなく DreamBooth 手法の token string でも有効です。
重みづけキャプションの記法はWeb UIとほぼ同じで、(abc)や[abc]、(abc:1.23)などが使用できます。入れ子も可能です。括弧内にカンマを含めるとプロンプトのshuffle/dropoutで括弧の対応付けがおかしくなるため、括弧内にはカンマを含めないでください。
- `--persistent_data_loader_workers`
Windows環境で指定するとエポック間の待ち時間が大幅に短縮されます。
@@ -527,12 +512,28 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。
- `--log_with` / `--log_tracker_name`
学習ログの保存に関するオプションです。`tensorboard` だけでなく `wandb`への保存が可能です。詳細は [PR#428](https://github.com/kohya-ss/sd-scripts/pull/428)をご覧ください。
- `--noise_offset`
こちらの記事の実装になります: https://www.crosslabs.org//blog/diffusion-with-offset-noise
全体的に暗い、明るい画像の生成結果が良くなる可能性があるようです。LoRA学習でも有効なようです。`0.1` 程度の値を指定するとよいようです。
- `--adaptive_noise_scale` (実験的オプション)
Noise offsetの値を、latentsの各チャネルの平均値の絶対値に応じて自動調整するオプションです。`--noise_offset` と同時に指定することで有効になります。Noise offsetの値は `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale` で計算されます。latentは正規分布に近いためnoise_offsetの1/10同程度の値を指定するとよいかもしれません。
負の値も指定でき、その場合はnoise offsetは0以上にclipされます。
- `--multires_noise_iterations` / `--multires_noise_discount`
Multi resolution noise (pyramid noise)の設定です。詳細は [PR#471](https://github.com/kohya-ss/sd-scripts/pull/471) およびこちらのページ [Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2) を参照してください。
`--multires_noise_iterations` に数値を指定すると有効になります。6~10程度の値が良いようです。`--multires_noise_discount` に0.1~0.3 程度の値LoRA学習等比較的データセットが小さい場合のPR作者の推奨、ないしは0.8程度の値(元記事の推奨)を指定してください(デフォルトは 0.3)。
- `--debug_dataset`
このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。`S`キーで次のステップ(バッチ)、`E`キーで次のエポックに進みます。
@@ -545,14 +546,62 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
DreamBoothおよびfine tuningでは、保存されるモデルはこのVAEを組み込んだものになります。
- `--cache_latents`
- `--cache_latents` / `--cache_latents_to_disk`
使用VRAMを減らすためVAEの出力をメインメモリにキャッシュします。`flip_aug` 以外のaugmentationは使えなくなります。また全体の学習速度が若干速くなります。
cache_latents_to_diskを指定するとキャッシュをディスクに保存します。スクリプトを終了し、再度起動した場合もキャッシュが有効になります。
- `--min_snr_gamma`
Min-SNR Weighting strategyを指定します。詳細は[こちら](https://github.com/kohya-ss/sd-scripts/pull/308)を参照してください。論文では`5`が推奨されています。
## モデルの保存に関する設定
- `--save_precision`
保存時のデータ精度を指定します。save_precisionオプションにfloat、fp16、bf16のいずれかを指定すると、その形式でモデルを保存しますDreamBooth、fine tuningでDiffusers形式でモデルを保存する場合は無効です。モデルのサイズを削減したい場合などにお使いください。
- `--save_every_n_epochs` / `--save_state` / `--resume`
save_every_n_epochsオプションに数値を指定すると、そのエポックごとに学習途中のモデルを保存します。
save_stateオプションを同時に指定すると、optimizer等の状態も含めた学習状態を合わせて保存します保存したモデルからも学習再開できますが、それに比べると精度の向上、学習時間の短縮が期待できます。保存先はフォルダになります。
学習状態は保存先フォルダに `<output_name>-??????-state`??????はエポック数)という名前のフォルダで出力されます。長時間にわたる学習時にご利用ください。
保存された学習状態から学習を再開するにはresumeオプションを使います。学習状態のフォルダ`output_dir` ではなくその中のstateのフォルダを指定してください。
なおAcceleratorの仕様により、エポック数、global stepは保存されておらず、resumeしたときにも1からになりますがご容赦ください。
- `--save_every_n_steps`
save_every_n_stepsオプションに数値を指定すると、そのステップごとに学習途中のモデルを保存します。save_every_n_epochsと同時に指定できます。
- `--save_model_as` DreamBooth, fine tuning のみ)
モデルの保存形式を`ckpt, safetensors, diffusers, diffusers_safetensors` から選べます。
`--save_model_as=safetensors` のように指定します。Stable Diffusion形式ckptまたはsafetensorsを読み込み、Diffusers形式で保存する場合、不足する情報はHugging Faceからv1.5またはv2.1の情報を落としてきて補完します。
- `--huggingface_repo_id`
huggingface_repo_idが指定されているとモデル保存時に同時にHuggingFaceにアップロードします。アクセストークンの取り扱いに注意してくださいHuggingFaceのドキュメントを参照してください
他の引数をたとえば以下のように指定してください。
- `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
huggingface_repo_visibilityに`public`を指定するとリポジトリが公開されます。省略時または`private`などpublic以外を指定すると非公開になります。
`--save_state`オプション指定時に`--save_state_to_huggingface`を指定するとstateもアップロードします。
`--resume`オプション指定時に`--resume_from_huggingface`を指定するとHuggingFaceからstateをダウンロードして再開します。その時の --resumeオプションは `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`になります。
例: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
`--async_upload`オプションを指定するとアップロードを非同期で行います。
## オプティマイザ関係
- `--optimizer_type`
@@ -560,13 +609,22 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
- 過去のバージョンのオプション未指定時と同じ
- AdamW8bit : 引数は同上
- PagedAdamW8bit : 引数は同上
- 過去のバージョンの--use_8bit_adam指定時と同じ
- Lion : https://github.com/lucidrains/lion-pytorch
- 過去のバージョンの--use_lion_optimizer指定時と同じ
- Lion8bit : 引数は同上
- PagedLion8bit : 引数は同上
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
- SGDNesterov8bit : 引数は同上
- DAdaptation : https://github.com/facebookresearch/dadaptation
- DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
- DAdaptAdam : 引数は同上
- DAdaptAdaGrad : 引数は同上
- DAdaptAdan : 引数は同上
- DAdaptAdanIP : 引数は同上
- DAdaptLion : 引数は同上
- DAdaptSGD : 引数は同上
- Prodigy : https://github.com/konstmish/prodigy
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
- 任意のオプティマイザ

View File

@@ -1,9 +1,9 @@
__由于文档正在更新中描述可能有错误。__
# 关于本学习文档,通用描述
# 关于训练,通用描述
本库支持模型微调(fine tuning)、DreamBooth、训练LoRA和文本反转(Textual Inversion)(包括[XTI:P+](https://github.com/kohya-ss/sd-scripts/pull/327)
本文档将说明它们通用的学习数据准备方法和选项等。
本文档将说明它们通用的训练数据准备方法和选项等。
# 概要
@@ -12,15 +12,15 @@ __由于文档正在更新中描述可能有错误。__
以下本节说明。
1. 关于准备学习数据的新形式(使用设置文件)
1. 对于在学习中使用的术语的简要解释
1. 准备训练数据(使用设置文件的新格式
1. 训练中使用的术语的简要解释
1. 先前的指定格式(不使用设置文件,而是从命令行指定)
1. 生成学习过程中的示例图像
1. 生成训练过程中的示例图像
1. 各脚本中常用的共同选项
1. 准备 fine tuning 方法的元数据:如说明文字(打标签)等
1. 如果只执行一次,学习就可以进行(相关内容,请参阅各个脚本的文档)。如果需要,以后可以随时参考。
1. 如果只执行一次,训练就可以进行(相关内容,请参阅各个脚本的文档)。如果需要,以后可以随时参考。
@@ -28,24 +28,25 @@ __由于文档正在更新中描述可能有错误。__
在任意文件夹(也可以是多个文件夹)中准备好训练数据的图像文件。支持 `.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp` 格式的文件。通常不需要进行任何预处理,如调整大小等。
但是请勿使用极小的图像其尺寸比训练分辨率稍后将提到还小建议事先使用超分辨率AI等进行放大。另外请注意不要使用过大的图像约为3000 x 3000像素以上因为这可能会导致错误建议事先缩小。
但是请勿使用极小的图像,其尺寸比训练分辨率稍后将提到还小建议事先使用超分辨率AI等进行放大。另外请注意不要使用过大的图像约为3000 x 3000像素以上因为这可能会导致错误建议事先缩小。
在训练时,需要整理要用于训练模型的图像数据,并将其指定给脚本。根据训练数据的数量、训练目标和说明(图像描述)是否可用等因素,可以使用几种方法指定训练数据。以下是其中的一些方法(每个名称都不是通用的,而是该存储库自定义的定义)。有关正则化图像的信息将在稍后提供。
1. DreamBooth、class + identifier方式可使用正则化图像
将训练目标与特定单词identifier相关联进行训练。无需准备说明。例如当要学习特定角色时由于无需准备说明因此比较方便但由于学习数据的所有元素都与identifier相关联例如发型、服装、背景等因此在生成时可能会出现无法更换服装的情况。
将训练目标与特定单词identifier相关联进行训练。无需准备说明。例如当要学习特定角色时由于无需准备说明因此比较方便但由于训练数据的所有元素都与identifier相关联例如发型、服装、背景等因此在生成时可能会出现无法更换服装的情况。
2. DreamBooth、说明方式可使用正则化图像
准备记录每个图像说明的文本文件进行训练。例如通过将图像详细信息如穿着白色衣服的角色A、穿着红色衣服的角色A等记录在说明中,可以将角色和其他元素分离,并期望模型更准确地学习角色。
事先给每个图片写说明caption存放到文本文件中然后进行训练。例如通过将图像详细信息如穿着白色衣服的角色A、穿着红色衣服的角色A等记录在caption中,可以将角色和其他元素分离,并期望模型更准确地学习角色。
3. 微调方式(不可使用正则化图像)
先将说明收集到元数据文件中。支持分离标签和说明以及预先缓存latents等功能以加速训练这些将在另一篇文档中介绍虽然名为fine tuning方式但不仅限于fine tuning。
你要学的东西和你可以使用的规范方法的组合如下。
训练对象和你可以使用的规范方法的组合如下。
| 学习对象或方法 | 脚本 | DB/class+identifier | DB/caption | fine tuning |
| 训练对象或方法 | 脚本 | DB/class+identifier | DB/caption | fine tuning |
|----------------| ----- | ----- | ----- | ----- |
| fine tuning微调模型 | `fine_tune.py`| x | x | o |
| DreamBooth训练模型 | `train_db.py`| o | o | x |
@@ -54,15 +55,15 @@ __由于文档正在更新中描述可能有错误。__
## 选择哪一个
如果您想要学习LoRA、Textual Inversion而不需要准备简介文件则建议使用DreamBooth class+identifier。如果您能够准备则DreamBooth Captions方法更好。如果您有大量的训练数据并且不使用则化图像则请考虑使用fine-tuning方法。
如果您想要训练LoRA、Textual Inversion而不需要准备说明caption文件则建议使用DreamBooth class+identifier。如果您能够准备caption文件则DreamBooth Captions方法更好。如果您有大量的训练数据并且不使用则化图像则请考虑使用fine-tuning方法。
对于DreamBooth也是一样的但不能使用fine-tuning方法。对于fine-tuning方法只能使用fine-tuning方式。
对于DreamBooth也是一样的但不能使用fine-tuning方法。若要进行微调只能使用fine-tuning方式。
# 每种方法的指定方式
在这里,我们只介绍每种指定方法的典型模式。有关更详细的指定方法,请参见[数据集设置](./config_README-ja.md)。
# DreamBoothclass+identifier方法可使用则化图像)
# DreamBoothclass+identifier方法可使用则化图像)
在该方法中,每个图像将被视为使用与 `class identifier` 相同的标题进行训练(例如 `shs dog`)。
@@ -70,15 +71,15 @@ __由于文档正在更新中描述可能有错误。__
## step 1.确定identifier和class
要将学习的目标与identifier和属于该目标的class相关联。
要将训练的目标与identifier和属于该目标的class相关联。
(虽然有很多称呼,但暂时按照原始论文的说法。)
以下是简要说明(请查阅详细信息)。
class是学习目标的一般类别。例如如果要学习特定品种的狗则class将是“dog”。对于动漫角色根据模型不同可能是“boy”或“girl”也可能是“1boy”或“1girl”。
class是训练目标的一般类别。例如如果要学习特定品种的狗则class将是“dog”。对于动漫角色根据模型不同可能是“boy”或“girl”也可能是“1boy”或“1girl”。
identifier是用于识别学习目标并进行学习的单词。可以使用任何单词但是根据原始论文“Tokenizer生成的3个或更少字符的罕见单词”是最好的选择。
identifier是用于识别训练目标并进行学习的单词。可以使用任何单词但是根据原始论文“Tokenizer生成的3个或更少字符的罕见单词”是最好的选择。
使用identifier和class例如“shs dog”可以将模型训练为从class中识别并学习所需的目标。
@@ -86,9 +87,9 @@ identifier是用于识别学习目标并进行学习的单词。可以使用任
作为identifier我最近使用的一些参考是“shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny”等。最好是不包含在Danbooru标签中的单词。
## step 2. 决定是否使用正则化图像,并生成正则化图像
## step 2. 决定是否使用正则化图像,并在使用时生成正则化图像
正则化图像是为防止前面提到的语言漂移,即整个类别被拉扯成为学习目标而生成的图像。如果不使用正则化图像,例如在 `shs 1girl` 中学习特定角色时,即使在简单的 `1girl` 提示下生成,也会越来越像该角色。这是因为 `1girl` 在训练时的标题中包含了该角色的信息。
正则化图像是为防止前面提到的语言漂移,即整个类别被拉扯成为训练目标而生成的图像。如果不使用正则化图像,例如在 `shs 1girl` 中学习特定角色时,即使在简单的 `1girl` 提示下生成,也会越来越像该角色。这是因为 `1girl` 在训练时的标题中包含了该角色的信息。
通过同时学习目标图像和正则化图像,类别仍然保持不变,仅在将标识符附加到提示中时才生成目标图像。
@@ -100,46 +101,48 @@ identifier是用于识别学习目标并进行学习的单词。可以使用任
(由于正则化图像也被训练,因此其质量会影响模型。)
通常,准备数百张图像是理想的(图像数量太少会导致类别图像无法推广并学习它们的特征)。
通常,准备数百张图像是理想的(图像数量太少会导致类别图像无法被归纳,特征也不会被学习)。
如果要使用生成的图像生成图像的大小通常应与训练分辨率更准确地说是bucket的分辨率见下文相匹配。
如果要使用生成的图像请将其大小通常与训练分辨率更准确地说是bucket的分辨率相适应。
## step 2. 设置文件的描述
创建一个文本文件,并将其扩展名更改为`.toml`。例如,您可以按以下方式进行描述:
(以``开头的部分是注释,因此您可以直接复制粘贴,或者将其删除,都没有问题。)
(以``开头的部分是注释,因此您可以直接复制粘贴,或者将其删除。)
```toml
[general]
enable_bucket = true # 是否使用Aspect Ratio Bucketing
[[datasets]]
resolution = 512 # 学习分辨率
batch_size = 4 # 批大小
resolution = 512 # 训练分辨率
batch_size = 4 # 批大小
[[datasets.subsets]]
image_dir = 'C:\hoge' # 指定包含训练图像的文件夹
class_tokens = 'hoge girl' # 指定标识符类
num_repeats = 10 # 训练图像的迭代次数
num_repeats = 10 # 训练图像的重复次数
# 以下仅在使用正则化图像时进行描述。不使用则删除
[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg' # 指定包含正则化图像的文件夹
class_tokens = 'girl' # 指定类别
num_repeats = 1 # 正则化图像的迭代次数基本上1就可以了
class_tokens = 'girl' # 指定class
num_repeats = 1 # 正则化图像的重复次数基本上1就可以了
```
基本上只需更改以下位置即可进行学习
基本上只需更改以下几个地方即可进行训练
1. 学习分辨率
1. 训练分辨率
指定一个数字表示正方形(如果是 `512`,则为 512x512如果使用方括号和逗号分隔的两个数字则表示横向×纵向如果是`[512,768]`,则为 512x768。在SD1.x系列中原始学习分辨率为512。指定较大的分辨率`[512,768]` 可能会减少纵向和横向图像生成时的错误。在SD2.x 768系列中分辨率为 `768`
指定一个数字表示正方形(如果是 `512`,则为 512x512如果使用方括号和逗号分隔的两个数字则表示横向×纵向如果是`[512,768]`,则为 512x768。在SD1.x系列中原始训练分辨率为512。指定较大的分辨率`[512,768]` 可能会减少纵向和横向图像生成时的错误。在SD2.x 768系列中分辨率为 `768`
1. 批大小
1. 批大小
指定同时学习多少个数据。这取决于GPU的VRAM大小和学习分辨率。详细信息将在后面说明。此外fine tuning/DreamBooth/LoRA等也会影响批大小,请查看各个脚本的说明。
指定同时训练多少个数据。这取决于GPU的VRAM大小和训练分辨率。详细信息将在后面说明。此外fine tuning/DreamBooth/LoRA等也会影响批大小,请查看各个脚本的说明。
1. 文件夹指定
@@ -149,7 +152,7 @@ batch_size = 4 # 批量大小
如前所述,与示例相同。
1. 迭代次数
1. 重复次数
将在后面说明。
@@ -159,69 +162,68 @@ batch_size = 4 # 批量大小
请将重复次数指定为“ __训练用图像的重复次数×训练用图像的数量≥正则化图像的重复次数×正则化图像的数量__ ”。
1个epoch数据一周一次)的数据量为“训练用图像的重复次数×训练用图像的数量”。如果正则化图像的数量多于这个值,则剩余的正则化图像将不会被使用。)
1个epoch指训练数据过完一遍)的数据量为“训练用图像的重复次数×训练用图像的数量”。如果正则化图像的数量多于这个值,则剩余的正则化图像将不会被使用。)
## 步骤 3. 学习
## 步骤 3. 训练
请根据每个文档的参考进行学习
详情请参考相关文档进行训练
# DreamBooth标题方式(可使用规范化图像)
# DreamBooth文本说明caption方式(可使用正则化图像)
在此方式中,每个图像都将通过标题进行学习
在此方式中,每个图像都将通过caption进行训练
## 步骤 1. 准备标题文件
## 步骤 1. 准备文本说明文件
请将与图像具有相同文件名且扩展名为 `.caption`(可以在设置中更改)的文件放置在用于训练图像的文件夹中。每个文件应该只有一行。编码为 `UTF-8`
## 步骤 2. 决定是否使用规范化图像,并在使用时生成规范化图像
## 步骤 2. 决定是否使用正则化图像,并在使用时生成正则化图像
与class+identifier格式相同。可以在规范化图像上附加标题,但通常不需要。
与class+identifier格式相同。可以在规范化图像上附加caption,但通常不需要。
## 步骤 2. 编写设置文件
创建一个文本文件并将扩展名更改为 `.toml`。例如,可以按以下方式进行记录。
创建一个文本文件并将扩展名更改为 `.toml`。例如,可以按以下方式进行描述:
```toml
[general]
enable_bucket = true # Aspect Ratio Bucketingを使うか否か
enable_bucket = true # 是否使用Aspect Ratio Bucketing
[[datasets]]
resolution = 512 # 学習解像度
batch_size = 4 # 批大小
resolution = 512 # 训练分辨率
batch_size = 4 # 批大小
[[datasets.subsets]]
image_dir = 'C:\hoge' # 指定包含训练图像的文件夹
caption_extension = '.caption' # 使用字幕文件扩展名 .txt 时重写
num_repeats = 10 # 训练图像的迭代次数
caption_extension = '.caption' # 使用txt文件,更改此项
num_repeats = 10 # 训练图像的重复次数
# 以下仅在使用正则化图像时进行描述。不使用则删除
[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg' #指定包含正则化图像的文件夹
class_tokens = 'girl' # class を指定
num_repeats = 1 #
正则化图像的迭代次数基本上1就可以了
image_dir = 'C:\reg' # 指定包含正则化图像的文件夹
class_tokens = 'girl' # 指定class
num_repeats = 1 # 正则化图像的重复次数基本上1就可以了
```
基本上,您可以通过仅重写以下位置来学习。除非另有说明,否则与类+标识符方法相同。
基本上只需更改以下几个地方来训练。除非另有说明否则与class+identifier方法相同。
1. 学习分辨率
2. 批大小
1. 训练分辨率
2. 批大小
3. 文件夹指定
4. 标题文件的扩展名
4. caption文件的扩展名
可以指定任意的扩展名。
5. 重复次数
## 步骤 3. 学习
## 步骤 3. 训练
请参考每个文档进行学习
详情请参考相关文档进行训练
# 微调方法
# 微调方法(fine tuning)
## 步骤 1. 准备元数据
标题和标签整合到管理文件中称为元数据。它的扩展名为 `.json`格式为json。由于创建方法较长因此在本文档的末尾进行描述。
caption和标签整合到管理文件中称为元数据。它的扩展名为 `.json`格式为json。由于创建方法较长因此在本文档的末尾进行描述。
## 步骤 2. 编写设置文件
@@ -233,16 +235,16 @@ keep_tokens = 1
[[datasets]]
resolution = 512 # 图像分辨率
batch_size = 4 # 批大小
batch_size = 4 # 批大小
[[datasets.subsets]]
image_dir = 'C:\piyo' # 指定包含训练图像的文件夹
metadata_file = 'C:\piyo\piyo_md.json' # 元数据文件名
```
基本上,您可以通过仅重写以下位置来学习。如无特别说明与DreamBooth相同,类+标识符方式
基本上只需更改以下几个地方来训练。除非另有说明,否则与DreamBooth, class+identifier方法相同。
1. 学习解像度
1. 训练分辨率
2. 批次大小
3. 指定文件夹
4. 元数据文件名
@@ -250,25 +252,25 @@ batch_size = 4 # 批量大小
指定使用后面所述方法创建的元数据文件。
## 第三步:学习
## 第三步:训练
请参考各个文档进行学习
详情请参考相关文档进行训练
# 学习中使用的术语简单解释
# 训练中使用的术语简单解释
由于省略了细节并且我自己也没有完全理解,因此请自行查阅详细信息。
## 微调fine tuning
指训练模型并微调其性能。具体含义因用法而异,但在 Stable Diffusion 中,狭义的微调是指使用图像和标题进行训练模型。DreamBooth 可视为狭义微调的一种特殊方法。广义的微调包括 LoRA、Textual Inversion、Hypernetworks 等,包括训练模型的所有内容。
指训练模型并微调其性能。具体含义因用法而异,但在 Stable Diffusion 中,狭义的微调是指使用图像和caption进行训练模型。DreamBooth 可视为狭义微调的一种特殊方法。广义的微调包括 LoRA、Textual Inversion、Hypernetworks 等,包括训练模型的所有内容。
## 步骤step
粗略地说,每次在训练数据上进行一次计算即为一步。具体来说,“将训练数据的标题传递给当前模型,将生成的图像与训练数据的图像进行比较,稍微更改模型,以使其更接近训练数据”即为一步。
粗略地说,每次在训练数据上进行一次计算即为一步。具体来说,“将训练数据的caption传递给当前模型,将生成的图像与训练数据的图像进行比较,稍微更改模型,以使其更接近训练数据”即为一步。
## 批次大小batch size
批次大小指定每个步骤要计算多少数据。批计算可以提高速度。一般来说,批次大小越大,精度也越高。
批次大小指定每个步骤要计算多少数据。批计算可以提高速度。一般来说,批次大小越大,精度也越高。
“批次大小×步数”是用于训练的数据数量。因此,建议减少步数以增加批次大小。
@@ -276,37 +278,37 @@ batch_size = 4 # 批量大小
批次大小越大GPU 内存消耗就越大。如果内存不足,将导致错误,或者在边缘时将导致训练速度降低。建议在任务管理器或 `nvidia-smi` 命令中检查使用的内存量进行调整。
另外,批次是指“一数据”的意思
注意,一个批次是指“一数据单位”。
## 学习率
学习率指的是每个步骤中改变的程度。如果指定一个大的值,学习速度就会加快,但是可能会出现变化太大导致模型崩溃或无法达到最佳状态的情况。如果指定一个小的值,学习速度会变慢,可能无法达到最佳状态。
学习率指的是每个步骤中改变的程度。如果指定一个大的值,学习速度就会加快,但是可能会出现变化太大导致模型崩溃或无法达到最佳状态的情况。如果指定一个小的值,学习速度会变慢,同时可能无法达到最佳状态。
在fine tuning、DreamBooth、LoRA等过程中学习率会有很大的差异并且也会受到训练数据、所需训练的模型、批大小和步骤数等因素的影响。建议从一般的值开始,观察训练状态并逐渐调整。
在fine tuning、DreamBooth、LoRA等过程中学习率会有很大的差异并且也会受到训练数据、所需训练的模型、批大小和步骤数等因素的影响。建议从通常值开始,观察训练状态并逐渐调整。
默认情况下,整个训练过程中学习率是固定的。但是可以通过调度程序指定学习率如何变化,因此结果也会有所不同。
## 时代epoch
## Epoch
Epoch指的是训练数据被完整训练一遍即数据一周)的情况。如果指定了重复次数,则在重复后的数据一周后,就是1个epoch。
Epoch指的是训练数据被完整训练一遍即数据已经迭代一轮)。如果指定了重复次数,则在重复后的数据迭代一轮后,为1个epoch。
1个epoch的步骤数通常为“数据量÷批大小”但如果使用Aspect Ratio Bucketing则略微增加由于不同bucket的数据不能在同一个批次中因此步骤数会增加
1个epoch的步骤数通常为“数据量÷批大小”但如果使用Aspect Ratio Bucketing则略微增加由于不同bucket的数据不能在同一个批次中因此步骤数会增加
## 纵横比分桶Aspect Ratio Bucketing)
## 长宽比分桶Aspect Ratio Bucketing
Stable Diffusion 的 v1 是以 512\*512 的分辨率进行训练的,但同时也可以在其他分辨率下进行训练,例如 256\*1024 和 384\*640。这样可以减少裁剪的部分望更准确地学习图像和标题之间的关系。
Stable Diffusion 的 v1 是以 512\*512 的分辨率进行训练的,但同时也可以在其他分辨率下进行训练,例如 256\*1024 和 384\*640。这样可以减少裁剪的部分望更准确地学习图像和标题之间的关系。
此外,由于可以在任意分辨率下进行训练,因此不再需要事先统一图像数据的纵横比。
此外,由于可以在任意分辨率下进行训练,因此不再需要事先统一图像数据的长宽比。
该设置在配置中有效,可以切换,但在此之前的配置文件示例中已启用(设置为 `true`)。
此值可以被设定,其在此之前的配置文件示例中已启用(设置为 `true`)。
学习分辨率将根据参数所提供的分辨率面积(即内存使用量)进行调整,以64像素为单位(默认值,可更改)在纵横方向上进行调整和创建。
只要不超过作为参数给出的分辨率区域(= 内存使用量),就可以按 64 像素的增量(默认值,可更改)在垂直和水平方向上调整和创建训练分辨率
在机器学习中,通常需要将所有输入大小统一,但实际上只要在同一批次中统一即可。 NovelAI 所说的分桶(bucketing) 指的是,预先将训练数据按照纵横比分类到每个学习分辨率下,并通过使用每个 bucket 内的图像创建批次来统一批次图像大小。
在机器学习中,通常需要将所有输入大小统一,但实际上只要在同一批次中统一即可。 NovelAI 所说的分桶(bucketing) 指的是,预先将训练数据按照长宽比分类到每个学习分辨率下,并通过使用每个 bucket 内的图像创建批次来统一批次图像大小。
# 以前的指定格式(不使用 .toml 文件,而是使用命令行选项指定)
这是一种通过命令行选项而不是指定 .toml 文件的方法。有 DreamBooth 类+标识符方法、DreamBooth 标题方法、微调方法三种方式。
这是一种通过命令行选项而不是指定 .toml 文件的方法。有 DreamBooth 类+标识符方法、DreamBooth caption方法、微调方法三种方式。
## DreamBooth、类+标识符方式
@@ -326,7 +328,7 @@ Stable Diffusion 的 v1 是以 512\*512 的分辨率进行训练的,但同时
![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png)
### 多个类别、多个标识符的学习
### 多个类别、多个标识符的训练
该方法很简单在用于训练的图像文件夹中需要准备多个文件夹每个文件夹都是以“重复次数_<标识符> <类别>”命名的同样在正则化图像文件夹中也需要准备多个文件夹每个文件夹都是以“重复次数_<类别>”命名的。
@@ -344,37 +346,37 @@ Stable Diffusion 的 v1 是以 512\*512 的分辨率进行训练的,但同时
### step 2. 准备正规化图像
这是使用则化图像时的过程。
这是使用则化图像时的过程。
创建一个文件夹来存储则化的图像。 __此外__ 创建一个名为``<repeat count>_<class>`` 的目录。
创建一个文件夹来存储则化的图像。 __此外__ 创建一个名为``<repeat count>_<class>`` 的目录。
例如使用提示“frog”并且不重复数据仅一次
![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png)
步骤3. 执行学习
步骤3. 执行训练
执行每个学习脚本。使用 `--train_data_dir` 选项指定包含训练数据文件夹的父文件夹(不是包含图像的文件夹),使用 `--reg_data_dir` 选项指定包含正则化图像的父文件夹(不是包含图像的文件夹)。
执行每个训练脚本。使用 `--train_data_dir` 选项指定包含训练数据文件夹的父文件夹(不是包含图像的文件夹),使用 `--reg_data_dir` 选项指定包含正则化图像的父文件夹(不是包含图像的文件夹)。
## DreamBooth标题方式
## DreamBooth文本说明caption方式
在包含训练图像和正则化图像的文件夹中,将与图像具有相同文件名的文件.caption可以使用选项进行更改放置在该文件夹中然后从该文件中加载标题作为提示进行学习
在包含训练图像和正则化图像的文件夹中,将与图像具有相同文件名的文件.caption可以使用选项进行更改放置在该文件夹中然后从该文件中加载caption所作为提示进行训练
※文件夹名称(标识符类)不再用于这些图像的训练。
默认的标题文件扩展名为.caption。可以使用学习脚本的 `--caption_extension` 选项进行更改。 使用 `--shuffle_caption` 选项,同时对每个逗号分隔的部分进行学习时会对学习时的标题进行混洗。
默认的caption文件扩展名为.caption。可以使用训练脚本的 `--caption_extension` 选项进行更改。 使用 `--shuffle_caption` 选项,同时对每个逗号分隔的部分进行训练时会对训练时的caption进行混洗。
## 微调方式
创建元数据的方式与使用配置文件相同。 使用 `in_json` 选项指定元数据文件。
# 学习过程中的样本输出
# 训练过程中的样本输出
通过在训练中使用模型生成图像,可以检查学习进度。将以下选项指定为学习脚本。
通过在训练中使用模型生成图像,可以检查训练进度。将以下选项指定为训练脚本。
- `--sample_every_n_steps` / `--sample_every_n_epochs`
指定要采样的步数或纪元数。为这些数字中的每一个输出样本。如果两者都指定,则 epoch 数优先。
指定要采样的步数或epoch数。为这些数字中的每一个输出样本。如果两者都指定,则 epoch 数优先。
- `--sample_prompts`
指定示例输出的提示文件。
@@ -421,11 +423,11 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
4. U-Net的结构CrossAttention的头数等
5. v-parameterization采样方式好像变了
其中碱基使用1-4,非碱基使用1-5768-v。使用 1-4 进行 v2 选择,使用 5 进行 v_parameterization 选择。
-`--pretrained_model_name_or_path`
其中base使用1-4base使用1-5768-v。使用 1-4 进行 v2 选择,使用 5 进行 v_parameterization 选择。
- `--pretrained_model_name_or_path`
指定要从中执行额外训练的模型。您可以指定稳定扩散检查点文件(.ckpt 或 .safetensors扩散器本地磁盘上的模型目录或扩散器模型 ID例如“stabilityai/stable-diffusion-2”
## 学习设置
指定要从中执行额外训练的模型。您可以指定Stable Diffusion检查点文件(.ckpt 或 .safetensorsdiffusers本地磁盘上的模型目录或diffusers模型 ID例如“stabilityai/stable-diffusion-2”
## 训练设置
- `--output_dir`
@@ -441,7 +443,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- `--max_train_steps` / `--max_train_epochs`
指定要学习的步数或纪元数。如果两者都指定,则 epoch 数优先。
指定要训练的步数或epoch数。如果两者都指定,则 epoch 数优先。
-
- `--mixed_precision`
@@ -450,9 +452,9 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
在RTX30系列以后也可以指定`bf16`,请配合您在搭建环境时做的加速设置)。
- `--gradient_checkpointing`
通过逐步计算权重而不是在训练期间一次计算所有权重来减少训练所需的 GPU 内存量。关闭它不会影响准确性,但打开它允许更大的批大小,所以那里有影响。
通过逐步计算权重而不是在训练期间一次计算所有权重来减少训练所需的 GPU 内存量。关闭它不会影响准确性,但打开它允许更大的批大小,所以那里有影响。
另外,打开它通常会减慢速度,但可以增加批大小,因此总的学习时间实际上可能会更快。
另外,打开它通常会减慢速度,但可以增加批大小,因此总的训练时间实际上可能会更快。
- `--xformers` / `--mem_eff_attn`
@@ -463,35 +465,35 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- `--save_every_n_epochs` / `--save_state` / `--resume`
为 save_every_n_epochs 选项指定一个数字可以在每个时期的训练期间保存模型。
如果同时指定save_state选项学习状态包括优化器的状态等都会一起保存。。保存目的地将是一个文件夹。
如果同时指定save_state选项训练状态包括优化器的状态等都会一起保存。。保存目的地将是一个文件夹。
学习状态输出到目标文件夹中名为“<output_name>-??????-state”??????是纪元数)的文件夹中。长时间学习时请使用。
训练状态输出到目标文件夹中名为“<output_name>-??????-state”??????是epoch数)的文件夹中。长时间训练时请使用。
使用 resume 选项从保存的训练状态恢复训练。指定学习状态文件夹(其中的状态文件夹,而不是 `output_dir`)。
使用 resume 选项从保存的训练状态恢复训练。指定训练状态文件夹(其中的状态文件夹,而不是 `output_dir`)。
请注意,由于 Accelerator 规范epoch 数和全局步数不会保存,即使恢复时它们也从 1 开始。
- `--save_model_as` DreamBooth, fine tuning 仅有的)
您可以从 `ckpt, safetensors, diffusers, diffusers_safetensors` 中选择模型保存格式。
- `--save_model_as=safetensors` 指定喜欢当读取稳定扩散格式ckpt 或安全张量)并以扩散器格式保存时,缺少的信息通过从 Hugging Face 中删除 v1.5 或 v2.1 信息来补充。
- `--save_model_as=safetensors` 指定喜欢当读取Stable Diffusion格式ckpt 或safetensors并以diffusers格式保存时,缺少的信息通过从 Hugging Face 中删除 v1.5 或 v2.1 信息来补充。
- `--clip_skip`
`2` 如果指定,则使用文本编码器 (CLIP) 的倒数第二层的输出。如果省略 1 或选项,则使用最后一层。
*SD2.0默认使用倒数第二层,学习SD2.0时请不要指定。
*SD2.0默认使用倒数第二层,训练SD2.0时请不要指定。
如果被训练的模型最初被训练为使用第二层,则 2 是一个很好的值。
如果您使用的是最后一层那么整个模型都会根据该假设进行训练。因此如果再次使用第二层进行训练可能需要一定数量的teacher数据和更长时间的学习才能得到想要的学习结果。
如果您使用的是最后一层那么整个模型都会根据该假设进行训练。因此如果再次使用第二层进行训练可能需要一定数量的teacher数据和更长时间的训练才能得到想要的训练结果。
- `--max_token_length`
默认值为 75。您可以通过指定“150”或“225”来扩展令牌长度来学习。使用长字幕学习时指定。
默认值为 75。您可以通过指定“150”或“225”来扩展令牌长度来训练。使用长字幕训练时指定。
但由于学习时token展开的规范与Automatic1111的web UI除法等规范略有不同如非必要建议用75学习
但由于训练时token展开的规范与Automatic1111的web UI除法等规范略有不同如非必要建议用75训练
与clip_skip一样学习与模型学习状态不同的长度可能需要一定量的teacher数据和更长的学习时间。
与clip_skip一样训练与模型训练状态不同的长度可能需要一定量的teacher数据和更长的学习时间。
- `--persistent_data_loader_workers`
@@ -502,7 +504,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
指定数据加载的进程数。大量的进程会更快地加载数据并更有效地使用 GPU但会消耗更多的主内存。默认是"`8`或者`CPU并发执行线程数 - 1`,取小者"所以如果主存没有空间或者GPU使用率大概在90%以上,就看那些数字和 `2` 或将其降低到大约 `1`
- `--logging_dir` / `--log_prefix`
保存学习日志的选项。在 logging_dir 选项中指定日志保存目标文件夹。以 TensorBoard 格式保存日志。
保存训练日志的选项。在 logging_dir 选项中指定日志保存目标文件夹。以 TensorBoard 格式保存日志。
例如,如果您指定 --logging_dir=logs将在您的工作文件夹中创建一个日志文件夹并将日志保存在日期/时间文件夹中。
此外,如果您指定 --log_prefix 选项,则指定的字符串将添加到日期和时间之前。使用“--logging_dir=logs --log_prefix=db_style1_”进行识别。
@@ -518,23 +520,23 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- `--noise_offset`
本文的实现https://www.crosslabs.org//blog/diffusion-with-offset-noise
看起来它可能会为整体更暗和更亮的图像产生更好的结果。它似乎对 LoRA 学习也有效。指定一个大约 0.1 的值似乎很好。
看起来它可能会为整体更暗和更亮的图像产生更好的结果。它似乎对 LoRA 训练也有效。指定一个大约 0.1 的值似乎很好。
- `--debug_dataset`
通过添加此选项,您可以在学习之前检查将学习什么样的图像数据和标题。按 Esc 退出并返回命令行。按 `S` 进入下一步(批次),按 `E` 进入下一个纪元
通过添加此选项,您可以在训练之前检查将训练什么样的图像数据和标题。按 Esc 退出并返回命令行。按 `S` 进入下一步(批次),按 `E` 进入下一个epoch
*图片在 Linux 环境(包括 Colab下不显示。
- `--vae`
如果您在 vae 选项中指定稳定扩散检查点、VAE 检查点文件、扩散模型或 VAE两者都可以指定本地或拥抱面模型 ID则该 VAE 用于学习(缓存时的潜伏)或在学习过程中获得潜伏)。
如果您在 vae 选项中指定Stable Diffusion检查点、VAE 检查点文件、扩散模型或 VAE两者都可以指定本地或拥抱面模型 ID则该 VAE 用于训练(缓存时的潜伏)或在训练过程中获得潜伏)。
对于 DreamBooth 和微调,保存的模型将包含此 VAE
- `--cache_latents`
在主内存中缓存 VAE 输出以减少 VRAM 使用。除 flip_aug 之外的任何增强都将不可用。此外,整体学习速度略快。
在主内存中缓存 VAE 输出以减少 VRAM 使用。除 flip_aug 之外的任何增强都将不可用。此外,整体训练速度略快。
- `--min_snr_gamma`
指定最小 SNR 加权策略。细节是[这里](https://github.com/kohya-ss/sd-scripts/pull/308)请参阅。论文中推荐`5`
@@ -545,19 +547,29 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
-- 指定优化器类型。您可以指定
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
- 与过去版本中未指定选项时相同
- AdamW8bit : 同上
- AdamW8bit : 参数同上
- PagedAdamW8bit : 参数同上
- 与过去版本中指定的 --use_8bit_adam 相同
- Lion : https://github.com/lucidrains/lion-pytorch
- Lion8bit : 参数同上
- PagedLion8bit : 参数同上
- 与过去版本中指定的 --use_lion_optimizer 相同
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
- SGDNesterov8bit : 数同上
- DAdaptation : https://github.com/facebookresearch/dadaptation
- SGDNesterov8bit : 数同上
- DAdaptation(DAdaptAdamPreprint) : https://github.com/facebookresearch/dadaptation
- DAdaptAdam : 参数同上
- DAdaptAdaGrad : 参数同上
- DAdaptAdan : 参数同上
- DAdaptAdanIP : 参数同上
- DAdaptLion : 参数同上
- DAdaptSGD : 参数同上
- Prodigy : https://github.com/konstmish/prodigy
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
- 任何优化器
- `--learning_rate`
指定学习率。合适的学习率取决于学习脚本,所以请参考每个解释。
指定学习率。合适的学习率取决于训练脚本,所以请参考每个解释。
- `--lr_scheduler` / `--lr_warmup_steps` / `--lr_scheduler_num_cycles` / `--lr_scheduler_power`
学习率的调度程序相关规范。
@@ -592,14 +604,14 @@ D-Adaptation 优化器自动调整学习率。学习率选项指定的值不是
(内部仅通过 importlib 未确认操作。如果需要,请安装包。)
<!--
## 使用任意大小的图像进行训练 --resolution
你可以在广场外学习。请在分辨率中指定“宽度、高度”如“448,640”。宽度和高度必须能被 64 整除。匹配训练图像和正则化图像的大小。
你可以在广场外训练。请在分辨率中指定“宽度、高度”如“448,640”。宽度和高度必须能被 64 整除。匹配训练图像和正则化图像的大小。
就我个人而言我经常生成垂直长的图像所以我有时会用“448、640”来学习
就我个人而言我经常生成垂直长的图像所以我有时会用“448、640”来训练
## 纵横比分桶 --enable_bucket / --min_bucket_reso / --max_bucket_reso
它通过指定 enable_bucket 选项来启用。 Stable Diffusion 在 512x512 分辨率下训练,但也在 256x768 和 384x640 等分辨率下训练。
如果指定此选项,则不需要将训练图像和正则化图像统一为特定分辨率。从多种分辨率(纵横比)中进行选择,并在该分辨率下学习
如果指定此选项,则不需要将训练图像和正则化图像统一为特定分辨率。从多种分辨率(纵横比)中进行选择,并在该分辨率下训练
由于分辨率为 64 像素,纵横比可能与原始图像不完全相同。
您可以使用 min_bucket_reso 选项指定分辨率的最小大小,使用 max_bucket_reso 指定最大大小。默认值分别为 256 和 1024。
@@ -611,13 +623,13 @@ D-Adaptation 优化器自动调整学习率。学习率选项指定的值不是
(因为一批中的图像不偏向于训练图像和正则化图像。
## 扩充 --color_aug / --flip_aug
增强是一种通过在学习过程中动态改变数据来提高模型性能的方法。在使用 color_aug 巧妙地改变色调并使用 flip_aug 左右翻转的同时学习
增强是一种通过在训练过程中动态改变数据来提高模型性能的方法。在使用 color_aug 巧妙地改变色调并使用 flip_aug 左右翻转的同时训练
由于数据是动态变化的,因此不能与 cache_latents 选项一起指定。
## 使用 fp16 梯度训练(实验特征)--full_fp16
如果指定 full_fp16 选项,梯度从普通 float32 变为 float16 (fp16) 并学习(它似乎是 full fp16 学习而不是混合精度)。
结果,似乎 SD1.x 512x512 大小可以在 VRAM 使用量小于 8GB 的​​情况下学习,而 SD2.x 512x512 大小可以在 VRAM 使用量小于 12GB 的情况下学习
如果指定 full_fp16 选项,梯度从普通 float32 变为 float16 (fp16) 并训练(它似乎是 full fp16 训练而不是混合精度)。
结果,似乎 SD1.x 512x512 大小可以在 VRAM 使用量小于 8GB 的​​情况下训练,而 SD2.x 512x512 大小可以在 VRAM 使用量小于 12GB 的情况下训练
预先在加速配置中指定 fp16并可选择设置 ``mixed_precision="fp16"``bf16 不起作用)。
@@ -631,20 +643,20 @@ D-Adaptation 优化器自动调整学习率。学习率选项指定的值不是
# 创建元数据文件
## 准备教师资料
## 准备训练数据
如上所述准备好你要学习的图像数据,放在任意文件夹中。
如上所述准备好你要训练的图像数据,放在任意文件夹中。
例如,存储这样的图像:
![教师数据文件夹的屏幕截图](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png)
## 自动字幕
## 自动captioning
如果您只想学习没有标题的标签,请跳过。
如果您只想训练没有标题的标签,请跳过。
另外,手动准备字幕时,请准备在与教师数据图像相同的目录下,文件名相同,扩展名.caption等。每个文件应该是只有一行的文本文件。
### 使用 BLIP 添加字幕
另外,手动准备caption时,请准备在与教师数据图像相同的目录下,文件名相同,扩展名.caption等。每个文件应该是只有一行的文本文件。
### 使用 BLIP 添加caption
最新版本不再需要 BLIP 下载、权重下载和额外的虚拟环境。按原样工作。
@@ -659,24 +671,24 @@ python finetune\make_captions.py --batch_size <バッチサイズ> <教師デー
python finetune\make_captions.py --batch_size 8 ..\train_data
```
字幕文件创建在与教师数据图像相同的目录中,具有相同的文件名和扩展名.caption。
caption文件创建在与教师数据图像相同的目录中,具有相同的文件名和扩展名.caption。
根据 GPU 的 VRAM 容量增加或减少 batch_size。越大越快我认为 12GB 的 VRAM 可以多一点)。
您可以使用 max_length 选项指定标题的最大长度。默认值为 75。如果使用 225 的令牌长度训练模型,它可能会更长。
您可以使用 caption_extension 选项更改标题扩展名。默认为 .caption.txt 与稍后描述的 DeepDanbooru 冲突)。
您可以使用 max_length 选项指定caption的最大长度。默认值为 75。如果使用 225 的令牌长度训练模型,它可能会更长。
您可以使用 caption_extension 选项更改caption扩展名。默认为 .caption.txt 与稍后描述的 DeepDanbooru 冲突)。
如果有多个教师数据文件夹,则对每个文件夹执行。
请注意,推理是随机的,因此每次运行时结果都会发生变化。如果要修复它,请使用 --seed 选项指定一个随机数种子,例如 `--seed 42`
其他的选项请参考help with `--help`(好像没有文档说明参数的含义,得看源码)。
默认情况下,会生成扩展名为 .caption 的字幕文件。
默认情况下,会生成扩展名为 .caption 的caption文件。
![caption生成的文件夹](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png)
例如,标题如下:
![字幕和图像](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png)
![caption和图像](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png)
## 由 DeepDanbooru 标记
@@ -695,7 +707,7 @@ python finetune\make_captions.py --batch_size 8 ..\train_data
做一个这样的目录结构
![DeepDanbooru的目录结构](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png)
扩散器环境安装必要的库。进入 DeepDanbooru 文件夹并安装它(我认为它实际上只是添加了 tensorflow-io
diffusers环境安装必要的库。进入 DeepDanbooru 文件夹并安装它(我认为它实际上只是添加了 tensorflow-io
```
pip install -r requirements.txt
```
@@ -768,12 +780,12 @@ python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data
如果有多个教师数据文件夹,则对每个文件夹执行。
## 预处理字幕和标签信息
## 预处理caption和标签信息
字幕和标签作为元数据合并到一个文件中,以便从脚本中轻松处理。
### 字幕预处理
caption和标签作为元数据合并到一个文件中,以便从脚本中轻松处理。
### caption预处理
要将字幕放入元数据,请在您的工作文件夹中运行以下命令(如果您不使用字幕进行学习,则不需要运行它)(它实际上是一行,依此类推)。指定 `--full_path` 选项以将图像文件的完整路径存储在元数据中。如果省略此选项,则会记录相对路径,但 .toml 文件中需要单独的文件夹规范。
要将caption放入元数据,请在您的工作文件夹中运行以下命令(如果您不使用caption进行训练,则不需要运行它)(它实际上是一行,依此类推)。指定 `--full_path` 选项以将图像文件的完整路径存储在元数据中。如果省略此选项,则会记录相对路径,但 .toml 文件中需要单独的文件夹规范。
```
python merge_captions_to_metadata.py --full_path <教师资料夹>
  --in_json <要读取的元数据文件名> <元数据文件名>
@@ -799,7 +811,7 @@ python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json
__* 每次重写 in_json 选项和写入目标并写入单独的元数据文件是安全的。 __
### 标签预处理
同样,标签也收集在元数据中(如果标签不用于学习,则无需这样做)。
同样,标签也收集在元数据中(如果标签不用于训练,则无需这样做)。
```
python merge_dd_tags_to_metadata.py --full_path <教师资料夹>
--in_json <要读取的元数据文件名> <要写入的元数据文件名>
@@ -855,7 +867,7 @@ python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json
python prepare_buckets_latents.py --full_path <教师资料夹>
<要读取的元数据文件名> <要写入的元数据文件名>
<要微调的模型名称或检查点>
--batch_size <批大小>
--batch_size <批大小>
--max_resolution <分辨率宽、高>
--mixed_precision <准确性>
```
@@ -875,7 +887,7 @@ python prepare_buckets_latents.py --full_path
对于翻转的图像也会获取latents并保存名为\ *_flip.npz的文件这是一个简单的实现。在fline_tune.py中不需要特定的选项。如果有带有\_flip的文件则会随机加载带有和不带有flip的文件。
即使VRAM为12GB大小也可以稍微增加。分辨率以“宽度高度”的形式指定必须是64的倍数。分辨率直接影响fine tuning时的内存大小。在12GB VRAM中512,512似乎是极限*。如果有16GB则可以将其提高到512,704或512,768。即使分辨率为256,256等VRAM 8GB也很难承受因为参数、优化器等与分辨率无关需要一定的内存
即使VRAM为12GB大小也可以稍微增加。分辨率以“宽度高度”的形式指定必须是64的倍数。分辨率直接影响fine tuning时的内存大小。在12GB VRAM中512,512似乎是极限*。如果有16GB则可以将其提高到512,704或512,768。即使分辨率为256,256等VRAM 8GB也很难承受因为参数、优化器等与分辨率无关需要一定的内存
*有报道称在batch size为1的训练中使用12GB VRAM和640,640的分辨率。

View File

@@ -0,0 +1,214 @@
# ControlNet-LLLite について
__きわめて実験的な実装のため、将来的に大きく変更される可能性があります。__
## 概要
ControlNet-LLLite は、[ControlNet](https://github.com/lllyasviel/ControlNet) の軽量版です。LoRA Like Lite という意味で、LoRAからインスピレーションを得た構造を持つ、軽量なControlNetです。現在はSDXLにのみ対応しています。
## サンプルの重みファイルと推論
こちらにあります: https://huggingface.co/kohya-ss/controlnet-lllite
ComfyUIのカスタムードを用意しています。: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
生成サンプルはこのページの末尾にあります。
## モデル構造
ひとつのLLLiteモジュールは、制御用画像以下conditioning imageを潜在空間に写像するconditioning image embeddingと、LoRAにちょっと似た構造を持つ小型のネットワークからなります。LLLiteモジュールを、LoRAと同様にU-NetのLinearやConvに追加します。詳しくはソースコードを参照してください。
推論環境の制限で、現在はCrossAttentionのみattn1のq/k/v、attn2のqに追加されます。
## モデルの学習
### データセットの準備
通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。
```toml
[[datasets.subsets]]
image_dir = "path/to/image/dir"
caption_extension = ".txt"
conditioning_data_dir = "path/to/conditioning/image/dir"
```
現時点の制約として、random_cropは使用できません。
学習データとしては、元のモデルで生成した画像を学習用画像として、そこから加工した画像をconditioning imageとした、合成によるデータセットを用いるのがもっとも簡単ですデータセットの品質的には問題があるかもしれません。具体的なデータセットの合成方法については後述します。
なお、元モデルと異なる画風の画像を学習用画像とすると、制御に加えて、その画風についても学ぶ必要が生じます。ControlNet-LLLiteは容量が少ないため、画風学習には不向きです。このような場合には、後述の次元数を多めにしてください。
### 学習
スクリプトで生成する場合は、`sdxl_train_control_net_lllite.py` を実行してください。`--cond_emb_dim` でconditioning image embeddingの次元数を指定できます。`--network_dim` でLoRA的モジュールのrankを指定できます。その他のオプションは`sdxl_train_network.py`に準じますが、`--network_module`の指定は不要です。
学習時にはメモリを大量に使用しますので、キャッシュやgradient checkpointingなどの省メモリ化のオプションを有効にしてください。また`--full_bf16` オプションで、BFloat16を使用するのも有効ですRTX 30シリーズ以降のGPUが必要です。24GB VRAMで動作確認しています。
conditioning image embeddingの次元数は、サンプルのCannyでは32を指定しています。LoRA的モジュールのrankは同じく64です。対象とするconditioning imageの特徴に合わせて調整してください。
サンプルのCannyは恐らくかなり難しいと思われます。depthなどでは半分程度にしてもいいかもしれません。
以下は .toml の設定例です。
```toml
pretrained_model_name_or_path = "/path/to/model_trained_on.safetensors"
max_train_epochs = 12
max_data_loader_n_workers = 4
persistent_data_loader_workers = true
seed = 42
gradient_checkpointing = true
mixed_precision = "bf16"
save_precision = "bf16"
full_bf16 = true
optimizer_type = "adamw8bit"
learning_rate = 2e-4
xformers = true
output_dir = "/path/to/output/dir"
output_name = "output_name"
save_every_n_epochs = 1
save_model_as = "safetensors"
vae_batch_size = 4
cache_latents = true
cache_latents_to_disk = true
cache_text_encoder_outputs = true
cache_text_encoder_outputs_to_disk = true
network_dim = 64
cond_emb_dim = 32
dataset_config = "/path/to/dataset.toml"
```
### 推論
スクリプトで生成する場合は、`sdxl_gen_img.py` を実行してください。`--control_net_lllite_models` でLLLiteのモデルファイルを指定できます。次元数はモデルファイルから自動取得します。
`--guide_image_path`で推論に用いるconditioning imageを指定してください。なおpreprocessは行われないため、たとえばCannyならCanny処理を行った画像を指定してください背景黒に白線`--control_net_preps`, `--control_net_weights`, `--control_net_ratios` には未対応です。
## データセットの合成方法
### 学習用画像の生成
学習のベースとなるモデルで画像生成を行います。Web UIやComfyUIなどで生成してください。画像サイズはモデルのデフォルトサイズで良いと思われます1024x1024など。bucketingを用いることもできます。その場合は適宜適切な解像度で生成してください。
生成時のキャプション等は、ControlNet-LLLiteの利用時に生成したい画像にあわせるのが良いと思われます。
生成した画像を任意のディレクトリに保存してください。このディレクトリをデータセットの設定ファイルで指定します。
当リポジトリ内の `sdxl_gen_img.py` でも生成できます。例えば以下のように実行します。
```dos
python sdxl_gen_img.py --ckpt path/to/model.safetensors --n_iter 1 --scale 10 --steps 36 --outdir path/to/output/dir --xformers --W 1024 --H 1024 --original_width 2048 --original_height 2048 --bf16 --sampler ddim --batch_size 4 --vae_batch_size 2 --images_per_prompt 512 --max_embeddings_multiples 1 --prompt "{portrait|digital art|anime screen cap|detailed illustration} of 1girl, {standing|sitting|walking|running|dancing} on {classroom|street|town|beach|indoors|outdoors}, {looking at viewer|looking away|looking at another}, {in|wearing} {shirt and skirt|school uniform|casual wear} { |, dynamic pose}, (solo), teen age, {0-1$$smile,|blush,|kind smile,|expression less,|happy,|sadness,} {0-1$$upper body,|full body,|cowboy shot,|face focus,} trending on pixiv, {0-2$$depth of fields,|8k wallpaper,|highly detailed,|pov,} {0-1$$summer, |winter, |spring, |autumn, } beautiful face { |, from below|, from above|, from side|, from behind|, from back} --n nsfw, bad face, lowres, low quality, worst quality, low effort, watermark, signature, ugly, poorly drawn"
```
VRAM 24GBの設定です。VRAMサイズにより`--batch_size` `--vae_batch_size`を調整してください。
`--prompt`でワイルドカードを利用してランダムに生成しています。適宜調整してください。
### 画像の加工
外部のプログラムを用いて、生成した画像を加工します。加工した画像を任意のディレクトリに保存してください。これらがconditioning imageになります。
加工にはたとえばCannyなら以下のようなスクリプトが使えます。
```python
import glob
import os
import random
import cv2
import numpy as np
IMAGES_DIR = "path/to/generated/images"
CANNY_DIR = "path/to/canny/images"
os.makedirs(CANNY_DIR, exist_ok=True)
img_files = glob.glob(IMAGES_DIR + "/*.png")
for img_file in img_files:
can_file = CANNY_DIR + "/" + os.path.basename(img_file)
if os.path.exists(can_file):
print("Skip: " + img_file)
continue
print(img_file)
img = cv2.imread(img_file)
# random threshold
# while True:
# threshold1 = random.randint(0, 127)
# threshold2 = random.randint(128, 255)
# if threshold2 - threshold1 > 80:
# break
# fixed threshold
threshold1 = 100
threshold2 = 200
img = cv2.Canny(img, threshold1, threshold2)
cv2.imwrite(can_file, img)
```
### キャプションファイルの作成
学習用画像のbasenameと同じ名前で、それぞれの画像に対応したキャプションファイルを作成してください。生成時のプロンプトをそのまま利用すれば良いと思われます。
`sdxl_gen_img.py` で生成した場合は、画像内のメタデータに生成時のプロンプトが記録されていますので、以下のようなスクリプトで学習用画像と同じディレクトリにキャプションファイルを作成できます(拡張子 `.txt`)。
```python
import glob
import os
from PIL import Image
IMAGES_DIR = "path/to/generated/images"
img_files = glob.glob(IMAGES_DIR + "/*.png")
for img_file in img_files:
cap_file = img_file.replace(".png", ".txt")
if os.path.exists(cap_file):
print(f"Skip: {img_file}")
continue
print(img_file)
img = Image.open(img_file)
prompt = img.text["prompt"] if "prompt" in img.text else ""
if prompt == "":
print(f"Prompt not found in {img_file}")
with open(cap_file, "w") as f:
f.write(prompt + "\n")
```
### データセットの設定ファイルの作成
コマンドラインオプションからの指定も可能ですが、`.toml`ファイルを作成する場合は `conditioning_data_dir` に加工した画像を保存したディレクトリを指定します。
以下は設定ファイルの例です。
```toml
[general]
flip_aug = false
color_aug = false
resolution = [1024,1024]
[[datasets]]
batch_size = 8
enable_bucket = false
[[datasets.subsets]]
image_dir = "path/to/generated/image/dir"
caption_extension = ".txt"
conditioning_data_dir = "path/to/canny/image/dir"
```
## 謝辞
ControlNetの作者である lllyasviel 氏、実装上のアドバイスとトラブル解決へのご尽力をいただいた furusu 氏、ControlNetデータセットを実装していただいた ddPn08 氏に感謝いたします。
## サンプル
Canny
![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/37e9a736-649b-4c0f-ab26-880a1bf319b5)
![im_20230820104253_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c8896900-ab86-4120-932f-6e2ae17b77c0)
![im_20230820104302_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b12457a0-ee3c-450e-ba9a-b712d0fe86bb)
![im_20230820104310_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/8845b8d9-804a-44ac-9618-113a28eac8a1)

217
docs/train_lllite_README.md Normal file
View File

@@ -0,0 +1,217 @@
# About ControlNet-LLLite
__This is an extremely experimental implementation and may change significantly in the future.__
日本語版は[こちら](./train_lllite_README-ja.md)
## Overview
ControlNet-LLLite is a lightweight version of [ControlNet](https://github.com/lllyasviel/ControlNet). It is a "LoRA Like Lite" that is inspired by LoRA and has a lightweight structure. Currently, only SDXL is supported.
## Sample weight file and inference
Sample weight file is available here: https://huggingface.co/kohya-ss/controlnet-lllite
A custom node for ComfyUI is available: https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
Sample images are at the end of this page.
## Model structure
A single LLLite module consists of a conditioning image embedding that maps a conditioning image to a latent space and a small network with a structure similar to LoRA. The LLLite module is added to U-Net's Linear and Conv in the same way as LoRA. Please refer to the source code for details.
Due to the limitations of the inference environment, only CrossAttention (attn1 q/k/v, attn2 q) is currently added.
## Model training
### Preparing the dataset
In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
```toml
[[datasets.subsets]]
image_dir = "path/to/image/dir"
caption_extension = ".txt"
conditioning_data_dir = "path/to/conditioning/image/dir"
```
At the moment, random_crop cannot be used.
For training data, it is easiest to use a synthetic dataset with the original model-generated images as training images and processed images as conditioning images (the quality of the dataset may be problematic). See below for specific methods of synthesizing datasets.
Note that if you use an image with a different art style than the original model as a training image, the model will have to learn not only the control but also the art style. ControlNet-LLLite has a small capacity, so it is not suitable for learning art styles. In such cases, increase the number of dimensions as described below.
### Training
Run `sdxl_train_control_net_lllite.py`. You can specify the dimension of the conditioning image embedding with `--cond_emb_dim`. You can specify the rank of the LoRA-like module with `--network_dim`. Other options are the same as `sdxl_train_network.py`, but `--network_module` is not required.
Since a large amount of memory is used during training, please enable memory-saving options such as cache and gradient checkpointing. It is also effective to use BFloat16 with the `--full_bf16` option (requires RTX 30 series or later GPU). It has been confirmed to work with 24GB VRAM.
For the sample Canny, the dimension of the conditioning image embedding is 32. The rank of the LoRA-like module is also 64. Adjust according to the features of the conditioning image you are targeting.
(The sample Canny is probably quite difficult. It may be better to reduce it to about half for depth, etc.)
The following is an example of a .toml configuration.
```toml
pretrained_model_name_or_path = "/path/to/model_trained_on.safetensors"
max_train_epochs = 12
max_data_loader_n_workers = 4
persistent_data_loader_workers = true
seed = 42
gradient_checkpointing = true
mixed_precision = "bf16"
save_precision = "bf16"
full_bf16 = true
optimizer_type = "adamw8bit"
learning_rate = 2e-4
xformers = true
output_dir = "/path/to/output/dir"
output_name = "output_name"
save_every_n_epochs = 1
save_model_as = "safetensors"
vae_batch_size = 4
cache_latents = true
cache_latents_to_disk = true
cache_text_encoder_outputs = true
cache_text_encoder_outputs_to_disk = true
network_dim = 64
cond_emb_dim = 32
dataset_config = "/path/to/dataset.toml"
```
### Inference
If you want to generate images with a script, run `sdxl_gen_img.py`. You can specify the LLLite model file with `--control_net_lllite_models`. The dimension is automatically obtained from the model file.
Specify the conditioning image to be used for inference with `--guide_image_path`. Since preprocess is not performed, if it is Canny, specify an image processed with Canny (white line on black background). `--control_net_preps`, `--control_net_weights`, and `--control_net_ratios` are not supported.
## How to synthesize a dataset
### Generating training images
Generate images with the base model for training. Please generate them with Web UI or ComfyUI etc. The image size should be the default size of the model (1024x1024, etc.). You can also use bucketing. In that case, please generate it at an arbitrary resolution.
The captions and other settings when generating the images should be the same as when generating the images with the trained ControlNet-LLLite model.
Save the generated images in an arbitrary directory. Specify this directory in the dataset configuration file.
You can also generate them with `sdxl_gen_img.py` in this repository. For example, run as follows:
```dos
python sdxl_gen_img.py --ckpt path/to/model.safetensors --n_iter 1 --scale 10 --steps 36 --outdir path/to/output/dir --xformers --W 1024 --H 1024 --original_width 2048 --original_height 2048 --bf16 --sampler ddim --batch_size 4 --vae_batch_size 2 --images_per_prompt 512 --max_embeddings_multiples 1 --prompt "{portrait|digital art|anime screen cap|detailed illustration} of 1girl, {standing|sitting|walking|running|dancing} on {classroom|street|town|beach|indoors|outdoors}, {looking at viewer|looking away|looking at another}, {in|wearing} {shirt and skirt|school uniform|casual wear} { |, dynamic pose}, (solo), teen age, {0-1$$smile,|blush,|kind smile,|expression less,|happy,|sadness,} {0-1$$upper body,|full body,|cowboy shot,|face focus,} trending on pixiv, {0-2$$depth of fields,|8k wallpaper,|highly detailed,|pov,} {0-1$$summer, |winter, |spring, |autumn, } beautiful face { |, from below|, from above|, from side|, from behind|, from back} --n nsfw, bad face, lowres, low quality, worst quality, low effort, watermark, signature, ugly, poorly drawn"
```
This is a setting for VRAM 24GB. Adjust `--batch_size` and `--vae_batch_size` according to the VRAM size.
The images are generated randomly using wildcards in `--prompt`. Adjust as necessary.
### Processing images
Use an external program to process the generated images. Save the processed images in an arbitrary directory. These will be the conditioning images.
For example, you can use the following script to process the images with Canny.
```python
import glob
import os
import random
import cv2
import numpy as np
IMAGES_DIR = "path/to/generated/images"
CANNY_DIR = "path/to/canny/images"
os.makedirs(CANNY_DIR, exist_ok=True)
img_files = glob.glob(IMAGES_DIR + "/*.png")
for img_file in img_files:
can_file = CANNY_DIR + "/" + os.path.basename(img_file)
if os.path.exists(can_file):
print("Skip: " + img_file)
continue
print(img_file)
img = cv2.imread(img_file)
# random threshold
# while True:
# threshold1 = random.randint(0, 127)
# threshold2 = random.randint(128, 255)
# if threshold2 - threshold1 > 80:
# break
# fixed threshold
threshold1 = 100
threshold2 = 200
img = cv2.Canny(img, threshold1, threshold2)
cv2.imwrite(can_file, img)
```
### Creating caption files
Create a caption file for each image with the same basename as the training image. It is fine to use the same caption as the one used when generating the image.
If you generated the images with `sdxl_gen_img.py`, you can use the following script to create the caption files (`*.txt`) from the metadata in the generated images.
```python
import glob
import os
from PIL import Image
IMAGES_DIR = "path/to/generated/images"
img_files = glob.glob(IMAGES_DIR + "/*.png")
for img_file in img_files:
cap_file = img_file.replace(".png", ".txt")
if os.path.exists(cap_file):
print(f"Skip: {img_file}")
continue
print(img_file)
img = Image.open(img_file)
prompt = img.text["prompt"] if "prompt" in img.text else ""
if prompt == "":
print(f"Prompt not found in {img_file}")
with open(cap_file, "w") as f:
f.write(prompt + "\n")
```
### Creating a dataset configuration file
You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
```toml
[general]
flip_aug = false
color_aug = false
resolution = [1024,1024]
[[datasets]]
batch_size = 8
enable_bucket = false
[[datasets.subsets]]
image_dir = "path/to/generated/image/dir"
caption_extension = ".txt"
conditioning_data_dir = "path/to/canny/image/dir"
```
## Credit
I would like to thank lllyasviel, the author of ControlNet, furusu, who provided me with advice on implementation and helped me solve problems, and ddPn08, who implemented the ControlNet dataset.
## Sample
Canny
![kohya_ss_girl_standing_at_classroom_smiling_to_the_viewer_class_78976b3e-0d4d-4ea0-b8e3-053ae493abbc](https://github.com/kohya-ss/sd-scripts/assets/52813779/37e9a736-649b-4c0f-ab26-880a1bf319b5)
![im_20230820104253_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/c8896900-ab86-4120-932f-6e2ae17b77c0)
![im_20230820104302_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/b12457a0-ee3c-450e-ba9a-b712d0fe86bb)
![im_20230820104310_000_1](https://github.com/kohya-ss/sd-scripts/assets/52813779/8845b8d9-804a-44ac-9618-113a28eac8a1)

View File

@@ -181,6 +181,8 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf
詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。
SDXLは現在サポートしていません。
フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
`--network_args` で以下の引数を指定してください。
@@ -246,6 +248,8 @@ network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,
merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。
SDXL向けにはsdxl_merge_lora.pyを用意しています。オプション等は同一ですので、以下のmerge_lora.pyを読み替えてください。
### Stable DiffusionのモデルにLoRAのモデルをマージする
マージ後のモデルは通常のStable Diffusionのckptと同様に扱えます。たとえば以下のようなコマンドラインになります。
@@ -276,26 +280,28 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt
### 複数のLoRAのモデルをマージする
複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります
--concatオプションを指定すると、複数のLoRAを単純に結合して新しいLoRAモデルを作成できます。ファイルサイズおよびdim/rankは指定したLoRAの合計サイズになりますマージ時にdim (rank)を変更する場合は `svd_merge_lora.py` を使用してください)
たとえば以下のようなコマンドラインになります。
```
python networks\merge_lora.py
python networks\merge_lora.py --save_precision bf16
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
--ratios 1.0 -1.0 --concat --shuffle
```
--sd_modelオプション指定不要です。
--concatオプション指定します。
また--shuffleオプションを追加し、重みをシャッフルします。シャッフルしないとマージ後のLoRAから元のLoRAを取り出せるため、コピー機学習などの場合には学習元データが明らかになります。ご注意ください。
--save_toオプションにマージ後のLoRAモデルの保存先を指定します.ckptまたは.safetensors、拡張子で自動判定
--modelsに学習したLoRAのモデルファイルを指定します。三つ以上も指定可能です。
--ratiosにそれぞれのモデルの比率どのくらい重みを元モデルに反映するかを0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
v1で学習したLoRAとv2で学習したLoRA、rank次元数や``alpha``の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
--ratiosにそれぞれのモデルの比率どのくらい重みを元モデルに反映するかを0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
v1で学習したLoRAとv2で学習したLoRA、rank次元数の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
### その他のオプション
@@ -304,6 +310,7 @@ v1で学習したLoRAとv2で学習したLoRA、rank次元数や``alpha``
* save_precision
* モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。
他にもいくつかのオプションがありますので、--helpで確認してください。
## 複数のrankが異なるLoRAのモデルをマージする

View File

@@ -5,13 +5,19 @@ import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
@@ -21,7 +27,12 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like
from library.custom_train_functions import (
apply_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
)
def train(args):
@@ -35,38 +46,42 @@ def train(args):
tokenizer = train_util.load_tokenizer(args)
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
)
else:
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
else:
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
@@ -84,7 +99,7 @@ def train(args):
# acceleratorを準備する
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -128,13 +143,13 @@ def train(args):
# モデルに xformers とか memory efficient attention を組み込む
if args.diffusers_xformers:
print("Use xformers by Diffusers")
accelerator.print("Use xformers by Diffusers")
set_diffusers_xformers_flag(unet, True)
else:
# Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある
print("Disable Diffusers' xformers")
accelerator.print("Disable Diffusers' xformers")
set_diffusers_xformers_flag(unet, False)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:
@@ -157,7 +172,7 @@ def train(args):
training_models.append(unet)
if args.train_text_encoder:
print("enable text encoder training")
accelerator.print("enable text encoder training")
if args.gradient_checkpointing:
text_encoder.gradient_checkpointing_enable()
training_models.append(text_encoder)
@@ -183,7 +198,7 @@ def train(args):
params_to_optimize = params
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# dataloaderを準備する
@@ -193,7 +208,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -203,7 +218,7 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -216,7 +231,7 @@ def train(args):
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
@@ -246,14 +261,16 @@ def train(args):
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient accumulation steps / 勾配合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
accelerator.print(
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
)
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
@@ -261,12 +278,18 @@ def train(args):
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
init_kwargs = {}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
@@ -302,21 +325,9 @@ def train(args):
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -328,11 +339,16 @@ def train(args):
else:
target = noise
if args.min_snr_gamma:
# do not mean over batch dimension for snr weight
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
@@ -373,15 +389,17 @@ def train(args):
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
accelerator.unwrap_model(text_encoder),
accelerator.unwrap_model(unet),
vae,
)
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
@@ -416,8 +434,8 @@ def train(args):
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
accelerator.unwrap_model(text_encoder),
accelerator.unwrap_model(unet),
vae,
)
@@ -425,8 +443,8 @@ def train(args):
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)
text_encoder = unwrap_model(text_encoder)
unet = accelerator.unwrap_model(unet)
text_encoder = accelerator.unwrap_model(text_encoder)
accelerator.end_training()

View File

@@ -3,6 +3,7 @@ import glob
import os
import json
import random
import sys
from pathlib import Path
from PIL import Image
@@ -11,6 +12,7 @@ import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder
import library.train_util as train_util

View File

@@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch):
def main(args):
r"""
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
@@ -65,6 +68,7 @@ def main(args):
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""
print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
@@ -81,7 +85,7 @@ def main(args):
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]
curr_batch_size[0] = len(path_imgs)
# curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)

View File

@@ -34,16 +34,7 @@ def collate_fn_remove_corrupted(batch):
return batch
def get_latents(vae, images, weight_dtype):
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
img_tensors = torch.stack(img_tensors)
img_tensors = img_tensors.to(DEVICE, weight_dtype)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
return latents
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
def get_npz_filename(data_dir, image_key, is_full_path, recursive):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
@@ -51,19 +42,20 @@ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
base_name = image_key
relative_path = ""
if flip:
base_name += "_flip"
if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name)
return os.path.join(data_dir, relative_path, base_name) + ".npz"
else:
return os.path.join(data_dir, base_name)
return os.path.join(data_dir, base_name) + ".npz"
def main(args):
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
if args.bucket_reso_steps % 32 > 0:
print(
f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
)
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
@@ -107,34 +99,7 @@ def main(args):
def process_batch(is_last):
for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
assert (
latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8
), f"latent shape {latents.shape}, {bucket[0][1].shape}"
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
np.savez(npz_file_name, latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(
args.train_data_dir, image_key, args.full_path, True, args.recursive
)
np.savez(npz_file_name, latent)
else:
# remove existing flipped npz
for image_key, _ in bucket:
npz_file_name = (
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
)
if os.path.isfile(npz_file_name):
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
os.remove(npz_file_name)
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
bucket.clear()
# 読み込みの高速化のためにDataLoaderを使うオプション
@@ -194,50 +159,19 @@ def main(args):
resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
), f"internal error resized size is small: {resized_size}, {reso}"
# 既に存在するファイルがあればshapeを確認して同じならskipする
# 既に存在するファイルがあればshapeを確認して同じならskipする
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
if args.flip_aug:
npz_files.append(
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
)
found = True
for npz_file in npz_files:
if not os.path.exists(npz_file):
found = False
break
dat = np.load(npz_file)["arr_0"]
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False
break
if found:
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
continue
# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size // 2 : trim_size // 2 + reso[0]]
if resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size // 2 : trim_size // 2 + reso[1]]
assert (
image.shape[0] == reso[1] and image.shape[1] == reso[0]
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
# # debug
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
# バッチへ追加
bucket_manager.add_image(reso, (image_key, image))
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
image_info.latents_npz = npz_file_name
image_info.bucket_reso = reso
image_info.resized_size = resized_size
image_info.image = image
bucket_manager.add_image(reso, image_info)
# バッチを推論するか判定して推論する
process_batch(False)

View File

@@ -1,17 +1,15 @@
import argparse
import csv
import glob
import os
from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
from pathlib import Path
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
# from wd14 tagger
@@ -20,6 +18,7 @@ IMAGE_SIZE = 448
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
FILES_ONNX = ["model.onnx"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
@@ -81,7 +80,10 @@ def main(args):
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
for file in FILES:
files = FILES
if args.onnx:
files += FILES_ONNX
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
@@ -96,7 +98,46 @@ def main(args):
print("using existing wd14 tagger model")
# 画像を読み込む
model = load_model(args.model_dir)
if args.onnx:
import onnx
import onnxruntime as ort
onnx_path = f"{args.model_dir}/model.onnx"
print("Running wd14 tagger with onnx")
print(f"loading onnx model: {onnx_path}")
if not os.path.exists(onnx_path):
raise Exception(
f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
+ " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
)
model = onnx.load(onnx_path)
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str:
# some rebatch model may use 'N' as dynamic axes
print(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
)
args.batch_size = batch_size
del model
ort_sess = ort.InferenceSession(
onnx_path,
providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"],
)
else:
from tensorflow.keras.models import load_model
model = load_model(f"{args.model_dir}")
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
@@ -124,8 +165,14 @@ def main(args):
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
probs = model(imgs, training=False)
probs = probs.numpy()
if args.onnx:
if len(imgs) < args.batch_size:
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
else:
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
@@ -165,9 +212,27 @@ def main(args):
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
tag_text = ", ".join(combined_tags)
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
if args.append_tags:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
# Check and remove repeating tags in tag_text
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
# Create new tag_text
tag_text = ", ".join(existing_tags + new_tags)
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
@@ -283,12 +348,15 @@ def setup_parser() -> argparse.ArgumentParser:
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
# スペルミスしていたオプションを復元する

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,227 @@
import math
from typing import Any
from einops import rearrange
import torch
from diffusers.models.attention_processor import Attention
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135
EPSILON = 1e-6
class FlashAttentionFunction(torch.autograd.function.Function):
@staticmethod
@torch.no_grad()
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
"""Algorithm 2 in the paper"""
device = q.device
dtype = q.dtype
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
o = torch.zeros_like(q)
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
all_row_maxes = torch.full(
(*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
)
scale = q.shape[-1] ** -0.5
if mask is None:
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
else:
mask = rearrange(mask, "b n -> b 1 1 n")
mask = mask.split(q_bucket_size, dim=-1)
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
mask,
all_row_sums.split(q_bucket_size, dim=-2),
all_row_maxes.split(q_bucket_size, dim=-2),
)
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff
col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
)
for k_ind, (kc, vc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size
attn_weights = (
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
)
if row_mask is not None:
attn_weights.masked_fill_(~row_mask, max_neg_value)
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones(
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
).triu(q_start_index - k_start_index + 1)
attn_weights.masked_fill_(causal_mask, max_neg_value)
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
attn_weights -= block_row_maxes
exp_weights = torch.exp(attn_weights)
if row_mask is not None:
exp_weights.masked_fill_(~row_mask, 0.0)
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
min=EPSILON
)
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
exp_values = torch.einsum(
"... i j, ... j d -> ... i d", exp_weights, vc
)
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
new_row_sums = (
exp_row_max_diff * row_sums
+ exp_block_row_max_diff * block_row_sums
)
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
(exp_block_row_max_diff / new_row_sums) * exp_values
)
row_maxes.copy_(new_row_maxes)
row_sums.copy_(new_row_sums)
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
return o
@staticmethod
@torch.no_grad()
def backward(ctx, do):
"""Algorithm 4 in the paper"""
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
q, k, v, o, l, m = ctx.saved_tensors
device = q.device
max_neg_value = -torch.finfo(q.dtype).max
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
dq = torch.zeros_like(q)
dk = torch.zeros_like(k)
dv = torch.zeros_like(v)
row_splits = zip(
q.split(q_bucket_size, dim=-2),
o.split(q_bucket_size, dim=-2),
do.split(q_bucket_size, dim=-2),
mask,
l.split(q_bucket_size, dim=-2),
m.split(q_bucket_size, dim=-2),
dq.split(q_bucket_size, dim=-2),
)
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
q_start_index = ind * q_bucket_size - qk_len_diff
col_splits = zip(
k.split(k_bucket_size, dim=-2),
v.split(k_bucket_size, dim=-2),
dk.split(k_bucket_size, dim=-2),
dv.split(k_bucket_size, dim=-2),
)
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
k_start_index = k_ind * k_bucket_size
attn_weights = (
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
)
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
causal_mask = torch.ones(
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
).triu(q_start_index - k_start_index + 1)
attn_weights.masked_fill_(causal_mask, max_neg_value)
exp_attn_weights = torch.exp(attn_weights - mc)
if row_mask is not None:
exp_attn_weights.masked_fill_(~row_mask, 0.0)
p = exp_attn_weights / lc
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
D = (doc * oc).sum(dim=-1, keepdims=True)
ds = p * scale * (dp - D)
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
dqc.add_(dq_chunk)
dkc.add_(dk_chunk)
dvc.add_(dv_chunk)
return dq, dk, dv, None, None, None, None
class FlashAttnProcessor:
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
) -> Any:
q_bucket_size = 512
k_bucket_size = 1024
h = attn.heads
q = attn.to_q(hidden_states)
encoder_hidden_states = (
encoder_hidden_states
if encoder_hidden_states is not None
else hidden_states
)
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
context_k, context_v = attn.hypernetwork.forward(
hidden_states, encoder_hidden_states
)
context_k = context_k.to(hidden_states.dtype)
context_v = context_v.to(hidden_states.dtype)
else:
context_k = encoder_hidden_states
context_v = encoder_hidden_states
k = attn.to_k(context_k)
v = attn.to_v(context_v)
del encoder_hidden_states, hidden_states
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
out = FlashAttentionFunction.apply(
q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
)
out = rearrange(out, "b h n d -> b n (h d)")
out = attn.to_out[0](out)
out = attn.to_out[1](out)
return out

View File

@@ -33,8 +33,10 @@ from . import train_util
from .train_util import (
DreamBoothSubset,
FineTuningSubset,
ControlNetSubset,
DreamBoothDataset,
FineTuningDataset,
ControlNetDataset,
DatasetGroup,
)
@@ -54,6 +56,8 @@ class BaseSubsetParams:
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
random_crop: bool = False
caption_prefix: Optional[str] = None
caption_suffix: Optional[str] = None
caption_dropout_rate: float = 0.0
caption_dropout_every_n_epochs: int = 0
caption_tag_dropout_rate: float = 0.0
@@ -70,9 +74,14 @@ class DreamBoothSubsetParams(BaseSubsetParams):
class FineTuningSubsetParams(BaseSubsetParams):
metadata_file: Optional[str] = None
@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
conditioning_data_dir: str = None
caption_extension: str = ".caption"
@dataclass
class BaseDatasetParams:
tokenizer: CLIPTokenizer = None
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
max_token_length: int = None
resolution: Optional[Tuple[int, int]] = None
debug_dataset: bool = False
@@ -96,6 +105,15 @@ class FineTuningDatasetParams(BaseDatasetParams):
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
@dataclass
class ControlNetDatasetParams(BaseDatasetParams):
batch_size: int = 1
enable_bucket: bool = False
min_bucket_reso: int = 256
max_bucket_reso: int = 1024
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
@dataclass
class SubsetBlueprint:
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
@@ -103,6 +121,7 @@ class SubsetBlueprint:
@dataclass
class DatasetBlueprint:
is_dreambooth: bool
is_controlnet: bool
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
subsets: Sequence[SubsetBlueprint]
@@ -142,6 +161,8 @@ class ConfigSanitizer:
"keep_tokens": int,
"token_warmup_min": int,
"token_warmup_step": Any(float,int),
"caption_prefix": str,
"caption_suffix": str,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -163,6 +184,13 @@ class ConfigSanitizer:
Required("metadata_file"): str,
"image_dir": str,
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
}
CN_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
Required("conditioning_data_dir"): str,
}
# datasets schema
DATASET_ASCENDABLE_SCHEMA = {
@@ -192,8 +220,8 @@ class ConfigSanitizer:
"dataset_repeats": "num_repeats",
}
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_dropout: bool) -> None:
assert support_dreambooth or support_finetuning, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
assert support_dreambooth or support_finetuning or support_controlnet, "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
self.db_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -208,6 +236,13 @@ class ConfigSanitizer:
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
self.cn_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
self.CN_SUBSET_DISTINCT_SCHEMA,
self.CN_SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
self.db_dataset_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -223,13 +258,23 @@ class ConfigSanitizer:
{"subsets": [self.ft_subset_schema]},
)
self.cn_dataset_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
self.CN_SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
{"subsets": [self.cn_subset_schema]},
)
if support_dreambooth and support_finetuning:
def validate_flex_dataset(dataset_config: dict):
subsets_config = dataset_config.get("subsets", [])
if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
return Schema(self.cn_dataset_schema)(dataset_config)
# check dataset meets FT style
# NOTE: all FT subsets should have "metadata_file"
if all(["metadata_file" in subset for subset in subsets_config]):
elif all(["metadata_file" in subset for subset in subsets_config]):
return Schema(self.ft_dataset_schema)(dataset_config)
# check dataset meets DB style
# NOTE: all DB subsets should have no "metadata_file"
@@ -241,13 +286,16 @@ class ConfigSanitizer:
self.dataset_schema = validate_flex_dataset
elif support_dreambooth:
self.dataset_schema = self.db_dataset_schema
else:
elif support_finetuning:
self.dataset_schema = self.ft_dataset_schema
elif support_controlnet:
self.dataset_schema = self.cn_dataset_schema
self.general_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
@@ -318,7 +366,11 @@ class BlueprintGenerator:
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
subsets = dataset_config.get("subsets", [])
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
if is_dreambooth:
is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
if is_controlnet:
subset_params_klass = ControlNetSubsetParams
dataset_params_klass = ControlNetDatasetParams
elif is_dreambooth:
subset_params_klass = DreamBoothSubsetParams
dataset_params_klass = DreamBoothDatasetParams
else:
@@ -333,7 +385,7 @@ class BlueprintGenerator:
params = self.generate_params_by_fallbacks(dataset_params_klass,
[dataset_config, general_config, argparse_config, runtime_params])
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, params, subset_blueprints))
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
@@ -361,10 +413,13 @@ class BlueprintGenerator:
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
datasets: List[Union[DreamBoothDataset, FineTuningDataset]] = []
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.is_dreambooth:
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
@@ -379,6 +434,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = isinstance(dataset, ControlNetDataset)
info += dedent(f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
@@ -407,6 +463,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
@@ -421,7 +479,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""), " ")
else:
elif not is_controlnet:
info += indent(dedent(f"""\
metadata_file: {subset.metadata_file}
\n"""), " ")
@@ -479,6 +537,27 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
return subsets_config
def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"):
def generate(base_dir: Optional[str]):
if base_dir is None:
return []
base_dir: Path = Path(base_dir)
if not base_dir.is_dir():
return []
subsets_config = []
subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
subsets_config.append(subset_config)
return subsets_config
subsets_config = []
subsets_config += generate(train_data_dir)
return subsets_config
def load_user_config(file: str) -> dict:
file: Path = Path(file)
if not file.is_file():
@@ -507,6 +586,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--support_dreambooth", action="store_true")
parser.add_argument("--support_finetuning", action="store_true")
parser.add_argument("--support_controlnet", action="store_true")
parser.add_argument("--support_dropout", action="store_true")
parser.add_argument("dataset_config")
config_args, remain = parser.parse_known_args()
@@ -525,7 +605,7 @@ if __name__ == "__main__":
print("\n[user_config]")
print(user_config)
sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout)
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
print("\n[sanitized_user_config]")

View File

@@ -5,20 +5,91 @@ import re
from typing import List, Optional, Union
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
def prepare_scheduler_for_custom_training(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"):
return
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
noise_scheduler.all_snr = all_snr.to(device)
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
# fix beta: zero terminal SNR
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
def enforce_zero_terminal_snr(betas):
# Convert betas to alphas_bar_sqrt
alphas = 1 - betas
alphas_bar = alphas.cumprod(0)
alphas_bar_sqrt = alphas_bar.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so first timestep is back to old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2
alphas = alphas_bar[1:] / alphas_bar[:-1]
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
betas = noise_scheduler.betas
betas = enforce_zero_terminal_snr(betas)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# print("original:", noise_scheduler.betas)
# print("fixed:", betas)
noise_scheduler.betas = betas
noise_scheduler.alphas = alphas
noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
loss = loss * snr_weight
return loss
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
scale = get_snr_scale(timesteps, noise_scheduler)
loss = loss * scale
return loss
def get_snr_scale(timesteps, noise_scheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1)
# # show debug info
# print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
return scale
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
scale = get_snr_scale(timesteps, noise_scheduler)
# print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
return loss
# TODO train_utilと分散しているのでどちらかに寄せる
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
parser.add_argument(
"--min_snr_gamma",
@@ -26,6 +97,17 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
default=None,
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
)
parser.add_argument(
"--scale_v_pred_loss_like_noise_pred",
action="store_true",
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
)
parser.add_argument(
"--v_pred_like_loss",
type=float,
default=None,
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",
@@ -240,11 +322,6 @@ def get_unweighted_text_embeddings(
text_embedding = enc_out["hidden_states"][-clip_skip]
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
@@ -259,7 +336,12 @@ def get_unweighted_text_embeddings(
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
text_embeddings = text_encoder(text_input)[0]
if clip_skip is None or clip_skip == 1:
text_embeddings = text_encoder(text_input)[0]
else:
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-clip_skip]
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
return text_embeddings
@@ -346,12 +428,88 @@ def get_weighted_text_embeddings(
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.3):
b, c, w, h = noise.shape
u = torch.nn.Upsample(size=(w, h), mode='bilinear').to(device)
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
for i in range(iterations):
r = random.random()*2+2 # Rather than always going 2x,
w, h = max(1, int(w/(r**i))), max(1, int(h/(r**i)))
noise += u(torch.randn(b, c, w, h).to(device)) * discount**i
if w==1 or h==1: break # Lowest resolution is 1x1
return noise/noise.std() # Scaled back to roughly unit variance
r = random.random() * 2 + 2 # Rather than always going 2x,
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
if wn == 1 or hn == 1:
break # Lowest resolution is 1x1
return noise / noise.std() # Scaled back to roughly unit variance
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
if noise_offset is None:
return noise
if adaptive_noise_scale is not None:
# latent shape: (batch_size, channels, height, width)
# abs mean value for each channel
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
# multiply adaptive noise scale to the mean value and add it to the noise offset
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
return noise
"""
##########################################
# Perlin Noise
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = (
torch.stack(
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
dim=-1,
)
% 1
)
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
tile_grads = (
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
dot = lambda grad, shift: (
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
* grad[: shape[0], : shape[1]]
).sum(dim=-1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
t = fade(grid[: shape[0], : shape[1]])
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
noise = torch.zeros(shape, device=device)
frequency = 1
amplitude = 1
for _ in range(octaves):
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
frequency *= 2
amplitude *= persistence
return noise
def perlin_noise(noise, device, octaves):
_, c, w, h = noise.shape
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
noise_perlin = []
for _ in range(c):
noise_perlin.append(perlin())
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
noise += noise_perlin # broadcast for each batch
return noise / noise.std() # Scaled back to roughly unit variance
"""

View File

@@ -6,9 +6,7 @@ import os
from library.utils import fire_in_thread
def exists_repo(
repo_id: str, repo_type: str, revision: str = "main", token: str = None
):
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
api = HfApi(
token=token,
)
@@ -28,31 +26,39 @@ def upload(
repo_id = args.huggingface_repo_id
repo_type = args.huggingface_repo_type
token = args.huggingface_token
path_in_repo = args.huggingface_path_in_repo + dest_suffix
path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
api = HfApi(token=token)
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
try:
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
print("===========================================")
print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
print("===========================================")
is_folder = (type(src) == str and os.path.isdir(src)) or (
isinstance(src, Path) and src.is_dir()
)
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
def uploader():
if is_folder:
api.upload_folder(
repo_id=repo_id,
repo_type=repo_type,
folder_path=src,
path_in_repo=path_in_repo,
)
else:
api.upload_file(
repo_id=repo_id,
repo_type=repo_type,
path_or_fileobj=src,
path_in_repo=path_in_repo,
)
try:
if is_folder:
api.upload_folder(
repo_id=repo_id,
repo_type=repo_type,
folder_path=src,
path_in_repo=path_in_repo,
)
else:
api.upload_file(
repo_id=repo_id,
repo_type=repo_type,
path_or_fileobj=src,
path_in_repo=path_in_repo,
)
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
print("===========================================")
print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
print("===========================================")
if args.async_upload and not force_sync_upload:
fire_in_thread(uploader)
@@ -71,7 +77,5 @@ def list_dir(
token=token,
)
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
file_list = [
file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
]
file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
return file_list

223
library/hypernetwork.py Normal file
View File

@@ -0,0 +1,223 @@
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import (
Attention,
AttnProcessor2_0,
SlicedAttnProcessor,
XFormersAttnProcessor
)
try:
import xformers.ops
except:
xformers = None
loaded_networks = []
def apply_single_hypernetwork(
hypernetwork, hidden_states, encoder_hidden_states
):
context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
return context_k, context_v
def apply_hypernetworks(context_k, context_v, layer=None):
if len(loaded_networks) == 0:
return context_v, context_v
for hypernetwork in loaded_networks:
context_k, context_v = hypernetwork.forward(context_k, context_v)
context_k = context_k.to(dtype=context_k.dtype)
context_v = context_v.to(dtype=context_k.dtype)
return context_k, context_v
def xformers_forward(
self: XFormersAttnProcessor,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: torch.Tensor = None,
):
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
key = attn.to_k(context_k)
value = attn.to_v(context_v)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attention_mask,
op=self.attention_op,
scale=attn.scale,
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def sliced_attn_forward(
self: SlicedAttnProcessor,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: torch.Tensor = None,
):
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
key = attn.to_k(context_k)
value = attn.to_v(context_v)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, _ = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads),
device=query.device,
dtype=query.dtype,
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = (
attention_mask[start_idx:end_idx] if attention_mask is not None else None
)
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def v2_0_forward(
self: AttnProcessor2_0,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
inner_dim = hidden_states.shape[-1]
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
key = attn.to_k(context_k)
value = attn.to_v(context_v)
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
def replace_attentions_for_hypernetwork():
import diffusers.models.attention_processor
diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
xformers_forward
)
diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
sliced_attn_forward
)
diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward

175
library/ipex/__init__.py Normal file
View File

@@ -0,0 +1,175 @@
import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
from .attention import attention_init
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def ipex_init(): # pylint: disable=too-many-statements
try:
#Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
torch.cuda.device_count = torch.xpu.device_count
torch.cuda.device_of = torch.xpu.device_of
torch.cuda.get_device_name = torch.xpu.get_device_name
torch.cuda.get_device_properties = torch.xpu.get_device_properties
torch.cuda.init = torch.xpu.init
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
#Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
torch.cuda.memory_cached = torch.xpu.memory_reserved
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
#RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
torch.cuda.manual_seed = torch.xpu.manual_seed
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
torch.cuda.seed = torch.xpu.seed
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
#AMP:
torch.cuda.amp = torch.xpu.amp
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
#C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
#Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.version.cuda = "11.7"
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
torch.cuda.get_device_properties.major = 11
torch.cuda.get_device_properties.minor = 7
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
if hasattr(torch.xpu, 'getDeviceIdListForCard'):
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
else:
torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
ipex_hijacks()
attention_init()
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
except Exception as e:
return False, e
return True, None

157
library/ipex/attention.py Normal file
View File

@@ -0,0 +1,157 @@
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
original_torch_bmm = torch.bmm
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
block_multiply = input.element_size()
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
do_split = True
#Find something divisible with the input_tokens
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_2_slice_size = input_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the input_tokens
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
if do_split:
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
if len(query.shape) == 3:
batch_size_attention, query_tokens, shape_four = query.shape
shape_one = 1
no_shape_one = True
else:
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
no_shape_one = False
block_multiply = query.element_size()
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
do_split = True
#Find something divisible with the shape_one
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_2_slice_size = query_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the batch_size_attention
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
if do_split:
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if no_shape_one:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
if no_shape_one:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
return original_scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)
return hidden_states
def attention_init():
#ARC GPUs can't allocate more than 4GB to a single block:
torch.bmm = torch_bmm
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention

120
library/ipex/diffusers.py Normal file
View File

@@ -0,0 +1,120 @@
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.21.1 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
# pylint: disable=protected-access, missing-function-docstring, line-too-long
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
r"""
Processor for implementing sliced attention.
Args:
slice_size (`int`, *optional*):
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
`attention_head_dim` must be a multiple of the `slice_size`.
"""
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, shape_three = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
block_multiply = query.element_size()
slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
block_size = query_tokens * slice_block_size
split_2_slice_size = query_tokens
if block_size > 4:
do_split_2 = True
#Find something divisible with the query_tokens
while (split_2_slice_size * slice_block_size) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def ipex_diffusers():
#ARC GPUs can't allocate more than 4GB to a single block:
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor

179
library/ipex/gradscaler.py Normal file
View File

@@ -0,0 +1,179 @@
from collections import defaultdict
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
# sync grad to master weight
if hasattr(optimizer, "sync_grad"):
optimizer.sync_grad()
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad
# -: is there a way to split by device and dtype without appending in the inner loop?
to_unscale = to_unscale.to("cpu")
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for _, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
core._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get("cpu"),
per_device_inv_scale.get("cpu"),
)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device="cpu", non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
to_device = _scale.device
_scale = _scale.to("cpu")
_growth_tracker = _growth_tracker.to("cpu")
core._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
_scale = _scale.to(to_device)
_growth_tracker = _growth_tracker.to(to_device)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def gradscaler_init():
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
torch.xpu.amp.GradScaler.unscale_ = unscale_
torch.xpu.amp.GradScaler.update = update
return torch.xpu.amp.GradScaler

196
library/ipex/hijacks.py Normal file
View File

@@ -0,0 +1,196 @@
import contextlib
import importlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
class CondFunc: # pylint: disable=missing-class-docstring
def __new__(cls, orig_func, sub_func, cond_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
except ImportError:
pass
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)
_utils = torch.utils.data._utils
def _shutdown_workers(self):
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
return
if hasattr(self, "_shutdown") and not self._shutdown:
self._shutdown = True
try:
if hasattr(self, '_pin_memory_thread'):
self._pin_memory_thread_done_event.set()
self._worker_result_queue.put((None, None))
self._pin_memory_thread.join()
self._worker_result_queue.cancel_join_thread()
self._worker_result_queue.close()
self._workers_done_event.set()
for worker_id in range(len(self._workers)):
if self._persistent_workers or self._workers_status[worker_id]:
self._mark_worker_as_unavailable(worker_id, shutdown=True)
for w in self._workers: # pylint: disable=invalid-name
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
for q in self._index_queues: # pylint: disable=invalid-name
q.cancel_join_thread()
q.close()
finally:
if self._worker_pids_set:
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
self._worker_pids_set = False
for w in self._workers: # pylint: disable=invalid-name
if w.is_alive():
w.terminate()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return contextlib.nullcontext()
def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
def ipex_no_cuda(orig_func, *args, **kwargs):
torch.cuda.is_available = lambda: False
orig_func(*args, **kwargs)
torch.cuda.is_available = torch.xpu.is_available
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
if len(args) > 0 and args[0] == "cuda":
return original_autocast("xpu", *args[1:], **kwargs)
else:
return original_autocast(*args, **kwargs)
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
original_interpolate = torch.nn.functional.interpolate
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
else:
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
original_linalg_solve = torch.linalg.solve
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
return_device = A.device
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
else:
return original_linalg_solve(A, B, *args, **kwargs)
def ipex_hijacks():
CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.Tensor.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
CondFunc('torch.randn',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.ones',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.zeros',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.linspace',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
CondFunc('torch.instance_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
#Functions with dtype errors:
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float)
#Broken functions when torch.cuda.is_available is True:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True)
#Functions that make compile mad with CondFunc:
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
torch.nn.DataParallel = DummyDataParallel
torch.autocast = ipex_autocast
torch.cat = torch_cat
torch.linalg.solve = linalg_solve
torch.nn.functional.interpolate = interpolate
torch.backends.cuda.sdp_kernel = return_null_context

View File

@@ -6,7 +6,7 @@ import re
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -245,11 +245,6 @@ def get_unweighted_text_embeddings(
text_embedding = enc_out["hidden_states"][-clip_skip]
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
@@ -264,7 +259,12 @@ def get_unweighted_text_embeddings(
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
text_embeddings = pipe.text_encoder(text_input)[0]
if clip_skip is None or clip_skip == 1:
text_embeddings = pipe.text_encoder(text_input)[0]
else:
enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-clip_skip]
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
return text_embeddings
@@ -426,6 +426,58 @@ def preprocess_mask(mask, scale_factor=8):
return mask
def prepare_controlnet_image(
image: PIL.Image.Image,
width: int,
height: int,
batch_size: int,
num_images_per_prompt: int,
device: torch.device,
dtype: torch.dtype,
do_classifier_free_guidance: bool = False,
guess_mode: bool = False,
):
if not isinstance(image, torch.Tensor):
if isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
images = []
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
image = images
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance and not guess_mode:
image = torch.cat([image] * 2)
return image
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
@@ -464,10 +516,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
clip_skip: int,
# clip_skip: int,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
clip_skip: int = 1,
):
super().__init__(
vae=vae,
@@ -707,6 +760,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
return_dict: bool = True,
controlnet=None,
controlnet_image=None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
callback_steps: int = 1,
@@ -767,6 +822,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
controlnet (`diffusers.ControlNetModel`, *optional*):
A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
`Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
inference.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
@@ -785,6 +845,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
if controlnet is not None and controlnet_image is None:
raise ValueError("controlnet_image must be provided if controlnet is not None.")
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -824,6 +887,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
else:
mask = None
if controlnet_image is not None:
controlnet_image = prepare_controlnet_image(
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
@@ -851,8 +919,22 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
unet_additional_args = {}
if controlnet is not None:
down_block_res_samples, mid_block_res_sample = controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=controlnet_image,
conditioning_scale=1.0,
guess_mode=False,
return_dict=False,
)
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
# perform guidance
if do_classifier_free_guidance:
@@ -874,20 +956,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
if is_cancelled_callback is not None and is_cancelled_callback():
return None
return latents
def latents_to_image(self, latents):
# 9. Post-processing
image = self.decode_latents(latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return image, has_nsfw_concept
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
image = self.decode_latents(latents.to(self.vae.dtype))
image = self.numpy_to_pil(image)
return image
def text2img(
self,

View File

@@ -4,9 +4,18 @@
import math
import os
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
# DiffUsers版StableDiffusionのモデルパラメータ
NUM_TRAIN_TIMESTEPS = 1000
@@ -126,17 +135,30 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
if diffusers.__version__ < "0.17.0":
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
else:
new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
@@ -191,8 +213,16 @@ def assign_to_checkpoint(
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
reshaping = False
if diffusers.__version__ < "0.17.0":
if "proj_attn.weight" in new_path:
reshaping = True
else:
if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2:
reshaping = True
if reshaping:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
@@ -361,7 +391,7 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):
# SDのv2では1*1のconv2dがlinearに変わっている
# 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要
if v2 and not config.get('use_linear_projection', False):
if v2 and not config.get("use_linear_projection", False):
linear_transformer_to_conv(new_checkpoint)
return new_checkpoint
@@ -540,6 +570,11 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
# support checkpoint without position_ids (invalid checkpoint)
if "text_model.embeddings.position_ids" not in text_model_dict:
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
return text_model_dict
@@ -732,6 +767,105 @@ def convert_unet_state_dict_to_sd(v2, unet_state_dict):
return new_state_dict
def controlnet_conversion_map():
unet_conversion_map = [
("time_embed.0.weight", "time_embedding.linear_1.weight"),
("time_embed.0.bias", "time_embedding.linear_1.bias"),
("time_embed.2.weight", "time_embedding.linear_2.weight"),
("time_embed.2.bias", "time_embedding.linear_2.bias"),
("input_blocks.0.0.weight", "conv_in.weight"),
("input_blocks.0.0.bias", "conv_in.bias"),
("middle_block_out.0.weight", "controlnet_mid_block.weight"),
("middle_block_out.0.bias", "controlnet_mid_block.bias"),
]
unet_conversion_map_resnet = [
("in_layers.0", "norm1"),
("in_layers.2", "conv1"),
("out_layers.0", "norm2"),
("out_layers.3", "conv2"),
("emb_layers.1", "time_emb_proj"),
("skip_connection", "conv_shortcut"),
]
unet_conversion_map_layer = []
for i in range(4):
for j in range(2):
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
if i < 3:
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
sd_prefix = f"input_hint_block.{i*2}."
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
for i in range(12):
hf_prefix = f"controlnet_down_blocks.{i}."
sd_prefix = f"zero_convs.{i}.0."
unet_conversion_map_layer.append((sd_prefix, hf_prefix))
return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
mapping = {k: k for k in controlnet_state_dict.keys()}
for sd_name, diffusers_name in unet_conversion_map:
mapping[diffusers_name] = sd_name
for k, v in mapping.items():
if "resnets" in k:
for sd_part, diffusers_part in unet_conversion_map_resnet:
v = v.replace(diffusers_part, sd_part)
mapping[k] = v
for k, v in mapping.items():
for sd_part, diffusers_part in unet_conversion_map_layer:
v = v.replace(diffusers_part, sd_part)
mapping[k] = v
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
mapping = {k: k for k in controlnet_state_dict.keys()}
for sd_name, diffusers_name in unet_conversion_map:
mapping[sd_name] = diffusers_name
for k, v in mapping.items():
for sd_part, diffusers_part in unet_conversion_map_layer:
v = v.replace(sd_part, diffusers_part)
mapping[k] = v
for k, v in mapping.items():
if "resnets" in v:
for sd_part, diffusers_part in unet_conversion_map_resnet:
v = v.replace(sd_part, diffusers_part)
mapping[k] = v
new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
return new_state_dict
# ================#
# VAE Conversion #
# ================#
@@ -779,14 +913,24 @@ def convert_vae_state_dict(vae_state_dict):
sd_mid_res_prefix = f"mid.block_{i+1}."
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "query."),
("k.", "key."),
("v.", "value."),
("proj_out.", "proj_attn."),
]
if diffusers.__version__ < "0.17.0":
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "query."),
("k.", "key."),
("v.", "value."),
("proj_out.", "proj_attn."),
]
else:
vae_conversion_map_attn = [
# (stable-diffusion, HF Diffusers)
("norm.", "group_norm."),
("q.", "to_q."),
("k.", "to_k."),
("v.", "to_v."),
("proj_out.", "to_out.0."),
]
mapping = {k: k for k in vae_state_dict.keys()}
for k, v in mapping.items():
@@ -803,7 +947,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
# print(f"Reshaping {k} for SD format")
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
@@ -852,7 +996,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
# TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=True):
_, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
# Convert the UNet2DConditionModel model.
@@ -900,16 +1044,49 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
else:
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
logging.set_verbosity_error() # don't show annoying warning
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
logging.set_verbosity_warning()
# logging.set_verbosity_error() # don't show annoying warning
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
# logging.set_verbosity_warning()
# print(f"config: {text_model.config}")
cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=768,
torch_dtype="float32",
)
text_model = CLIPTextModel._from_config(cfg)
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
print("loading text encoder:", info)
return text_model, vae, unet
def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
# only for reference
version_str = "sd"
if v2:
version_str += "_v2"
else:
version_str += "_v1"
if v_parameterization:
version_str += "_v"
return version_str
def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
def convert_key(key):
# position_idsの除去
@@ -981,7 +1158,9 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
return new_sd
def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
def save_stable_diffusion_checkpoint(
v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None
):
if ckpt_path is not None:
# epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む
checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
@@ -1043,7 +1222,7 @@ def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_p
if is_safetensors(output_file):
# TODO Tensor以外のdictの値を削除したほうがいいか
save_file(state_dict, output_file)
save_file(state_dict, output_file, metadata)
else:
torch.save(new_ckpt, output_file)

1606
library/original_unet.py Normal file

File diff suppressed because it is too large Load Diff

305
library/sai_model_spec.py Normal file
View File

@@ -0,0 +1,305 @@
# based on https://github.com/Stability-AI/ModelSpec
import datetime
import hashlib
from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
import safetensors
r"""
# Metadata Example
metadata = {
# === Must ===
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
"modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID
"modelspec.implementation": "sgm",
"modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc
# === Should ===
"modelspec.author": "Example Corp", # Your name or company name
"modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know
"modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created
# === Can ===
"modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc.
"modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model
}
"""
BASE_METADATA = {
# === Must ===
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
"modelspec.architecture": None,
"modelspec.implementation": None,
"modelspec.title": None,
"modelspec.resolution": None,
# === Should ===
"modelspec.description": None,
"modelspec.author": None,
"modelspec.date": None,
# === Can ===
"modelspec.license": None,
"modelspec.tags": None,
"modelspec.merged_from": None,
"modelspec.prediction_type": None,
"modelspec.timestep_range": None,
"modelspec.encoder_layer": None,
}
# 別に使うやつだけ定義
MODELSPEC_TITLE = "modelspec.title"
ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_DIFFUSERS = "diffusers"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
def load_bytes_in_safetensors(tensors):
bytes = safetensors.torch.save(tensors)
b = BytesIO(bytes)
b.seek(0)
header = b.read(8)
n = int.from_bytes(header, "little")
offset = n + 8
b.seek(offset)
return b.read()
def precalculate_safetensors_hashes(state_dict):
# calculate each tensor one by one to reduce memory usage
hash_sha256 = hashlib.sha256()
for tensor in state_dict.values():
single_tensor_sd = {"tensor": tensor}
bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
hash_sha256.update(bytes_for_tensor)
return f"0x{hash_sha256.hexdigest()}"
def update_hash_sha256(metadata: dict, state_dict: dict):
raise NotImplementedError
def build_metadata(
state_dict: Optional[dict],
v2: bool,
v_parameterization: bool,
sdxl: bool,
lora: bool,
textual_inversion: bool,
timestamp: float,
title: Optional[str] = None,
reso: Optional[Union[int, Tuple[int, int]]] = None,
is_stable_diffusion_ckpt: Optional[bool] = None,
author: Optional[str] = None,
description: Optional[str] = None,
license: Optional[str] = None,
tags: Optional[str] = None,
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
):
# if state_dict is None, hash is not calculated
metadata = {}
metadata.update(BASE_METADATA)
# TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
# if state_dict is not None:
# hash = precalculate_safetensors_hashes(state_dict)
# metadata["modelspec.hash_sha256"] = hash
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
else:
arch = ARCH_SD_V2_512
else:
arch = ARCH_SD_V1
if lora:
arch += f"/{ADAPTER_LORA}"
elif textual_inversion:
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
# v1/v2 LoRA or Diffusers
impl = IMPL_DIFFUSERS
metadata["modelspec.implementation"] = impl
if title is None:
if lora:
title = "LoRA"
elif textual_inversion:
title = "TextualInversion"
else:
title = "Checkpoint"
title += f"@{timestamp}"
metadata[MODELSPEC_TITLE] = title
if author is not None:
metadata["modelspec.author"] = author
else:
del metadata["modelspec.author"]
if description is not None:
metadata["modelspec.description"] = description
else:
del metadata["modelspec.description"]
if merged_from is not None:
metadata["modelspec.merged_from"] = merged_from
else:
del metadata["modelspec.merged_from"]
if license is not None:
metadata["modelspec.license"] = license
else:
del metadata["modelspec.license"]
if tags is not None:
metadata["modelspec.tags"] = tags
else:
del metadata["modelspec.tags"]
# remove microsecond from time
int_ts = int(timestamp)
# time to iso-8601 compliant date
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
metadata["modelspec.date"] = date
if reso is not None:
# comma separated to tuple
if isinstance(reso, str):
reso = tuple(map(int, reso.split(",")))
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl:
reso = 1024
elif v2 and v_parameterization:
reso = 768
else:
reso = 512
if isinstance(reso, int):
reso = (reso, reso)
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
if v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
if timesteps is not None:
if isinstance(timesteps, str) or isinstance(timesteps, int):
timesteps = (timesteps, timesteps)
if len(timesteps) == 1:
timesteps = (timesteps[0], timesteps[0])
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
else:
del metadata["modelspec.timestep_range"]
if clip_skip is not None:
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
else:
del metadata["modelspec.encoder_layer"]
# # assert all values are filled
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
print(f"Internal error: some metadata values are None: {metadata}")
return metadata
# region utils
def get_title(metadata: dict) -> Optional[str]:
return metadata.get(MODELSPEC_TITLE, None)
def load_metadata_from_safetensors(model: str) -> dict:
if not model.endswith(".safetensors"):
return {}
with safetensors.safe_open(model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:
metadata = {}
return metadata
def build_merged_from(models: List[str]) -> str:
def get_title(model: str):
metadata = load_metadata_from_safetensors(model)
title = metadata.get(MODELSPEC_TITLE, None)
if title is None:
title = os.path.splitext(os.path.basename(model))[0] # use filename
return title
titles = [get_title(model) for model in models]
return ", ".join(titles)
# endregion
r"""
if __name__ == "__main__":
import argparse
import torch
from safetensors.torch import load_file
from library import train_util
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
args = parser.parse_args()
print(f"Loading {args.ckpt}")
state_dict = load_file(args.ckpt)
print(f"Calculating metadata")
metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0)
print(metadata)
del state_dict
# by reference implementation
with open(args.ckpt, mode="rb") as file_data:
file_hash = hashlib.sha256()
head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix
header = json.loads(file_data.read(head_len[0])) # header itself, json string
content = (
file_data.read()
) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
file_hash.update(content)
# ===== Update the hash for modelspec =====
by_ref = f"0x{file_hash.hexdigest()}"
print(by_ref)
print("is same?", by_ref == metadata["modelspec.hash_sha256"])
"""

File diff suppressed because it is too large Load Diff

572
library/sdxl_model_util.py Normal file
View File

@@ -0,0 +1,572 @@
import torch
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
VAE_SCALE_FACTOR = 0.13025
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
# Diffusersの設定を読み込むための参照モデル
DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0"
DIFFUSERS_SDXL_UNET_CONFIG = {
"act_fn": "silu",
"addition_embed_type": "text_time",
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": 256,
"attention_head_dim": [5, 10, 20],
"block_out_channels": [320, 640, 1280],
"center_input_sample": False,
"class_embed_type": None,
"class_embeddings_concat": False,
"conv_in_kernel": 3,
"conv_out_kernel": 3,
"cross_attention_dim": 2048,
"cross_attention_norm": None,
"down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"],
"downsample_padding": 1,
"dual_cross_attention": False,
"encoder_hid_dim": None,
"encoder_hid_dim_type": None,
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_only_cross_attention": None,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": None,
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 4,
"projection_class_embeddings_input_dim": 2816,
"resnet_out_scale_factor": 1.0,
"resnet_skip_time_act": False,
"resnet_time_scale_shift": "default",
"sample_size": 128,
"time_cond_proj_dim": None,
"time_embedding_act_fn": None,
"time_embedding_dim": None,
"time_embedding_type": "positional",
"timestep_post_act": None,
"transformer_layers_per_block": [1, 2, 10],
"up_block_types": ["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"],
"upcast_attention": False,
"use_linear_projection": True,
}
def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
SDXL_KEY_PREFIX = "conditioner.embedders.1.model."
# SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す
# logit_scaleはcheckpointの保存時に使用する
def convert_key(key):
# common conversion
key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.")
key = key.replace(SDXL_KEY_PREFIX, "text_model.")
if "resblocks" in key:
# resblocks conversion
key = key.replace(".resblocks.", ".layers.")
if ".ln_" in key:
key = key.replace(".ln_", ".layer_norm")
elif ".mlp." in key:
key = key.replace(".c_fc.", ".fc1.")
key = key.replace(".c_proj.", ".fc2.")
elif ".attn.out_proj" in key:
key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
elif ".attn.in_proj" in key:
key = None # 特殊なので後で処理する
else:
raise ValueError(f"unexpected key in SD: {key}")
elif ".positional_embedding" in key:
key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
elif ".text_projection" in key:
key = key.replace("text_model.text_projection", "text_projection.weight")
elif ".logit_scale" in key:
key = None # 後で処理する
elif ".token_embedding" in key:
key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
elif ".ln_final" in key:
key = key.replace(".ln_final", ".final_layer_norm")
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
elif ".embeddings.position_ids" in key:
key = None # remove this key: make position_ids by ourselves
return key
keys = list(checkpoint.keys())
new_sd = {}
for key in keys:
new_key = convert_key(key)
if new_key is None:
continue
new_sd[new_key] = checkpoint[key]
# attnの変換
for key in keys:
if ".resblocks" in key and ".attn.in_proj_" in key:
# 三つに分割
values = torch.chunk(checkpoint[key], 3)
key_suffix = ".weight" if "weight" in key else ".bias"
key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.")
key_pfx = key_pfx.replace("_weight", "")
key_pfx = key_pfx.replace("_bias", "")
key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
# original SD にはないので、position_idsを追加
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
return new_sd, logit_scale
# load state_dict without allocating new tensors
def _load_state_dict_on_device(model, state_dict, device, dtype=None):
# dtype will use fp32 as default
missing_keys = list(model.state_dict().keys() - state_dict.keys())
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
# similar to model.load_state_dict()
if not missing_keys and not unexpected_keys:
for k in list(state_dict.keys()):
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
return "<All keys matched successfully>"
# error_msgs
error_msgs: List[str] = []
if missing_keys:
error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys)))
if unexpected_keys:
error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys)))
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
# model_version is reserved for future use
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
# Load the state dict
if model_util.is_safetensors(ckpt_path):
checkpoint = None
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path) # prevent device invalid Error
epoch = None
global_step = None
else:
checkpoint = torch.load(ckpt_path, map_location=map_location)
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
epoch = checkpoint.get("epoch", 0)
global_step = checkpoint.get("global_step", 0)
else:
state_dict = checkpoint
epoch = 0
global_step = 0
checkpoint = None
# U-Net
print("building U-Net")
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
print("loading U-Net from checkpoint")
unet_sd = {}
for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."):
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
print("U-Net: ", info)
# Text Encoders
print("building text encoders")
# Text Encoder 1 is same to Stability AI's SDXL
text_model1_cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=768,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
text_model1 = CLIPTextModel._from_config(text_model1_cfg)
# Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
# Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer.
text_model2_cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
print("loading text encoders from checkpoint")
te1_sd = {}
te2_sd = {}
for k in list(state_dict.keys()):
if k.startswith("conditioner.embedders.0.transformer."):
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
if "text_model.embeddings.position_ids" not in te1_sd:
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
print("text encoder 1:", info1)
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
print("text encoder 2:", info2)
# prepare vae
print("building VAE")
vae_config = model_util.create_vae_diffusers_config()
with init_empty_weights():
vae = AutoencoderKL(**vae_config)
print("loading VAE from checkpoint")
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
print("VAE:", info)
ckpt_info = (epoch, global_step) if epoch is not None else None
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
def make_unet_conversion_map():
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
def convert_diffusers_unet_state_dict_to_sdxl(du_sd):
unet_conversion_map = make_unet_conversion_map()
conversion_map = {hf: sd for sd, hf in unet_conversion_map}
return convert_unet_state_dict(du_sd, conversion_map)
def convert_unet_state_dict(src_sd, conversion_map):
converted_sd = {}
for src_key, value in src_sd.items():
# さすがに全部回すのは時間がかかるので右から要素を削りつつprefixを探す
src_key_fragments = src_key.split(".")[:-1] # remove weight/bias
while len(src_key_fragments) > 0:
src_key_prefix = ".".join(src_key_fragments) + "."
if src_key_prefix in conversion_map:
converted_prefix = conversion_map[src_key_prefix]
converted_key = converted_prefix + src_key[len(src_key_prefix) :]
converted_sd[converted_key] = value
break
src_key_fragments.pop(-1)
assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map"
return converted_sd
def convert_sdxl_unet_state_dict_to_diffusers(sd):
unet_conversion_map = make_unet_conversion_map()
conversion_dict = {sd: hf for sd, hf in unet_conversion_map}
return convert_unet_state_dict(sd, conversion_dict)
def convert_text_encoder_2_state_dict_to_sdxl(checkpoint, logit_scale):
def convert_key(key):
# position_idsの除去
if ".position_ids" in key:
return None
# common
key = key.replace("text_model.encoder.", "transformer.")
key = key.replace("text_model.", "")
if "layers" in key:
# resblocks conversion
key = key.replace(".layers.", ".resblocks.")
if ".layer_norm" in key:
key = key.replace(".layer_norm", ".ln_")
elif ".mlp." in key:
key = key.replace(".fc1.", ".c_fc.")
key = key.replace(".fc2.", ".c_proj.")
elif ".self_attn.out_proj" in key:
key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
elif ".self_attn." in key:
key = None # 特殊なので後で処理する
else:
raise ValueError(f"unexpected key in DiffUsers model: {key}")
elif ".position_embedding" in key:
key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
elif ".token_embedding" in key:
key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
elif "text_projection" in key: # no dot in key
key = key.replace("text_projection.weight", "text_projection")
elif "final_layer_norm" in key:
key = key.replace("final_layer_norm", "ln_final")
return key
keys = list(checkpoint.keys())
new_sd = {}
for key in keys:
new_key = convert_key(key)
if new_key is None:
continue
new_sd[new_key] = checkpoint[key]
# attnの変換
for key in keys:
if "layers" in key and "q_proj" in key:
# 三つを結合
key_q = key
key_k = key.replace("q_proj", "k_proj")
key_v = key.replace("q_proj", "v_proj")
value_q = checkpoint[key_q]
value_k = checkpoint[key_k]
value_v = checkpoint[key_v]
value = torch.cat([value_q, value_k, value_v])
new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
new_sd[new_key] = value
if logit_scale is not None:
new_sd["logit_scale"] = logit_scale
return new_sd
def save_stable_diffusion_checkpoint(
output_file,
text_encoder1,
text_encoder2,
unet,
epochs,
steps,
ckpt_info,
vae,
logit_scale,
metadata,
save_dtype=None,
):
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
if save_dtype is not None:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
# Convert the UNet model
update_sd("model.diffusion_model.", unet.state_dict())
# Convert the text encoders
update_sd("conditioner.embedders.0.transformer.", text_encoder1.state_dict())
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(text_encoder2.state_dict(), logit_scale)
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
# Convert the VAE
vae_dict = model_util.convert_vae_state_dict(vae.state_dict())
update_sd("first_stage_model.", vae_dict)
# Put together new checkpoint
key_count = len(state_dict.keys())
new_ckpt = {"state_dict": state_dict}
# epoch and global_step are sometimes not int
if ckpt_info is not None:
epochs += ckpt_info[0]
steps += ckpt_info[1]
new_ckpt["epoch"] = epochs
new_ckpt["global_step"] = steps
if model_util.is_safetensors(output_file):
save_file(state_dict, output_file, metadata)
else:
torch.save(new_ckpt, output_file)
return key_count
def save_diffusers_checkpoint(
output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
):
from diffusers import StableDiffusionXLPipeline
# convert U-Net
unet_sd = unet.state_dict()
du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
diffusers_unet = UNet2DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
if save_dtype is not None:
diffusers_unet.to(save_dtype)
diffusers_unet.load_state_dict(du_unet_sd)
# create pipeline to save
if pretrained_model_name_or_path is None:
pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_SDXL
scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
tokenizer1 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
tokenizer2 = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer_2")
if vae is None:
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
# prevent local path from being saved
def remove_name_or_path(model):
if hasattr(model, "config"):
model.config._name_or_path = None
model.config._name_or_path = None
remove_name_or_path(diffusers_unet)
remove_name_or_path(text_encoder1)
remove_name_or_path(text_encoder2)
remove_name_or_path(scheduler)
remove_name_or_path(tokenizer1)
remove_name_or_path(tokenizer2)
remove_name_or_path(vae)
pipeline = StableDiffusionXLPipeline(
unet=diffusers_unet,
text_encoder=text_encoder1,
text_encoder_2=text_encoder2,
vae=vae,
scheduler=scheduler,
tokenizer=tokenizer1,
tokenizer_2=tokenizer2,
)
if save_dtype is not None:
pipeline.to(None, save_dtype)
pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)

File diff suppressed because it is too large Load Diff

369
library/sdxl_train_util.py Normal file
View File

@@ -0,0 +1,369 @@
import argparse
import gc
import math
import os
from typing import Optional
import torch
from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
# DEFAULT_NOISE_OFFSET = 0.0357
def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = _load_target_model(
args.pretrained_model_name_or_path,
args.vae,
model_version,
weight_dtype,
accelerator.device if args.lowram else "cpu",
model_dtype,
)
# work on low-ram device
if args.lowram:
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
def _load_target_model(
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
):
# model_dtype only work with full fp16/bf16
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print(f"load StableDiffusion checkpoint: {name_or_path}")
(
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline
variant = "fp16" if weight_dtype == torch.float16 else None
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
try:
try:
pipe = StableDiffusionXLPipeline.from_pretrained(
name_or_path, torch_dtype=model_dtype, variant=variant, tokenizer=None
)
except EnvironmentError as ex:
if variant is not None:
print("try to load fp32 model")
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
else:
raise ex
except EnvironmentError as ex:
print(
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
)
raise ex
text_encoder1 = pipe.text_encoder
text_encoder2 = pipe.text_encoder_2
# convert to fp32 for cache text_encoders outputs
if text_encoder1.dtype != torch.float32:
text_encoder1 = text_encoder1.to(dtype=torch.float32)
if text_encoder2.dtype != torch.float32:
text_encoder2 = text_encoder2.to(dtype=torch.float32)
vae = pipe.vae
unet = pipe.unet
del pipe
# Diffusers U-Net to original U-Net
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
print("U-Net converted to original U-Net")
logit_scale = None
ckpt_info = None
# VAEを読み込む
if vae_path is not None:
vae = model_util.load_vae(vae_path, weight_dtype)
print("additional VAE loaded")
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
tokeniers = []
for i, original_path in enumerate(original_paths):
tokenizer: CLIPTokenizer = None
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
print(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
if tokenizer is None:
tokenizer = CLIPTokenizer.from_pretrained(original_path)
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
print(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
if i == 1:
tokenizer.pad_token_id = 0 # fix pad token id to make same as open clip tokenizer
tokeniers.append(tokenizer)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
return tokeniers
def match_mixed_precision(args, weight_dtype):
if args.full_fp16:
assert (
weight_dtype == torch.float16
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
return weight_dtype
elif args.full_bf16:
assert (
weight_dtype == torch.bfloat16
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
return weight_dtype
else:
return None
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def get_timestep_embedding(x, outdim):
assert len(x.shape) == 2
b, dims = x.shape[0], x.shape[1]
x = torch.flatten(x)
emb = timestep_embedding(x, outdim)
emb = torch.reshape(emb, (b, dims * outdim))
return emb
def get_size_embeddings(orig_size, crop_size, target_size, device):
emb1 = get_timestep_embedding(orig_size, 256)
emb2 = get_timestep_embedding(crop_size, 256)
emb3 = get_timestep_embedding(target_size, 256)
vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
return vector
def save_sd_model_on_train_end(
args: argparse.Namespace,
src_path: str,
save_stable_diffusion_format: bool,
use_safetensors: bool,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
text_encoder1,
text_encoder2,
unet,
vae,
logit_scale,
ckpt_info,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
sdxl_model_util.save_stable_diffusion_checkpoint(
ckpt_file,
text_encoder1,
text_encoder2,
unet,
epoch_no,
global_step,
ckpt_info,
vae,
logit_scale,
sai_metadata,
save_dtype,
)
def diffusers_saver(out_dir):
sdxl_model_util.save_diffusers_checkpoint(
out_dir,
text_encoder1,
text_encoder2,
unet,
src_path,
vae,
use_safetensors=use_safetensors,
save_dtype=save_dtype,
)
train_util.save_sd_model_on_train_end_common(
args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
)
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
def save_sd_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
src_path,
save_stable_diffusion_format: bool,
use_safetensors: bool,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
text_encoder1,
text_encoder2,
unet,
vae,
logit_scale,
ckpt_info,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, True, False, False, is_stable_diffusion_ckpt=True)
sdxl_model_util.save_stable_diffusion_checkpoint(
ckpt_file,
text_encoder1,
text_encoder2,
unet,
epoch_no,
global_step,
ckpt_info,
vae,
logit_scale,
sai_metadata,
save_dtype,
)
def diffusers_saver(out_dir):
sdxl_model_util.save_diffusers_checkpoint(
out_dir,
text_encoder1,
text_encoder2,
unet,
src_path,
vae,
use_safetensors=use_safetensors,
save_dtype=save_dtype,
)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
save_stable_diffusion_format,
use_safetensors,
epoch,
num_train_epochs,
global_step,
sd_saver,
diffusers_saver,
)
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
parser.add_argument(
"--cache_text_encoder_outputs_to_disk",
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
if args.clip_skip is not None:
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
# if args.multires_noise_iterations:
# print(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
# )
# else:
# if args.noise_offset is None:
# args.noise_offset = DEFAULT_NOISE_OFFSET
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
# print(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
# )
# print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
print(
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
)
def sample_images(*args, **kwargs):
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)

679
library/slicing_vae.py Normal file
View File

@@ -0,0 +1,679 @@
# Modified from Diffusers to reduce VRAM usage
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
def slice_h(x, num_slices):
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
# Conv2dのpaddingの副作用を排除するために、両側にpad 1しながらHをスライスする
# NCHWでもNHWCでもどちらでも動く
size = (x.shape[2] + num_slices - 1) // num_slices
sliced = []
for i in range(num_slices):
if i == 0:
sliced.append(x[:, :, : size + 1, :])
else:
end = size * (i + 1) + 1
if x.shape[2] - end < 3: # if the last slice is too small, use the rest of the tensor 最後が細すぎるとconv2dできないので全部使う
end = x.shape[2]
sliced.append(x[:, :, size * i - 1 : end, :])
if end >= x.shape[2]:
break
return sliced
def cat_h(sliced):
# padding分を除いて結合する
cat = []
for i, x in enumerate(sliced):
if i == 0:
cat.append(x[:, :, :-1, :])
elif i == len(sliced) - 1:
cat.append(x[:, :, 1:, :])
else:
cat.append(x[:, :, 1:-1, :])
del x
x = torch.cat(cat, dim=2)
return x
def resblock_forward(_self, num_slices, input_tensor, temb):
assert _self.upsample is None and _self.downsample is None
assert _self.norm1.num_groups == _self.norm2.num_groups
assert temb is None
# make sure norms are on cpu
org_device = input_tensor.device
cpu_device = torch.device("cpu")
_self.norm1.to(cpu_device)
_self.norm2.to(cpu_device)
# GroupNormがCPUでfp16で動かない対策
org_dtype = input_tensor.dtype
if org_dtype == torch.float16:
_self.norm1.to(torch.float32)
_self.norm2.to(torch.float32)
# すべてのテンソルをCPUに移動する
input_tensor = input_tensor.to(cpu_device)
hidden_states = input_tensor
# どうもこれは結果が異なるようだ……
# def sliced_norm1(norm, x):
# num_div = 4 if up_block_idx <= 2 else x.shape[1] // norm.num_groups
# sliced_tensor = torch.chunk(x, num_div, dim=1)
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
# print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
# normed_tensor = []
# for i in range(num_div):
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
# normed_tensor.append(n)
# del n
# x = torch.cat(normed_tensor, dim=1)
# return num_div, x
# normを分割すると結果が変わるので、ここだけは分割しない。GPUで計算するとVRAMが足りなくなるので、CPUで計算する。幸いCPUでもそこまで遅くない
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float32)
hidden_states = _self.norm1(hidden_states) # run on cpu
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float16)
sliced = slice_h(hidden_states, num_slices)
del hidden_states
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
# 計算する部分だけGPUに移動する、以下同様
x = x.to(org_device)
x = _self.nonlinearity(x)
x = _self.conv1(x)
x = x.to(cpu_device)
sliced[i] = x
del x
hidden_states = cat_h(sliced)
del sliced
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float32)
hidden_states = _self.norm2(hidden_states) # run on cpu
if org_dtype == torch.float16:
hidden_states = hidden_states.to(torch.float16)
sliced = slice_h(hidden_states, num_slices)
del hidden_states
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
x = _self.nonlinearity(x)
x = _self.dropout(x)
x = _self.conv2(x)
x = x.to(cpu_device)
sliced[i] = x
del x
hidden_states = cat_h(sliced)
del sliced
# make shortcut
if _self.conv_shortcut is not None:
sliced = list(torch.chunk(input_tensor, num_slices, dim=2)) # no padding in conv_shortcut パディングがないので普通にスライスする
del input_tensor
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
x = _self.conv_shortcut(x)
x = x.to(cpu_device)
sliced[i] = x
del x
input_tensor = torch.cat(sliced, dim=2)
del sliced
output_tensor = (input_tensor + hidden_states) / _self.output_scale_factor
output_tensor = output_tensor.to(org_device) # 次のレイヤーがGPUで計算する
return output_tensor
class SlicingEncoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
double_z=True,
num_slices=2,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
self.mid_block = None
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=self.layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=not is_final_block,
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=None,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
)
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
# replace forward of ResBlocks
def wrapper(func, module, num_slices):
def forward(*args, **kwargs):
return func(module, num_slices, *args, **kwargs)
return forward
self.num_slices = num_slices
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
# print(f"initial divisor: {div}")
if div >= 2:
div = int(div)
for resnet in self.mid_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
# midblock doesn't have downsample
for i, down_block in enumerate(self.down_blocks[::-1]):
if div >= 2:
div = int(div)
# print(f"down block: {i} divisor: {div}")
for resnet in down_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
if down_block.downsamplers is not None:
# print("has downsample")
for downsample in down_block.downsamplers:
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
div *= 2
def forward(self, x):
sample = x
del x
org_device = sample.device
cpu_device = torch.device("cpu")
# sample = self.conv_in(sample)
sample = sample.to(cpu_device)
sliced = slice_h(sample, self.num_slices)
del sample
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
x = self.conv_in(x)
x = x.to(cpu_device)
sliced[i] = x
del x
sample = cat_h(sliced)
del sliced
sample = sample.to(org_device)
# down
for down_block in self.down_blocks:
sample = down_block(sample)
# middle
sample = self.mid_block(sample)
# post-process
# ここも省メモリ化したいが、恐らくそこまでメモリを食わないので省略
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
def downsample_forward(self, _self, num_slices, hidden_states):
assert hidden_states.shape[1] == _self.channels
assert _self.use_conv and _self.padding == 0
print("downsample forward", num_slices, hidden_states.shape)
org_device = hidden_states.device
cpu_device = torch.device("cpu")
hidden_states = hidden_states.to(cpu_device)
pad = (0, 1, 0, 1)
hidden_states = torch.nn.functional.pad(hidden_states, pad, mode="constant", value=0)
# slice with even number because of stride 2
# strideが2なので偶数でスライスする
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
size = (hidden_states.shape[2] + num_slices - 1) // num_slices
size = size + 1 if size % 2 == 1 else size
sliced = []
for i in range(num_slices):
if i == 0:
sliced.append(hidden_states[:, :, : size + 1, :])
else:
end = size * (i + 1) + 1
if hidden_states.shape[2] - end < 4: # if the last slice is too small, use the rest of the tensor
end = hidden_states.shape[2]
sliced.append(hidden_states[:, :, size * i - 1 : end, :])
if end >= hidden_states.shape[2]:
break
del hidden_states
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
x = _self.conv(x)
x = x.to(cpu_device)
# ここだけ雰囲気が違うのはCopilotのせい
if i == 0:
hidden_states = x
else:
hidden_states = torch.cat([hidden_states, x], dim=2)
hidden_states = hidden_states.to(org_device)
# print("downsample forward done", hidden_states.shape)
return hidden_states
class SlicingDecoder(nn.Module):
def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(64,),
layers_per_block=2,
norm_num_groups=32,
act_fn="silu",
num_slices=2,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.mid_block = None
self.up_blocks = nn.ModuleList([])
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attention_head_dim=block_out_channels[-1],
resnet_groups=norm_num_groups,
temb_channels=None,
)
self.mid_block.attentions[0].set_use_memory_efficient_attention_xformers(True) # とりあえずDiffusersのxformersを使う
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=None,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
# replace forward of ResBlocks
def wrapper(func, module, num_slices):
def forward(*args, **kwargs):
return func(module, num_slices, *args, **kwargs)
return forward
self.num_slices = num_slices
div = num_slices / (2 ** (len(self.up_blocks) - 1))
print(f"initial divisor: {div}")
if div >= 2:
div = int(div)
for resnet in self.mid_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
# midblock doesn't have upsample
for i, up_block in enumerate(self.up_blocks):
if div >= 2:
div = int(div)
# print(f"up block: {i} divisor: {div}")
for resnet in up_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
if up_block.upsamplers is not None:
# print("has upsample")
for upsample in up_block.upsamplers:
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
div *= 2
def forward(self, z):
sample = z
del z
sample = self.conv_in(sample)
# middle
sample = self.mid_block(sample)
# up
for i, up_block in enumerate(self.up_blocks):
sample = up_block(sample)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
# conv_out with slicing because of VRAM usage
# conv_outはとてもVRAM使うのでスライスして対応
org_device = sample.device
cpu_device = torch.device("cpu")
sample = sample.to(cpu_device)
sliced = slice_h(sample, self.num_slices)
del sample
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
x = self.conv_out(x)
x = x.to(cpu_device)
sliced[i] = x
sample = cat_h(sliced)
del sliced
sample = sample.to(org_device)
return sample
def upsample_forward(self, _self, num_slices, hidden_states, output_size=None):
assert hidden_states.shape[1] == _self.channels
assert _self.use_conv_transpose == False and _self.use_conv
org_dtype = hidden_states.dtype
org_device = hidden_states.device
cpu_device = torch.device("cpu")
hidden_states = hidden_states.to(cpu_device)
sliced = slice_h(hidden_states, num_slices)
del hidden_states
for i in range(len(sliced)):
x = sliced[i]
sliced[i] = None
x = x.to(org_device)
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
# https://github.com/pytorch/pytorch/issues/86679
# PyTorch 2で直らないかね……
if org_dtype == torch.bfloat16:
x = x.to(torch.float32)
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if org_dtype == torch.bfloat16:
x = x.to(org_dtype)
x = _self.conv(x)
# upsampleされてるのでpadは2になる
if i == 0:
x = x[:, :, :-2, :]
elif i == num_slices - 1:
x = x[:, :, 2:, :]
else:
x = x[:, :, 2:-2, :]
x = x.to(cpu_device)
sliced[i] = x
del x
hidden_states = torch.cat(sliced, dim=2)
# print("us hidden_states", hidden_states.shape)
del sliced
hidden_states = hidden_states.to(org_device)
return hidden_states
class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
and Max Welling.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(64,)`): Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
num_slices: int = 16,
):
super().__init__()
# pass init params to Encoder
self.encoder = SlicingEncoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
num_slices=num_slices,
)
# pass init params to Decoder
self.decoder = SlicingDecoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
num_slices=num_slices,
)
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# これはバッチ方向のスライシング 紛らわしい
def enable_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

File diff suppressed because it is too large Load Diff

View File

@@ -5,35 +5,41 @@ from safetensors.torch import load_file
def main(file):
print(f"loading: {file}")
if os.path.splitext(file)[1] == '.safetensors':
sd = load_file(file)
else:
sd = torch.load(file, map_location='cpu')
print(f"loading: {file}")
if os.path.splitext(file)[1] == ".safetensors":
sd = load_file(file)
else:
sd = torch.load(file, map_location="cpu")
values = []
values = []
keys = list(sd.keys())
for key in keys:
if 'lora_up' in key or 'lora_down' in key:
values.append((key, sd[key]))
print(f"number of LoRA modules: {len(values)}")
keys = list(sd.keys())
for key in keys:
if "lora_up" in key or "lora_down" in key:
values.append((key, sd[key]))
print(f"number of LoRA modules: {len(values)}")
for key, value in values:
value = value.to(torch.float32)
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
if args.show_all_keys:
for key in [k for k in keys if k not in values]:
values.append((key, sd[key]))
print(f"number of all modules: {len(values)}")
for key, value in values:
value = value.to(torch.float32)
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
parser = argparse.ArgumentParser()
parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = parser.parse_args()
main(args.file)
main(args.file)

View File

@@ -0,0 +1,446 @@
import os
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
# input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
# output_blocksに適用するかどうか / if True, output_blocks are not applied
SKIP_OUTPUT_BLOCKS = True
# conv2dに適用するかどうか / if True, conv2d are not applied
SKIP_CONV2D = False
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
# if True, only transformer_blocks are applied, and ResBlocks are not applied
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
ATTN1_2_ONLY = True
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
ATTN_QKV_ONLY = True
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
ATTN1_ETC_ONLY = False # True
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
# max index of transformer_blocks. if None, apply to all transformer_blocks
TRANSFORMER_MAX_BLOCK_INDEX = None
class LLLiteModule(torch.nn.Module):
def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
super().__init__()
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
self.lllite_name = name
self.cond_emb_dim = cond_emb_dim
self.org_module = [org_module]
self.dropout = dropout
self.multiplier = multiplier
if self.is_conv2d:
in_dim = org_module.in_channels
else:
in_dim = org_module.in_features
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
# conditioning1 embeds conditioning image. it is not called for each timestep
modules = []
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
if depth == 1:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
elif depth == 2:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
elif depth == 3:
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
modules.append(torch.nn.ReLU(inplace=True))
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
self.conditioning1 = torch.nn.Sequential(*modules)
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
# midでconditioning image embeddingと入力を結合する
# upで元の次元数に戻す
# これらはtimestepごとに呼ばれる
# reduce the number of input dimensions with down. inspired by LoRA
# combine conditioning image embedding and input with mid
# restore to the original dimension with up
# these are called for each timestep
if self.is_conv2d:
self.down = torch.nn.Sequential(
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
)
self.mid = torch.nn.Sequential(
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
torch.nn.ReLU(inplace=True),
)
self.up = torch.nn.Sequential(
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
)
else:
# midの前にconditioningをreshapeすること / reshape conditioning before mid
self.down = torch.nn.Sequential(
torch.nn.Linear(in_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
self.mid = torch.nn.Sequential(
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
self.up = torch.nn.Sequential(
torch.nn.Linear(mlp_dim, in_dim),
)
# Zero-Convにする / set to Zero-Conv
torch.nn.init.zeros_(self.up[0].weight) # zero conv
self.depth = depth # 1~3
self.cond_emb = None
self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
# batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
# Controlの種類によっては使えるかも
# both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
# it may be available depending on the type of Control
def set_cond_image(self, cond_image):
r"""
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
/ call the model inside, so if necessary, surround it with torch.no_grad()
"""
if cond_image is None:
self.cond_emb = None
return
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
cx = self.conditioning1(cond_image)
if not self.is_conv2d:
# reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape
cx = cx.view(n, c, h * w).permute(0, 2, 1)
self.cond_emb = cx
def set_batch_cond_only(self, cond_only, zeros):
self.batch_cond_only = cond_only
self.use_zeros_for_batch_uncond = zeros
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def forward(self, x):
r"""
学習用の便利forward。元のモジュールのforwardを呼び出す
/ convenient forward for training. call the forward of the original module
"""
if self.multiplier == 0.0 or self.cond_emb is None:
return self.org_forward(x)
cx = self.cond_emb
if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
if self.use_zeros_for_batch_uncond:
cx[0::2] = 0.0 # uncond is zero
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
# downで入力の次元数を削減し、conditioning image embeddingと結合する
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
# down reduces the number of input dimensions and combines it with conditioning image embedding
# we expect that it will mix well by combining in the channel direction instead of adding
cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
cx = self.mid(cx)
if self.dropout is not None and self.training:
cx = torch.nn.functional.dropout(cx, p=self.dropout)
cx = self.up(cx) * self.multiplier
# residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
if self.batch_cond_only:
zx = torch.zeros_like(x)
zx[1::2] += cx
cx = zx
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
return x
class ControlNetLLLite(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
def __init__(
self,
unet: sdxl_original_unet.SdxlUNet2DConditionModel,
cond_emb_dim: int = 16,
mlp_dim: int = 16,
dropout: Optional[float] = None,
varbose: Optional[bool] = False,
multiplier: Optional[float] = 1.0,
) -> None:
super().__init__()
# self.unets = [unet]
def create_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
module_class: Type[object],
) -> List[torch.nn.Module]:
prefix = "lllite_unet"
modules = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
if is_linear or (is_conv2d and not SKIP_CONV2D):
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
# block index to depth: depth is using to calculate conditioning size and channels
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
index1 = int(index1)
if block_name == "input_blocks":
if SKIP_INPUT_BLOCKS:
continue
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
elif block_name == "middle_block":
depth = 3
elif block_name == "output_blocks":
if SKIP_OUTPUT_BLOCKS:
continue
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
if int(index2) >= 2:
depth -= 1
else:
raise NotImplementedError()
lllite_name = prefix + "." + name + "." + child_name
lllite_name = lllite_name.replace(".", "_")
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
p = lllite_name.find("transformer_blocks")
if p >= 0:
tf_index = int(lllite_name[p:].split("_")[2])
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
continue
# time embは適用外とする
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
# time emb is not applied
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
if "emb_layers" in lllite_name or (
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
):
continue
if ATTN1_2_ONLY:
if not ("attn1" in lllite_name or "attn2" in lllite_name):
continue
if ATTN_QKV_ONLY:
if "to_out" in lllite_name:
continue
if ATTN1_ETC_ONLY:
if "proj_out" in lllite_name:
pass
elif "attn1" in lllite_name and (
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
):
pass
elif "ff_net_2" in lllite_name:
pass
else:
continue
module = module_class(
depth,
cond_emb_dim,
lllite_name,
child_module,
mlp_dim,
dropout=dropout,
multiplier=multiplier,
)
modules.append(module)
return modules
target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
if not TRANSFORMER_ONLY:
target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
# create module instances
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
def forward(self, x):
return x # dummy
def set_cond_image(self, cond_image):
r"""
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
/ call the model inside, so if necessary, surround it with torch.no_grad()
"""
for module in self.unet_modules:
module.set_cond_image(cond_image)
def set_batch_cond_only(self, cond_only, zeros):
for module in self.unet_modules:
module.set_batch_cond_only(cond_only, zeros)
def set_multiplier(self, multiplier):
for module in self.unet_modules:
module.multiplier = multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self):
print("applying LLLite for U-Net...")
for module in self.unet_modules:
module.apply_to()
self.add_module(module.lllite_name, module)
# マージできるかどうかを返す
def is_mergeable(self):
return False
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
raise NotImplementedError()
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_optimizer_params(self):
self.requires_grad_(True)
return self.parameters()
def prepare_grad_etc(self):
self.requires_grad_(True)
def on_epoch_start(self):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
if __name__ == "__main__":
# デバッグ用 / for debug
# sdxl_original_unet.USE_REENTRANT = False
# test shape etc
print("create unet")
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
unet.to("cuda").to(torch.float16)
print("create ControlNet-LLLite")
control_net = ControlNetLLLite(unet, 32, 64)
control_net.apply_to()
control_net.to("cuda")
print(control_net)
# print number of parameters
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
input()
unet.set_use_memory_efficient_attention(True, False)
unet.set_gradient_checkpointing(True)
unet.train() # for gradient checkpointing
control_net.train()
# # visualize
# import torchviz
# print("run visualize")
# controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y)
# print("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render")
# image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input()
import bitsandbytes
optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
steps = 10
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
for step in range(steps):
print(f"step {step}")
batch_size = 1
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
x = torch.randn(batch_size, 4, 128, 128).cuda()
t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
ctx = torch.randn(batch_size, 77, 2048).cuda()
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
with torch.cuda.amp.autocast(enabled=True):
control_net.set_cond_image(conditioning_image)
output = unet(x, t, ctx, y)
target = torch.randn_like(output)
loss = torch.nn.functional.mse_loss(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
print(sample_param)
# from safetensors.torch import save_file
# save_file(control_net.state_dict(), "logs/control_net.safetensors")

View File

@@ -0,0 +1,502 @@
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装
# ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward
import os
import re
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
# input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
# output_blocksに適用するかどうか / if True, output_blocks are not applied
SKIP_OUTPUT_BLOCKS = True
# conv2dに適用するかどうか / if True, conv2d are not applied
SKIP_CONV2D = False
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
# if True, only transformer_blocks are applied, and ResBlocks are not applied
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
ATTN1_2_ONLY = True
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
ATTN_QKV_ONLY = True
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
ATTN1_ETC_ONLY = False # True
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
# max index of transformer_blocks. if None, apply to all transformer_blocks
TRANSFORMER_MAX_BLOCK_INDEX = None
ORIGINAL_LINEAR = torch.nn.Linear
ORIGINAL_CONV2D = torch.nn.Conv2d
def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None:
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
# conditioning1 embeds conditioning image. it is not called for each timestep
modules = []
modules.append(ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
if depth == 1:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
elif depth == 2:
modules.append(torch.nn.ReLU(inplace=True))
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
elif depth == 3:
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
modules.append(torch.nn.ReLU(inplace=True))
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
modules.append(torch.nn.ReLU(inplace=True))
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
module.lllite_conditioning1 = torch.nn.Sequential(*modules)
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
# midでconditioning image embeddingと入力を結合する
# upで元の次元数に戻す
# これらはtimestepごとに呼ばれる
# reduce the number of input dimensions with down. inspired by LoRA
# combine conditioning image embedding and input with mid
# restore to the original dimension with up
# these are called for each timestep
module.lllite_down = torch.nn.Sequential(
ORIGINAL_LINEAR(in_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
module.lllite_mid = torch.nn.Sequential(
ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim),
torch.nn.ReLU(inplace=True),
)
module.lllite_up = torch.nn.Sequential(
ORIGINAL_LINEAR(mlp_dim, in_dim),
)
# Zero-Convにする / set to Zero-Conv
torch.nn.init.zeros_(module.lllite_up[0].weight) # zero conv
class LLLiteLinear(ORIGINAL_LINEAR):
def __init__(self, in_features: int, out_features: int, **kwargs):
super().__init__(in_features, out_features, **kwargs)
self.enabled = False
def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
self.enabled = True
self.lllite_name = name
self.cond_emb_dim = cond_emb_dim
self.dropout = dropout
self.multiplier = multiplier # ignored
in_dim = self.in_features
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
self.cond_image = None
self.cond_emb = None
def set_cond_image(self, cond_image):
self.cond_image = cond_image
self.cond_emb = None
def forward(self, x):
if not self.enabled:
return super().forward(x)
if self.cond_emb is None:
self.cond_emb = self.lllite_conditioning1(self.cond_image)
cx = self.cond_emb
# reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape
cx = cx.view(n, c, h * w).permute(0, 2, 1)
cx = torch.cat([cx, self.lllite_down(x)], dim=2)
cx = self.lllite_mid(cx)
if self.dropout is not None and self.training:
cx = torch.nn.functional.dropout(cx, p=self.dropout)
cx = self.lllite_up(cx) * self.multiplier
x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
return x
class LLLiteConv2d(ORIGINAL_CONV2D):
def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self.enabled = False
def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
self.enabled = True
self.lllite_name = name
self.cond_emb_dim = cond_emb_dim
self.dropout = dropout
self.multiplier = multiplier # ignored
in_dim = self.in_channels
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
self.cond_image = None
self.cond_emb = None
def set_cond_image(self, cond_image):
self.cond_image = cond_image
self.cond_emb = None
def forward(self, x): # , cond_image=None):
if not self.enabled:
return super().forward(x)
if self.cond_emb is None:
self.cond_emb = self.lllite_conditioning1(self.cond_image)
cx = self.cond_emb
cx = torch.cat([cx, self.down(x)], dim=1)
cx = self.mid(cx)
if self.dropout is not None and self.training:
cx = torch.nn.functional.dropout(cx, p=self.dropout)
cx = self.up(cx) * self.multiplier
x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
return x
class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
LLLITE_PREFIX = "lllite_unet"
def __init__(self, **kwargs):
super().__init__(**kwargs)
def apply_lllite(
self,
cond_emb_dim: int = 16,
mlp_dim: int = 16,
dropout: Optional[float] = None,
varbose: Optional[bool] = False,
multiplier: Optional[float] = 1.0,
) -> None:
def apply_to_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[torch.nn.Module]:
prefix = "lllite_unet"
modules = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "LLLiteLinear"
is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d"
if is_linear or (is_conv2d and not SKIP_CONV2D):
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
# block index to depth: depth is using to calculate conditioning size and channels
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
index1 = int(index1)
if block_name == "input_blocks":
if SKIP_INPUT_BLOCKS:
continue
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
elif block_name == "middle_block":
depth = 3
elif block_name == "output_blocks":
if SKIP_OUTPUT_BLOCKS:
continue
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
if int(index2) >= 2:
depth -= 1
else:
raise NotImplementedError()
lllite_name = prefix + "." + name + "." + child_name
lllite_name = lllite_name.replace(".", "_")
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
p = lllite_name.find("transformer_blocks")
if p >= 0:
tf_index = int(lllite_name[p:].split("_")[2])
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
continue
# time embは適用外とする
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
# time emb is not applied
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
if "emb_layers" in lllite_name or (
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
):
continue
if ATTN1_2_ONLY:
if not ("attn1" in lllite_name or "attn2" in lllite_name):
continue
if ATTN_QKV_ONLY:
if "to_out" in lllite_name:
continue
if ATTN1_ETC_ONLY:
if "proj_out" in lllite_name:
pass
elif "attn1" in lllite_name and (
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
):
pass
elif "ff_net_2" in lllite_name:
pass
else:
continue
child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier)
modules.append(child_module)
return modules
target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE
if not TRANSFORMER_ONLY:
target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
# create module instances
self.lllite_modules = apply_to_modules(self, target_modules)
print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
# def prepare_optimizer_params(self):
def prepare_params(self):
train_params = []
non_train_params = []
for name, p in self.named_parameters():
if "lllite" in name:
train_params.append(p)
else:
non_train_params.append(p)
print(f"count of trainable parameters: {len(train_params)}")
print(f"count of non-trainable parameters: {len(non_train_params)}")
for p in non_train_params:
p.requires_grad_(False)
# without this, an error occurs in the optimizer
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
non_train_params[0].requires_grad_(True)
for p in train_params:
p.requires_grad_(True)
return train_params
# def prepare_grad_etc(self):
# self.requires_grad_(True)
# def on_epoch_start(self):
# self.train()
def get_trainable_params(self):
return [p[1] for p in self.named_parameters() if "lllite" in p[0]]
def save_lllite_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
org_state_dict = self.state_dict()
# copy LLLite keys from org_state_dict to state_dict with key conversion
state_dict = {}
for key in org_state_dict.keys():
# split with ".lllite"
pos = key.find(".lllite")
if pos < 0:
continue
lllite_key = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "." + key[:pos]
lllite_key = lllite_key.replace(".", "_") + key[pos:]
lllite_key = lllite_key.replace(".lllite_", ".")
state_dict[lllite_key] = org_state_dict[key]
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def load_lllite_weights(self, file, non_lllite_unet_sd=None):
r"""
LLLiteの重みを読み込まないinitされた値を使う場合はfileにNoneを指定する。
この場合、non_lllite_unet_sdにはU-Netのstate_dictを指定する。
If you do not want to load LLLite weights (use initialized values), specify None for file.
In this case, specify the state_dict of U-Net for non_lllite_unet_sd.
"""
if not file:
state_dict = self.state_dict()
for key in non_lllite_unet_sd:
if key in state_dict:
state_dict[key] = non_lllite_unet_sd[key]
info = self.load_state_dict(state_dict, False)
return info
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# module_name = module_name.replace("_block", "@blocks")
# module_name = module_name.replace("_layer", "@layer")
# module_name = module_name.replace("to_", "to@")
# module_name = module_name.replace("time_embed", "time@embed")
# module_name = module_name.replace("label_emb", "label@emb")
# module_name = module_name.replace("skip_connection", "skip@connection")
# module_name = module_name.replace("proj_in", "proj@in")
# module_name = module_name.replace("proj_out", "proj@out")
pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)")
# convert to lllite with U-Net state dict
state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {}
for key in weights_sd.keys():
# split with "."
pos = key.find(".")
if pos < 0:
continue
module_name = key[:pos]
weight_name = key[pos + 1 :] # exclude "."
module_name = module_name.replace(SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "_", "")
# これはうまくいかない。逆変換を考えなかった設計が悪い / this does not work well. bad design because I didn't think about inverse conversion
# module_name = module_name.replace("_", ".")
# ださいけどSDXLのU-Netの "_" を "@" に変換する / ugly but convert "_" of SDXL U-Net to "@"
matches = pattern.findall(module_name)
if matches is not None:
for m in matches:
print(module_name, m)
module_name = module_name.replace(m, m.replace("_", "@"))
module_name = module_name.replace("_", ".")
module_name = module_name.replace("@", "_")
lllite_key = module_name + ".lllite_" + weight_name
state_dict[lllite_key] = weights_sd[key]
info = self.load_state_dict(state_dict, False)
return info
def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs):
for m in self.lllite_modules:
m.set_cond_image(cond_image)
return super().forward(x, timesteps, context, y, **kwargs)
def replace_unet_linear_and_conv2d():
print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
if __name__ == "__main__":
# デバッグ用 / for debug
# sdxl_original_unet.USE_REENTRANT = False
replace_unet_linear_and_conv2d()
# test shape etc
print("create unet")
unet = SdxlUNet2DConditionModelControlNetLLLite()
print("enable ControlNet-LLLite")
unet.apply_lllite(32, 64, None, False, 1.0)
unet.to("cuda") # .to(torch.float16)
# from safetensors.torch import load_file
# model_sd = load_file(r"E:\Work\SD\Models\sdxl\sd_xl_base_1.0_0.9vae.safetensors")
# unet_sd = {}
# # copy U-Net keys from unet_state_dict to state_dict
# prefix = "model.diffusion_model."
# for key in model_sd.keys():
# if key.startswith(prefix):
# converted_key = key[len(prefix) :]
# unet_sd[converted_key] = model_sd[key]
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
# print(info)
# print(unet)
# print number of parameters
params = unet.prepare_params()
print("number of parameters", sum(p.numel() for p in params))
# print("type any key to continue")
# input()
unet.set_use_memory_efficient_attention(True, False)
unet.set_gradient_checkpointing(True)
unet.train() # for gradient checkpointing
# # visualize
# import torchviz
# print("run visualize")
# controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y)
# print("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render")
# image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input()
import bitsandbytes
optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
steps = 10
batch_size = 1
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
for step in range(steps):
print(f"step {step}")
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
x = torch.randn(batch_size, 4, 128, 128).cuda()
t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
ctx = torch.randn(batch_size, 77, 2048).cuda()
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
output = unet(x, t, ctx, y, conditioning_image)
target = torch.randn_like(output)
loss = torch.nn.functional.mse_loss(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
print(sample_param)
# from safetensors.torch import save_file
# print("save weights")
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)

View File

@@ -239,7 +239,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
class DyLoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"

View File

@@ -3,187 +3,265 @@
# Thanks to cloneofsimo!
import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import library.model_util as model_util
from library import sai_model_spec, model_util, sdxl_model_util
import lora
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-6
MIN_DIFF = 1e-1
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name)
else:
torch.save(model, file_name)
def svd(args):
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
save_dtype = str_to_dtype(args.save_precision)
assert args.v2 != args.sdxl or (
not args.v2 and not args.sdxl
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if args.v_parameterization is None:
args.v_parameterization = args.v2
print(f"loading SD model : {args.model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
print(f"loading SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
save_dtype = str_to_dtype(args.save_precision)
# create LoRA network to extract weights: Use dim (rank) as alpha
if args.conv_dim is None:
kwargs = {}
else:
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
# load models
if not args.sdxl:
print(f"loading original SD model : {args.model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
text_encoders_o = [text_encoder_o]
print(f"loading tuned SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
text_encoders_t = [text_encoder_t]
model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
else:
print(f"loading original SDXL model : {args.model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
print(f"loading original SDXL model : {args.model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
# create LoRA network to extract weights: Use dim (rank) as alpha
if args.conv_dim is None:
kwargs = {}
else:
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
# get diffs
diffs = {}
text_encoder_different = False
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
# Text Encoder might be same
if torch.max(torch.abs(diff)) > MIN_DIFF:
text_encoder_different = True
diff = diff.float()
diffs[lora_name] = diff
if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = []
# get diffs
diffs = {}
text_encoder_different = False
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = diff.float()
# Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
if args.device:
diff = diff.to(args.device)
diff = diff.float()
diffs[lora_name] = diff
diffs[lora_name] = diff
if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = []
diffs = {}
# make LoRA with svd
print("calculating by svd")
lora_weights = {}
with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())):
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = (len(mat.size()) == 4)
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = diff.float()
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
out_dim, in_dim = mat.size()[0:2]
if args.device:
diff = diff.to(args.device)
if args.device:
mat = mat.to(args.device)
diffs[lora_name] = diff
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
# make LoRA with svd
print("calculating by svd")
lora_weights = {}
with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())):
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
if conv2d:
if conv2d_3x3:
mat = mat.flatten(start_dim=1)
else:
mat = mat.squeeze()
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
out_dim, in_dim = mat.size()[0:2]
U, S, Vh = torch.linalg.svd(mat)
if args.device:
mat = mat.to(args.device)
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
Vh = Vh[:rank, :]
if conv2d:
if conv2d_3x3:
mat = mat.flatten(start_dim=1)
else:
mat = mat.squeeze()
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U, S, Vh = torch.linalg.svd(mat)
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
Vh = Vh[:rank, :]
U = U.to("cpu").contiguous()
Vh = Vh.to("cpu").contiguous()
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
lora_weights[lora_name] = (U, Vh)
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
# make state dict for LoRA
lora_sd = {}
for lora_name, (up_weight, down_weight) in lora_weights.items():
lora_sd[lora_name + '.lora_up.weight'] = up_weight
lora_sd[lora_name + '.lora_down.weight'] = down_weight
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
if conv2d:
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
# load state dict to LoRA and save it
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
U = U.to("cpu").contiguous()
Vh = Vh.to("cpu").contiguous()
info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}")
lora_weights[lora_name] = (U, Vh)
dir_name = os.path.dirname(args.save_to)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
# make state dict for LoRA
lora_sd = {}
for lora_name, (up_weight, down_weight) in lora_weights.items():
lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + ".lora_down.weight"] = down_weight
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
# minimum metadata
metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
# load state dict to LoRA and save it
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}")
info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}")
dir_name = os.path.dirname(args.save_to)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
# minimum metadata
net_kwargs = {}
if args.conv_dim is not None:
net_kwargs["conv_dim"] = args.conv_dim
net_kwargs["conv_alpha"] = args.conv_dim
metadata = {
"ss_v2": str(args.v2),
"ss_base_model_version": model_version,
"ss_network_module": "networks.lora",
"ss_network_dim": str(args.dim),
"ss_network_alpha": str(args.dim),
"ss_network_args": json.dumps(net_kwargs),
}
if not args.no_metadata:
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title
)
metadata.update(sai_metadata)
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat")
parser.add_argument("--model_org", type=str, default=None,
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
parser.add_argument("--model_tuned", type=str, default=None,
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument("--conv_dim", type=int, default=None,
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
parser.add_argument(
"--v_parameterization",
type=bool,
default=None,
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する省略時はv2と同じ",
)
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
)
parser.add_argument(
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
)
parser.add_argument(
"--model_org",
type=str,
default=None,
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
)
parser.add_argument(
"--model_tuned",
type=str,
default=None,
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
)
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument(
"--conv_dim",
type=int,
default=None,
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
svd(args)
args = parser.parse_args()
svd(args)

View File

@@ -5,7 +5,9 @@
import math
import os
from typing import List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re
@@ -19,7 +21,17 @@ class LoRAModule(torch.nn.Module):
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
dropout=None,
rank_dropout=None,
module_dropout=None,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
@@ -60,12 +72,87 @@ class LoRAModule(torch.nn.Module):
self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
org_forwarded = self.org_forward(x)
# module dropout
if self.module_dropout is not None and self.training:
if torch.rand(1) < self.module_dropout:
return org_forwarded
lx = self.lora_down(x)
# normal dropout
if self.dropout is not None and self.training:
lx = torch.nn.functional.dropout(lx, p=self.dropout)
# rank dropout
if self.rank_dropout is not None and self.training:
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
if len(lx.size()) == 3:
mask = mask.unsqueeze(1) # for Text Encoder
elif len(lx.size()) == 4:
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
lx = lx * mask
# scaling for rank dropout: treat as if the rank is changed
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
else:
scale = self.scale
lx = self.lora_up(lx)
return org_forwarded + lx * self.multiplier * scale
class LoRAInfModule(LoRAModule):
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
self.org_module_ref = [org_module] # 後から参照できるように
self.enabled = True
# check regional or not by lora_name
self.text_encoder = False
if lora_name.startswith("lora_te_"):
self.regional = False
self.use_sub_prompt = True
self.text_encoder = True
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
self.regional = False
self.use_sub_prompt = True
elif "time_emb" in lora_name:
self.regional = False
self.use_sub_prompt = False
else:
self.regional = True
self.use_sub_prompt = False
self.network: LoRANetwork = None
def set_network(self, network):
self.network = network
# freezeしてマージする
def merge_to(self, sd, dtype, device):
# get up/down weight
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
@@ -97,44 +184,45 @@ class LoRAModule(torch.nn.Module):
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
# 復元できるマージのため、このモジュールのweightを返す
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
# get up/down weight from module
up_weight = self.lora_up.weight.to(torch.float)
down_weight = self.lora_down.weight.to(torch.float)
# pre-calculated weight
if len(down_weight.size()) == 2:
# linear
weight = self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = self.multiplier * conved * self.scale
return weight
def set_region(self, region):
self.region = region
self.region_mask = None
def forward(self, x):
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
class LoRAInfModule(LoRAModule):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
# check regional or not by lora_name
self.text_encoder = False
if lora_name.startswith("lora_te_"):
self.regional = False
self.use_sub_prompt = True
self.text_encoder = True
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
self.regional = False
self.use_sub_prompt = True
elif "time_emb" in lora_name:
self.regional = False
self.use_sub_prompt = False
else:
self.regional = True
self.use_sub_prompt = False
self.network: LoRANetwork = None
def set_network(self, network):
self.network = network
def default_forward(self, x):
# print("default_forward", self.lora_name, x.size())
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x):
if not self.enabled:
return self.org_forward(x)
if self.network is None or self.network.sub_prompt_index is None:
return self.default_forward(x)
if not self.regional and not self.use_sub_prompt:
@@ -153,9 +241,13 @@ class LoRAInfModule(LoRAModule):
else:
area = x.size()[1]
mask = self.network.mask_dic[area]
mask = self.network.mask_dic.get(area, None)
if mask is None:
raise ValueError(f"mask is None for resolution {area}")
# raise ValueError(f"mask is None for resolution {area}")
# emb_layers in SDXL doesn't have mask
# print(f"mask is None for resolution {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4:
mask = torch.reshape(mask, (1, -1, 1))
return mask
@@ -260,9 +352,10 @@ class LoRAInfModule(LoRAModule):
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# for i in range(len(masks)):
# if masks[i] is None:
# masks[i] = torch.zeros_like(masks[-1])
# if num_sub_prompts > num of LoRAs, fill with zero
for i in range(len(masks)):
if masks[i] is None:
masks[i] = torch.zeros_like(masks[0])
mask = torch.cat(masks)
mask_sum = torch.sum(mask, dim=0) + 1e-4
@@ -285,7 +378,45 @@ class LoRAInfModule(LoRAModule):
return out
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
def parse_block_lr_kwargs(nw_kwargs):
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
# 以上のいずれにも設定がない場合は無効としてNoneを返す
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
return None, None, None
# extract learning rate weight for each block
if down_lr_weight is not None:
# if some parameters are not set, use zero
if "," in down_lr_weight:
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
if mid_lr_weight is not None:
mid_lr_weight = float(mid_lr_weight)
if up_lr_weight is not None:
if "," in up_lr_weight:
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
)
return down_lr_weight, mid_lr_weight, up_lr_weight
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
@@ -303,9 +434,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
# block dim/alpha/lr
block_dims = kwargs.get("block_dims", None)
down_lr_weight = kwargs.get("down_lr_weight", None)
mid_lr_weight = kwargs.get("mid_lr_weight", None)
up_lr_weight = kwargs.get("up_lr_weight", None)
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
@@ -317,23 +446,6 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
)
# extract learning rate weight for each block
if down_lr_weight is not None:
# if some parameters are not set, use zero
if "," in down_lr_weight:
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
if mid_lr_weight is not None:
mid_lr_weight = float(mid_lr_weight)
if up_lr_weight is not None:
if "," in up_lr_weight:
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, float(kwargs.get("block_lr_zero_threshold", 0.0))
)
# remove block dim/alpha without learning rate
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
@@ -344,6 +456,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_block_dims = None
conv_block_alphas = None
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# すごく引数が多いな ( ^ω^)・・・
network = LoRANetwork(
text_encoder,
@@ -351,6 +471,9 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
block_dims=block_dims,
@@ -593,43 +716,55 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
# support old LoRA without alpha
for key in modules_dim.keys():
if key not in modules_alpha:
modules_alpha = modules_dim[key]
modules_alpha[key] = modules_dim[key]
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
)
# block lr
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
return network, weights_sd
class LoRANetwork(torch.nn.Module):
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
# is it possible to apply conv_in and conv_out? -> yes, newer LoCon supports it (^^;)
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
def __init__(
self,
text_encoder,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier=1.0,
lora_dim=4,
alpha=1,
conv_lora_dim=None,
conv_alpha=None,
block_dims=None,
block_alphas=None,
conv_block_dims=None,
conv_block_alphas=None,
modules_dim=None,
modules_alpha=None,
module_class=LoRAModule,
varbose=False,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
block_dims: Optional[List[int]] = None,
block_alphas: Optional[List[float]] = None,
conv_block_dims: Optional[List[int]] = None,
conv_block_alphas: Optional[List[float]] = None,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
module_class: Type[object] = LoRAModule,
varbose: Optional[bool] = False,
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -646,11 +781,15 @@ class LoRANetwork(torch.nn.Module):
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
if modules_dim is not None:
print(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
@@ -658,12 +797,26 @@ class LoRANetwork(torch.nn.Module):
print(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
prefix = LoRANetwork.LORA_PREFIX_UNET if is_unet else LoRANetwork.LORA_PREFIX_TEXT_ENCODER
def create_modules(
is_unet: bool,
text_encoder_idx: Optional[int], # None, 1, 2
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_UNET
if is_unet
else (
self.LORA_PREFIX_TEXT_ENCODER
if text_encoder_idx is None
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
)
)
loras = []
skipped = []
for name, module in root_module.named_modules():
@@ -679,11 +832,14 @@ class LoRANetwork(torch.nn.Module):
dim = None
alpha = None
if modules_dim is not None:
# モジュール指定あり
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
elif is_unet and block_dims is not None:
# U-Netでblock_dims指定あり
block_idx = get_block_index(lora_name)
if is_linear or is_conv2d_1x1:
dim = block_dims[block_idx]
@@ -692,6 +848,7 @@ class LoRANetwork(torch.nn.Module):
dim = conv_block_dims[block_idx]
alpha = conv_block_alphas[block_idx]
else:
# 通常、すべて対象とする
if is_linear or is_conv2d_1x1:
dim = self.lora_dim
alpha = self.alpha
@@ -700,15 +857,41 @@ class LoRANetwork(torch.nn.Module):
alpha = self.conv_alpha
if dim is None or dim == 0:
# skipした情報を出力
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
skipped.append(lora_name)
continue
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
)
loras.append(lora)
return loras, skipped
self.text_encoder_loras, skipped_te = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
# create LoRA for text encoder
# 毎回すべてのモジュールを作るのは無駄なので要検討
self.text_encoder_loras = []
skipped_te = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")
else:
index = None
print(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
@@ -716,7 +899,7 @@ class LoRANetwork(torch.nn.Module):
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, unet, target_modules)
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
@@ -769,6 +952,10 @@ class LoRANetwork(torch.nn.Module):
lora.apply_to()
self.add_module(lora.lora_name, lora)
# マージできるかどうかを返す
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
apply_text_encoder = apply_unet = False
@@ -797,7 +984,7 @@ class LoRANetwork(torch.nn.Module):
print(f"weights are merged")
# 層別学習率用に層ごとの学習率に対する倍率を定義する
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
def set_block_lr_weight(
self,
up_lr_weight: List[float] = None,
@@ -827,6 +1014,7 @@ class LoRANetwork(torch.nn.Module):
return lr_weight
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
@@ -955,3 +1143,83 @@ class LoRANetwork(torch.nn.Module):
w = (w + 1) // 2
self.mask_dic = mask_dic
def backup_weights(self):
# 重みのバックアップを行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# 重みのリストアを行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# 事前計算を行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
downkeys = []
upkeys = []
alphakeys = []
norms = []
keys_scaled = 0
state_dict = self.state_dict()
for key in state_dict.keys():
if "lora_down" in key and "weight" in key:
downkeys.append(key)
upkeys.append(key.replace("lora_down", "lora_up"))
alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)):
down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
dim = down.shape[0]
scale = alpha / dim
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
updown *= scale
norm = updown.norm().clamp(min=max_norm_value / 2)
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired.cpu() / norm.cpu()
sqrt_ratio = ratio**0.5
if ratio != 1:
keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)

609
networks/lora_diffusers.py Normal file
View File

@@ -0,0 +1,609 @@
# Diffusersで動くLoRA。このファイル単独で完結する。
# LoRA module for Diffusers. This file works independently.
import bisect
import math
import random
from typing import Any, Dict, List, Mapping, Optional, Union
from diffusers import UNet2DConditionModel
import numpy as np
from tqdm import tqdm
from transformers import CLIPTextModel
import torch
def make_unet_conversion_map() -> Dict[str, str]:
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
return sd_hf_conversion_map
UNET_CONVERSION_MAP = make_unet_conversion_map()
class LoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
multiplier=1.0,
lora_dim=4,
alpha=1,
):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation
# same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight)
self.multiplier = multiplier
self.org_module = [org_module]
self.enabled = True
self.network: LoRANetwork = None
self.org_forward = None
# override org_module's forward method
def apply_to(self, multiplier=None):
if multiplier is not None:
self.multiplier = multiplier
if self.org_forward is None:
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
# restore org_module's forward method
def unapply_to(self):
if self.org_forward is not None:
self.org_module[0].forward = self.org_forward
# forward with lora
# scale is used LoRACompatibleConv, but we ignore it because we have multiplier
def forward(self, x, scale=1.0):
if not self.enabled:
return self.org_forward(x)
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def set_network(self, network):
self.network = network
# merge lora weight to org weight
def merge_to(self, multiplier=1.0):
# get lora weight
lora_weight = self.get_weight(multiplier)
# get org weight
org_sd = self.org_module[0].state_dict()
org_weight = org_sd["weight"]
weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype)
# set weight to org_module
org_sd["weight"] = weight
self.org_module[0].load_state_dict(org_sd)
# restore org weight from lora weight
def restore_from(self, multiplier=1.0):
# get lora weight
lora_weight = self.get_weight(multiplier)
# get org weight
org_sd = self.org_module[0].state_dict()
org_weight = org_sd["weight"]
weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype)
# set weight to org_module
org_sd["weight"] = weight
self.org_module[0].load_state_dict(org_sd)
# return lora weight
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
# get up/down weight from module
up_weight = self.lora_up.weight.to(torch.float)
down_weight = self.lora_down.weight.to(torch.float)
# pre-calculated weight
if len(down_weight.size()) == 2:
# linear
weight = self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = self.multiplier * conved * self.scale
return weight
# Create network from weights for inference, weights are not loaded here
def create_network_from_weights(
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
):
# get dim/alpha mapping
modules_dim = {}
modules_alpha = {}
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# support old LoRA without alpha
for key in modules_dim.keys():
if key not in modules_alpha:
modules_alpha[key] = modules_dim[key]
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
unet = pipe.unet
lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
lora_network.load_state_dict(weights_sd)
lora_network.merge_to(multiplier=multiplier)
# block weightや学習に対応しない簡易版 / simple version without block weight and training
class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet: UNet2DConditionModel,
multiplier: float = 1.0,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
print(f"create LoRA network from weights")
# convert SDXL Stability AI's U-Net modules to Diffusers
converted = self.convert_unet_modules(modules_dim, modules_alpha)
if converted:
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
# create module instances
def create_modules(
is_unet: bool,
text_encoder_idx: Optional[int], # None, 1, 2
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_UNET
if is_unet
else (
self.LORA_PREFIX_TEXT_ENCODER
if text_encoder_idx is None
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
)
)
loras = []
skipped = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = (
child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
)
is_conv2d = (
child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
)
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
if lora_name not in modules_dim:
# print(f"skipped {lora_name} (not found in modules_dim)")
skipped.append(lora_name)
continue
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
lora = LoRAModule(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
)
loras.append(lora)
return loras, skipped
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
# create LoRA for text encoder
# 毎回すべてのモジュールを作るのは無駄なので要検討 / it is wasteful to create all modules every time, need to consider
self.text_encoder_loras: List[LoRAModule] = []
skipped_te = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
else:
index = None
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
if len(skipped_te) > 0:
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
# extend U-Net target modules to include Conv2d 3x3
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras: List[LoRAModule]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
if len(skipped_un) > 0:
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
names.add(lora.lora_name)
for lora_name in modules_dim.keys():
assert lora_name in names, f"{lora_name} is not found in created LoRA modules."
# make to work load_state_dict
for lora in self.text_encoder_loras + self.unet_loras:
self.add_module(lora.lora_name, lora)
# SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
def convert_unet_modules(self, modules_dim, modules_alpha):
converted_count = 0
not_converted_count = 0
map_keys = list(UNET_CONVERSION_MAP.keys())
map_keys.sort()
for key in list(modules_dim.keys()):
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
position = bisect.bisect_right(map_keys, search_key)
map_key = map_keys[position - 1]
if search_key.startswith(map_key):
new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
modules_dim[new_key] = modules_dim[key]
modules_alpha[new_key] = modules_alpha[key]
del modules_dim[key]
del modules_alpha[key]
converted_count += 1
else:
not_converted_count += 1
assert (
converted_count == 0 or not_converted_count == 0
), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
return converted_count
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
for lora in self.text_encoder_loras:
lora.apply_to(multiplier)
if apply_unet:
print("enable LoRA for U-Net")
for lora in self.unet_loras:
lora.apply_to(multiplier)
def unapply_to(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.unapply_to()
def merge_to(self, multiplier=1.0):
print("merge LoRA weights to original weights")
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
lora.merge_to(multiplier)
print(f"weights are merged")
def restore_from(self, multiplier=1.0):
print("restore LoRA weights from original weights")
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
lora.restore_from(multiplier)
print(f"weights are restored")
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
# convert SDXL Stability AI's state dict to Diffusers' based state dict
map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
map_keys.sort()
for key in list(state_dict.keys()):
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
position = bisect.bisect_right(map_keys, search_key)
map_key = map_keys[position - 1]
if search_key.startswith(map_key):
new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
state_dict[new_key] = state_dict[key]
del state_dict[key]
# in case of V2, some weights have different shape, so we need to convert them
# because V2 LoRA is based on U-Net created by use_linear_projection=False
my_state_dict = self.state_dict()
for key in state_dict.keys():
if state_dict[key].size() != my_state_dict[key].size():
# print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
return super().load_state_dict(state_dict, strict)
if __name__ == "__main__":
# sample code to use LoRANetwork
import os
import argparse
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
parser.add_argument("--lora_weights", type=str, default=None, help="path to LoRA weights")
parser.add_argument("--sdxl", action="store_true", help="use SDXL model")
parser.add_argument("--prompt", type=str, default="A photo of cat", help="prompt text")
parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt text")
parser.add_argument("--seed", type=int, default=0, help="random seed")
args = parser.parse_args()
image_prefix = args.model_id.replace("/", "_") + "_"
# load Diffusers model
print(f"load model from {args.model_id}")
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
if args.sdxl:
# use_safetensors=True does not work with 0.18.2
pipe = StableDiffusionXLPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
else:
pipe = StableDiffusionPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
pipe.to(device)
pipe.set_use_memory_efficient_attention_xformers(True)
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
# load LoRA weights
print(f"load LoRA weights from {args.lora_weights}")
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
from safetensors.torch import load_file
lora_sd = load_file(args.lora_weights)
else:
lora_sd = torch.load(args.lora_weights)
# create by LoRA weights and load weights
print(f"create LoRA network")
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
print(f"load LoRA network weights")
lora_network.load_state_dict(lora_sd)
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
# 必要があれば、元のモデルの重みをバックアップしておく
# back-up unet/text encoder weights if necessary
def detach_and_move_to_cpu(state_dict):
for k, v in state_dict.items():
state_dict[k] = v.detach().cpu()
return state_dict
org_unet_sd = pipe.unet.state_dict()
detach_and_move_to_cpu(org_unet_sd)
org_text_encoder_sd = pipe.text_encoder.state_dict()
detach_and_move_to_cpu(org_text_encoder_sd)
if args.sdxl:
org_text_encoder_2_sd = pipe.text_encoder_2.state_dict()
detach_and_move_to_cpu(org_text_encoder_2_sd)
def seed_everything(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
# create image with original weights
print(f"create image with original weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "original.png")
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
print(f"apply LoRA network to the model")
lora_network.apply_to(multiplier=1.0)
print(f"create image with applied LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "applied_lora.png")
# unapply LoRA network to the model
print(f"unapply LoRA network to the model")
lora_network.unapply_to()
print(f"create image with unapplied LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "unapplied_lora.png")
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
print(f"merge LoRA network to the model")
lora_network.merge_to(multiplier=1.0)
print(f"create image with LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "merged_lora.png")
# restore (unmerge) LoRA weights: numerically unstable
# マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない
# 保存したstate_dictから元の重みを復元するのが確実
print(f"restore (unmerge) LoRA weights")
lora_network.restore_from(multiplier=1.0)
print(f"create image without LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "unmerged_lora.png")
# restore original weights
print(f"restore original weights")
pipe.unet.load_state_dict(org_unet_sd)
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
if args.sdxl:
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
print(f"create image with restored original weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "restore_original.png")
# use convenience function to merge LoRA weights
print(f"merge LoRA weights with convenience function")
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
print(f"create image with merged LoRA weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "convenience_merged_lora.png")

1241
networks/lora_fa.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,10 @@
import math
import argparse
import os
import time
import torch
from safetensors.torch import load_file, save_file
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
@@ -10,22 +12,26 @@ import lora
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name)
metadata = train_util.load_metadata_from_safetensors(file_name)
else:
sd = torch.load(file_name, map_location="cpu")
metadata = {}
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype):
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name)
save_file(model, file_name, metadata=metadata)
else:
torch.save(model, file_name)
@@ -56,7 +62,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
lora_sd, _ = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
@@ -81,9 +87,11 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
# W <- W + U * D
weight = module.weight
# print(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
if len(up_weight.size()) == 4: # use linear projection mismatch
up_weight = up_weight.squeeze(3).squeeze(2)
down_weight = down_weight.squeeze(3).squeeze(2)
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
@@ -102,14 +110,22 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
merged_sd = {}
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
if v2 is None:
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string
if base_model is None:
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
# get alpha and dim
alphas = {} # alpha for current model
@@ -142,6 +158,12 @@ def merge_lora_models(models, ratios, merge_dtype):
for key in lora_sd.keys():
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None
lora_module_name = key[: key.rfind(".lora_")]
@@ -149,12 +171,16 @@ def merge_lora_models(models, ratios, merge_dtype):
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size()
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
@@ -162,11 +188,37 @@ def merge_lora_models(models, ratios, merge_dtype):
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
return merged_sd
# check all dims are same
dims_list = list(set(base_dims.values()))
alphas_list = list(set(base_alphas.values()))
all_same_dims = True
all_same_alphas = True
for dims in dims_list:
if dims != dims_list[0]:
all_same_dims = False
break
for alphas in alphas_list:
if alphas != alphas_list[0]:
all_same_alphas = False
break
# build minimum metadata
dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)
return merged_sd, metadata, v2 == "True"
def merge(args):
@@ -193,13 +245,57 @@ def merge(args):
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
if args.no_metadata:
sai_metadata = None
else:
merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None,
args.v2,
args.v2,
False,
False,
False,
time.time(),
title=title,
merged_from=merged_from,
is_stable_diffusion_ckpt=True,
)
if args.v2:
# TODO read sai modelspec
print(
"Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae)
model_util.save_stable_diffusion_checkpoint(
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
)
else:
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
if not args.no_metadata:
merged_from = sai_model_spec.build_merged_from(args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
state_dict, v2, v2, False, True, False, time.time(), title=title, merged_from=merged_from
)
if v2:
# TODO read sai modelspec
print(
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -232,7 +328,25 @@ def setup_parser() -> argparse.ArgumentParser:
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合するLoRAのdim(rank)は入力dimの合計になる",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
)
return parser

View File

@@ -148,13 +148,13 @@ def merge(args):
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}")
print(f"\nsaving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}")
print(f"\nsaving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)

430
networks/oft.py Normal file
View File

@@ -0,0 +1,430 @@
# OFT network module
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
class OFTModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
):
"""
dim -> num blocks
alpha -> constraint
"""
super().__init__()
self.oft_name = oft_name
self.num_blocks = dim
if "Linear" in org_module.__class__.__name__:
out_dim = org_module.out_features
elif "Conv" in org_module.__class__.__name__:
out_dim = org_module.out_channels
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
self.constraint = alpha * out_dim
self.register_buffer("alpha", torch.tensor(alpha))
self.block_size = out_dim // self.num_blocks
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
self.out_dim = out_dim
self.shape = org_module.weight.shape
self.multiplier = multiplier
self.org_module = [org_module] # moduleにならないようにlistに入れる
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
R = torch.block_diag(*block_R_weighted)
return R
def forward(self, x, scale=None):
x = self.org_forward(x)
if self.multiplier == 0.0:
return x
R = self.get_weight().to(x.device, dtype=x.dtype)
if x.dim() == 4:
x = x.permute(0, 2, 3, 1)
x = torch.matmul(x, R)
x = x.permute(0, 3, 1, 2)
else:
x = torch.matmul(x, R)
return x
class OFTInfModule(OFTModule):
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(oft_name, org_module, multiplier, dim, alpha)
self.enabled = True
self.network: OFTNetwork = None
def set_network(self, network):
self.network = network
def forward(self, x, scale=None):
if not self.enabled:
return self.org_forward(x)
return super().forward(x, scale)
def merge_to(self, multiplier=None, sign=1):
R = self.get_weight(multiplier) * sign
# get org weight
org_sd = self.org_module[0].state_dict()
org_weight = org_sd["weight"]
R = R.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
else:
weight = torch.einsum("oi, op -> pi", org_weight, R)
# set weight to org_module
org_sd["weight"] = weight
self.org_module[0].load_state_dict(org_sd)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
enable_all_linear = kwargs.get("enable_all_linear", None)
enable_conv = kwargs.get("enable_conv", None)
if enable_all_linear is not None:
enable_all_linear = bool(enable_all_linear)
if enable_conv is not None:
enable_conv = bool(enable_conv)
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=network_dim,
alpha=network_alpha,
enable_all_linear=enable_all_linear,
enable_conv=enable_conv,
varbose=True,
)
return network
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# check dim, alpha and if weights have for conv2d
dim = None
alpha = None
has_conv2d = None
all_linear = None
for name, param in weights_sd.items():
if name.endswith(".alpha"):
if alpha is None:
alpha = param.item()
else:
if dim is None:
dim = param.size()[0]
if has_conv2d is None and param.dim() == 4:
has_conv2d = True
if all_linear is None:
if param.dim() == 3 and "attn" not in name:
all_linear = True
if dim is not None and alpha is not None and has_conv2d is not None:
break
if has_conv2d is None:
has_conv2d = False
if all_linear is None:
all_linear = False
module_class = OFTInfModule if for_inference else OFTModule
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=dim,
alpha=alpha,
enable_all_linear=all_linear,
enable_conv=has_conv2d,
module_class=module_class,
)
return network, weights_sd
class OFTNetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
OFT_PREFIX_UNET = "oft_unet" # これ変えないほうがいいかな
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier: float = 1.0,
dim: int = 4,
alpha: float = 1,
enable_all_linear: Optional[bool] = False,
enable_conv: Optional[bool] = False,
module_class: Type[object] = OFTModule,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.dim = dim
self.alpha = alpha
print(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
)
# create module instances
def create_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[OFTModule]:
prefix = self.OFT_PREFIX_UNET
ofts = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = "Linear" in child_module.__class__.__name__
is_conv2d = "Conv2d" in child_module.__class__.__name__
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_")
# print(oft_name)
oft = module_class(
oft_name,
child_module,
self.multiplier,
dim,
alpha,
)
ofts.append(oft)
return ofts
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
if enable_all_linear:
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
else:
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
if enable_conv:
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
# assertion
names = set()
for oft in self.unet_ofts:
assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
names.add(oft.oft_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for oft in self.unet_ofts:
oft.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
assert apply_unet, "apply_unet must be True"
for oft in self.unet_ofts:
oft.apply_to()
self.add_module(oft.oft_name, oft)
# マージできるかどうかを返す
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
print("enable OFT for U-Net")
for oft in self.unet_ofts:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(oft.oft_name):
sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
oft.load_state_dict(sd_for_lora, False)
oft.merge_to()
print(f"weights are merged")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(ofts):
params = []
for oft in ofts:
params.extend(oft.parameters())
# print num of params
num_params = 0
for p in params:
num_params += p.numel()
print(f"OFT params: {num_params}")
return params
param_data = {"params": enumerate_params(self.unet_ofts)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
return all_params
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# 重みのバックアップを行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# 重みのリストアを行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# 事前計算を行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
oft.merge_to()
# sd = org_module.state_dict()
# org_weight = sd["weight"]
# lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
# sd["weight"] = org_weight + lora_weight
# assert sd["weight"].shape == org_weight.shape
# org_module.load_state_dict(sd)
org_module._lora_restored = False
oft.enabled = False

View File

@@ -219,8 +219,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
for key, value in tqdm(lora_sd.items()):
weight_name = None
if 'lora_down' in key:
block_down_name = key.split(".")[0]
weight_name = key.split(".")[-1]
block_down_name = key.rsplit('.lora_down', 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
continue
@@ -283,7 +283,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
def resize(args):
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
def str_to_dtype(p):
if p == 'float':
return torch.float

348
networks/sdxl_merge_lora.py Normal file
View File

@@ -0,0 +1,348 @@
import math
import argparse
import os
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name)
metadata = train_util.load_metadata_from_safetensors(file_name)
else:
sd = torch.load(file_name, map_location="cpu")
metadata = {}
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name, metadata=metadata)
else:
torch.save(model, file_name)
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
text_encoder1.to(merge_dtype)
text_encoder1.to(merge_dtype)
unet.to(merge_dtype)
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
if i <= 1:
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
else:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
# print(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
merged_sd = {}
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
if v2 is None:
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
if base_model is None:
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
# get alpha and dim
alphas = {} # alpha for current model
dims = {} # dims for current model
for key in lora_sd.keys():
if "alpha" in key:
lora_module_name = key[: key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
elif "lora_down" in key:
lora_module_name = key[: key.rfind(".lora_down")]
dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim
if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim
for lora_module_name in dims.keys():
if lora_module_name not in alphas:
alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None
lora_module_name = key[: key.rfind(".lora_")]
base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
# set alpha to sd
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same
dims_list = list(set(base_dims.values()))
alphas_list = list(set(base_alphas.values()))
all_same_dims = True
all_same_alphas = True
for dims in dims_list:
if dims != dims_list[0]:
all_same_dims = False
break
for alphas in alphas_list:
if alphas != alphas_list[0]:
all_same_alphas = False
break
# build minimum metadata
dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, None)
return merged_sd, metadata
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
(
text_model1,
text_model2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
if args.no_metadata:
sai_metadata = None
else:
merged_from = sai_model_spec.build_merged_from([args.sd_model] + args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
)
print(f"saving SD model to: {args.save_to}")
sdxl_model_util.save_stable_diffusion_checkpoint(
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
if not args.no_metadata:
merged_from = sai_model_spec.build_merged_from(args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
state_dict, False, False, True, True, False, time.time(), title=title, merged_from=merged_from
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
)
parser.add_argument(
"--precision",
type=str,
default="float",
choices=["float", "fp16", "bf16"],
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--sd_model",
type=str,
default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合するLoRAのdim(rank)は入力dimの合計になる",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
merge(args)

View File

@@ -1,10 +1,11 @@
import math
import argparse
import os
import time
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
@@ -13,180 +14,247 @@ CLAMP_QUANTILE = 0.99
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name)
metadata = train_util.load_metadata_from_safetensors(file_name)
else:
sd = torch.load(file_name, map_location="cpu")
metadata = {}
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd, metadata
def save_to_file(file_name, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(state_dict, file_name)
else:
torch.save(state_dict, file_name)
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(state_dict, file_name, metadata=metadata)
else:
torch.save(state_dict, file_name)
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
merged_sd = {}
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
# merge
print(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if 'lora_down' not in key:
continue
if lora_metadata is not None:
if v2 is None:
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # return string
if base_model is None:
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
lora_module_name = key[:key.rfind(".lora_down")]
# merge
print(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if "lora_down" not in key:
continue
down_weight = lora_sd[key]
network_dim = down_weight.size()[0]
lora_module_name = key[: key.rfind(".lora_down")]
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
down_weight = lora_sd[key]
network_dim = down_weight.size()[0]
in_dim = down_weight.size()[1]
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
kernel_size = None if not conv2d else down_weight.size()[2:4]
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
up_weight = lora_sd[lora_module_name + ".lora_up.weight"]
alpha = lora_sd.get(lora_module_name + ".alpha", network_dim)
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
in_dim = down_weight.size()[1]
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
kernel_size = None if not conv2d else down_weight.size()[2:4]
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
# merge to weight
if device:
up_weight = up_weight.to(device)
down_weight = down_weight.to(device)
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
# W <- W + U * D
scale = (alpha / network_dim)
# merge to weight
if device:
up_weight = up_weight.to(device)
down_weight = down_weight.to(device)
if device: # and isinstance(scale, torch.Tensor):
scale = scale.to(device)
# W <- W + U * D
scale = alpha / network_dim
if not conv2d: # linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif kernel_size == (1, 1):
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
else:
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
if device: # and isinstance(scale, torch.Tensor):
scale = scale.to(device)
merged_sd[lora_module_name] = weight
if not conv2d: # linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif kernel_size == (1, 1):
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
# extract from merged weights
print("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
conv2d = (len(mat.size()) == 4)
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = mat.size()[0:2]
merged_sd[lora_module_name] = weight
if conv2d:
if conv2d_3x3:
mat = mat.flatten(start_dim=1)
else:
mat = mat.squeeze()
# extract from merged weights
print("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
out_dim, in_dim = mat.size()[0:2]
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
if conv2d:
if conv2d_3x3:
mat = mat.flatten(start_dim=1)
else:
mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
module_new_rank = min(module_new_rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
U = U[:, :module_new_rank]
S = S[:module_new_rank]
U = U @ torch.diag(S)
U, S, Vh = torch.linalg.svd(mat)
Vh = Vh[:module_new_rank, :]
U = U[:, :module_new_rank]
S = S[:module_new_rank]
U = U @ torch.diag(S)
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
Vh = Vh[:module_new_rank, :]
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
if conv2d:
U = U.reshape(out_dim, module_new_rank, 1, 1)
Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
up_weight = U
down_weight = Vh
if conv2d:
U = U.reshape(out_dim, module_new_rank, 1, 1)
Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)
up_weight = U
down_weight = Vh
return merged_lora_sd
merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
# build minimum metadata
dims = f"{new_rank}"
alphas = f"{new_rank}"
if new_conv_rank is not None:
network_args = {"conv_dim": new_conv_rank, "conv_alpha": new_conv_rank}
else:
network_args = None
metadata = train_util.build_minimum_network_metadata(v2, base_model, "networks.lora", dims, alphas, network_args)
return merged_lora_sd, metadata, v2 == "True", base_model
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
state_dict, metadata, v2, base_model = merge_lora_models(
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype)
print(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
if not args.no_metadata:
is_sdxl = base_model is not None and base_model.lower().startswith("sdxl")
merged_from = sai_model_spec.build_merged_from(args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
state_dict, v2, v2, is_sdxl, True, False, time.time(), title=title, merged_from=merged_from
)
if v2:
# TODO read sai modelspec
print(
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--new_conv_rank", type=int, default=None,
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
)
parser.add_argument(
"--precision",
type=str,
default="float",
choices=["float", "fp16", "bf16"],
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument(
"--new_conv_rank",
type=int,
default=None,
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
merge(args)
args = parser.parse_args()
merge(args)

View File

@@ -1,26 +1,33 @@
accelerate==0.15.0
transformers==4.26.0
accelerate==0.23.0
transformers==4.30.2
diffusers[torch]==0.21.2
ftfy==6.1.1
albumentations==1.3.0
# albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
diffusers[torch]==0.10.2
pytorch-lightning==1.9.0
bitsandbytes==0.35.0
# bitsandbytes==0.39.1
tensorboard==2.10.1
safetensors==0.2.6
gradio==3.16.2
safetensors==0.3.1
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.15.1
# for BLIP captioning
requests==2.28.2
timm==0.6.12
fairscale==0.4.13
# for WD14 captioning
# tensorflow<2.11
tensorflow==2.10.1
huggingface-hub==0.13.3
# requests==2.28.2
# timm==0.6.12
# fairscale==0.4.13
# for WD14 captioning (tensorflow)
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.14.1
# onnxruntime-gpu==1.16.0
# onnxruntime==1.16.0
# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
open-clip-torch==2.20.0
# for kohya_ss library
.
-e .

2748
sdxl_gen_img.py Executable file

File diff suppressed because it is too large Load Diff

328
sdxl_minimal_inference.py Normal file
View File

@@ -0,0 +1,328 @@
# 手元で推論を行うための最低限のコード。HuggingFaceDiffusersのCLIP、schedulerとVAEを使う
# Minimal code for performing inference at local. Use HuggingFace/Diffusers CLIP, scheduler and VAE
import argparse
import datetime
import math
import os
import random
from einops import repeat
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler
from PIL import Image
import open_clip
from safetensors.torch import load_file
from library import model_util, sdxl_model_util
import networks.lora as lora
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
# scheduler: The settings around here seem to be the same as SD1/2
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
# Time EmbeddingはDiffusersからのコピー
# Time Embedding is copied from Diffusers
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=timesteps.device
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, "b -> b d", d=dim)
return embedding
def get_timestep_embedding(x, outdim):
assert len(x.shape) == 2
b, dims = x.shape[0], x.shape[1]
# x = rearrange(x, "b d -> (b d)")
x = torch.flatten(x)
emb = timestep_embedding(x, outdim)
# emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=outdim)
emb = torch.reshape(emb, (b, dims * outdim))
return emb
if __name__ == "__main__":
# 画像生成条件を変更する場合はここを変更 / change here to change image generation conditions
# SDXLの追加のvector embeddingへ渡す値 / Values to pass to additional vector embedding of SDXL
target_height = 1024
target_width = 1024
original_height = target_height
original_width = target_width
crop_top = 0
crop_left = 0
steps = 50
guidance_scale = 7
seed = None # 1
DEVICE = "cuda"
DTYPE = torch.float16 # bfloat16 may work
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--prompt", type=str, default="A photo of a cat")
parser.add_argument("--prompt2", type=str, default=None)
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--output_dir", type=str, default=".")
parser.add_argument(
"--lora_weights",
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
if args.prompt2 is None:
args.prompt2 = args.prompt
# HuggingFaceのmodel id
text_encoder_1_name = "openai/clip-vit-large-patch14"
text_encoder_2_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
# checkpointを読み込む。モデル変換についてはそちらの関数を参照
# Load checkpoint. For model conversion, see this function
# 本体RAMが少ない場合はGPUにロードするといいかも
# If the main RAM is small, it may be better to load it on the GPU
text_model1, text_model2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.ckpt_path, "cpu"
)
# Text Encoder 1はSDXL本体でもHuggingFaceのものを使っている
# In SDXL, Text Encoder 1 is also using HuggingFace's
# Text Encoder 2はSDXL本体ではopen_clipを使っている
# それを使ってもいいが、SD2のDiffusers版に合わせる形で、HuggingFaceのものを使う
# 重みの変換コードはSD2とほぼ同じ
# In SDXL, Text Encoder 2 is using open_clip
# It's okay to use it, but to match the Diffusers version of SD2, use HuggingFace's
# The weight conversion code is almost the same as SD2
# VAEの構造はSDXLもSD1/2と同じだが、重みは異なるようだ。何より謎のscale値が違う
# fp16でNaNが出やすいようだ
# The structure of VAE is the same as SD1/2, but the weights seem to be different. Above all, the mysterious scale value is different.
# NaN seems to be more likely to occur in fp16
unet.to(DEVICE, dtype=DTYPE)
unet.eval()
vae_dtype = DTYPE
if DTYPE == torch.float16:
print("use float32 for vae")
vae_dtype = torch.float32
vae.to(DEVICE, dtype=vae_dtype)
vae.eval()
text_model1.to(DEVICE, dtype=DTYPE)
text_model1.eval()
text_model2.to(DEVICE, dtype=DTYPE)
text_model2.eval()
unet.set_use_memory_efficient_attention(True, False)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(True)
# Tokenizers
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
# LoRA
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
lora_model, weights_sd = lora.create_network_from_weights(
multiplier, weights_file, vae, [text_model1, text_model2], unet, None, True
)
lora_model.merge_to([text_model1, text_model2], unet, weights_sd, DTYPE, DEVICE)
# scheduler
scheduler = EulerDiscreteScheduler(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
beta_end=SCHEDULER_LINEAR_END,
beta_schedule=SCHEDLER_SCHEDULE,
)
def generate_image(prompt, prompt2, negative_prompt, seed=None):
# 将来的にサイズ情報も変えられるようにする / Make it possible to change the size information in the future
# prepare embedding
with torch.no_grad():
# vector
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
# print("emb1", emb1.shape)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
# crossattn
# Text Encoderを二つ呼ぶ関数 Function to call two Text Encoders
def call_text_encoder(text, text2):
# text encoder 1
batch_encoding = tokenizer1(
text,
truncation=True,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(DEVICE)
with torch.no_grad():
enc_out = text_model1(tokens, output_hidden_states=True, return_dict=True)
text_embedding1 = enc_out["hidden_states"][11]
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
# text encoder 2
with torch.no_grad():
tokens = tokenizer2(text2).to(DEVICE)
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
# print("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# 連結して終了 concat and finish
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
return text_embedding, text_embedding2_pool
# cond
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
# uncond
uc_ctx, uc_ctx_pool = call_text_encoder(negative_prompt, negative_prompt)
uc_vector = torch.cat([uc_ctx_pool, uc_vector], dim=1)
text_embeddings = torch.cat([uc_ctx, c_ctx])
vector_embeddings = torch.cat([uc_vector, c_vector])
# メモリ使用量を減らすにはここでText Encoderを削除するかCPUへ移動する
if seed is not None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# # random generator for initial noise
# generator = torch.Generator(device="cuda").manual_seed(seed)
generator = None
else:
generator = None
# get the initial random noise unless the user supplied it
# SDXLはCPUでlatentsを作成しているので一応合わせておく、Diffusersはtarget deviceでlatentsを作成している
# SDXL creates latents in CPU, Diffusers creates latents in target device
latents_shape = (1, 4, target_height // 8, target_width // 8)
latents = torch.randn(
latents_shape,
generator=generator,
device="cpu",
dtype=torch.float32,
).to(DEVICE, dtype=DTYPE)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler.init_noise_sigma
# set timesteps
scheduler.set_timesteps(steps, DEVICE)
# このへんはDiffusersからのコピペ
# Copy from Diffusers
timesteps = scheduler.timesteps.to(DEVICE) # .to(DTYPE)
num_latent_input = 2
with torch.no_grad():
for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
noise_pred = unet(latent_model_input, t, text_embeddings, vector_embeddings)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(num_latent_input) # uncond by negative prompt
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
# latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
latents = scheduler.step(noise_pred, t, latents).prev_sample
# latents = 1 / 0.18215 * latents
latents = 1 / sdxl_model_util.VAE_SCALE_FACTOR * latents
latents = latents.to(vae_dtype)
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
# image = self.numpy_to_pil(image)
image = (image * 255).round().astype("uint8")
image = [Image.fromarray(im) for im in image]
# 保存して終了 save and finish
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
for i, img in enumerate(image):
img.save(os.path.join(args.output_dir, f"image_{timestamp}_{i:03d}.png"))
if not args.interactive:
generate_image(args.prompt, args.prompt2, args.negative_prompt, seed)
else:
# loop for interactive
while True:
prompt = input("prompt: ")
if prompt == "":
break
prompt2 = input("prompt2: ")
if prompt2 == "":
prompt2 = prompt
negative_prompt = input("negative prompt: ")
seed = input("seed: ")
if seed == "":
seed = None
else:
seed = int(seed)
generate_image(prompt, prompt2, negative_prompt, seed)
print("Done!")

753
sdxl_train.py Normal file
View File

@@ -0,0 +1,753 @@
# training with captions
import argparse
import gc
import math
import os
from multiprocessing import Value
from typing import List
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util
import library.train_util as train_util
import library.config_util as config_util
import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel
UNET_NUM_BLOCKS_FOR_BLOCK_LR = 23
def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List[float]) -> List[dict]:
block_params = [[] for _ in range(len(block_lrs))]
for i, (name, param) in enumerate(unet.named_parameters()):
if name.startswith("time_embed.") or name.startswith("label_emb."):
block_index = 0 # 0
elif name.startswith("input_blocks."): # 1-9
block_index = 1 + int(name.split(".")[1])
elif name.startswith("middle_block."): # 10-12
block_index = 10 + int(name.split(".")[1])
elif name.startswith("output_blocks."): # 13-21
block_index = 13 + int(name.split(".")[1])
elif name.startswith("out."): # 22
block_index = 22
else:
raise ValueError(f"unexpected parameter name: {name}")
block_params[block_index].append(param)
params_to_optimize = []
for i, params in enumerate(block_params):
if block_lrs[i] == 0: # 0のときは学習しない do not optimize when lr is 0
continue
params_to_optimize.append({"params": params, "lr": block_lrs[i]})
return params_to_optimize
def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
lrs = lr_scheduler.get_last_lr()
lr_index = 0
block_index = 0
while lr_index < len(lrs):
if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
name = f"block{block_index}"
if block_lrs[block_index] == 0:
block_index += 1
continue
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
name = "text_encoder1"
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
name = "text_encoder2"
else:
raise ValueError(f"unexpected block_index: {block_index}")
block_index += 1
logs["lr/" + name] = float(lrs[lr_index])
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
logs["lr/d*lr/" + name] = (
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
)
lr_index += 1
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
assert (
not args.train_text_encoder or not args.cache_text_encoder_outputs
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
if args.block_lr:
block_lrs = [float(lr) for lr in args.block_lr.split(",")]
assert (
len(block_lrs) == UNET_NUM_BLOCKS_FOR_BLOCK_LR
), f"block_lr must have {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / block_lrは{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値を指定してください"
else:
block_lrs = None
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
print("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
print(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
print("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
# verify load/save model formats
if load_stable_diffusion_format:
src_stable_diffusion_ckpt = args.pretrained_model_name_or_path
src_diffusers_model_path = None
else:
src_stable_diffusion_ckpt = None
src_diffusers_model_path = args.pretrained_model_name_or_path
if args.save_model_as is None:
save_stable_diffusion_format = load_stable_diffusion_format
use_safetensors = args.use_safetensors
else:
save_stable_diffusion_format = args.save_model_as.lower() == "ckpt" or args.save_model_as.lower() == "safetensors"
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
# assert save_stable_diffusion_format, "save_model_as must be ckpt or safetensors / save_model_asはckptかsafetensorsである必要があります"
# Diffusers版のxformers使用フラグを設定する関数
def set_diffusers_xformers_flag(model, valid):
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid)
for child in module.children():
fn_recursive_set_mem_eff(child)
fn_recursive_set_mem_eff(model)
# モデルに xformers とか memory efficient attention を組み込む
if args.diffusers_xformers:
# もうU-Netを独自にしたので動かないけどVAEのxformersは動くはず
accelerator.print("Use xformers by Diffusers")
# set_diffusers_xformers_flag(unet, True)
set_diffusers_xformers_flag(vae, True)
else:
# Windows版のxformersはfloatで学習できなかったりするのでxformersを使わない設定も可能にしておく必要がある
accelerator.print("Disable Diffusers' xformers")
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
training_models = []
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
training_models.append(unet)
if args.train_text_encoder:
# TODO each option for two text encoders?
accelerator.print("enable text encoder training")
if args.gradient_checkpointing:
text_encoder1.gradient_checkpointing_enable()
text_encoder2.gradient_checkpointing_enable()
training_models.append(text_encoder1)
training_models.append(text_encoder2)
# set require_grad=True later
else:
text_encoder1.requires_grad_(False)
text_encoder2.requires_grad_(False)
text_encoder1.eval()
text_encoder2.eval()
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
with torch.no_grad():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
accelerator.wait_for_everyone()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models:
m.requires_grad_(True)
if block_lrs is None:
params = []
for m in training_models:
params.extend(m.parameters())
params_to_optimize = params
# calculate number of trainable parameters
n_params = 0
for p in params:
n_params += p.numel()
else:
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
for m in training_models[1:]: # Text Encoders if exists
params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
unet.to(weight_dtype)
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
(unet,) = train_util.transform_models_if_DDP([unet])
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
# TODO support weighted captions
# if args.weighted_captions:
# encoder_hidden_states = get_weighted_text_embeddings(
# tokenizer,
# text_encoder,
# batch["captions"],
# accelerator.device,
# args.max_token_length // 75 if args.max_token_length else 1,
# clip_skip=args.clip_skip,
# )
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
# # verify that the text encoder outputs are correct
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
# args.max_token_length,
# batch["input_ids"].to(text_encoder1.device),
# batch["input_ids2"].to(text_encoder1.device),
# tokenizer1,
# tokenizer2,
# text_encoder1,
# text_encoder2,
# None if not args.full_fp16 else weight_dtype,
# )
# b_size = encoder_hidden_states1.shape[0]
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified")
# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# concat embeddings
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# Predict the noise residual
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
target = noise
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
sdxl_train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(text_encoder1),
accelerator.unwrap_model(text_encoder2),
accelerator.unwrap_model(unet),
vae,
logit_scale,
ckpt_info,
)
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
logs = {"loss": current_loss}
if block_lrs is None:
logs["lr"] = float(lr_scheduler.get_last_lr()[0])
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
else:
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
accelerator.log(logs, step=global_step)
# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
sdxl_train_util.save_sd_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(text_encoder1),
accelerator.unwrap_model(text_encoder2),
accelerator.unwrap_model(unet),
vae,
logit_scale,
ckpt_info,
)
sdxl_train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2],
unet,
)
is_main_process = accelerator.is_main_process
# if is_main_process:
unet = accelerator.unwrap_model(unet)
text_encoder1 = accelerator.unwrap_model(text_encoder1)
text_encoder2 = accelerator.unwrap_model(text_encoder2)
accelerator.end_training()
if args.save_state: # and is_main_process:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
sdxl_train_util.save_sd_model_on_train_end(
args,
src_path,
save_stable_diffusion_format,
use_safetensors,
save_dtype,
epoch,
global_step,
text_encoder1,
text_encoder2,
unet,
vae,
logit_scale,
ckpt_info,
)
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--block_lr",
type=str,
default=None,
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -0,0 +1,609 @@
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
import argparse
import gc
import json
import math
import os
import random
import time
from multiprocessing import Value
from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
import accelerate
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
add_v_prediction_like_loss,
apply_snr_weight,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
import networks.control_net_lllite_for_train as control_net_lllite_for_train
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {
"loss/current": current_loss,
"loss/average": avr_loss,
"lr": lr_scheduler.get_last_lr()[0],
}
if args.optimizer_type.lower().startswith("DAdapt".lower()):
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
return logs
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
args.train_data_dir,
args.conditioning_data_dir,
args.caption_extension,
)
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else:
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
print("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
with torch.no_grad():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
accelerator.wait_for_everyone()
# prepare ControlNet-LLLite
control_net_lllite_for_train.replace_unet_linear_and_conv2d()
if args.network_weights is not None:
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
with accelerate.init_empty_weights():
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
unet_lllite.to(accelerator.device, dtype=weight_dtype)
unet_sd = unet.state_dict()
info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
else:
# cosumes large memory, so send to GPU before creating the LLLite model
accelerator.print("sending U-Net to GPU")
unet.to(accelerator.device, dtype=weight_dtype)
unet_sd = unet.state_dict()
# init LLLite weights
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
if args.lowram:
with accelerate.init_on_device(accelerator.device):
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
else:
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
unet_lllite.to(weight_dtype)
info = unet_lllite.load_lllite_weights(None, unet_sd)
accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
del unet_sd, unet
unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
del unet_lllite
unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(unet.prepare_params())
print(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
# if args.full_fp16:
# assert (
# args.mixed_precision == "fp16"
# ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
# accelerator.print("enable full fp16 training.")
# unet.to(weight_dtype)
# elif args.full_bf16:
# assert (
# args.mixed_precision == "bf16"
# ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
# accelerator.print("enable full bf16 training.")
# unet.to(weight_dtype)
unet.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare (train_network here only)
unet = train_util.transform_models_if_DDP([unet])[0]
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
unet.eval()
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
# TODO: find a way to handle total batch size when there are multiple datasets
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_list = []
loss_total = 0.0
del train_dataset_group
# function for saving/removing
def save_model(
ckpt_name,
unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
steps,
epoch_no,
force_sync_upload=False,
):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(unet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.no_grad():
# Get the text embedding for conditioning
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# concat embeddings
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
with accelerator.autocast():
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
# 内部でcond_embに変換される / it will be converted to cond_emb inside
# それらの値を使いつつ、U-Netでイズを予測する / predict noise with U-Net using those values
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = unet.get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
remove_model(remove_ckpt_name)
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
# self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# end of epoch
if is_main_process:
unet = accelerator.unwrap_model(unet)
accelerator.end_training()
if is_main_process and args.save_state:
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
parser.add_argument(
"--network_dropout",
type=float,
default=None,
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする0またはNoneはdropoutなし、1は全ニューロンをdropout",
)
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser
if __name__ == "__main__":
# sdxl_original_unet.USE_REENTRANT = False
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -0,0 +1,579 @@
import argparse
import gc
import json
import math
import os
import random
import time
from multiprocessing import Value
from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
add_v_prediction_like_loss,
apply_snr_weight,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
import networks.control_net_lllite as control_net_lllite
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {
"loss/current": current_loss,
"loss/average": avr_loss,
"lr": lr_scheduler.get_last_lr()[0],
}
if args.optimizer_type.lower().startswith("DAdapt".lower()):
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
return logs
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
args.train_data_dir,
args.conditioning_data_dir,
args.caption_extension,
)
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else:
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
print("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
with torch.no_grad():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
accelerator.wait_for_everyone()
# prepare ControlNet
network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
network.apply_to()
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
network.enable_gradient_checkpointing() # may have no effect
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(network.prepare_optimizer_params())
print(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
network.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
unet.to(weight_dtype)
network.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
network: control_net_lllite.ControlNetLLLite
# transform DDP after prepare (train_network here only)
unet, network = train_util.transform_models_if_DDP([unet, network])
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
unet.eval()
network.prepare_grad_etc()
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
# TODO: find a way to handle total batch size when there are multiple datasets
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_list = []
loss_total = 0.0
del train_dataset_group
# function for saving/removing
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
network.on_epoch_start() # train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.no_grad():
# Get the text embedding for conditioning
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# concat embeddings
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
with accelerator.autocast():
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
# 内部でcond_embに変換される / it will be converted to cond_emb inside
network.set_cond_image(controlnet_image)
# それらの値を使いつつ、U-Netでイズを予測する / predict noise with U-Net using those values
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = network.get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
# sdxl_train_util.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
remove_model(remove_ckpt_name)
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
# self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# end of epoch
if is_main_process:
network = accelerator.unwrap_model(network)
accelerator.end_training()
if is_main_process and args.save_state:
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
parser.add_argument(
"--network_dropout",
type=float,
default=None,
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする0またはNoneはdropoutなし、1は全ニューロンをdropout",
)
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser
if __name__ == "__main__":
# sdxl_original_unet.USE_REENTRANT = False
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

183
sdxl_train_network.py Normal file
View File

@@ -0,0 +1,183 @@
import argparse
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network
class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
assert (
args.network_train_unet_only or not args.cache_text_encoder_outputs
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
train_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
self.load_stable_diffusion_format = load_stable_diffusion_format
self.logit_scale = logit_scale
self.ckpt_info = ckpt_info
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
def load_tokenizer(self, args):
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer
def is_text_encoder_outputs_cached(self, args):
return args.cache_text_encoder_outputs
def cache_text_encoder_outputs_if_needed(
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
print("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
if not args.lowram:
print("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device)
text_encoders[1].to(accelerator.device)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.enable_grad():
# Get the text embedding for conditioning
# TODO support weighted captions
# if args.weighted_captions:
# encoder_hidden_states = get_weighted_text_embeddings(
# tokenizer,
# text_encoder,
# batch["captions"],
# accelerator.device,
# args.max_token_length // 75 if args.max_token_length else 1,
# clip_skip=args.clip_skip,
# )
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
# # verify that the text encoder outputs are correct
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
# args.max_token_length,
# batch["input_ids"].to(text_encoders[0].device),
# batch["input_ids2"].to(text_encoders[0].device),
# tokenizers[0],
# tokenizers[1],
# text_encoders[0],
# text_encoders[1],
# None if not args.full_fp16 else weight_dtype,
# )
# b_size = encoder_hidden_states1.shape[0]
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified")
return encoder_hidden_states1, encoder_hidden_states2, pool2
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# concat embeddings
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
sdxl_train_util.add_sdxl_training_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
trainer = SdxlNetworkTrainer()
trainer.train(args)

View File

@@ -0,0 +1,140 @@
import argparse
import os
import regex
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util
import train_textual_inversion
class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTrainer):
def __init__(self):
super().__init__()
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args, supportTextEncoderCaching=False)
train_dataset_group.verify_bucket_reso_steps(32)
def load_target_model(self, args, weight_dtype, accelerator):
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
self.load_stable_diffusion_format = load_stable_diffusion_format
self.logit_scale = logit_scale
self.ckpt_info = ckpt_info
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
def load_tokenizer(self, args):
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.enable_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# concat embeddings
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
sdxl_train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype, metadata):
state_dict = {"clip_l": updated_embs[0], "clip_g": updated_embs[1]}
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
data = load_file(file)
else:
data = torch.load(file, map_location="cpu")
emb_l = data.get("clip_l", None) # ViT-L text encoder 1
emb_g = data.get("clip_g", None) # BiG-G text encoder 2
assert (
emb_l is not None or emb_g is not None
), f"weight file does not contains weights for text encoder 1 or 2 / 重みファイルにテキストエンコーダー1または2の重みが含まれていません: {file}"
return [emb_l, emb_g]
def setup_parser() -> argparse.ArgumentParser:
parser = train_textual_inversion.setup_parser()
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
# sdxl_train_util.add_sdxl_training_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
trainer = SdxlTextualInversionTrainer()
trainer.train(args)

194
tools/cache_latents.py Normal file
View File

@@ -0,0 +1,194 @@
# latentsのdiskへの事前キャッシュを行う / cache latents to disk
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import train_util
from library import sdxl_train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
# check cache latents arg
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
# tokenizerを準備するdatasetを動かすために必要
if args.sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
else:
tokenizer = train_util.load_tokenizer(args)
tokenizers = [tokenizer]
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
print("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
# datasetのcache_latentsを呼ばなければ、生の画像が返る
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
print("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, _ = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
print("load model")
if args.sdxl:
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
else:
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
# dataloaderを準備する
train_dataset_group.set_caching_mode("latents")
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# acceleratorを使ってモデルを準備するマルチGPUで使えるようになるはず
train_dataloader = accelerator.prepare(train_dataloader)
# データ取得のためのループ
for batch in tqdm(train_dataloader):
b_size = len(batch["images"])
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
flip_aug = batch["flip_aug"]
random_crop = batch["random_crop"]
bucket_reso = batch["bucket_reso"]
# バッチを分割して処理する
for i in range(0, b_size, vae_batch_size):
images = batch["images"][i : i + vae_batch_size]
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
image_infos = []
for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.image = image
image_info.bucket_reso = bucket_reso
image_info.resized_size = resized_size
image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
print(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)
if len(image_infos) > 0:
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
config_util.add_config_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
cache_to_disk(args)

View File

@@ -0,0 +1,191 @@
# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import train_util
from library import sdxl_train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
# check cache arg
assert (
args.cache_text_encoder_outputs_to_disk
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
# できるだけ準備はしておくが今のところSDXLのみしか動かない
assert (
args.sdxl
), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
# tokenizerを準備するdatasetを動かすために必要
if args.sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
else:
tokenizer = train_util.load_tokenizer(args)
tokenizers = [tokenizer]
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
print("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
print("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, _ = train_util.prepare_dtype(args)
# モデルを読み込む
print("load model")
if args.sdxl:
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
text_encoders = [text_encoder1, text_encoder2]
else:
text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
text_encoders = [text_encoder1]
for text_encoder in text_encoders:
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
# dataloaderを準備する
train_dataset_group.set_caching_mode("text")
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# acceleratorを使ってモデルを準備するマルチGPUで使えるようになるはず
train_dataloader = accelerator.prepare(train_dataloader)
# データ取得のためのループ
for batch in tqdm(train_dataloader):
absolute_paths = batch["absolute_paths"]
input_ids1_list = batch["input_ids1_list"]
input_ids2_list = batch["input_ids2_list"]
image_infos = []
for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
image_info
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1
image_info.input_ids2 = input_ids2
image_infos.append(image_info)
if len(image_infos) > 0:
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
train_util.cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
config_util.add_config_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
cache_to_disk(args)

168
tools/merge_models.py Normal file
View File

@@ -0,0 +1,168 @@
import argparse
import os
import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
def is_unet_key(key):
# VAE or TextEncoder, the last one is for SDXL
return not ("first_stage_model" in key or "cond_stage_model" in key or "conditioner." in key)
TEXT_ENCODER_KEY_REPLACEMENTS = [
("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
]
# support for models with different text encoder keys
def replace_text_encoder_key(key):
for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
if key.startswith(rep_from):
return True, rep_to + key[len(rep_from) :]
return False, key
def merge(args):
if args.precision == "fp16":
dtype = torch.float16
elif args.precision == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float
if args.saving_precision == "fp16":
save_dtype = torch.float16
elif args.saving_precision == "bf16":
save_dtype = torch.bfloat16
else:
save_dtype = torch.float
# check if all models are safetensors
for model in args.models:
if not model.endswith("safetensors"):
print(f"Model {model} is not a safetensors model")
exit()
if not os.path.isfile(model):
print(f"Model {model} does not exist")
exit()
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
# load and merge
ratio = 1.0 / len(args.models) # default
supplementary_key_ratios = {} # [key] = ratio, for keys not in all models, add later
merged_sd = None
first_model_keys = set() # check missing keys in other models
for i, model in enumerate(args.models):
if args.ratios is not None:
ratio = args.ratios[i]
if merged_sd is None:
# load first model
print(f"Loading model {model}, ratio = {ratio}...")
merged_sd = {}
with safe_open(model, framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
value = f.get_tensor(key)
_, key = replace_text_encoder_key(key)
first_model_keys.add(key)
if not is_unet_key(key) and args.unet_only:
supplementary_key_ratios[key] = 1.0 # use first model's value for VAE or TextEncoder
continue
value = ratio * value.to(dtype) # first model's value * ratio
merged_sd[key] = value
print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
continue
# load other models
print(f"Loading model {model}, ratio = {ratio}...")
with safe_open(model, framework="pt", device=args.device) as f:
model_keys = f.keys()
for key in tqdm(model_keys):
_, new_key = replace_text_encoder_key(key)
if new_key not in merged_sd:
if args.show_skipped and new_key not in first_model_keys:
print(f"Skip: {new_key}")
continue
value = f.get_tensor(key)
merged_sd[new_key] = merged_sd[new_key] + ratio * value.to(dtype)
# enumerate keys not in this model
model_keys = set(model_keys)
for key in merged_sd.keys():
if key in model_keys:
continue
print(f"Key {key} not in model {model}, use first model's value")
if key in supplementary_key_ratios:
supplementary_key_ratios[key] += ratio
else:
supplementary_key_ratios[key] = ratio
# add supplementary keys' value (including VAE and TextEncoder)
if len(supplementary_key_ratios) > 0:
print("add first model's value")
with safe_open(args.models[0], framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
_, new_key = replace_text_encoder_key(key)
if new_key not in supplementary_key_ratios:
continue
if is_unet_key(new_key): # not VAE or TextEncoder
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
value = f.get_tensor(key) # original key
if new_key not in merged_sd:
merged_sd[new_key] = supplementary_key_ratios[new_key] * value.to(dtype)
else:
merged_sd[new_key] = merged_sd[new_key] + supplementary_key_ratios[new_key] * value.to(dtype)
# save
output_file = args.output
if not output_file.endswith(".safetensors"):
output_file = output_file + ".safetensors"
print(f"Saving to {output_file}...")
# convert to save_dtype
for k in merged_sd.keys():
merged_sd[k] = merged_sd[k].to(save_dtype)
save_file(merged_sd, output_file)
print("Done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Merge models")
parser.add_argument("--models", nargs="+", type=str, help="Models to merge")
parser.add_argument("--output", type=str, help="Output model")
parser.add_argument("--ratios", nargs="+", type=float, help="Ratios of models, default is equal, total = 1.0")
parser.add_argument("--unet_only", action="store_true", help="Only merge unet")
parser.add_argument("--device", type=str, default="cpu", help="Device to use, default is cpu")
parser.add_argument(
"--precision", type=str, default="float", choices=["float", "fp16", "bf16"], help="Calculation precision, default is float"
)
parser.add_argument(
"--saving_precision",
type=str,
default="float",
choices=["float", "fp16", "bf16"],
help="Saving precision, default is float",
)
parser.add_argument("--show_skipped", action="store_true", help="Show skipped keys (keys not in first model)")
args = parser.parse_args()
merge(args)

View File

@@ -4,175 +4,187 @@ import cv2
import torch
from safetensors.torch import load_file
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from library.original_unet import UNet2DConditionModel, SampleOutput
import library.model_util as model_util
class ControlNetInfo(NamedTuple):
unet: Any
net: Any
prep: Any
weight: float
ratio: float
unet: Any
net: Any
prep: Any
weight: float
ratio: float
class ControlNet(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def __init__(self) -> None:
super().__init__()
# make control model
self.control_model = torch.nn.Module()
# make control model
self.control_model = torch.nn.Module()
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
zero_convs = torch.nn.ModuleList()
for i, dim in enumerate(dims):
sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
zero_convs.append(sub_list)
self.control_model.add_module("zero_convs", zero_convs)
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
zero_convs = torch.nn.ModuleList()
for i, dim in enumerate(dims):
sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
zero_convs.append(sub_list)
self.control_model.add_module("zero_convs", zero_convs)
middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))
dims = [16, 16, 32, 32, 96, 96, 256, 320]
strides = [1, 1, 2, 1, 2, 1, 2, 1]
prev_dim = 3
input_hint_block = torch.nn.Sequential()
for i, (dim, stride) in enumerate(zip(dims, strides)):
input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
if i < len(dims) - 1:
input_hint_block.append(torch.nn.SiLU())
prev_dim = dim
self.control_model.add_module("input_hint_block", input_hint_block)
dims = [16, 16, 32, 32, 96, 96, 256, 320]
strides = [1, 1, 2, 1, 2, 1, 2, 1]
prev_dim = 3
input_hint_block = torch.nn.Sequential()
for i, (dim, stride) in enumerate(zip(dims, strides)):
input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
if i < len(dims) - 1:
input_hint_block.append(torch.nn.SiLU())
prev_dim = dim
self.control_model.add_module("input_hint_block", input_hint_block)
def load_control_net(v2, unet, model):
device = unet.device
device = unet.device
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
# state dictを読み込む
print(f"ControlNet: loading control SD model : {model}")
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
# state dictを読み込む
print(f"ControlNet: loading control SD model : {model}")
if model_util.is_safetensors(model):
ctrl_sd_sd = load_file(model)
else:
ctrl_sd_sd = torch.load(model, map_location='cpu')
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
if model_util.is_safetensors(model):
ctrl_sd_sd = load_file(model)
else:
ctrl_sd_sd = torch.load(model, map_location="cpu")
ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
is_difference = "difference" in ctrl_sd_sd
print("ControlNet: loading difference")
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
is_difference = "difference" in ctrl_sd_sd
print("ControlNet: loading difference:", is_difference)
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
# またTransfer Controlの元weightとなる
ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
# またTransfer Controlの元weightとなる
ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())
# 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
for key in list(ctrl_unet_sd_sd.keys()):
ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
# 元のU-Netに影響しないようにコピーする。またprefixが付いていないので付ける
for key in list(ctrl_unet_sd_sd.keys()):
ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()
zero_conv_sd = {}
for key in list(ctrl_sd_sd.keys()):
if key.startswith("control_"):
unet_key = "model.diffusion_" + key[len("control_"):]
if unet_key not in ctrl_unet_sd_sd: # zero conv
zero_conv_sd[key] = ctrl_sd_sd[key]
continue
if is_difference: # Transfer Control
ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
else:
ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
zero_conv_sd = {}
for key in list(ctrl_sd_sd.keys()):
if key.startswith("control_"):
unet_key = "model.diffusion_" + key[len("control_") :]
if unet_key not in ctrl_unet_sd_sd: # zero conv
zero_conv_sd[key] = ctrl_sd_sd[key]
continue
if is_difference: # Transfer Control
ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
else:
ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)
unet_config = model_util.create_unet_diffusers_config(v2)
ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
unet_config = model_util.create_unet_diffusers_config(v2)
ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config) # DiffUsers版ControlNetのstate dict
# ControlNetのU-Netを作成する
ctrl_unet = UNet2DConditionModel(**unet_config)
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
print("ControlNet: loading Control U-Net:", info)
# ControlNetのU-Netを作成する
ctrl_unet = UNet2DConditionModel(**unet_config)
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
print("ControlNet: loading Control U-Net:", info)
# U-Net以外のControlNetを作成する
# TODO support middle only
ctrl_net = ControlNet()
info = ctrl_net.load_state_dict(zero_conv_sd)
print("ControlNet: loading ControlNet:", info)
# U-Net以外のControlNetを作成する
# TODO support middle only
ctrl_net = ControlNet()
info = ctrl_net.load_state_dict(zero_conv_sd)
print("ControlNet: loading ControlNet:", info)
ctrl_unet.to(unet.device, dtype=unet.dtype)
ctrl_net.to(unet.device, dtype=unet.dtype)
return ctrl_unet, ctrl_net
ctrl_unet.to(unet.device, dtype=unet.dtype)
ctrl_net.to(unet.device, dtype=unet.dtype)
return ctrl_unet, ctrl_net
def load_preprocess(prep_type: str):
if prep_type is None or prep_type.lower() == "none":
if prep_type is None or prep_type.lower() == "none":
return None
if prep_type.startswith("canny"):
args = prep_type.split("_")
th1 = int(args[1]) if len(args) >= 2 else 63
th2 = int(args[2]) if len(args) >= 3 else 191
def canny(img):
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
return cv2.Canny(img, th1, th2)
return canny
print("Unsupported prep type:", prep_type)
return None
if prep_type.startswith("canny"):
args = prep_type.split("_")
th1 = int(args[1]) if len(args) >= 2 else 63
th2 = int(args[2]) if len(args) >= 3 else 191
def canny(img):
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
return cv2.Canny(img, th1, th2)
return canny
print("Unsupported prep type:", prep_type)
return None
def preprocess_ctrl_net_hint_image(image):
image = np.array(image).astype(np.float32) / 255.0
image = image[:, :, ::-1].copy() # rgb to bgr
image = image[None].transpose(0, 3, 1, 2) # nchw
image = torch.from_numpy(image)
return image # 0 to 1
image = np.array(image).astype(np.float32) / 255.0
# ControlNetのサンプルはcv2を使っているが、読み込みはGradioなので実はRGBになっている
# image = image[:, :, ::-1].copy() # rgb to bgr
image = image[None].transpose(0, 3, 1, 2) # nchw
image = torch.from_numpy(image)
return image # 0 to 1
def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
guided_hints = []
for i, cnet_info in enumerate(control_nets):
# hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
b_hints = []
if len(hints) == 1: # すべて同じ画像をhintとして使う
hint = hints[0]
if cnet_info.prep is not None:
hint = cnet_info.prep(hint)
hint = preprocess_ctrl_net_hint_image(hint)
b_hints = [hint for _ in range(b_size)]
else:
for bi in range(b_size):
hint = hints[(bi * len(control_nets) + i) % len(hints)]
if cnet_info.prep is not None:
hint = cnet_info.prep(hint)
hint = preprocess_ctrl_net_hint_image(hint)
b_hints.append(hint)
b_hints = torch.cat(b_hints, dim=0)
b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
guided_hints = []
for i, cnet_info in enumerate(control_nets):
# hintは 1枚目の画像のcnet1, 1枚目の画像のcnet2, 1枚目の画像のcnet3, 2枚目の画像のcnet1, 2枚目の画像のcnet2 ... と並んでいること
b_hints = []
if len(hints) == 1: # すべて同じ画像をhintとして使う
hint = hints[0]
if cnet_info.prep is not None:
hint = cnet_info.prep(hint)
hint = preprocess_ctrl_net_hint_image(hint)
b_hints = [hint for _ in range(b_size)]
else:
for bi in range(b_size):
hint = hints[(bi * len(control_nets) + i) % len(hints)]
if cnet_info.prep is not None:
hint = cnet_info.prep(hint)
hint = preprocess_ctrl_net_hint_image(hint)
b_hints.append(hint)
b_hints = torch.cat(b_hints, dim=0)
b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)
guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
guided_hints.append(guided_hint)
return guided_hints
guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
guided_hints.append(guided_hint)
return guided_hints
def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
# ControlNet
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
cnet_cnt = len(control_nets)
cnet_idx = step % cnet_cnt
cnet_info = control_nets[cnet_idx]
def call_unet_and_control_net(
step,
num_latent_input,
original_unet,
control_nets: List[ControlNetInfo],
guided_hints,
current_ratio,
sample,
timestep,
encoder_hidden_states,
encoder_hidden_states_for_control_net,
):
# ControlNet
# 複数のControlNetの場合は、出力をマージするのではなく交互に適用する
cnet_cnt = len(control_nets)
cnet_idx = step % cnet_cnt
cnet_info = control_nets[cnet_idx]
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
return original_unet(sample, timestep, encoder_hidden_states)
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
return original_unet(sample, timestep, encoder_hidden_states)
guided_hint = guided_hints[cnet_idx]
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
outs = [o * cnet_info.weight for o in outs]
guided_hint = guided_hints[cnet_idx]
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
outs = [o * cnet_info.weight for o in outs]
# U-Net
return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
# U-Net
return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)
"""
@@ -203,118 +215,123 @@ def call_unet_and_control_net(step, num_latent_input, original_unet, control_net
"""
def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
# copy from UNet2DConditionModel
default_overall_up_factor = 2**unet.num_upsamplers
def unet_forward(
is_control_net,
control_net: ControlNet,
unet: UNet2DConditionModel,
guided_hint,
ctrl_outs,
sample,
timestep,
encoder_hidden_states,
):
# copy from UNet2DConditionModel
default_overall_up_factor = 2**unet.num_upsamplers
forward_upsample_size = False
upsample_size = None
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
print("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
print("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 0. center input if necessary
if unet.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = unet.time_proj(timesteps)
t_emb = unet.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=unet.dtype)
emb = unet.time_embedding(t_emb)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=unet.dtype)
emb = unet.time_embedding(t_emb)
outs = [] # output of ControlNet
zc_idx = 0
outs = [] # output of ControlNet
zc_idx = 0
# 2. pre-process
sample = unet.conv_in(sample)
if is_control_net:
sample += guided_hint
outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
zc_idx += 1
# 3. down
down_block_res_samples = (sample,)
for downsample_block in unet.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
# 2. pre-process
sample = unet.conv_in(sample)
if is_control_net:
for rs in res_samples:
outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
sample += guided_hint
outs.append(control_net.control_model.zero_convs[zc_idx][0](sample)) # , emb, encoder_hidden_states))
zc_idx += 1
down_block_res_samples += res_samples
# 3. down
down_block_res_samples = (sample,)
for downsample_block in unet.down_blocks:
if downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
if is_control_net:
for rs in res_samples:
outs.append(control_net.control_model.zero_convs[zc_idx][0](rs)) # , emb, encoder_hidden_states))
zc_idx += 1
# 4. mid
sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
if is_control_net:
outs.append(control_net.control_model.middle_block_out[0](sample))
return outs
down_block_res_samples += res_samples
if not is_control_net:
sample += ctrl_outs.pop()
# 4. mid
sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
if is_control_net:
outs.append(control_net.control_model.middle_block_out[0](sample))
return outs
# 5. up
for i, upsample_block in enumerate(unet.up_blocks):
is_final_block = i == len(unet.up_blocks) - 1
if not is_control_net:
sample += ctrl_outs.pop()
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# 5. up
for i, upsample_block in enumerate(unet.up_blocks):
is_final_block = i == len(unet.up_blocks) - 1
if not is_control_net and len(ctrl_outs) > 0:
res_samples = list(res_samples)
apply_ctrl_outs = ctrl_outs[-len(res_samples):]
ctrl_outs = ctrl_outs[:-len(res_samples)]
for j in range(len(res_samples)):
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
res_samples = tuple(res_samples)
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if not is_control_net and len(ctrl_outs) > 0:
res_samples = list(res_samples)
apply_ctrl_outs = ctrl_outs[-len(res_samples) :]
ctrl_outs = ctrl_outs[: -len(res_samples)]
for j in range(len(res_samples)):
res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
res_samples = tuple(res_samples)
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = unet.conv_norm_out(sample)
sample = unet.conv_act(sample)
sample = unet.conv_out(sample)
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
return UNet2DConditionOutput(sample=sample)
if upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = unet.conv_norm_out(sample)
sample = unet.conv_act(sample)
sample = unet.conv_out(sample)
return SampleOutput(sample=sample)

19
tools/show_metadata.py Normal file
View File

@@ -0,0 +1,19 @@
import json
import argparse
from safetensors import safe_open
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
args = parser.parse_args()
with safe_open(args.model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:
print("No metadata found")
else:
# metadata is json dict, but not pretty printed
# sort by key and pretty print
print(json.dumps(metadata, indent=4, sort_keys=True))

611
train_controlnet.py Normal file
View File

@@ -0,0 +1,611 @@
import argparse
import gc
import json
import math
import os
import random
import time
from multiprocessing import Value
from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
pyramid_noise_like,
apply_noise_offset,
)
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {
"loss/current": current_loss,
"loss/average": avr_loss,
"lr": lr_scheduler.get_last_lr()[0],
}
if args.optimizer_type.lower().startswith("DAdapt".lower()):
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
return logs
def train(args):
# session_id = random.randint(0, 2**32)
# training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
args.train_data_dir,
args.conditioning_data_dir,
args.caption_extension,
)
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する
print("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
text_encoder, vae, unet, _ = train_util.load_target_model(
args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=True
)
# DiffusersのControlNetが使用するデータを準備する
if args.v2:
unet.config = {
"act_fn": "silu",
"attention_head_dim": [5, 10, 20, 20],
"block_out_channels": [320, 640, 1280, 1280],
"center_input_sample": False,
"cross_attention_dim": 1024,
"down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
"downsample_padding": 1,
"dual_cross_attention": False,
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 4,
"sample_size": 96,
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
"use_linear_projection": True,
"upcast_attention": True,
"only_cross_attention": False,
"downsample_padding": 1,
"use_linear_projection": True,
"class_embed_type": None,
"num_class_embeds": None,
"resnet_time_scale_shift": "default",
"projection_class_embeddings_input_dim": None,
}
else:
unet.config = {
"act_fn": "silu",
"attention_head_dim": 8,
"block_out_channels": [320, 640, 1280, 1280],
"center_input_sample": False,
"cross_attention_dim": 768,
"down_block_types": ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"],
"downsample_padding": 1,
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"norm_eps": 1e-05,
"norm_num_groups": 32,
"out_channels": 4,
"sample_size": 64,
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
"only_cross_attention": False,
"downsample_padding": 1,
"use_linear_projection": False,
"class_embed_type": None,
"num_class_embeds": None,
"upcast_attention": False,
"resnet_time_scale_shift": "default",
"projection_class_embeddings_input_dim": None,
}
unet.config = SimpleNamespace(**unet.config)
controlnet = ControlNetModel.from_unet(unet)
if args.controlnet_model_name_or_path:
filename = args.controlnet_model_name_or_path
if os.path.isfile(filename):
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
else:
state_dict = torch.load(filename)
state_dict = model_util.convert_controlnet_state_dict_to_diffusers(state_dict)
controlnet.load_state_dict(state_dict)
elif os.path.isdir(filename):
controlnet = ControlNetModel.from_pretrained(filename)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
controlnet.enable_gradient_checkpointing()
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = controlnet.parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
controlnet.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet, optimizer, train_dataloader, lr_scheduler
)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.to(accelerator.device)
text_encoder.to(accelerator.device)
# transform DDP after prepare
controlnet = controlnet.module if isinstance(controlnet, DDP) else controlnet
controlnet.train()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
# TODO: find a way to handle total batch size when there are multiple datasets
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(
range(args.max_train_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="steps",
)
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
num_train_timesteps=1000,
clip_sample=False,
)
if accelerator.is_main_process:
init_kwargs = {}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
loss_list = []
loss_total = 0.0
del train_dataset_group
# function for saving/removing
def save_model(ckpt_name, model, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(ckpt_file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, ckpt_file)
else:
torch.save(state_dict, ckpt_file)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs):
if is_main_process:
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(controlnet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
noise = apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(
noise,
latents.device,
args.multires_noise_iterations,
args.multires_noise_discount,
)
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(b_size,),
device=latents.device,
)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
with accelerator.autocast():
down_block_res_samples, mid_block_res_sample = controlnet(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=controlnet_image,
return_dict=False,
)
# Predict the noise residual
noise_pred = unet(
noisy_latents,
timesteps,
encoder_hidden_states,
down_block_additional_residuals=[sample.to(dtype=weight_dtype) for sample in down_block_res_samples],
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = controlnet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
controlnet=controlnet,
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(
ckpt_name,
accelerator.unwrap_model(controlnet),
)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(controlnet))
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
remove_model(remove_ckpt_name)
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
tokenizer,
text_encoder,
unet,
controlnet=controlnet,
)
# end of epoch
if is_main_process:
controlnet = accelerator.unwrap_model(controlnet)
accelerator.end_training()
if is_main_process and args.save_state:
train_util.save_state_on_train_end(args, accelerator)
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, controlnet, force_sync_upload=True)
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
)
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -2,18 +2,23 @@
# XXX dropped option: fine_tune
import gc
import time
import argparse
import itertools
import math
import os
import toml
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
@@ -23,7 +28,16 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like
from library.custom_train_functions import (
apply_snr_weight,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
# perlin_noise,
def train(args):
@@ -37,31 +51,35 @@ def train(args):
tokenizer = train_util.load_tokenizer(args)
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
)
else:
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else:
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
if args.no_token_padding:
train_dataset_group.disable_token_padding()
@@ -86,7 +104,7 @@ def train(args):
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
)
accelerator, unwrap_model = train_util.prepare_accelerator(args)
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -110,7 +128,7 @@ def train(args):
use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower())
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:
@@ -131,7 +149,7 @@ def train(args):
unet.requires_grad_(True) # 念のため追加
text_encoder.requires_grad_(train_text_encoder)
if not train_text_encoder:
print("Text Encoder is not trained.")
accelerator.print("Text Encoder is not trained.")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
@@ -143,12 +161,13 @@ def train(args):
vae.to(accelerator.device, dtype=weight_dtype)
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
accelerator.print("prepare optimizer, data loader etc.")
if train_text_encoder:
trainable_params = itertools.chain(unet.parameters(), text_encoder.parameters())
# wightout list, adamw8bit is crashed
trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
else:
trainable_params = unet.parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
@@ -158,7 +177,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -168,7 +187,7 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -184,7 +203,7 @@ def train(args):
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
@@ -217,15 +236,17 @@ def train(args):
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
accelerator.print(
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
)
accelerator.print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
@@ -233,14 +254,20 @@ def train(args):
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
init_kwargs = {}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
# 指定したステップ数までText Encoderを学習するepoch最初の状態
@@ -253,7 +280,7 @@ def train(args):
current_step.value = global_step
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}")
accelerator.print(f"stop text encoder training at step {global_step}")
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
@@ -268,14 +295,6 @@ def train(args):
latents = latents * 0.18215
b_size = latents.shape[0]
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
if args.weighted_captions:
@@ -293,13 +312,9 @@ def train(args):
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -319,6 +334,8 @@ def train(args):
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -359,15 +376,17 @@ def train(args):
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
accelerator.unwrap_model(text_encoder),
accelerator.unwrap_model(unet),
vae,
)
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
@@ -407,8 +426,8 @@ def train(args):
epoch,
num_train_epochs,
global_step,
unwrap_model(text_encoder),
unwrap_model(unet),
accelerator.unwrap_model(text_encoder),
accelerator.unwrap_model(unet),
vae,
)
@@ -416,8 +435,8 @@ def train(args):
is_main_process = accelerator.is_main_process
if is_main_process:
unet = unwrap_model(unet)
text_encoder = unwrap_model(text_encoder)
unet = accelerator.unwrap_model(unet)
text_encoder = accelerator.unwrap_model(text_encoder)
accelerator.end_training()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -8,9 +8,17 @@ from multiprocessing import Value
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library
import library.train_util as train_util
import library.huggingface_util as huggingface_util
@@ -20,7 +28,14 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, pyramid_noise_like
from library.custom_train_functions import (
apply_snr_weight,
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
imagenet_templates_small = [
@@ -88,6 +103,9 @@ def train(args):
print(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
)
assert (
args.dataset_class is None
), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"
cache_latents = args.cache_latents
@@ -98,7 +116,7 @@ def train(args):
# acceleratorを準備する
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -178,7 +196,7 @@ def train(args):
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -218,12 +236,12 @@ def train(args):
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
print("use template for training captions. is object: {args.use_object_template}")
print(f"use template for training captions. is object: {args.use_object_template}")
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
replace_to = " ".join(token_strings)
captions = []
@@ -256,10 +274,10 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
original_unet.UNet2DConditionModel.forward = unet_forward_XTI
original_unet.CrossAttnDownBlock2D.forward = downblock_forward_XTI
original_unet.CrossAttnUpBlock2D.forward = upblock_forward_XTI
# 学習を準備する
if cache_latents:
@@ -291,7 +309,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -319,7 +337,7 @@ def train(args):
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
@@ -372,16 +390,22 @@ def train(args):
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
init_kwargs = {}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# function for saving/removing
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
print(f"\nsaving checkpoint: {ckpt_file}")
save_weights(ckpt_file, embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -394,7 +418,7 @@ def train(args):
# training loop
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
text_encoder.train()
@@ -423,21 +447,9 @@ def train(args):
]
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
elif args.multires_noise_iterations:
noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -452,11 +464,13 @@ def train(args):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -471,7 +485,7 @@ def train(args):
# Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad():
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
index_no_updates
]
@@ -488,7 +502,13 @@ def train(args):
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
updated_embs = (
accelerator.unwrap_model(text_encoder)
.get_input_embeddings()
.weight[token_ids_XTI]
.data.detach()
.clone()
)
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, updated_embs, global_step, epoch)
@@ -504,7 +524,9 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
@@ -524,7 +546,7 @@ def train(args):
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
updated_embs = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
@@ -549,7 +571,7 @@ def train(args):
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = unwrap_model(text_encoder)
text_encoder = accelerator.unwrap_model(text_encoder)
accelerator.end_training()