Compare commits

..

123 Commits

Author SHA1 Message Date
Kohya S
26d35794e3 Merge pull request #1052 from kohya-ss/dev
merge dev
2024-01-15 21:39:02 +09:00
Kohya S
dcf0eeb5b6 update readme 2024-01-15 21:35:26 +09:00
Kohya S
32b759a328 Add wandb_run_name parameter to init_kwargs #1032 2024-01-14 22:02:03 +09:00
Kohya S
09ef3ffa8b Merge branch 'main' into dev 2024-01-14 21:49:25 +09:00
Kohya S
aab265e431 Fix an issue with saving as diffusers sd1/2 model close #1033 2024-01-04 21:43:50 +09:00
Kohya S
716bad188b Update dependencies ref #1024 2024-01-04 19:53:25 +09:00
Kohya S
4f93bf10f0 Merge pull request #1032 from hopl1t/wandb_session_name_support
Added cli argument for wandb session name
2024-01-04 11:10:31 +09:00
Kohya S
07bf2a21ac Merge pull request #1024 from p1atdev/main
Add support for `torch.compile`
2024-01-04 10:49:52 +09:00
Kohya S
8ac2d2a92f Merge pull request #1030 from Disty0/dev
Update IPEX Libs
2024-01-04 10:46:07 +09:00
Kohya S
76aee71257 Merge branch 'main' into dev 2024-01-04 10:42:16 +09:00
Kohya S
1db5d790ed Merge pull request #1029 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.16.26
Bump crate-ci/typos from 1.16.15 to 1.16.26
2024-01-04 10:41:07 +09:00
Kohya S
663b481029 fix TI training with full_fp16/bf16 ref #1019 2024-01-03 23:22:00 +09:00
Kohya S
1ab6493268 Merge branch 'main' into dev 2024-01-03 21:36:31 +09:00
Nir Weingarten
ab716302e4 Added cli argument for wandb session name 2024-01-03 11:52:38 +02:00
Disty0
b9d2181192 Cleanup 2024-01-02 11:51:29 +03:00
Disty0
49148eb36e Disable Diffusers slicing if device is not XPU 2024-01-02 11:50:08 +03:00
Disty0
479bac447e Fix typo 2024-01-01 12:51:23 +03:00
Disty0
15d5e78ac2 Update IPEX Libs 2024-01-01 12:44:26 +03:00
dependabot[bot]
fd7f27f044 Bump crate-ci/typos from 1.16.15 to 1.16.26
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.16.15 to 1.16.26.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.16.15...v1.16.26)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-01-01 01:28:55 +00:00
Plat
62e7516537 feat: support torch.compile 2023-12-27 02:17:24 +09:00
Plat
20296b4f0e chore: bump eniops version due to support torch.compile 2023-12-27 02:17:24 +09:00
Kohya S
5cae6db804 Fix to work with DDP TextualInversionTrainer ref #1019 2023-12-24 22:05:56 +09:00
Kohya S
1a36f9dc65 Merge pull request #1020 from kohya-ss/dev
Fix convert_diffusers20_original_sd.py and add metadata & variant options
2023-12-24 21:48:25 +09:00
Kohya S
c2497877ca Merge branch 'main' into dev 2023-12-24 21:46:05 +09:00
Kohya S
3b5c1a1d4b Fix issue with tools/convert_diffusers20_original_sd.py 2023-12-24 21:45:51 +09:00
Kohya S
9a2e385f12 Merge pull request #1016 from Disty0/dev
Fix convert_diffusers20_original_sd.py and add --variant option for loading
2023-12-24 21:41:12 +09:00
Disty0
7080e1a11c Fix convert_diffusers20_original_sd.py and add metadata & variant options 2023-12-22 22:40:03 +03:00
Kohya S
0a52b83c6a Merge pull request #1012 from kohya-ss/dev
merge dev to main
2023-12-21 22:18:39 +09:00
Kohya S
11ed8e2a6d update readme 2023-12-21 22:16:55 +09:00
Kohya S
bb20c09a9a update readme 2023-12-21 22:10:47 +09:00
Kohya S
04ef8d395f speed up nan replace in sdxl training ref #1009 2023-12-21 21:44:03 +09:00
Kohya S
0676f1a86f Merge pull request #1009 from liubo0902/main
speed up latents nan replace
2023-12-21 21:37:16 +09:00
Kohya S
6b7823df07 Merge branch 'main' into dev 2023-12-21 21:33:43 +09:00
Kohya S
2186e417ba fix size of bucket < min_size ref #1008 2023-12-20 22:12:21 +09:00
Kohya S
1519e3067c Merge pull request #1008 from Cauldrath/zero_height_error
Fix zero height buckets
2023-12-20 22:09:04 +09:00
Kohya S
35e5424255 Merge pull request #1007 from Disty0/dev
IPEX fix SDPA
2023-12-20 21:53:11 +09:00
liubo0902
8c7d05afd2 speed up latents nan replace 2023-12-20 09:35:17 +08:00
Cauldrath
f8360a4831 Fix zero height buckets
If max_size is too large relative to max_reso, it will calculate a height of zero for some buckets.
This causes a crash later when it divides the width by the height.

This change also simplifies some math and consolidates the redundant "size" variable into "width".
2023-12-19 18:35:09 -05:00
Disty0
8556b9d7f5 IPEX fix SDPA 2023-12-19 22:59:06 +03:00
Kohya S
3efd90b2ad fix sampling in training with mutiple gpus ref #989 2023-12-15 22:35:54 +09:00
Kohya S
7adcd9cd1a Merge pull request #1003 from Disty0/dev
IPEX support for Torch 2.1 and fix dtype erros
2023-12-15 08:18:48 +09:00
Disty0
aff05e043f IPEX support for Torch 2.1 and fix dtype erros 2023-12-13 19:40:38 +03:00
Kohya S
ff2c0c192e update readme 2023-12-13 23:13:22 +09:00
Kohya S
d309a27a51 change option names, add ddp kwargs if needed ref #1000 2023-12-13 21:02:26 +09:00
Kohya S
471d274803 Merge pull request #1000 from Isotr0py/dev
Fix multi-gpu SDXL training
2023-12-13 20:52:11 +09:00
Kohya S
35f4c9b5c7 fix an error when keep_tokens_separator is not set ref #975 2023-12-12 21:43:21 +09:00
Kohya S
034a49c69d Merge pull request #975 from Linaqruf/dev
Add keep_tokens_separator as alternative for keep_tokens
2023-12-12 21:28:32 +09:00
Kohya S
3b6825d7e2 Merge pull request #986 from CjangCjengh/dev
Fixed the path error in finetune/make_captions.py
2023-12-12 21:17:24 +09:00
Isotr0py
bb5ae389f7 fix DDP SDXL training 2023-12-12 19:58:44 +08:00
Kohya S
4a2cef887c fix lllite training not working ref #913 2023-12-10 09:23:37 +09:00
Kohya S
42750f7846 fix error on pool_workaround in sdxl TE training ref #994 2023-12-10 09:18:33 +09:00
CjangCjengh
d31aa143f4 fix path error 2023-12-08 00:27:32 +08:00
CjangCjengh
710e777a92 fix path error 2023-12-08 00:22:13 +08:00
Kohya S
912dca8f65 fix duplicated sample gen for every epoch ref #907 2023-12-07 22:13:38 +09:00
Isotr0py
db84530074 Fix gradients synchronization for multi-GPUs training (#989)
* delete DDP wrapper

* fix train_db vae and train_network

* fix train_db vae and train_network unwrap

* network grad sync

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2023-12-07 22:01:42 +09:00
Kohya S
72bbaac96d Merge pull request #985 from Disty0/dev
Update IPEX hijacks
2023-12-07 21:39:24 +09:00
Kohya S
5713d63dc5 add temporary workaround for playground-v2 2023-12-06 23:08:02 +09:00
CjangCjengh
d653e594c2 fix path error 2023-12-06 09:48:42 +08:00
Disty0
dd7bb33ab6 IPEX fix torch.UntypedStorage.is_cuda 2023-12-05 22:18:47 +03:00
Disty0
a9c6182b3f Cleanup IPEX libs 2023-12-05 19:52:31 +03:00
Disty0
3d70137d31 Disable IPEX attention if the GPU supports 64 bit 2023-12-05 19:40:16 +03:00
Disty0
bce9a081db Update IPEX hijacks 2023-12-05 14:17:31 +03:00
Kohya S
46cf41cc93 Merge pull request #961 from rockerBOO/attention-processor
Add attention processor
2023-12-03 21:24:12 +09:00
Kohya S
81a440c8e8 Merge pull request #955 from xzuyn/paged_adamw
Add PagedAdamW
2023-12-03 21:22:38 +09:00
Kohya S
f24a3b5282 show seed in generating samples 2023-12-03 21:15:30 +09:00
Kohya S
383b4a2c3e Merge pull request #907 from shirayu/add_option_sample_at_first
Add option --sample_at_first
2023-12-03 21:00:32 +09:00
Kohya S
df59822a27 Merge pull request #906 from shirayu/accept_scheduler_designation_in_training
Accept sampler designation in sampling of training
2023-12-03 20:46:16 +09:00
Kohya S
0908c5414d Merge pull request #978 from kohya-ss/dev
Dev
2023-12-03 18:27:29 +09:00
Kohya S
ee46134fa7 update readme 2023-12-03 18:24:50 +09:00
Kohya S
39bb319d4c fix to work with cfg scale=1 2023-11-29 12:42:12 +09:00
Furqanil Taqwa
1bdd83a85f remove unnecessary debug print 2023-11-28 17:26:27 +07:00
Furqanil Taqwa
1624c239c2 added keep_tokens_separator to dynamically keep token for being shuffled 2023-11-28 17:23:55 +07:00
Furqanil Taqwa
4a913ce61e initialize keep_tokens_separator to dataset config 2023-11-28 17:22:35 +07:00
Kohya S
764e333fa2 make slicing vae compatible with latest diffusers 2023-11-26 18:12:04 +09:00
Kohya S
c61e3bf4c9 make separate U-Net for inference 2023-11-26 18:11:30 +09:00
Kohya S
fc8649d80f Merge pull request #934 from feffy380/fix-minsnr-vpred-zsnr
Fix min-snr-gamma for v-prediction and ZSNR.
2023-11-25 21:19:39 +09:00
Kohya S
0fb9ecf1f3 format by black, add ja comment 2023-11-25 21:05:55 +09:00
Kohya S
97958400fb Merge pull request #936 from wkpark/model_util-update
use **kwargs and change svd() calling convention to make svd() reusable
2023-11-25 21:01:14 +09:00
Kohya S
6d6d86260b add Deep Shrink 2023-11-23 19:40:48 +09:00
rockerBOO
c856ea4249 Add attention processor 2023-11-19 12:11:36 -05:00
Kohya S
d0923d6710 add caption_separator option 2023-11-19 21:44:52 +09:00
Kohya S
f312522cef Merge pull request #913 from KohakuBlueleaf/custom-seperator
Add custom seperator for shuffle caption
2023-11-19 21:32:01 +09:00
xzuyn
da5a144589 Add PagedAdamW 2023-11-18 07:47:27 -05:00
Won-Kyu Park
2c1e669bd8 add min_diff, clamp_quantile args
based on https://github.com/bmaltais/kohya_ss/pull/1332 a9ec90c40a
2023-11-10 02:35:55 +09:00
Won-Kyu Park
e20e9f61ac use **kwargs and change svd() calling convention to make svd() reusable
* add required attributes to model_org, model_tuned, save_to
 * set "*_alpha" using str(float(foo))
2023-11-10 02:35:10 +09:00
feffy380
6b3148fd3f Fix min-snr-gamma for v-prediction and ZSNR.
This fixes min-snr for vpred+zsnr by dividing directly by SNR+1.
The old implementation did it in two steps: (min-snr/snr) * (snr/(snr+1)), which causes division by zero when combined with --zero_terminal_snr
2023-11-07 23:02:25 +01:00
Kohya S
95ae56bd22 Update README.md 2023-11-05 21:10:26 +09:00
Kohya S
990192d077 Merge pull request #927 from kohya-ss/dev
Dev
2023-11-05 19:31:41 +09:00
Kohya S
f3e69531c3 update readme 2023-11-05 19:30:52 +09:00
Kohya S
0cb3272bda update readme 2023-11-05 19:26:35 +09:00
Kohya S
6231aa91e2 common lr logging, set default None to ddp_timeout 2023-11-05 19:09:17 +09:00
Kohaku-Blueleaf
489b728dbc Fix typo again 2023-10-30 20:19:51 +08:00
Kohaku-Blueleaf
583e2b2d01 Fix typo 2023-10-30 20:02:04 +08:00
Kohaku-Blueleaf
5dc2a0d3fd Add custom seperator 2023-10-30 19:55:30 +08:00
Yuta Hayashibe
2c731418ad Added sample_images() for --sample_at_first 2023-10-29 22:08:42 +09:00
Yuta Hayashibe
5c150675bf Added --sample_at_first description 2023-10-29 21:46:47 +09:00
Yuta Hayashibe
fea810b437 Added --sample_at_first to generate sample images before training 2023-10-29 21:44:57 +09:00
Kohya S
96d877be90 support separate LR for Text Encoder for SD1/2 2023-10-29 21:30:32 +09:00
Yuta Hayashibe
40d917b0fe Removed incorrect comments 2023-10-29 21:02:44 +09:00
Kohya S
e72020ae01 update readme 2023-10-29 20:52:43 +09:00
Kohya S
01d929ee2a support separate learning rates for TE1/2 2023-10-29 20:38:01 +09:00
Yuta Hayashibe
cf876fcdb4 Accept --ss to set sample_sampler dynamically 2023-10-29 20:15:04 +09:00
Yuta Hayashibe
291c29caaf Added a function line_to_prompt_dict() and removed duplicated initializations 2023-10-29 19:57:25 +09:00
Yuta Hayashibe
01e00ac1b0 Make a function get_my_scheduler() 2023-10-29 19:46:02 +09:00
Kohya S
a9ed4ed8a8 Merge pull request #900 from xzuyn/paged_adamw_32bit
Add PagedAdamW32bit
2023-10-29 15:01:55 +09:00
Kohya S
9d6a5a0c79 Merge pull request #899 from shirayu/use_moving_average
Show moving average loss in the progress bar
2023-10-29 14:37:58 +09:00
Kohya S
fb97a7aab1 Merge pull request #898 from shirayu/update_repare_buckets_latents
Fix a typo and add assertions in making buckets
2023-10-29 14:29:53 +09:00
Kohaku-Blueleaf
1cefb2a753 Better implementation for te autocast (#895)
* Better implementation for te

* Fix some misunderstanding

* as same as unet, add explicit convert

* Better cache TE and TE lr

* Fix with list

* Add timeout settings

* Fix arg style
2023-10-28 15:49:59 +09:00
Yuta Hayashibe
63992b81c8 Fix initialize place of loss_recorder 2023-10-27 21:13:29 +09:00
xzuyn
d8f68674fb Update train_util.py 2023-10-27 07:05:53 -04:00
Yuta Hayashibe
9d00c8eea2 Use LossRecorder 2023-10-27 18:31:36 +09:00
Yuta Hayashibe
0d21925bdf Use @property 2023-10-27 18:14:27 +09:00
Yuta Hayashibe
efef5c8ead Show "avr_loss" instead of "loss" because it is moving average 2023-10-27 17:59:58 +09:00
Yuta Hayashibe
3d2bb1a8f1 Add LossRecorder and use moving average in all places 2023-10-27 17:49:49 +09:00
Yuta Hayashibe
837a4dddb8 Added assertions 2023-10-26 13:34:36 +09:00
Yuta Hayashibe
b2626bc7a9 Fix a typo 2023-10-26 00:51:17 +09:00
青龍聖者@bdsqlsz
202f2c3292 Debias Estimation loss (#889)
* update for bnb 0.41.1

* fixed generate_controlnet_subsets_config for training

* Revert "update for bnb 0.41.1"

This reverts commit 70bd3612d8.

* add debiased_estimation_loss

* add train_network

* Revert "add train_network"

This reverts commit 6539363c5c.

* Update train_network.py
2023-10-23 22:59:14 +09:00
Kohya S
2a23713f71 Merge pull request #872 from kohya-ss/dev
fix make_captions_by_git, improve image generation scripts
2023-10-11 07:56:39 +09:00
Kohya S
681034d001 update readme 2023-10-11 07:54:30 +09:00
Kohya S
17813ff5b4 remove workaround for transfomers bs>1 close #869 2023-10-11 07:40:12 +09:00
Kohya S
3e81bd6b67 fix network_merge, add regional mask as color code 2023-10-09 23:07:14 +09:00
Kohya S
23ae358e0f Merge branch 'main' into dev 2023-10-09 21:42:13 +09:00
Kohya S
f611726364 add network_merge_n_models option 2023-10-09 21:41:50 +09:00
39 changed files with 2178 additions and 1041 deletions

View File

@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.16.15
uses: crate-ci/typos@v1.16.26

View File

@@ -249,65 +249,33 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History
### Oct 9. 2023 / 2023/10/9
### Jan 15, 2024 / 2024/1/15: v0.8.0
- `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!
- `--onnx` option is added. If you use Onnx, specify `--onnx` option.
- Please install Onnx and other required packages.
1. Uninstall TensorFlow.
1. `pip install tensorboard==2.14.1` This is required for the specified version of protobuf.
1. `pip install protobuf==3.20.3` This is required for Onnx.
1. `pip install onnx==1.14.1`
1. `pip install onnxruntime-gpu==1.16.0` or `pip install onnxruntime==1.16.0`
- `--append_tags` option is added to `tag_images_by_wd_14_tagger.py`. This option appends the tags to the existing tags, instead of replacing them. [#858](https://github.com/kohya-ss/sd-scripts/pull/858) Thanks to a-l-e-x-d-s-9!
- [OFT](https://oft.wyliu.com/) is now supported.
- You can use `networks.oft` for the network module in `sdxl_train_network.py`. The usage is the same as `networks.lora`. Some options are not supported.
- `sdxl_gen_img.py` also supports OFT as `--network_module`.
- OFT only supports SDXL currently. Because current OFT tweaks Q/K/V and O in the transformer, and SD1/2 have extremely fewer transformers than SDXL.
- The implementation is heavily based on laksjdjf's [OFT implementation](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py). Thanks to laksjdjf!
- Other bug fixes and improvements.
- Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade).
- Some model files (Text Encoder without position_id) based on the latest Transformers can be loaded.
- `torch.compile` is supported (experimental). PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) Thanks to p1atdev!
- This feature works only on Linux or WSL.
- Please specify `--torch_compile` option in each training script.
- You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work.
- Please use `--spda` option instead of `--xformers` option.
- PyTorch 2.1 or later is recommended.
- Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details.
- The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t!
- IPEX library is updated. PR [#1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Thanks to Disty0!
- Fixed a bug that Diffusers format model cannot be saved.
- `tag_images_by_wd_14_tagger.py` が Onnx をサポートしました。Onnx を使用する場合は TensorFlow は不要です。[#864](https://github.com/kohya-ss/sd-scripts/pull/864) Isotr0py氏に感謝します。
- Onnxを使用する場合は、`--onnx` オプションを指定してください
- Onnx とその他の必要なパッケージをインストールしてください
1. TensorFlow をアンインストールしてください
1. `pip install tensorboard==2.14.1` protobufの指定バージョンにこれが必要
1. `pip install protobuf==3.20.3` Onnxのために必要
1. `pip install onnx==1.14.1`
1. `pip install onnxruntime-gpu==1.16.0` または `pip install onnxruntime==1.16.0`
- `tag_images_by_wd_14_tagger.py``--append_tags` オプションが追加されました。このオプションを指定すると、既存のタグに上書きするのではなく、新しいタグのみが既存のタグに追加されます。 [#858](https://github.com/kohya-ss/sd-scripts/pull/858) a-l-e-x-d-s-9氏に感謝します
- [OFT](https://oft.wyliu.com/) をサポートしました
- `sdxl_train_network.py``--network_module``networks.oft` を指定してください。使用方法は `networks.lora` と同様ですが一部のオプションは未サポートです。
- `sdxl_gen_img.py` でも同様に OFT を指定できます
- OFT は現在 SDXL のみサポートしています。OFT は現在 transformer の Q/K/V と O を変更しますが、SD1/2 は transformer の数が SDXL よりも極端に少ないためです。
- 実装は laksjdjf 氏の [OFT実装](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py) を多くの部分で参考にしています。laksjdjf 氏に感謝します。
- その他のバグ修正と改善。
### Oct 1. 2023 / 2023/10/1
- SDXL training is now available in the main branch. The sdxl branch is merged into the main branch.
- [SAI Model Spec](https://github.com/Stability-AI/ModelSpec) metadata is now supported partially. `hash_sha256` is not supported yet.
- The main items are set automatically.
- You can set title, author, description, license and tags with `--metadata_xxx` options in each training script.
- Merging scripts also support minimum SAI Model Spec metadata. See the help message for the usage.
- Metadata editor will be available soon.
- `bitsandbytes` is now optional. Please install it if you want to use it. The insructions are in the later section.
- `albumentations` is not required anymore.
- `--v_pred_like_loss ratio` option is added. This option adds the loss like v-prediction loss in SDXL training. `0.1` means that the loss is added 10% of the v-prediction loss. The default value is None (disabled).
- In v-prediction, the loss is higher in the early timesteps (near the noise). This option can be used to increase the loss in the early timesteps.
- Arbitrary options can be used for Diffusers' schedulers. For example `--lr_scheduler_args "lr_end=1e-8"`.
- LoRA-FA is added experimentally. Specify `--network_module networks.lora_fa` option instead of `--network_module networks.lora`. The trained model can be used as a normal LoRA model.
- JPEG XL is supported. [#786](https://github.com/kohya-ss/sd-scripts/pull/786)
- Input perturbation noise is added. See [#798](https://github.com/kohya-ss/sd-scripts/pull/798) for details.
- Dataset subset now has `caption_prefix` and `caption_suffix` options. The strings are added to the beginning and the end of the captions before shuffling. You can specify the options in `.toml`.
- Intel ARC support with IPEX is added. [#825](https://github.com/kohya-ss/sd-scripts/pull/825)
- Other bug fixes and improvements.
- Diffusers、Accelerate、Transformers 等の関連ライブラリを更新しました。[Upgrade](#upgrade) を参照し更新をお願いします。
- 最新の Transformers を前提とした一部のモデルファイルText Encoder が position_id を持たないもの)が読み込めるようになりました
- `torch.compile` がサポートされしました(実験的)。 PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) p1atdev 氏に感謝します
- Linux または WSL でのみ動作します
- 各学習スクリプトで `--torch_compile` オプションを指定してください
- `--dynamo_backend` オプションで使用される backend を選択できます。デフォルトは `"inductor"` です。 `inductor` または `eager` が動作するようです
- `--xformers` オプションとは互換性がありません。 代わりに `--spda` オプションを使用してください。
- PyTorch 2.1以降を推奨します。
- 詳細は [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) をご覧ください
- wandb 保存時のセッション名が各学習スクリプトの `--wandb_run_name` オプションで指定できるようになりました。 PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) hopl1t 氏に感謝します
- IPEX ライブラリが更新されました。[PR #1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Disty0 氏に感謝します。
- Diffusers 形式でのモデル保存ができなくなっていた不具合を修正しました
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.

View File

@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ
サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
- `--sample_at_first`
学習開始前にサンプル出力します。学習前との比較ができます。
- `--sample_prompts`
サンプル出力用プロンプトのファイルを指定します。

View File

@@ -10,10 +10,13 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -32,6 +35,7 @@ from library.custom_train_functions import (
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
@@ -192,14 +196,20 @@ def train(args):
for m in training_models:
m.requires_grad_(True)
params = []
for m in training_models:
params.extend(m.parameters())
params_to_optimize = params
trainable_params = []
if args.learning_rate_te is None or not args.train_text_encoder:
for m in training_models:
trainable_params.extend(m.parameters())
else:
trainable_params = [
{"params": list(unet.parameters()), "lr": args.learning_rate},
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
]
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
@@ -243,9 +253,6 @@ def train(args):
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
@@ -284,10 +291,16 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -295,7 +308,6 @@ def train(args):
for m in training_models:
m.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
@@ -339,15 +351,17 @@ def train(args):
else:
target = noise
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # mean over batch dimension
else:
@@ -396,26 +410,20 @@ def train(args):
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -474,6 +482,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--learning_rate_te",
type=float,
default=None,
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
)
return parser

View File

@@ -13,7 +13,7 @@ import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder
from blip.blip import blip_decoder, is_url
import library.train_util as train_util
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -76,6 +76,8 @@ def main(args):
cwd = os.getcwd()
print("Current Working Directory is: ", cwd)
os.chdir("finetune")
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
args.caption_weights = os.path.join("..", args.caption_weights)
print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)

View File

@@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch):
def main(args):
r"""
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
@@ -65,6 +68,7 @@ def main(args):
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""
print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
@@ -81,7 +85,7 @@ def main(args):
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]
curr_batch_size[0] = len(path_imgs)
# curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)

View File

@@ -215,7 +215,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)",
)
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最解像度")
parser.add_argument(
"--bucket_reso_steps",
type=int,

View File

@@ -160,7 +160,9 @@ def main(args):
tag_freq = {}
undesired_tags = set(args.undesired_tags.split(","))
caption_separator = args.caption_separator
stripped_caption_separator = caption_separator.strip()
undesired_tags = set(args.undesired_tags.split(stripped_caption_separator))
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
@@ -194,7 +196,7 @@ def main(args):
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name
general_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
@@ -203,18 +205,18 @@ def main(args):
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name
character_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
# 先頭のカンマを取る
if len(general_tag_text) > 0:
general_tag_text = general_tag_text[2:]
general_tag_text = general_tag_text[len(caption_separator) :]
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
character_tag_text = character_tag_text[len(caption_separator) :]
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
tag_text = ", ".join(combined_tags)
tag_text = caption_separator.join(combined_tags)
if args.append_tags:
# Check if file exists
@@ -224,13 +226,13 @@ def main(args):
existing_content = f.read().strip("\n") # Remove newlines
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
# Check and remove repeating tags in tag_text
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
# Create new tag_text
tag_text = ", ".join(existing_tags + new_tags)
tag_text = caption_separator.join(existing_tags + new_tags)
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
@@ -350,6 +352,12 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
parser.add_argument(
"--caption_separator",
type=str,
default=", ",
help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
)
return parser

View File

@@ -65,10 +65,13 @@ import re
import diffusers
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -102,7 +105,7 @@ import library.train_util as train_util
from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.original_unet import FlashAttentionFunction
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -375,7 +378,7 @@ class PipelineLike:
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
unet: InferUNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
clip_skip: int,
clip_model: CLIPModel,
@@ -954,7 +957,7 @@ class PipelineLike:
text_emb_last = torch.stack(text_emb_last)
else:
text_emb_last = text_embeddings
for i, t in enumerate(tqdm(timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
@@ -2193,6 +2196,7 @@ def main(args):
)
original_unet.load_state_dict(unet.state_dict())
unet = original_unet
unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)
# VAEを読み込む
if args.vae is not None:
@@ -2349,13 +2353,20 @@ def main(args):
vae = sli_vae
del sli_vae
vae.to(dtype).to(device)
vae.eval()
text_encoder.to(dtype).to(device)
unet.to(dtype).to(device)
text_encoder.eval()
unet.eval()
if clip_model is not None:
clip_model.to(dtype).to(device)
clip_model.eval()
if vgg16_model is not None:
vgg16_model.to(dtype).to(device)
vgg16_model.eval()
# networkを組み込む
if args.network_module:
@@ -2363,12 +2374,19 @@ def main(args):
network_default_muls = []
network_pre_calc = args.network_pre_calc
# merge関連の引数を統合する
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = 0
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {}
if args.network_args and i < len(args.network_args):
@@ -2379,31 +2397,32 @@ def main(args):
key, value = net_arg.split("=")
net_kwargs[key] = value
if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
else:
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
if network is None:
return
mergeable = network.is_mergeable()
if args.network_merge and not mergeable:
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergeable:
if not mergeable or i >= network_merge:
# not merging
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
@@ -2417,6 +2436,7 @@ def main(args):
network.backup_weights()
networks.append(network)
network_default_muls.append(network_mul)
else:
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
@@ -2489,6 +2509,10 @@ def main(args):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
# Deep Shrink
if args.ds_depth_1 is not None:
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
# Extended Textual Inversion および Textual Inversionを処理する
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
@@ -2712,9 +2736,18 @@ def main(args):
size = None
for i, network in enumerate(networks):
if i < 3:
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]
if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
@@ -3064,6 +3097,13 @@ def main(args):
clip_prompt = None
network_muls = None
# Deep Shrink
ds_depth_1 = None # means no override
ds_timesteps_1 = args.ds_timesteps_1
ds_depth_2 = args.ds_depth_2
ds_timesteps_2 = args.ds_timesteps_2
ds_ratio = args.ds_ratio
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
@@ -3135,10 +3175,51 @@ def main(args):
print(f"network mul: {network_muls}")
continue
# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}")
continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}")
continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}")
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
# override Deep Shrink
if ds_depth_1 is not None:
if ds_depth_1 < 0:
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
@@ -3367,10 +3448,19 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
)
parser.add_argument(
"--network_regional_mask_max_color_codes",
type=int,
default=None,
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数デフォルトはNoneでチャンネルごとのマスク",
)
parser.add_argument(
"--textual_inversion_embeddings",
type=str,
@@ -3479,6 +3569,30 @@ def setup_parser() -> argparse.ArgumentParser:
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )
# Deep Shrink
parser.add_argument(
"--ds_depth_1",
type=int,
default=None,
help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする",
)
parser.add_argument(
"--ds_timesteps_1",
type=int,
default=650,
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
)
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
parser.add_argument(
"--ds_timesteps_2",
type=int,
default=650,
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
)
parser.add_argument(
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
)
return parser

View File

@@ -51,7 +51,9 @@ class BaseSubsetParams:
image_dir: Optional[str] = None
num_repeats: int = 1
shuffle_caption: bool = False
caption_separator: str = ',',
keep_tokens: int = 0
keep_tokens_separator: str = None,
color_aug: bool = False
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
@@ -159,6 +161,7 @@ class ConfigSanitizer:
"random_crop": bool,
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
"token_warmup_min": int,
"token_warmup_step": Any(float,int),
"caption_prefix": str,
@@ -460,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}

View File

@@ -57,10 +57,13 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
else:
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
loss = loss * snr_weight
return loss
@@ -86,6 +89,12 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
loss = loss + loss / scale * v_pred_like_loss
return loss
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
weight = 1/torch.sqrt(snr_t)
loss = weight * loss
return loss
# TODO train_utilと分散しているのでどちらかに寄せる
@@ -108,6 +117,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
default=None,
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
)
parser.add_argument(
"--debiased_estimation_loss",
action="store_true",
help="debiased estimation loss / debiased estimation loss",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",

View File

@@ -4,13 +4,12 @@ import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
from .attention import attention_init
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def ipex_init(): # pylint: disable=too-many-statements
try:
#Replace cuda with xpu:
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
@@ -30,6 +29,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
@@ -90,9 +90,9 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
#Memory:
# Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
@@ -112,7 +112,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
#RNG:
# RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
@@ -123,7 +123,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
#AMP:
# AMP:
torch.cuda.amp = torch.xpu.amp
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
@@ -138,12 +138,13 @@ def ipex_init(): # pylint: disable=too-many-statements
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
#C
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
#Fix functions with ipex:
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
@@ -156,20 +157,14 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.get_device_properties.minor = 7
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
if hasattr(torch.xpu, 'getDeviceIdListForCard'):
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
else:
torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
ipex_hijacks()
attention_init()
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
if not torch.xpu.has_fp64_dtype():
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
except Exception as e:
return False, e
return True, None

View File

@@ -1,45 +1,98 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long
original_torch_bmm = torch.bmm
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
block_multiply = input.element_size()
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
# Find something divisible with the input_tokens
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
# Find slice sizes for SDPA
@cache
def find_sdpa_slice_sizes(query_shape, query_element_size):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > sdpa_slice_trigger_rate:
do_split = True
#Find something divisible with the input_tokens
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
# Find slice sizes for BMM
@cache
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = input_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the input_tokens
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
split_3_slice_size = mat2_atten_shape
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
if input.device.type != "xpu":
return original_torch_bmm(input, mat2, out=out)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
# Slice BMM
if do_split:
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
@@ -48,11 +101,21 @@ def torch_bmm(input, mat2, *, out=None):
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
if do_split_3:
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
out=out
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
@@ -64,46 +127,14 @@ def torch_bmm(input, mat2, *, out=None):
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
if len(query.shape) == 3:
batch_size_attention, query_tokens, shape_four = query.shape
shape_one = 1
no_shape_one = True
else:
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
no_shape_one = False
block_multiply = query.element_size()
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
do_split = True
#Find something divisible with the shape_one
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_2_slice_size = query_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the batch_size_attention
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
# Slice SDPA
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
@@ -112,7 +143,18 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if no_shape_one:
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
@@ -120,38 +162,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
if no_shape_one:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
return original_scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
return hidden_states
def attention_init():
#ARC GPUs can't allocate more than 4GB to a single block:
torch.bmm = torch_bmm
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention

View File

@@ -1,10 +1,62 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.21.1 # pylint: disable=import-error
import diffusers #0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND
from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
@cache
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
if slice_size is not None:
batch_size_attention = slice_size
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
do_split = False
do_split_2 = False
do_split_3 = False
if query_device_type != "xpu":
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
r"""
Processor for implementing sliced attention.
@@ -18,7 +70,9 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
input_ndim = hidden_states.ndim
@@ -54,49 +108,61 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
block_multiply = query.element_size()
slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
block_size = query_tokens * slice_block_size
split_2_slice_size = query_tokens
if block_size > 4:
do_split_2 = True
#Find something divisible with the query_tokens
while (split_2_slice_size * slice_block_size) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
@@ -115,6 +181,130 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
return hidden_states
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None,
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
if do_split:
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
else:
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def ipex_diffusers():
#ARC GPUs can't allocate more than 4GB to a single block:
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
diffusers.models.attention_processor.AttnProcessor = AttnProcessor

View File

@@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
# pylint: disable=protected-access, missing-function-docstring, line-too-long
device_supports_fp64 = torch.xpu.has_fp64_dtype()
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
@@ -96,7 +97,10 @@ def unscale_(self, optimizer):
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
if device_supports_fp64:
inv_scale = self._scale.double().reciprocal().float()
else:
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)

View File

@@ -1,67 +1,9 @@
import contextlib
import importlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
class CondFunc: # pylint: disable=missing-class-docstring
def __new__(cls, orig_func, sub_func, cond_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
except ImportError:
pass
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)
_utils = torch.utils.data._utils
def _shutdown_workers(self):
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
return
if hasattr(self, "_shutdown") and not self._shutdown:
self._shutdown = True
try:
if hasattr(self, '_pin_memory_thread'):
self._pin_memory_thread_done_event.set()
self._worker_result_queue.put((None, None))
self._pin_memory_thread.join()
self._worker_result_queue.cancel_join_thread()
self._worker_result_queue.close()
self._workers_done_event.set()
for worker_id in range(len(self._workers)):
if self._persistent_workers or self._workers_status[worker_id]:
self._mark_worker_as_unavailable(worker_id, shutdown=True)
for w in self._workers: # pylint: disable=invalid-name
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
for q in self._index_queues: # pylint: disable=invalid-name
q.cancel_join_thread()
q.close()
finally:
if self._worker_pids_set:
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
self._worker_pids_set = False
for w in self._workers: # pylint: disable=invalid-name
if w.is_alive():
w.terminate()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
@@ -71,17 +13,18 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return contextlib.nullcontext()
@property
def is_cuda(self):
return self.device.type == 'xpu' or self.device.type == 'cuda'
def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
def ipex_no_cuda(orig_func, *args, **kwargs):
torch.cuda.is_available = lambda: False
orig_func(*args, **kwargs)
torch.cuda.is_available = torch.xpu.is_available
# Autocast
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
if len(args) > 0 and args[0] == "cuda":
@@ -89,13 +32,7 @@ def ipex_autocast(*args, **kwargs):
else:
return original_autocast(*args, **kwargs)
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
# Latent Antialias CPU Offload:
original_interpolate = torch.nn.functional.interpolate
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
@@ -107,90 +44,205 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
original_linalg_solve = torch.linalg.solve
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
return_device = A.device
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
original_from_numpy = torch.from_numpy
def from_numpy(ndarray):
if ndarray.dtype == float:
return original_from_numpy(ndarray.astype('float32'))
else:
return original_linalg_solve(A, B, *args, **kwargs)
return original_from_numpy(ndarray)
if torch.xpu.has_fp64_dtype():
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
# 32 bit attention workarounds for Alchemist:
try:
from .attention import torch_bmm_32_bit as original_torch_bmm
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
except Exception: # pylint: disable=broad-exception-caught
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
# Data Type Errors:
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
if query.dtype != key.dtype:
key = key.to(dtype=query.dtype)
if query.dtype != value.dtype:
value = value.to(dtype=query.dtype)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
# A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
# A1111 BF16
original_functional_layer_norm = torch.nn.functional.layer_norm
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
# Training
original_functional_linear = torch.nn.functional.linear
def functional_linear(input, weight, bias=None):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_linear(input, weight, bias=bias)
original_functional_conv2d = torch.nn.functional.conv2d
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
# A1111 Embedding BF16
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
# SwinIR BF16:
original_functional_pad = torch.nn.functional.pad
def functional_pad(input, pad, mode='constant', value=None):
if mode == 'reflect' and input.dtype == torch.bfloat16:
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
else:
return original_functional_pad(input, pad, mode=mode, value=value)
original_torch_tensor = torch.tensor
def torch_tensor(*args, device=None, **kwargs):
if check_device(device):
return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_tensor(*args, device=device, **kwargs)
original_Tensor_to = torch.Tensor.to
def Tensor_to(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_to(self, device, *args, **kwargs)
original_Tensor_cuda = torch.Tensor.cuda
def Tensor_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_cuda(self, device, *args, **kwargs)
original_UntypedStorage_init = torch.UntypedStorage.__init__
def UntypedStorage_init(*args, device=None, **kwargs):
if check_device(device):
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
else:
return original_UntypedStorage_init(*args, device=device, **kwargs)
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
else:
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
original_torch_empty = torch.empty
def torch_empty(*args, device=None, **kwargs):
if check_device(device):
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_empty(*args, device=device, **kwargs)
original_torch_randn = torch.randn
def torch_randn(*args, device=None, **kwargs):
if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_randn(*args, device=device, **kwargs)
original_torch_ones = torch.ones
def torch_ones(*args, device=None, **kwargs):
if check_device(device):
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_ones(*args, device=device, **kwargs)
original_torch_zeros = torch.zeros
def torch_zeros(*args, device=None, **kwargs):
if check_device(device):
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_zeros(*args, device=device, **kwargs)
original_torch_linspace = torch.linspace
def torch_linspace(*args, device=None, **kwargs):
if check_device(device):
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_linspace(*args, device=device, **kwargs)
original_torch_Generator = torch.Generator
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
original_torch_load = torch.load
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
if check_device(map_location):
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
else:
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
# Hijack Functions:
def ipex_hijacks():
CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.Tensor.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
CondFunc('torch.randn',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.ones',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.zeros',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.linspace',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
torch.UntypedStorage.__init__ = UntypedStorage_init
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.empty = torch_empty
torch.randn = torch_randn
torch.ones = torch_ones
torch.zeros = torch_zeros
torch.linspace = torch_linspace
torch.Generator = torch_Generator
torch.load = torch_load
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
CondFunc('torch.instance_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
#Functions with dtype errors:
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float)
#Broken functions when torch.cuda.is_available is True:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True)
#Functions that make compile mad with CondFunc:
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
torch.nn.DataParallel = DummyDataParallel
torch.autocast = ipex_autocast
torch.cat = torch_cat
torch.linalg.solve = linalg_solve
torch.nn.functional.interpolate = interpolate
torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel
torch.UntypedStorage.is_cuda = is_cuda
torch.autocast = ipex_autocast
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
torch.nn.functional.group_norm = functional_group_norm
torch.nn.functional.layer_norm = functional_layer_norm
torch.nn.functional.linear = functional_linear
torch.nn.functional.conv2d = functional_conv2d
torch.nn.functional.interpolate = interpolate
torch.nn.functional.pad = functional_pad
torch.bmm = torch_bmm
torch.cat = torch_cat
if not torch.xpu.has_fp64_dtype():
torch.from_numpy = from_numpy

View File

@@ -9,7 +9,7 @@ import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
import diffusers
from diffusers import SchedulerMixin, StableDiffusionPipeline
@@ -520,6 +520,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
image_encoder: CLIPVisionModelWithProjection = None,
clip_skip: int = 1,
):
super().__init__(
@@ -531,32 +532,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker,
image_encoder=image_encoder,
)
self.clip_skip = clip_skip
self.custom_clip_skip = clip_skip
self.__init__additional__()
# else:
# def __init__(
# self,
# vae: AutoencoderKL,
# text_encoder: CLIPTextModel,
# tokenizer: CLIPTokenizer,
# unet: UNet2DConditionModel,
# scheduler: SchedulerMixin,
# safety_checker: StableDiffusionSafetyChecker,
# feature_extractor: CLIPFeatureExtractor,
# ):
# super().__init__(
# vae=vae,
# text_encoder=text_encoder,
# tokenizer=tokenizer,
# unet=unet,
# scheduler=scheduler,
# safety_checker=safety_checker,
# feature_extractor=feature_extractor,
# )
# self.__init__additional__()
def __init__additional__(self):
if not hasattr(self, "vae_scale_factor"):
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
@@ -624,7 +604,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
clip_skip=self.custom_clip_skip,
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)

View File

@@ -4,10 +4,13 @@
import math
import os
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -571,9 +574,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
# support checkpoint without position_ids (invalid checkpoint)
if "text_model.embeddings.position_ids" not in text_model_dict:
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
# remove position_ids for newer transformer, which causes error :(
if "text_model.embeddings.position_ids" in text_model_dict:
text_model_dict.pop("text_model.embeddings.position_ids")
return text_model_dict
@@ -1242,8 +1245,13 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
if vae is None:
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
# original U-Net cannot be saved, so we need to convert it to the Diffusers version
# TODO this consumes a lot of memory
diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
diffusers_unet.load_state_dict(unet.state_dict())
pipeline = StableDiffusionPipeline(
unet=unet,
unet=diffusers_unet,
text_encoder=text_encoder,
vae=vae,
scheduler=scheduler,
@@ -1307,19 +1315,19 @@ def load_vae(vae_id, dtype):
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
max_width, max_height = max_reso
max_area = (max_width // divisible) * (max_height // divisible)
max_area = max_width * max_height
resos = set()
size = int(math.sqrt(max_area)) * divisible
resos.add((size, size))
width = int(math.sqrt(max_area) // divisible) * divisible
resos.add((width, width))
size = min_size
while size <= max_size:
width = size
height = min(max_size, (max_area // (width // divisible)) * divisible)
resos.add((width, height))
resos.add((height, width))
width = min_size
while width <= max_size:
height = min(max_size, int((max_area // width) // divisible) * divisible)
if height >= min_size:
resos.add((width, height))
resos.add((height, width))
# # make additional resos
# if width >= height and width - divisible >= min_size:
@@ -1329,7 +1337,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
# resos.add((width, height - divisible))
# resos.add((height - divisible, width))
size += divisible
width += divisible
resos = list(resos)
resos.sort()

View File

@@ -361,6 +361,23 @@ def get_timestep_embedding(
return emb
# Deep Shrink: We do not common this function, because minimize dependencies.
def resize_like(x, target, mode="bicubic", align_corners=False):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.to(torch.float32)
if x.shape[-2:] != target.shape[-2:]:
if mode == "nearest":
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
else:
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
if org_dtype == torch.bfloat16:
x = x.to(org_dtype)
return x
class SampleOutput:
def __init__(self, sample):
self.sample = sample
@@ -569,6 +586,9 @@ class CrossAttention(nn.Module):
self.use_memory_efficient_attention_mem_eff = False
self.use_sdpa = False
# Attention processor
self.processor = None
def set_use_memory_efficient_attention(self, xformers, mem_eff):
self.use_memory_efficient_attention_xformers = xformers
self.use_memory_efficient_attention_mem_eff = mem_eff
@@ -590,7 +610,28 @@ class CrossAttention(nn.Module):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def forward(self, hidden_states, context=None, mask=None):
def set_processor(self):
return self.processor
def get_processor(self):
return self.processor
def forward(self, hidden_states, context=None, mask=None, **kwargs):
if self.processor is not None:
(
hidden_states,
encoder_hidden_states,
attention_mask,
) = translate_attention_names_from_diffusers(
hidden_states=hidden_states, context=context, mask=mask, **kwargs
)
return self.processor(
attn=self,
hidden_states=hidden_states,
encoder_hidden_states=context,
attention_mask=mask,
**kwargs
)
if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
if self.use_memory_efficient_attention_mem_eff:
@@ -703,6 +744,21 @@ class CrossAttention(nn.Module):
out = self.to_out[0](out)
return out
def translate_attention_names_from_diffusers(
hidden_states: torch.FloatTensor,
context: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
# HF naming
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None
):
# translate from hugging face diffusers
context = context if context is not None else encoder_hidden_states
# translate from hugging face diffusers
mask = mask if mask is not None else attention_mask
return hidden_states, context, mask
# feedforward
class GEGLU(nn.Module):
@@ -1130,6 +1186,7 @@ class UpBlock2D(nn.Module):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -1221,6 +1278,7 @@ class CrossAttnUpBlock2D(nn.Module):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -1331,7 +1389,7 @@ class UNet2DConditionModel(nn.Module):
self.out_channels = OUT_CHANNELS
self.sample_size = sample_size
self.prepare_config()
self.prepare_config(sample_size=sample_size)
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
@@ -1418,8 +1476,8 @@ class UNet2DConditionModel(nn.Module):
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
def prepare_config(self, *args, **kwargs):
self.config = SimpleNamespace(**kwargs)
@property
def dtype(self) -> torch.dtype:
@@ -1519,7 +1577,6 @@ class UNet2DConditionModel(nn.Module):
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
@@ -1604,3 +1661,255 @@ class UNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(sample.shape[0])
return timesteps
class InferUNet2DConditionModel:
def __init__(self, original_unet: UNet2DConditionModel):
self.delegate = original_unet
# override original model's forward method: because forward is not called by `__call__`
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
self.delegate.forward = self.forward
# override original model's up blocks' forward method
for up_block in self.delegate.up_blocks:
if up_block.__class__.__name__ == "UpBlock2D":
def resnet_wrapper(func, block):
def forward(*args, **kwargs):
return func(block, *args, **kwargs)
return forward
up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
def cross_attn_up_wrapper(func, block):
def forward(*args, **kwargs):
return func(block, *args, **kwargs)
return forward
up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet in _self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if _self.upsamplers is not None:
for upsampler in _self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
def cross_attn_up_block_forward(
self,
_self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
):
for resnet, attn in zip(_self.resnets, _self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
if _self.upsamplers is not None:
for upsampler in _self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
) -> Union[Dict, Tuple]:
r"""
current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
"""
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a dict instead of a plain tuple.
Returns:
`SampleOutput` or `tuple`:
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
_self = self.delegate
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
default_overall_up_factor = 2**_self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
# 64で割り切れないときはupsamplerにサイズを伝える
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
# logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time
timesteps = timestep
timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
t_emb = _self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
# timestepsは重みを含まないので常にfloat32のテンソルを返す
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
# time_projでキャストしておけばいいんじゃね
t_emb = t_emb.to(dtype=_self.dtype)
emb = _self.time_embedding(t_emb)
# 2. pre-process
sample = _self.conv_in(sample)
down_block_res_samples = (sample,)
for depth, downsample_block in enumerate(_self.down_blocks):
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
self.ds_depth_2 is not None
and depth == self.ds_depth_2
and timesteps[0] < self.ds_timesteps_1
and timesteps[0] >= self.ds_timesteps_2
):
org_dtype = sample.dtype
if org_dtype == torch.bfloat16:
sample = sample.to(torch.float32)
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
# まあこちらのほうがわかりやすいかもしれない
if downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# skip connectionにControlNetの出力を追加する
if down_block_additional_residuals is not None:
down_block_res_samples = list(down_block_res_samples)
for i in range(len(down_block_res_samples)):
down_block_res_samples[i] += down_block_additional_residuals[i]
down_block_res_samples = tuple(down_block_res_samples)
# 4. mid
sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# ControlNetの出力を追加する
if mid_block_additional_residual is not None:
sample += mid_block_additional_residual
# 5. up
for i, upsample_block in enumerate(_self.up_blocks):
is_final_block = i == len(_self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
# if we have not reached the final block and need to forward the upsample size, we do it here
# 前述のように最後のブロック以外ではupsample_sizeを伝える
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = _self.conv_norm_out(sample)
sample = _self.conv_act(sample)
sample = _self.conv_out(sample)
if not return_dict:
return (sample,)
return SampleOutput(sample=sample)

View File

@@ -100,7 +100,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
key = key.replace(".ln_final", ".final_layer_norm")
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
elif ".embeddings.position_ids" in key:
key = None # remove this key: make position_ids by ourselves
key = None # remove this key: position_ids is not used in newer transformers
return key
keys = list(checkpoint.keys())
@@ -126,13 +126,15 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
# original SD にはないので、position_idsを追加
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
# temporary workaround for text_projection.weight.weight for Playground-v2
if "text_projection.weight.weight" in new_sd:
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
del new_sd["text_projection.weight.weight"]
return new_sd, logit_scale
@@ -258,10 +260,10 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
if "text_model.embeddings.position_ids" not in te1_sd:
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
if "text_model.embeddings.position_ids" in te1_sd:
te1_sd.pop("text_model.embeddings.position_ids")
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
print("text encoder 1:", info1)

View File

@@ -24,7 +24,7 @@
import math
from types import SimpleNamespace
from typing import Optional
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
@@ -266,6 +266,23 @@ def get_timestep_embedding(
return emb
# Deep Shrink: We do not common this function, because minimize dependencies.
def resize_like(x, target, mode="bicubic", align_corners=False):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.to(torch.float32)
if x.shape[-2:] != target.shape[-2:]:
if mode == "nearest":
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
else:
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
if org_dtype == torch.bfloat16:
x = x.to(org_dtype)
return x
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
if self.weight.dtype != torch.float32:
@@ -1077,6 +1094,7 @@ class SdxlUNet2DConditionModel(nn.Module):
# h = x.type(self.dtype)
h = x
for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
@@ -1093,6 +1111,121 @@ class SdxlUNet2DConditionModel(nn.Module):
return h
class InferSdxlUNet2DConditionModel:
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
self.delegate = original_unet
# override original model's forward method: because forward is not called by `__call__`
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
self.delegate.forward = self.forward
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
"""
_self = self.delegate
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
# assert x.dtype == _self.dtype
emb = emb + _self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
if isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
# h = x.type(self.dtype)
h = x
for depth, module in enumerate(_self.input_blocks):
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
self.ds_depth_2 is not None
and depth == self.ds_depth_2
and timesteps[0] < self.ds_timesteps_1
and timesteps[0] >= self.ds_timesteps_2
):
# print("downsample", h.shape, self.ds_ratio)
org_dtype = h.dtype
if org_dtype == torch.bfloat16:
h = h.to(torch.float32)
h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(_self.middle_block, h, emb, context)
for module in _self.output_blocks:
# Deep Shrink
if self.ds_depth_1 is not None:
if hs[-1].shape[-2:] != h.shape[-2:]:
# print("upsample", h.shape, hs[-1].shape)
h = resize_like(h, hs[-1])
h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)
# Deep Shrink: in case of depth 0
if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
# print("upsample", h.shape, x.shape)
h = resize_like(h, x)
h = h.type(x.dtype)
h = call_module(_self.out, h, emb, context)
return h
if __name__ == "__main__":
import time

View File

@@ -51,8 +51,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

View File

@@ -62,7 +62,7 @@ def cat_h(sliced):
return x
def resblock_forward(_self, num_slices, input_tensor, temb):
def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
assert _self.upsample is None and _self.downsample is None
assert _self.norm1.num_groups == _self.norm2.num_groups
assert temb is None

View File

@@ -3,6 +3,7 @@
import argparse
import ast
import asyncio
import datetime
import importlib
import json
import pathlib
@@ -18,7 +19,7 @@ from typing import (
Tuple,
Union,
)
from accelerate import Accelerator
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
import gc
import glob
import math
@@ -148,6 +149,13 @@ class ImageInfo:
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
if max_size is not None:
if max_reso is not None:
assert max_size >= max_reso[0], "the max_size should be larger than the width of max_reso"
assert max_size >= max_reso[1], "the max_size should be larger than the height of max_reso"
if min_size is not None:
assert max_size >= min_size, "the max_size should be larger than the min_size"
self.no_upscale = no_upscale
if max_reso is None:
self.max_reso = None
@@ -341,7 +349,9 @@ class BaseSubset:
image_dir: Optional[str],
num_repeats: int,
shuffle_caption: bool,
caption_separator: str,
keep_tokens: int,
keep_tokens_separator: str,
color_aug: bool,
flip_aug: bool,
face_crop_aug_range: Optional[Tuple[float, float]],
@@ -357,7 +367,9 @@ class BaseSubset:
self.image_dir = image_dir
self.num_repeats = num_repeats
self.shuffle_caption = shuffle_caption
self.caption_separator = caption_separator
self.keep_tokens = keep_tokens
self.keep_tokens_separator = keep_tokens_separator
self.color_aug = color_aug
self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range
@@ -383,7 +395,9 @@ class DreamBoothSubset(BaseSubset):
caption_extension: str,
num_repeats,
shuffle_caption,
caption_separator: str,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -402,7 +416,9 @@ class DreamBoothSubset(BaseSubset):
image_dir,
num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -435,7 +451,9 @@ class FineTuningSubset(BaseSubset):
metadata_file: str,
num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -454,7 +472,9 @@ class FineTuningSubset(BaseSubset):
image_dir,
num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -484,7 +504,9 @@ class ControlNetSubset(BaseSubset):
caption_extension: str,
num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -503,7 +525,9 @@ class ControlNetSubset(BaseSubset):
image_dir,
num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -638,15 +662,33 @@ class BaseDataset(torch.utils.data.Dataset):
caption = ""
else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")]
fixed_tokens = []
flex_tokens = []
if (
hasattr(subset, "keep_tokens_separator")
and subset.keep_tokens_separator
and subset.keep_tokens_separator in caption
):
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
else:
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]
if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
math.floor(
(self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))
)
+ subset.token_warmup_min
)
tokens = tokens[:tokens_len]
flex_tokens = flex_tokens[:tokens_len]
def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
@@ -657,12 +699,6 @@ class BaseDataset(torch.utils.data.Dataset):
l.append(token)
return l
fixed_tokens = []
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]
if subset.shuffle_caption:
random.shuffle(flex_tokens)
@@ -1706,7 +1742,9 @@ class ControlNetDataset(BaseDataset):
subset.caption_extension,
subset.num_repeats,
subset.shuffle_caption,
subset.caption_separator,
subset.keep_tokens,
subset.keep_tokens_separator,
subset.color_aug,
subset.flip_aug,
subset.face_crop_aug_range,
@@ -2649,7 +2687,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
"--optimizer_type",
type=str,
default="",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
)
# backward compatibility
@@ -2810,6 +2848,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
)
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
parser.add_argument(
"--dynamo_backend",
type=str,
default="inductor",
# available backends:
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
# https://pytorch.org/docs/stable/torch.compiler.html
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor"
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument(
"--sdpa",
@@ -2855,6 +2904,22 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument(
"--ddp_timeout",
type=int,
default=None,
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト分、Noneでaccelerateのデフォルト",
)
parser.add_argument(
"--ddp_gradient_as_bucket_view",
action="store_true",
help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
)
parser.add_argument(
"--ddp_static_graph",
action="store_true",
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
)
parser.add_argument(
"--clip_skip",
type=int,
@@ -2881,6 +2946,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
)
parser.add_argument(
"--wandb_run_name",
type=str,
default=None,
help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
)
parser.add_argument(
"--log_tracker_config",
type=str,
@@ -2957,6 +3028,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
)
parser.add_argument("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する")
parser.add_argument(
"--sample_every_n_epochs",
type=int,
@@ -3090,9 +3162,8 @@ def add_dataset_arguments(
):
# dataset common
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする"
)
parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする")
parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字")
parser.add_argument(
"--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"
)
@@ -3108,6 +3179,13 @@ def add_dataset_arguments(
default=0,
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残すトークンはカンマ区切りの各部分を意味する",
)
parser.add_argument(
"--keep_tokens_separator",
type=str,
default="",
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
)
parser.add_argument(
"--caption_prefix",
type=str,
@@ -3359,7 +3437,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
optimizer_type = args.optimizer_type
if args.use_8bit_adam:
@@ -3463,6 +3541,34 @@ def get_optimizer(args, trainable_params):
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "PagedAdamW".lower():
print(f"use PagedAdamW optimizer | {optimizer_kwargs}")
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
try:
optimizer_class = bnb.optim.PagedAdamW
except AttributeError:
raise AttributeError(
"No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "PagedAdamW32bit".lower():
print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}")
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
try:
optimizer_class = bnb.optim.PagedAdamW32bit
except AttributeError:
raise AttributeError(
"No PagedAdamW32bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW32bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "SGDNesterov".lower():
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
if "momentum" not in optimizer_kwargs:
@@ -3780,12 +3886,26 @@ def prepare_accelerator(args: argparse.Namespace):
os.environ["WANDB_DIR"] = logging_dir
if args.wandb_api_key is not None:
wandb.login(key=args.wandb_api_key)
# torch.compile のオプション。 NO の場合は torch.compile は使わない
dynamo_backend = "NO"
if args.torch_compile:
dynamo_backend = args.dynamo_backend
kwargs_handlers = (
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None,
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=log_with,
project_dir=logging_dir,
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
)
return accelerator
@@ -3854,17 +3974,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
return text_encoder, vae, unet, load_stable_diffusion_format
# TODO remove this function in the future
def transform_if_model_is_DDP(text_encoder, unet, network=None):
# Transform text_encoder, unet and network from DistributedDataParallel
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
def transform_models_if_DDP(models):
# Transform text_encoder, unet and network from DistributedDataParallel
return [model.module if type(model) == DDP else model for model in models if model is not None]
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
@@ -3888,8 +3997,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)
return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4001,6 +4108,7 @@ def get_hidden_states_sdxl(
text_encoder1: CLIPTextModel,
text_encoder2: CLIPTextModelWithProjection,
weight_dtype: Optional[str] = None,
accelerator: Optional[Accelerator] = None,
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
@@ -4016,7 +4124,8 @@ def get_hidden_states_sdxl(
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
# pool2 = enc_out["text_embeds"]
pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75
@@ -4375,6 +4484,29 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
return noise, noisy_latents, timesteps
def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
names = []
if including_unet:
names.append("unet")
names.append("text_encoder1")
names.append("text_encoder2")
append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
lrs = lr_scheduler.get_last_lr()
for lr_index in range(len(lrs)):
name = names[lr_index]
logs["lr/" + name] = float(lrs[lr_index])
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
logs["lr/d*lr/" + name] = (
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
)
# scheduler:
SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
@@ -4382,13 +4514,119 @@ SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
def get_my_scheduler(
*,
sample_sampler: str,
v_parameterization: bool,
):
sched_init_args = {}
if sample_sampler == "ddim":
scheduler_cls = DDIMScheduler
elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
scheduler_cls = DDPMScheduler
elif sample_sampler == "pndm":
scheduler_cls = PNDMScheduler
elif sample_sampler == "lms" or sample_sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler
elif sample_sampler == "euler" or sample_sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler
elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler
elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = sample_sampler
elif sample_sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
elif sample_sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2":
scheduler_cls = KDPM2DiscreteScheduler
elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
else:
scheduler_cls = DDIMScheduler
if v_parameterization:
sched_init_args["prediction_type"] = "v_prediction"
scheduler = scheduler_cls(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
beta_end=SCHEDULER_LINEAR_END,
beta_schedule=SCHEDLER_SCHEDULE,
**sched_init_args,
)
# clip_sample=Trueにする
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# print("set clip_sample to True")
scheduler.config.clip_sample = True
return scheduler
def sample_images(*args, **kwargs):
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
def line_to_prompt_dict(line: str) -> dict:
# subset of gen_img_diffusers
prompt_args = line.split(" --")
prompt_dict = {}
prompt_dict["prompt"] = prompt_args[0]
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
prompt_dict["width"] = int(m.group(1))
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
prompt_dict["height"] = int(m.group(1))
continue
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m:
prompt_dict["seed"] = int(m.group(1))
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
prompt_dict["scale"] = float(m.group(1))
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
prompt_dict["negative_prompt"] = m.group(1)
continue
m = re.match(r"ss (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["sample_sampler"] = m.group(1)
continue
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["controlnet_image"] = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
return prompt_dict
def sample_images_common(
pipe_class,
accelerator,
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
@@ -4403,15 +4641,19 @@ def sample_images_common(
"""
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
"""
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
if steps == 0:
if not args.sample_at_first:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
@@ -4421,6 +4663,13 @@ def sample_images_common(
org_vae_device = vae.device # CPUにいるはず
vae.to(device)
# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet)
if isinstance(text_encoder, (list, tuple)):
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
else:
text_encoder = accelerator.unwrap_model(text_encoder)
# read prompts
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
@@ -4438,56 +4687,19 @@ def sample_images_common(
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
# schedulerを用意する
sched_init_args = {}
if args.sample_sampler == "ddim":
scheduler_cls = DDIMScheduler
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
scheduler_cls = DDPMScheduler
elif args.sample_sampler == "pndm":
scheduler_cls = PNDMScheduler
elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler
elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler
elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = args.sample_sampler
elif args.sample_sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
elif args.sample_sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
scheduler_cls = KDPM2DiscreteScheduler
elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
else:
scheduler_cls = DDIMScheduler
if args.v_parameterization:
sched_init_args["prediction_type"] = "v_prediction"
scheduler = scheduler_cls(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
beta_end=SCHEDULER_LINEAR_END,
beta_schedule=SCHEDLER_SCHEDULE,
**sched_init_args,
schedulers: dict = {}
default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler,
v_parameterization=args.v_parameterization,
)
# clip_sample=Trueにする
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# print("set clip_sample to True")
scheduler.config.clip_sample = True
schedulers[args.sample_sampler] = default_scheduler
pipeline = pipe_class(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
scheduler=default_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
@@ -4503,78 +4715,37 @@ def sample_images_common(
with torch.no_grad():
# with accelerator.autocast():
for i, prompt in enumerate(prompts):
for i, prompt_dict in enumerate(prompts):
if not accelerator.is_main_process:
continue
if isinstance(prompt, dict):
negative_prompt = prompt.get("negative_prompt")
sample_steps = prompt.get("sample_steps", 30)
width = prompt.get("width", 512)
height = prompt.get("height", 512)
scale = prompt.get("scale", 7.5)
seed = prompt.get("seed")
controlnet_image = prompt.get("controlnet_image")
prompt = prompt.get("prompt")
else:
# prompt = prompt.strip()
# if len(prompt) == 0 or prompt[0] == "#":
# continue
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
# subset of gen_img_diffusers
prompt_args = prompt.split(" --")
prompt = prompt_args[0]
negative_prompt = None
sample_steps = 30
width = height = 512
scale = 7.5
seed = None
controlnet_image = None
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
continue
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m:
seed = int(m.group(1))
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
sample_steps = max(1, min(1000, int(m.group(1))))
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
continue
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
if m: # negative prompt
controlnet_image = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
scheduler = schedulers.get(sampler_name)
if scheduler is None:
scheduler = get_my_scheduler(
sample_sampler=sampler_name,
v_parameterization=args.v_parameterization,
)
schedulers[sampler_name] = scheduler
pipeline.scheduler = scheduler
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
@@ -4592,6 +4763,9 @@ def sample_images_common(
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
print(f"sample_sampler: {sampler_name}")
if seed is not None:
print(f"seed: {seed}")
with accelerator.autocast():
latents = pipeline(
prompt=prompt,
@@ -4685,3 +4859,21 @@ class collator_class:
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
class LossRecorder:
def __init__(self):
self.loss_list: List[float] = []
self.loss_total: float = 0.0
def add(self, *, epoch: int, step: int, loss: float) -> None:
if epoch == 0:
self.loss_list.append(loss)
else:
self.loss_total -= self.loss_list[step]
self.loss_list[step] = loss
self.loss_total += loss
@property
def moving_average(self) -> float:
return self.loss_total / len(self.loss_list)

View File

@@ -13,8 +13,8 @@ from library import sai_model_spec, model_util, sdxl_model_util
import lora
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-1
# CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1
def save_to_file(file_name, model, state_dict, dtype):
@@ -29,7 +29,21 @@ def save_to_file(file_name, model, state_dict, dtype):
torch.save(model, file_name)
def svd(args):
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
):
def str_to_dtype(p):
if p == "float":
return torch.float
@@ -39,44 +53,42 @@ def svd(args):
return torch.bfloat16
return None
assert args.v2 != args.sdxl or (
not args.v2 and not args.sdxl
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if args.v_parameterization is None:
args.v_parameterization = args.v2
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if v_parameterization is None:
v_parameterization = v2
save_dtype = str_to_dtype(args.save_precision)
save_dtype = str_to_dtype(save_precision)
# load models
if not args.sdxl:
print(f"loading original SD model : {args.model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
if not sdxl:
print(f"loading original SD model : {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
print(f"loading tuned SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
print(f"loading tuned SD model : {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
print(f"loading original SDXL model : {args.model_org}")
print(f"loading original SDXL model : {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
print(f"loading original SDXL model : {args.model_tuned}")
print(f"loading original SDXL model : {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# create LoRA network to extract weights: Use dim (rank) as alpha
if args.conv_dim is None:
if conv_dim is None:
kwargs = {}
else:
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs)
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
@@ -91,9 +103,9 @@ def svd(args):
diff = module_t.weight - module_o.weight
# Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
diff = diff.float()
diffs[lora_name] = diff
@@ -120,16 +132,16 @@ def svd(args):
lora_weights = {}
with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())):
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
out_dim, in_dim = mat.size()[0:2]
if args.device:
mat = mat.to(args.device)
if device:
mat = mat.to(device)
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
@@ -149,7 +161,7 @@ def svd(args):
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
hi_val = torch.quantile(dist, clamp_quantile)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
@@ -178,34 +190,32 @@ def svd(args):
info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}")
dir_name = os.path.dirname(args.save_to)
dir_name = os.path.dirname(save_to)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
# minimum metadata
net_kwargs = {}
if args.conv_dim is not None:
net_kwargs["conv_dim"] = args.conv_dim
net_kwargs["conv_alpha"] = args.conv_dim
if conv_dim is not None:
net_kwargs["conv_dim"] = str(conv_dim)
net_kwargs["conv_alpha"] = str(float(conv_dim))
metadata = {
"ss_v2": str(args.v2),
"ss_v2": str(v2),
"ss_base_model_version": model_version,
"ss_network_module": "networks.lora",
"ss_network_dim": str(args.dim),
"ss_network_alpha": str(args.dim),
"ss_network_dim": str(dim),
"ss_network_alpha": str(float(dim)),
"ss_network_args": json.dumps(net_kwargs),
}
if not args.no_metadata:
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title
)
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
metadata.update(sai_metadata)
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}")
lora_network_save.save_weights(save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {save_to}")
def setup_parser() -> argparse.ArgumentParser:
@@ -213,7 +223,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
parser.add_argument(
"--v_parameterization",
type=bool,
action="store_true",
default=None,
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する省略時はv2と同じ",
)
@@ -231,16 +241,22 @@ def setup_parser() -> argparse.ArgumentParser:
"--model_org",
type=str,
default=None,
required=True,
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
)
parser.add_argument(
"--model_tuned",
type=str,
default=None,
required=True,
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
required=True,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument(
@@ -250,6 +266,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--clamp_quantile",
type=float,
default=0.99,
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
)
parser.add_argument(
"--min_diff",
type=float,
default=0.01,
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
)
parser.add_argument(
"--no_metadata",
action="store_true",
@@ -264,4 +293,4 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
svd(args)
svd(**vars(args))

View File

@@ -1,10 +1,10 @@
accelerate==0.23.0
transformers==4.30.2
diffusers[torch]==0.21.2
accelerate==0.25.0
transformers==4.36.2
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
einops==0.6.1
pytorch-lightning==1.9.0
# bitsandbytes==0.39.1
tensorboard==2.10.1
@@ -14,7 +14,7 @@ altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.15.1
huggingface-hub==0.20.1
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12

View File

@@ -17,10 +17,13 @@ import re
import diffusers
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -54,7 +57,7 @@ import library.train_util as train_util
import library.sdxl_model_util as sdxl_model_util
import library.sdxl_train_util as sdxl_train_util
from networks.lora import LoRANetwork
from library.sdxl_original_unet import SdxlUNet2DConditionModel
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite
@@ -287,7 +290,7 @@ class PipelineLike:
vae: AutoencoderKL,
text_encoders: List[CLIPTextModel],
tokenizers: List[CLIPTokenizer],
unet: SdxlUNet2DConditionModel,
unet: InferSdxlUNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
clip_skip: int,
):
@@ -325,7 +328,7 @@ class PipelineLike:
self.vae = vae
self.text_encoders = text_encoders
self.tokenizers = tokenizers
self.unet: SdxlUNet2DConditionModel = unet
self.unet: InferSdxlUNet2DConditionModel = unet
self.scheduler = scheduler
self.safety_checker = None
@@ -501,7 +504,8 @@ class PipelineLike:
uncond_embeddings = tes_uncond_embs[0]
for i in range(1, len(tes_text_embs)):
text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048
uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048
if do_classifier_free_guidance:
uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048
if do_classifier_free_guidance:
if negative_scale is None:
@@ -564,9 +568,11 @@ class PipelineLike:
text_pool = clip_vision_embeddings # replace: same as ComfyUI (?)
c_vector = torch.cat([text_pool, c_vector], dim=1)
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
vector_embeddings = torch.cat([uc_vector, c_vector])
if do_classifier_free_guidance:
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
vector_embeddings = torch.cat([uc_vector, c_vector])
else:
vector_embeddings = c_vector
# set timesteps
self.scheduler.set_timesteps(num_inference_steps, self.device)
@@ -1368,6 +1374,7 @@ def main(args):
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
)
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
# xformers、Hypernetwork対応
if not args.diffusers_xformers:
@@ -1523,10 +1530,14 @@ def main(args):
print("set vae_dtype to float32")
vae_dtype = torch.float32
vae.to(vae_dtype).to(device)
vae.eval()
text_encoder1.to(dtype).to(device)
text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device)
text_encoder1.eval()
text_encoder2.eval()
unet.eval()
# networkを組み込む
if args.network_module:
@@ -1534,12 +1545,20 @@ def main(args):
network_default_muls = []
network_pre_calc = args.network_pre_calc
# merge関連の引数を統合する
if args.network_merge:
network_merge = len(args.network_module) # all networks are merged
elif args.network_merge_n_models:
network_merge = args.network_merge_n_models
else:
network_merge = 0
print(f"network_merge: {network_merge}")
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {}
if args.network_args and i < len(args.network_args):
@@ -1550,31 +1569,32 @@ def main(args):
key, value = net_arg.split("=")
net_kwargs[key] = value
if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
else:
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata()
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
)
if network is None:
return
mergeable = network.is_mergeable()
if args.network_merge and not mergeable:
if network_merge and not mergeable:
print("network is not mergiable. ignore merge option.")
if not args.network_merge or not mergeable:
if not mergeable or i >= network_merge:
# not merging
network.apply_to([text_encoder1, text_encoder2], unet)
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
print(f"weights are loaded: {info}")
@@ -1588,6 +1608,7 @@ def main(args):
network.backup_weights()
networks.append(network)
network_default_muls.append(network_mul)
else:
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)
@@ -1683,6 +1704,10 @@ def main(args):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
# Deep Shrink
if args.ds_depth_1 is not None:
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
# Textual Inversionを処理する
if args.textual_inversion_embeddings:
token_ids_embeds1 = []
@@ -1864,9 +1889,18 @@ def main(args):
size = None
for i, network in enumerate(networks):
if i < 3:
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]
if args.network_regional_mask_max_color_codes:
# カラーコードでマスクを指定する
ch0 = (i + 1) & 1
ch1 = ((i + 1) >> 1) & 1
ch2 = ((i + 1) >> 2) & 1
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
np_mask = np_mask.astype(np.uint8) * 255
else:
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
@@ -2264,6 +2298,13 @@ def main(args):
clip_prompt = None
network_muls = None
# Deep Shrink
ds_depth_1 = None # means no override
ds_timesteps_1 = args.ds_timesteps_1
ds_depth_2 = args.ds_depth_2
ds_timesteps_2 = args.ds_timesteps_2
ds_ratio = args.ds_ratio
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
@@ -2371,10 +2412,51 @@ def main(args):
print(f"network mul: {network_muls}")
continue
# Deep Shrink
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 1
ds_depth_1 = int(m.group(1))
print(f"deep shrink depth 1: {ds_depth_1}")
continue
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 1
ds_timesteps_1 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
continue
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink depth 2
ds_depth_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink depth 2: {ds_depth_2}")
continue
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink timesteps 2
ds_timesteps_2 = int(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
continue
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
if m: # deep shrink ratio
ds_ratio = float(m.group(1))
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
print(f"deep shrink ratio: {ds_ratio}")
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
# override Deep Shrink
if ds_depth_1 is not None:
if ds_depth_1 < 0:
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
@@ -2615,10 +2697,19 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument(
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
)
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
)
parser.add_argument(
"--network_regional_mask_max_color_codes",
type=int,
default=None,
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数デフォルトはNoneでチャンネルごとのマスク",
)
parser.add_argument(
"--textual_inversion_embeddings",
type=str,
@@ -2703,6 +2794,31 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する",
)
# Deep Shrink
parser.add_argument(
"--ds_depth_1",
type=int,
default=None,
help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする",
)
parser.add_argument(
"--ds_timesteps_1",
type=int,
default=650,
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
)
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
parser.add_argument(
"--ds_timesteps_2",
type=int,
default=650,
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
)
parser.add_argument(
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
)
# # parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )

View File

@@ -10,10 +10,13 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -34,6 +37,7 @@ from library.custom_train_functions import (
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -70,33 +74,22 @@ def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List
def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
lrs = lr_scheduler.get_last_lr()
lr_index = 0
names = []
block_index = 0
while lr_index < len(lrs):
while block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR + 2:
if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
name = f"block{block_index}"
if block_lrs[block_index] == 0:
block_index += 1
continue
names.append(f"block{block_index}")
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
name = "text_encoder1"
names.append("text_encoder1")
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
name = "text_encoder2"
else:
raise ValueError(f"unexpected block_index: {block_index}")
names.append("text_encoder2")
block_index += 1
logs["lr/" + name] = float(lrs[lr_index])
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
logs["lr/d*lr/" + name] = (
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
)
lr_index += 1
train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
def train(args):
@@ -271,10 +264,11 @@ def train(args):
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
training_models = []
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
training_models.append(unet)
train_unet = args.learning_rate > 0
train_text_encoder1 = False
train_text_encoder2 = False
if args.train_text_encoder:
# TODO each option for two text encoders?
@@ -282,10 +276,23 @@ def train(args):
if args.gradient_checkpointing:
text_encoder1.gradient_checkpointing_enable()
text_encoder2.gradient_checkpointing_enable()
training_models.append(text_encoder1)
training_models.append(text_encoder2)
# set require_grad=True later
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 > 0
train_text_encoder2 = lr_te2 > 0
# caching one text encoder output is not supported
if not train_text_encoder1:
text_encoder1.to(weight_dtype)
if not train_text_encoder2:
text_encoder2.to(weight_dtype)
text_encoder1.requires_grad_(train_text_encoder1)
text_encoder2.requires_grad_(train_text_encoder2)
text_encoder1.train(train_text_encoder1)
text_encoder2.train(train_text_encoder2)
else:
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
text_encoder1.requires_grad_(False)
text_encoder2.requires_grad_(False)
text_encoder1.eval()
@@ -294,7 +301,7 @@ def train(args):
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
with torch.no_grad():
with torch.no_grad(), accelerator.autocast():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
@@ -310,30 +317,33 @@ def train(args):
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models:
m.requires_grad_(True)
unet.requires_grad_(train_unet)
if not train_unet:
unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
if block_lrs is None:
params = []
for m in training_models:
params.extend(m.parameters())
params_to_optimize = params
training_models = []
params_to_optimize = []
if train_unet:
training_models.append(unet)
if block_lrs is None:
params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
else:
params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
# calculate number of trainable parameters
n_params = 0
for p in params:
if train_text_encoder1:
training_models.append(text_encoder1)
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
if train_text_encoder2:
training_models.append(text_encoder2)
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()
else:
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
for m in training_models[1:]: # Text Encoders if exists
params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()
accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")
@@ -385,18 +395,17 @@ def train(args):
text_encoder2.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler
)
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
(unet,) = train_util.transform_models_if_DDP([unet])
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
@@ -448,10 +457,18 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -459,10 +476,9 @@ def train(args):
for m in training_models:
m.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
@@ -473,7 +489,7 @@ def train(args):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -494,6 +510,7 @@ def train(args):
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
@@ -503,6 +520,7 @@ def train(args):
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
@@ -548,7 +566,12 @@ def train(args):
target = noise
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss:
if (
args.min_snr_gamma
or args.scale_v_pred_loss_like_noise_pred
or args.v_pred_like_loss
or args.debiased_estimation_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
@@ -559,6 +582,8 @@ def train(args):
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # mean over batch dimension
else:
@@ -620,29 +645,22 @@ def train(args):
if args.logging_dir is not None:
logs = {"loss": current_loss}
if block_lrs is None:
logs["lr"] = float(lr_scheduler.get_last_lr()[0])
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
else:
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs
accelerator.log(logs, step=global_step)
# TODO moving averageにする
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -726,6 +744,19 @@ def setup_parser() -> argparse.ArgumentParser:
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument(
"--learning_rate_te1",
type=float,
default=None,
help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
)
parser.add_argument(
"--learning_rate_te2",
type=float,
default=None,
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(

View File

@@ -44,6 +44,7 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import networks.control_net_lllite_for_train as control_net_lllite_for_train
@@ -282,9 +283,6 @@ def train(args):
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare (train_network here only)
unet = train_util.transform_models_if_DDP([unet])[0]
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
@@ -344,14 +342,15 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
@@ -397,7 +396,7 @@ def train(args):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -460,11 +459,13 @@ def train(args):
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -500,14 +501,9 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
@@ -518,7 +514,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()

View File

@@ -40,6 +40,7 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import networks.control_net_lllite as control_net_lllite
@@ -253,9 +254,6 @@ def train(args):
)
network: control_net_lllite.ControlNetLLLite
# transform DDP after prepare (train_network here only)
unet, network = train_util.transform_models_if_DDP([unet, network])
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
@@ -323,8 +321,7 @@ def train(args):
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
@@ -366,7 +363,7 @@ def train(args):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -430,11 +427,13 @@ def train(args):
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -470,14 +469,9 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
@@ -488,7 +482,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()

View File

@@ -1,9 +1,12 @@
import argparse
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -70,14 +73,16 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
if torch.cuda.is_available():
torch.cuda.empty_cache()
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
with accelerator.autocast():
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
@@ -121,6 +126,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)

View File

@@ -64,6 +64,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2

View File

@@ -23,7 +23,7 @@ def convert(args):
is_load_ckpt = os.path.isfile(args.model_to_load)
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
# assert (
# is_save_ckpt or args.reference_model is not None
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
@@ -34,10 +34,12 @@ def convert(args):
if is_load_ckpt:
v2_model = args.v2
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection)
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
)
else:
pipe = StableDiffusionPipeline.from_pretrained(
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
)
text_encoder = pipe.text_encoder
vae = pipe.vae
@@ -57,15 +59,26 @@ def convert(args):
if is_save_ckpt:
original_model = args.model_to_load if is_load_ckpt else None
key_count = model_util.save_stable_diffusion_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae
v2_model,
args.model_to_save,
text_encoder,
unet,
original_model,
args.epoch,
args.global_step,
None if args.metadata is None else eval(args.metadata),
save_dtype=save_dtype,
vae=vae,
)
print(f"model saved. total converted state_dict keys: {key_count}")
else:
print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}")
print(
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
)
model_util.save_diffusers_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
)
print(f"model saved.")
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
@@ -77,7 +90,9 @@ def setup_parser() -> argparse.ArgumentParser:
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
)
parser.add_argument(
"--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にするstabilityaiのモデルと合わせる"
"--unet_use_linear_projection",
action="store_true",
help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にするstabilityaiのモデルと合わせる",
)
parser.add_argument(
"--fp16",
@@ -99,6 +114,18 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
)
parser.add_argument(
"--metadata",
type=str,
default=None,
help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
)
parser.add_argument(
"--reference_model",
type=str,

View File

@@ -1,84 +0,0 @@
import argparse
import os
import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
def split(args):
# load embedding
if args.embedding.endswith(".safetensors"):
embedding = load_file(args.embedding)
with safe_open(args.embedding, framework="pt") as f:
metadata = f.metadata()
else:
embedding = torch.load(args.embedding)
metadata = None
# check format
if "emb_params" in embedding:
# SD1/2
keys = ["emb_params"]
elif "clip_l" in embedding:
# SDXL
keys = ["clip_l", "clip_g"]
else:
print("Unknown embedding format")
exit()
num_vectors = embedding[keys[0]].shape[0]
# prepare output directory
os.makedirs(args.output_dir, exist_ok=True)
# prepare splits
if args.vectors_per_split is not None:
num_splits = (num_vectors + args.vectors_per_split - 1) // args.vectors_per_split
vectors_for_split = [args.vectors_per_split] * num_splits
if sum(vectors_for_split) > num_vectors:
vectors_for_split[-1] -= sum(vectors_for_split) - num_vectors
assert sum(vectors_for_split) == num_vectors
elif args.vectors is not None:
vectors_for_split = args.vectors
num_splits = len(vectors_for_split)
else:
print("Must specify either --vectors_per_split or --vectors / --vectors_per_split または --vectors のどちらかを指定する必要があります")
exit()
assert (
sum(vectors_for_split) == num_vectors
), "Sum of vectors must be equal to the number of vectors in the embedding / 分割したベクトルの合計はembeddingのベクトル数と等しくなければなりません"
# split
basename = os.path.splitext(os.path.basename(args.embedding))[0]
done_vectors = 0
for i, num_vectors in enumerate(vectors_for_split):
print(f"Splitting {num_vectors} vectors...")
split_embedding = {}
for key in keys:
split_embedding[key] = embedding[key][done_vectors : done_vectors + num_vectors]
output_file = os.path.join(args.output_dir, f"{basename}_{i}.safetensors")
save_file(split_embedding, output_file, metadata)
print(f"Saved to {output_file}")
done_vectors += num_vectors
print("Done")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Merge models")
parser.add_argument("--embedding", type=str, help="Embedding to split")
parser.add_argument("--output_dir", type=str, help="Output directory")
parser.add_argument(
"--vectors_per_split",
type=int,
default=None,
help="Number of vectors per split. If num_vectors is 8 and vectors_per_split is 3, then 3, 3, 2 vectors will be split",
)
parser.add_argument("--vectors", type=int, default=None, nargs="*", help="number of vectors for each split. e.g. 3 3 2")
args = parser.parse_args()
split(args)

View File

@@ -11,10 +11,13 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -333,12 +336,15 @@ def train(args):
)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
@@ -372,6 +378,11 @@ def train(args):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
)
# training loop
for epoch in range(num_train_epochs):
if is_main_process:
@@ -450,7 +461,7 @@ def train(args):
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -500,14 +511,9 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
@@ -518,7 +524,7 @@ def train(args):
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()

View File

@@ -11,10 +11,13 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -35,6 +38,7 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
# perlin_noise,
@@ -108,6 +112,7 @@ def train(args):
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -132,7 +137,7 @@ def train(args):
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
@@ -163,11 +168,17 @@ def train(args):
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
if train_text_encoder:
# wightout list, adamw8bit is crashed
trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
if args.learning_rate_te is None:
# wightout list, adamw8bit is crashed
trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
else:
trainable_params = [
{"params": list(unet.parameters()), "lr": args.learning_rate},
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
]
else:
trainable_params = unet.parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
@@ -215,9 +226,6 @@ def train(args):
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
@@ -260,12 +268,16 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
loss_list = []
loss_total = 0.0
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -333,9 +345,11 @@ def train(args):
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -383,30 +397,20 @@ def train(args):
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -464,6 +468,12 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument(
"--learning_rate_te",
type=float,
default=None,
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
)
parser.add_argument(
"--no_token_padding",
action="store_true",
@@ -475,6 +485,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser

View File

@@ -12,6 +12,7 @@ import toml
from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
try:
import intel_extension_for_pytorch as ipex
@@ -43,6 +44,7 @@ from library.custom_train_functions import (
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
)
@@ -108,6 +110,9 @@ class NetworkTrainer:
def is_text_encoder_outputs_cached(self, args):
return False
def is_train_text_encoder(self, args):
return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
def cache_text_encoder_outputs_if_needed(
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
):
@@ -123,6 +128,11 @@ class NetworkTrainer:
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
return noise_pred
def all_reduce_network(self, accelerator, network):
for param in network.parameters():
if param.grad is not None:
param.grad = accelerator.reduce(param.grad, reduction="mean")
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
@@ -309,7 +319,7 @@ class NetworkTrainer:
args.scale_weight_norms = False
train_unet = not args.network_train_text_encoder_only
train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
train_text_encoder = self.is_train_text_encoder(args)
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
@@ -386,44 +396,20 @@ class NetworkTrainer:
# acceleratorがなんかよろしくやってくれるらしい
# TODO めちゃくちゃ冗長なのでコードを整理する
if train_unet and train_text_encoder:
if len(text_encoders) > 1:
unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoders = [t_enc1, t_enc2]
del t_enc1, t_enc2
else:
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
text_encoders = [text_encoder]
elif train_unet:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
elif train_text_encoder:
if len(text_encoders) > 1:
t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoders = [t_enc1, t_enc2]
del t_enc1, t_enc2
else:
text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
text_encoders = [text_encoder]
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
if train_unet:
unet = accelerator.prepare(unet)
else:
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare (train_network here only)
text_encoders = train_util.transform_models_if_DDP(text_encoders)
unet, network = train_util.transform_models_if_DDP([unet, network])
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
@@ -445,7 +431,7 @@ class NetworkTrainer:
del t_enc
network.prepare_grad_etc(text_encoder, unet)
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
vae.requires_grad_(False)
@@ -528,6 +514,7 @@ class NetworkTrainer:
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_ip_noise_gamma": args.ip_noise_gamma,
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
}
if use_user_config:
@@ -697,19 +684,20 @@ class NetworkTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_list = []
loss_total = 0.0
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# callback for step start
if hasattr(network, "on_step_start"):
on_step_start = network.on_step_start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
on_step_start = accelerator.unwrap_model(network).on_step_start
else:
on_step_start = lambda *args, **kwargs: None
@@ -737,6 +725,9 @@ class NetworkTrainer:
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -744,7 +735,7 @@ class NetworkTrainer:
metadata["ss_epoch"] = str(epoch + 1)
network.on_epoch_start(text_encoder, unet)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
@@ -761,11 +752,11 @@ class NetworkTrainer:
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
with torch.set_grad_enabled(train_text_encoder):
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
@@ -806,17 +797,20 @@ class NetworkTrainer:
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = network.get_trainable_params()
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -824,7 +818,7 @@ class NetworkTrainer:
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
@@ -854,14 +848,9 @@ class NetworkTrainer:
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.scale_weight_norms:
@@ -875,7 +864,7 @@ class NetworkTrainer:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(loss_list)}
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()

View File

@@ -35,6 +35,7 @@ from library.custom_train_functions import (
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
)
imagenet_templates_small = [
@@ -170,13 +171,6 @@ class TextualInversionTrainer:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
assert (
args.token_string is not None or args.token_strings is not None
), "token_string or token_strings must be specified / token_stringまたはtoken_stringsを指定してください"
assert (
not use_template or args.token_strings is None
), "token_strings cannot be used with template / token_stringsはテンプレートと一緒に使えません"
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
@@ -225,17 +219,9 @@ class TextualInversionTrainer:
# add new word to tokenizer, count is num_vectors_per_token
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
if args.token_strings is not None:
token_strings = args.token_strings
assert (
len(token_strings) == args.num_vectors_per_token
), f"num_vectors_per_token is mismatch for token_strings / token_stringsの数がnum_vectors_per_tokenと合いません: {len(token_strings)}"
for token_string in token_strings:
self.assert_token_string(token_string, tokenizers)
else:
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
token_ids_list = []
token_embeds_list = []
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
@@ -350,7 +336,7 @@ class TextualInversionTrainer:
prompt_replacement = None
else:
# サンプル生成用
if args.num_vectors_per_token > 1 and args.token_strings is None:
if args.num_vectors_per_token > 1:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
prompt_replacement = (args.token_string, replace_to)
@@ -432,15 +418,11 @@ class TextualInversionTrainer:
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
elif len(text_encoders) == 2:
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
@@ -459,9 +441,10 @@ class TextualInversionTrainer:
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
unwrapped_text_encoder.text_model.encoder.requires_grad_(False)
unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False)
unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
unet.requires_grad_(False)
@@ -521,6 +504,8 @@ class TextualInversionTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
@@ -546,6 +531,20 @@ class TextualInversionTrainer:
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(
accelerator,
args,
0,
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -595,17 +594,19 @@ class TextualInversionTrainer:
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = text_encoder.get_input_embeddings().parameters()
params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -617,9 +618,11 @@ class TextualInversionTrainer:
for text_encoder, orig_embeds_params, index_no_updates in zip(
text_encoders, orig_embeds_params_list, index_no_updates_list
):
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
# if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
index_no_updates
] = orig_embeds_params[index_no_updates]
]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -770,13 +773,6 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument(
"--token_strings",
type=str,
default=None,
nargs="*",
help="token strings used in training for multiple embedding / 複数のembeddingsの個別学習時に使用されるトークン文字列",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--use_object_template",

View File

@@ -34,6 +34,7 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -332,9 +333,6 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
@@ -396,6 +394,8 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
@@ -468,9 +468,11 @@ def train(args):
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし