Compare commits


150 Commits

Author SHA1 Message Date
Kohya S
235a1ea2c6 Merge branch 'dev' into stable-cascade 2024-02-25 20:03:39 +09:00
Kohya S
cb648a2bf8 update readme 2024-02-25 20:03:00 +09:00
Kohya S
3a2a48c15d make LoRA compatible with ComfyUI #1119 2024-02-25 20:01:37 +09:00
Kohya S
e0acb10f31 Merge pull request #1137 from shirayu/replace_print_with_logger
Replaced print with logger
2024-02-25 10:34:19 +09:00
Kohya S
40f2c688db fix stage c weight is loaded in bf16/fp16 #1119 2024-02-25 09:39:53 +09:00
Kohya S
e4f8736c60 Merge branch 'dev' into stable-cascade 2024-02-25 08:58:27 +09:00
Yuta Hayashibe
5d5f39b6e6 Replaced print with logger 2024-02-25 01:24:11 +09:00
Kohya S
e69d34103b Merge pull request #1136 from kohya-ss/dev
v0.8.4
2024-02-24 21:15:46 +09:00
Kohya S
a21218bdd5 update readme 2024-02-24 21:09:59 +09:00
Kohya S
81e8af6519 fix ipex init 2024-02-24 20:51:26 +09:00
Kohya S
8b7c14246a some log output to print 2024-02-24 20:50:00 +09:00
Kohya S
52b3799989 fix format, add new conv rank to metadata comment 2024-02-24 20:49:41 +09:00
Kohya S
738c397e1a Merge pull request #1102 from mgz-dev/resize_lora-add-rank-for-conv
Resize lora add new rank for conv
2024-02-24 20:10:20 +09:00
Kohya S
0e703608f9 Merge branch 'dev' into resize_lora-add-rank-for-conv 2024-02-24 20:09:38 +09:00
Kohya S
fb9110bac1 format by black 2024-02-24 20:00:57 +09:00
Kohya S
24092e6f21 update einops to 0.7.0 #1122 2024-02-24 19:51:51 +09:00
Kohya S
f4132018c5 fix to work with cpu_count() == 1 closes #1134 2024-02-24 19:25:31 +09:00
Kohya S
488d1870ab Merge pull request #1126 from tamlog06/DyLoRA-xl
Fix dylora create_modules error when training sdxl
2024-02-24 19:19:33 +09:00
Kohya S
86279c8855 Merge branch 'dev' into DyLoRA-xl 2024-02-24 19:18:36 +09:00
Kohya S
13f49d1e4a update readme 2024-02-22 23:50:10 +09:00
Kohya S
df7648245e update readme 2024-02-22 23:41:46 +09:00
Kohya S
3368fb1af7 Modify nn.MHA to attn with q/k/v 2024-02-22 23:39:28 +09:00
Kohya S
417f14d245 Merge pull request #1130 from sdbds/fixbugs
[stable-cascade]add save parser and fix lora scripts model name and hash
2024-02-22 12:30:59 +09:00
青龍聖者@bdsqlsz
86503cb945 add save parser and fix lora scripts model name and hash 2024-02-21 19:38:12 +08:00
Kohya S
d91b1d3793 update readme 2024-02-20 22:39:57 +09:00
Kohya S
70917077a6 update readme 2024-02-20 22:38:36 +09:00
Kohya S
69dbc50912 fix effnet encoder preprocess issue 2024-02-20 22:34:06 +09:00
Kohya S
985761ca43 fix to work without network module 2024-02-20 20:33:03 +09:00
Kohya S
71e03559e2 support LoRA training for Stable Cascade Stage C 2024-02-20 08:27:11 +09:00
tamlog06
a6f1ed2e14 fix dylora create_modules error 2024-02-18 13:20:47 +00:00
Kohya S
806a6237fb minor fixes 2024-02-18 21:57:16 +09:00
Kohya S
9b0e532942 add command line sample 2024-02-18 21:40:36 +09:00
Kohya S
c26f01241f input prompt from console 2024-02-18 21:29:46 +09:00
Kohya S
ac71168939 add train_text_encoder arg 2024-02-18 21:29:10 +09:00
Kohya S
4e37d950d2 fix typos 2024-02-18 18:02:20 +09:00
Kohya S
4b5784eb44 update stable cascade stage C training #1119 2024-02-18 17:54:21 +09:00
Kohya S
856df07f49 Merge branch 'dev' into stable-cascade 2024-02-18 09:15:12 +09:00
Kohya S
d1fb480887 format by black 2024-02-18 09:13:24 +09:00
Kohya S
80ef59c115 support text encoder training in stable cascade 2024-02-18 09:12:37 +09:00
Kohya S
319bbf8057 add stage c tmp training code 2024-02-17 23:59:20 +09:00
Kohya S
fa440208b7 add inference script 2024-02-17 17:57:30 +09:00
Kohya S
75e4a951d0 update readme 2024-02-17 12:04:12 +09:00
Kohya S
42f3318e17 Merge pull request #1116 from kohya-ss/dev_device_support
Dev device support
2024-02-17 11:58:02 +09:00
Kohya S
baa0e97ced Merge branch 'dev' into dev_device_support 2024-02-17 11:54:07 +09:00
Kohya S
71ebcc5e25 update readme and gradual latent doc 2024-02-12 14:52:19 +09:00
Kohya S
93bed60762 fix to work --console_log_xxx options 2024-02-12 14:49:29 +09:00
Kohya S
41d32c0be4 Merge pull request #1117 from kohya-ss/gradual_latent_hires_fix
Gradual latent hires fix
2024-02-12 14:21:27 +09:00
Kohya S
cbe9c5dc06 support deep shrink with regional lora, add prompter module 2024-02-12 14:17:27 +09:00
Kohya S
d3745db764 add args for logging 2024-02-12 13:15:21 +09:00
Kohya S
358ca205a3 Merge branch 'dev' into dev_device_support 2024-02-12 13:01:54 +09:00
Kohya S
c748719115 fix indent 2024-02-12 12:59:45 +09:00
Kohya S
98f42d3a0b Merge branch 'dev' into gradual_latent_hires_fix 2024-02-12 12:59:25 +09:00
Kohya S
35c6053de3 Merge pull request #1104 from kohya-ss/dev_improve_log
replace print with logger
2024-02-12 11:33:32 +09:00
Kohya S
20ae603221 Merge branch 'dev' into gradual_latent_hires_fix 2024-02-12 11:26:36 +09:00
Kohya S
672851e805 Merge branch 'dev' into dev_improve_log 2024-02-12 11:24:33 +09:00
Kohya S
e579648ce9 fix help for highvram arg 2024-02-12 11:12:41 +09:00
Kohya S
e24d9606a2 add clean_memory_on_device and use it from training 2024-02-12 11:10:52 +09:00
Kohya S
75ecb047e2 Merge branch 'dev' into dev_device_support 2024-02-11 19:51:28 +09:00
Kohya S
f897d55781 Merge pull request #1113 from kohya-ss/dev_multi_gpu_sample_gen
Dev multi gpu sample gen
2024-02-11 19:49:08 +09:00
Kohya S
7202596393 log to print tag frequencies 2024-02-10 09:59:12 +09:00
Kohya S
5d9e2873f6 make rich to output to stderr instead of stdout 2024-02-08 21:38:02 +09:00
Kohya S
055f02e1e1 add logging args for training scripts 2024-02-08 21:16:42 +09:00
Kohya S
9b8ea12d34 update log initialization without rich 2024-02-08 21:06:39 +09:00
Kohya S
74fe0453b2 add comment for get_preferred_device 2024-02-08 20:58:54 +09:00
Kohya S
efd3b58973 Add logging arguments and update logging setup 2024-02-04 20:44:10 +09:00
Kohya S
6279b33736 fallback to basic logging if rich is not installed 2024-02-04 18:28:54 +09:00
Yuta Hayashibe
5f6bf29e52 Replace print with logger if they are logs (#905)
* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-02-04 18:14:34 +09:00
Kohya S
e793d7780d reduce peak VRAM in sample gen 2024-02-04 17:31:01 +09:00
mgz
1492bcbfa2 add --new_conv_rank option
update script to also take a separate conv rank value
2024-02-03 23:18:55 -06:00
mgz
bf2de5620c fix formatting in resize_lora.py 2024-02-03 20:09:37 -06:00
Kohya S
6269682c56 unification of gen scripts for SD and SDXL, work in progress 2024-02-03 23:33:48 +09:00
Kohya S
2f9a344297 fix typo 2024-02-03 23:26:57 +09:00
Kohya S
11aced3500 simplify multi-GPU sample generation 2024-02-03 22:25:29 +09:00
DKnight54
1567ce1e17 Enable distributed sample image generation on multi-GPU environment (#1061)
* Update train_util.py — modifying to attempt to enable multi-GPU inference
* Update train_util.py — additional VRAM checking; refactor check_vram_usage to return a string for use with accelerator.print
* (many incremental "Update train_util.py" / "Update train_network.py" commits)
* Update train_util.py — remove sample image debug outputs
* Cleanup of debugging outputs
* adopt more elegant coding (Co-authored-by: Aarni Koskela <akx@iki.fi>)
* Update train_util.py — fix leftover debugging code; attempt to refactor inference into a separate function
* refactor generation of the distributed prompt list into generate_per_device_prompt_list()
* Clean up missing variables
* fix syntax error
* true random sample image generation — reinitialize the random seed to true random if a seed was set
* simplify per-process prompt

---------

Co-authored-by: Aarni Koskela <akx@iki.fi>
2024-02-03 21:46:31 +09:00
Kohya S
5cca1fdc40 add highvram option and do not clear cache in caching latents 2024-02-01 21:55:55 +09:00
Kohya S
9f0f0d573d Merge pull request #1092 from Disty0/dev_device_support
Fix IPEX support and add XPU device to device_utils
2024-02-01 20:41:21 +09:00
Disty0
a6a2b5a867 Fix IPEX support and add XPU device to device_utils 2024-01-31 17:32:37 +03:00
Kohya S
2ca4d0c831 Merge pull request #1054 from akx/mps
Device support improvements (MPS)
2024-01-31 21:30:12 +09:00
Kohya S
7f948db158 Merge pull request #1087 from mgz-dev/fix-imports-on-svd_merge_lora
fix broken import in svd_merge_lora script
2024-01-31 21:08:40 +09:00
Kohya S
9d7729c00d Merge pull request #1086 from Disty0/dev
Update IPEX Libs
2024-01-31 21:06:34 +09:00
Disty0
988dee02b9 IPEX torch.tensor FP64 workaround 2024-01-30 01:52:32 +03:00
mgz
d4b9568269 fix broken import in svd_merge_lora script
remove missing import, and remove unused imports
2024-01-28 11:59:07 -06:00
Disty0
ccc3a481e7 Update IPEX Libs 2024-01-28 14:14:31 +03:00
Kohya S
8f6f734a6f Merge branch 'dev' into gradual_latent_hires_fix 2024-01-28 08:21:15 +09:00
Kohya S
cd19df49cd Merge pull request #1085 from kohya-ss/dev
Dev
2024-01-27 18:32:06 +09:00
Kohya S
736365bdd5 update README.md 2024-01-27 18:31:01 +09:00
Kohya S
6ceedb9448 Merge branch 'main' into dev 2024-01-27 18:23:52 +09:00
Kohya S
930a3912a7 Merge pull request #1084 from fireicewolf/devel
Fix network multiplier cause crashed while use multi-GPUs
2024-01-27 18:22:00 +09:00
Kohya S
cf790d87c4 Merge pull request #1079 from feffy380/fix/fp8savestate
Update safetensors to fix a crash with `--fp8_base --save_state`
2024-01-26 22:34:35 +09:00
DukeG
4e67fb8444 test 2024-01-26 20:22:49 +08:00
DukeG
50f631c768 test 2024-01-26 20:02:48 +08:00
DukeG
85bc371ebc test 2024-01-26 18:58:47 +08:00
feffy380
322ee52c77 Update requirements.txt
Update safetensors to fix a crash when using `--fp8_base --save_state`
2024-01-25 19:15:53 +01:00
Kohya S
c576f80639 Fix ControlNetLLLite training issue #1069 2024-01-25 18:43:07 +09:00
Aarni Koskela
478156b4f7 Refactor device determination to function; add MPS fallback 2024-01-23 14:29:03 +02:00
Aarni Koskela
afc38707d5 Refactor memory cleaning into a single function 2024-01-23 14:28:50 +02:00
Aarni Koskela
2e4bee6f24 Log accelerator device 2024-01-23 14:20:40 +02:00
Kohya S
d5ab97b69b Merge pull request #1067 from kohya-ss/dev
Dev
2024-01-23 21:04:16 +09:00
Kohya S
7cb44e4502 update readme 2024-01-23 21:02:40 +09:00
Kohya S
7a20df5ad5 Merge pull request #1064 from KohakuBlueleaf/fix-grad-sync
Avoid grad sync on each step even when doing accumulation
2024-01-23 20:33:55 +09:00
Kohya S
bea4362e21 Merge pull request #1060 from akx/refactor-xpu-init
Deduplicate ipex initialization code
2024-01-23 20:25:37 +09:00
Kohya S
6805cafa9b fix TI training crashes in multigpu #1019 2024-01-23 20:17:19 +09:00
Kohaku-Blueleaf
711b40ccda Avoid always sync 2024-01-23 11:49:03 +08:00
Kohya S
696dd7f668 Fix dtype issue in PyTorch 2.0 for generating samples in training sdxl network 2024-01-22 12:43:37 +09:00
Kohya S
e0a3c69223 update readme 2024-01-20 18:47:10 +09:00
Kohya S
c59249a664 Add options to reduce memory usage in extract_lora_from_models.py closes #1059 2024-01-20 18:45:54 +09:00
Kohya S
fef172966f Add network_multiplier for dataset and train LoRA 2024-01-20 16:24:43 +09:00
Kohya S
5a1ebc4c7c format by black 2024-01-20 13:10:45 +09:00
Kohya S
2a0f45aea9 update readme 2024-01-20 11:08:20 +09:00
Kohya S
1f77bb6e73 fix to work sample generation in fp8 ref #1057 2024-01-20 10:57:42 +09:00
Kohya S
a7ef6422b6 fix to work with torch 2.0 2024-01-20 10:00:30 +09:00
Kohaku-Blueleaf
9cfa68c92f [Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057)
* Add fp8 support

* remove some debug prints

* Better implementation for te

* Fix some misunderstanding

* as same as unet, add explicit convert

* better impl for convert TE to fp8

* fp8 for not only unet

* Better cache TE and TE lr

* match arg name

* Fix with list

* Add timeout settings

* Fix arg style

* Add custom separator

* Fix typo

* Fix typo again

* Fix dtype error

* Fix gradient problem

* Fix req grad

* fix merge

* Fix merge

* Resolve merge

* arrangement and document

* Resolve merge error

* Add assert for mixed precision
2024-01-20 09:46:53 +09:00
Aarni Koskela
6f3f701d3d Deduplicate ipex initialization code 2024-01-19 18:07:36 +02:00
Kohya S
d2a99a19d4 Merge pull request #1056 from kohya-ss/dev
fix vram usage in LoRA training
2024-01-17 21:41:36 +09:00
Kohya S
0395a35543 Merge branch 'main' into dev 2024-01-17 21:39:13 +09:00
Kohya S
987d4a969d update readme 2024-01-17 21:38:49 +09:00
Kohya S
976d092c68 fix text encoders are on gpu even when not trained 2024-01-17 21:31:50 +09:00
Kohya S
e6b15c7e4a Merge pull request #1053 from akx/sdpa
Fix typo `--spda` (it's `--sdpa`)
2024-01-16 21:50:45 +09:00
Aarni Koskela
ef50436464 Fix typo --spda (it's --sdpa) 2024-01-16 14:32:48 +02:00
Kohya S
26d35794e3 Merge pull request #1052 from kohya-ss/dev
merge dev
2024-01-15 21:39:02 +09:00
Kohya S
dcf0eeb5b6 update readme 2024-01-15 21:35:26 +09:00
Kohya S
32b759a328 Add wandb_run_name parameter to init_kwargs #1032 2024-01-14 22:02:03 +09:00
Kohya S
09ef3ffa8b Merge branch 'main' into dev 2024-01-14 21:49:25 +09:00
Kohya S
aab265e431 Fix an issue with saving as diffusers sd1/2 model close #1033 2024-01-04 21:43:50 +09:00
Kohya S
da9b34fa26 Merge branch 'dev' into gradual_latent_hires_fix 2024-01-04 19:53:46 +09:00
Kohya S
716bad188b Update dependencies ref #1024 2024-01-04 19:53:25 +09:00
Kohya S
07bf2a21ac Merge pull request #1024 from p1atdev/main
Add support for `torch.compile`
2024-01-04 10:49:52 +09:00
Kohya S
8ac2d2a92f Merge pull request #1030 from Disty0/dev
Update IPEX Libs
2024-01-04 10:46:07 +09:00
Kohya S
76aee71257 Merge branch 'main' into dev 2024-01-04 10:42:16 +09:00
Kohya S
663b481029 fix TI training with full_fp16/bf16 ref #1019 2024-01-03 23:22:00 +09:00
Kohya S
1ab6493268 Merge branch 'main' into dev 2024-01-03 21:36:31 +09:00
Disty0
b9d2181192 Cleanup 2024-01-02 11:51:29 +03:00
Disty0
49148eb36e Disable Diffusers slicing if device is not XPU 2024-01-02 11:50:08 +03:00
Disty0
479bac447e Fix typo 2024-01-01 12:51:23 +03:00
Disty0
15d5e78ac2 Update IPEX Libs 2024-01-01 12:44:26 +03:00
Plat
62e7516537 feat: support torch.compile 2023-12-27 02:17:24 +09:00
Plat
20296b4f0e chore: bump einops version to support torch.compile 2023-12-27 02:17:24 +09:00
Kohya S
5cae6db804 Fix to work with DDP TextualInversionTrainer ref #1019 2023-12-24 22:05:56 +09:00
Kohya S
d61ecb26fd enable comment in prompt file, record raw prompt to metadata 2023-12-12 08:20:36 +09:00
Kohya S
07ef03d340 fix controlnet to work with gradual latent 2023-12-12 08:03:27 +09:00
Kohya S
9278031e60 Merge branch 'dev' into gradual_latent_hires_fix 2023-12-12 07:49:36 +09:00
Kohya S
e8c3a02830 Merge branch 'dev' into gradual_latent_hires_fix 2023-12-08 08:23:53 +09:00
Kohya S
7a4e50705c add target_x flag (not sure this impl is correct) 2023-12-03 17:59:41 +09:00
Kohya S
2952bca520 fix strength error 2023-12-01 21:56:08 +09:00
Kohya S
29b6fa6212 add unsharp mask 2023-11-28 22:33:22 +09:00
Kohya S
2c50ea0403 apply unsharp mask 2023-11-27 23:50:21 +09:00
Kohya S
298c6c2343 fix gradual latent cannot be disabled 2023-11-26 21:48:36 +09:00
Kohya S
2897a89dfd Merge branch 'dev' into gradual_latent_hires_fix 2023-11-26 18:12:24 +09:00
Kohya S
610566fbb9 Update README.md 2023-11-23 22:22:36 +09:00
Kohya S
684954695d add gradual latent 2023-11-23 22:17:49 +09:00
79 changed files with 13106 additions and 2710 deletions

README.md (344 changed lines)
View File

@@ -1,3 +1,174 @@
# Training Stable Cascade Stage C
This is an experimental feature. There may be bugs.
__Feb 25, 2024 Update:__ Fixed a bug that prevented the trained LoRA weights from being loaded in ComfyUI. If you still have a problem, please let me know.
__Feb 25, 2024 Update:__ Fixed a bug where Stage C training with mixed precision behaved the same as with `--full_bf16` (or `--full_fp16`), regardless of whether the option was specified.
This was because the Stage C weights were loaded in bf16/fp16. With this fix, memory usage without `--full_bf16` (or `--full_fp16`) will increase, so specify the option as needed.
__Feb 22, 2024 Update:__ Fixed a bug where LoRA was not applied to some modules (to_q/k/v and to_out) in Attention. The model structure of Stage C has also been changed so that you can choose between xformers and SDPA (previously SDPA was always used). Please specify the `--sdpa` or `--xformers` option.
__Feb 20, 2024 Update:__ There was a problem with the preprocessing of the EfficientNetEncoder, and the latents became invalid (the saturation of the training results decreases). If you have `_sc_latents.npz` files cached with `--cache_latents_to_disk`, please delete them before training.
## Usage
Training is run with `stable_cascade_train_stage_c.py`.
The main options are the same as `sdxl_train.py`. The following options have been added.
- `--effnet_checkpoint_path`: Specifies the path to the EfficientNetEncoder weights.
- `--stage_c_checkpoint_path`: Specifies the path to the Stage C weights.
- `--text_model_checkpoint_path`: Specifies the path to the Text Encoder weights. If omitted, the model from Hugging Face will be used.
- `--save_text_model`: Saves the model downloaded from Hugging Face to `--text_model_checkpoint_path`.
- `--previewer_checkpoint_path`: Specifies the path to the Previewer weights. Used to generate sample images during training.
- `--adaptive_loss_weight`: Uses [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) . If omitted, P2LossWeight is used. The official settings use Adaptive Loss Weight.
The learning rate is set to 1e-4 in the official settings.
For the first run, specify `--text_model_checkpoint_path` and `--save_text_model` to save the Text Encoder weights. From then on, specify `--text_model_checkpoint_path` to load the saved weights.
Sample image generation during training is done with the Previewer, a simple decoder that converts EfficientNetEncoder latents to images.
Some of the options for SDXL are simply ignored or cause an error (especially noise-related options such as `--noise_offset`). `--vae_batch_size` and `--no_half_vae` are applied directly to the EfficientNetEncoder (when `bf16` is specified for mixed precision, `--no_half_vae` is not necessary).
Options for latents and Text Encoder output caches can be used as is, but since the EfficientNetEncoder is much lighter than the VAE, you may not need to use the cache unless memory is particularly tight.
The options to reduce memory consumption, `--gradient_checkpointing`, `--full_bf16`, and `--full_fp16` (untested), can be used as is.
A CFG scale of about 4 is suitable for sample image generation.
Since the official settings use `bf16` for training, training with `fp16` may be unstable.
The code for training the Text Encoder is also written, but it is untested.
### Command line sample
```batch
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --output_dir ../output --output_name sc_test --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt --adaptive_loss_weight
```
### About the dataset for fine tuning
If latents cache files for SD/SDXL (extension `*.npz`) exist, they will be read and cause an error during training. Please move them to another location in advance.
After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (suffix `_sc_latents.npz` is added).
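For reference, a run might look like the following; the paths and optional arguments here are illustrative assumptions, so check the script's `--help` for the actual interface:
```batch
python finetune/prepare_buckets_latents.py --stable_cascade train_data meta.json meta_lat.json ../models/effnet_encoder.safetensors --batch_size 4 --max_resolution 1024,1024 --mixed_precision bf16
```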
## LoRA training
`stable_cascade_train_c_network.py` is used for LoRA training. The main options are the same as `train_network.py`, and the same options as `stable_cascade_train_stage_c.py` have been added.
__This is an experimental feature, so the format of the saved weights may change in the future and become incompatible.__
There is no compatibility with the official LoRA, and the Text Encoder embedding training (Pivotal Tuning) from the official implementation is not implemented here.
Text Encoder LoRA training is implemented, but untested.
## Image generation
Basic image generation functionality is available in `stable_cascade_gen_img.py`. See `--help` for usage.
When using LoRA, specify `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors`.
The following prompt options are available.
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
* `--t` Specifies the t_start of the generation.
* `--f` Specifies the shift of the generation.
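For example, a single line in the prompts file might combine these options as follows (illustrative values; the CFG scale of 4 follows the note above):
```
a photo of a cat sitting on a windowsill --n low quality, blurry --w 1024 --h 1024 --d 1234 --l 4.0 --s 20
```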
# Stable Cascade Stage C の学習
実験的機能です。不具合があるかもしれません。
__2024/2/25 追記:__ 学習される LoRA の重みが ComfyUI で読み込めるよう修正しました。依然として不具合がある場合にはご連絡ください。
__2024/2/25 追記:__ Mixed precision 時の Stage C の学習が、`--full_bf16` (fp16) の指定に関わらず `--full_bf16` (fp16) 指定時と同じ動作となる(と思われる)不具合を修正しました。
Stage C の重みを bf16/fp16 で読み込んでいたためです。この修正により `--full_bf16` (fp16) 未指定時のメモリ使用量が増えますので、必要に応じて `--full_bf16` (fp16) を指定してください。
__2024/2/22 追記:__ LoRA が一部のモジュール(Attention の to_q/k/v および to_out)に適用されない不具合を修正しました。また Stage C のモデル構造を変更し、xformers と SDPA を選べるようになりました(今までは SDPA が使用されていました)。`--sdpa` または `--xformers` オプションを指定してください。
__2024/2/20 追記:__ EfficientNetEncoder の前処理に不具合があり、latents が不正になっていました(学習結果の彩度が低下する現象が起きます)。`--cache_latents_to_disk` でキャッシュした `_sc_latents.npz` がある場合、いったん削除してから学習してください。
## 使い方
学習は `stable_cascade_train_stage_c.py` で行います。
主なオプションは `sdxl_train.py` と同様です。以下のオプションが追加されています。
- `--effnet_checkpoint_path` : EfficientNetEncoder の重みのパスを指定します。
- `--stage_c_checkpoint_path` : Stage C の重みのパスを指定します。
- `--text_model_checkpoint_path` : Text Encoder の重みのパスを指定します。省略時は Hugging Face のモデルを使用します。
- `--save_text_model` : `--text_model_checkpoint_path` にHugging Face からダウンロードしたモデルを保存します。
- `--previewer_checkpoint_path` : Previewer の重みのパスを指定します。学習中のサンプル画像生成に使用します。
- `--adaptive_loss_weight` : [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) を用います。省略時は P2LossWeight が使用されます。公式では Adaptive Loss Weight が使用されているようです。
学習率は、公式の設定では 1e-4 のようです。
初回は `--text_model_checkpoint_path` と `--save_text_model` を指定して、Text Encoder の重みを保存すると良いでしょう。次からは `--text_model_checkpoint_path` を指定して、保存した重みを読み込むことができます。
学習中のサンプル画像生成は Previewer で行われます。Previewer は EfficientNetEncoder の latents を画像に変換する簡易的な decoder です。
SDXL 向けの一部のオプションは単に無視されるか、エラーになります(特に `--noise_offset` などのノイズ関係)。`--vae_batch_size` および `--no_half_vae` はそのまま EfficientNetEncoder に適用されます(mixed precision に `bf16` 指定時は `--no_half_vae` は不要のようです)。
latents および Text Encoder 出力キャッシュのためのオプションはそのまま使用できますが、EfficientNetEncoder は VAE よりもかなり軽量のため、メモリが特に厳しい場合以外はキャッシュを使用する必要はないかもしれません。
メモリ消費を抑えるための `--gradient_checkpointing`、`--full_bf16`、`--full_fp16`(未テスト)はそのまま使用できます。
サンプル画像生成時の Scale には 4 程度が適しているようです。
公式の設定では学習に `bf16` を用いているため、`fp16` での学習は不安定かもしれません。
Text Encoder 学習のコードも書いてありますが、未テストです。
### コマンドラインのサンプル
[Command-line-sample](#command-line-sample)を参照してください。
### fine tuning方式のデータセットについて
SD/SDXL 向けの latents キャッシュファイル(拡張子 `*.npz`)が存在するとそれを読み込んでしまい学習時にエラーになります。あらかじめ他の場所に退避しておいてください。
その後、`finetune/prepare_buckets_latents.py` をオプション `--stable_cascade` を指定して実行すると、Stable Cascade 向けの latents キャッシュファイル(接尾辞 `_sc_latents.npz` が付きます)が作成されます。
## LoRA 等の学習
LoRA の学習は `stable_cascade_train_c_network.py` で行います。主なオプションは `train_network.py` と同様で、`stable_cascade_train_stage_c.py` と同様のオプションが追加されています。
__実験的機能のため、保存される重みのフォーマットは将来的に変更され、互換性がなくなる可能性があります。__
公式の LoRA と重みの互換性はありません。また公式で実装されている Text Encoder の embedding 学習(Pivotal Tuning)も実装されていません。
Text Encoder の LoRA 学習は実装してありますが、未テストです。
## 画像生成
最低限の画像生成機能が `stable_cascade_gen_img.py` にあります。使用法は `--help` を参照してください。
LoRA 使用時は `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors` のように指定します。
プロンプトオプションとして以下が使用できます。
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
* `--t` Specifies the t_start of the generation.
* `--f` Specifies the shift of the generation.
---
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
This repository contains training, generation and utility scripts for Stable Diffusion.
@@ -249,98 +420,117 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History
### Dec 24, 2023 / 2023/12/24
### Feb 24, 2024 / 2024/2/24: v0.8.4
- Fixed `tools/convert_diffusers20_original_sd.py` so it works again. Thanks to Disty0! PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016)
- The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu!
- The log is formatted by default. The `rich` library is required. Please see [Upgrade](#upgrade) and update the library.
- If `rich` is not installed, the log output will be the same as before.
- The following options are available in each training script:
- `--console_log_simple` option can be used to switch to the previous log output.
- `--console_log_level` option can be used to specify the log level. The default is `INFO`.
- `--console_log_file` option can be used to output the log to a file. The default is `None` (output to the console).
- The sample image generation during multi-GPU training is now done with multiple GPUs. PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) Thanks to DKnight54!
- The support for mps devices is improved. PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) Thanks to akx! If an mps device exists and CUDA does not, the mps device is used automatically.
- The `--new_conv_rank` option to specify the new rank of Conv2d is added to `networks/resize_lora.py`. PR [#1102](https://github.com/kohya-ss/sd-scripts/pull/1102) Thanks to mgz-dev!
- An option `--highvram` to disable the optimization for environments with little VRAM is added to the training scripts. If you specify it when there is enough VRAM, the operation will be faster.
- Currently, only the cache part of latents is optimized.
- The IPEX support is improved. PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Thanks to Disty0!
- Fixed a bug that `svd_merge_lora.py` crashes in some cases. PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) Thanks to mgz-dev!
- DyLoRA is fixed to work with SDXL. PR [#1126](https://github.com/kohya-ss/sd-scripts/pull/1126) Thanks to tamlog06!
- The common image generation script `gen_img.py` for SD 1/2 and SDXL is added. The basic functions are the same as the scripts for SD 1/2 and SDXL, but some new features are added.
- External scripts that generate prompts are now supported. They can be called with the `--from_module` option. (The documentation will be added later.)
- The normalization method after prompt weighting can be specified with `--emb_normalize_mode` option. `original` is the original method, `abs` is the normalization with the average of the absolute values, `none` is no normalization.
- Gradual Latent Hires fix is added to each generation script. See [here](./docs/gen_img_README-ja.md#about-gradual-latent) for details.
- `tools/convert_diffusers20_original_sd.py` が動かなくなっていたのが修正されました。Disty0 氏に感謝します。 PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016)
- ログ出力が改善されました。 PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) shirayu 氏に感謝します。
- デフォルトでログが成形されます。`rich` ライブラリが必要なため、[Upgrade](#upgrade) を参照し更新をお願いします。
- `rich` がインストールされていない場合は、従来のログ出力になります。
- 各学習スクリプトでは以下のオプションが有効です。
- `--console_log_simple` オプションで従来のログ出力に切り替えられます。
- `--console_log_level` でログレベルを指定できます。デフォルトは `INFO` です。
- `--console_log_file` でログファイルを出力できます。デフォルトは `None`(コンソールに出力) です。
- 複数 GPU 学習時に学習中のサンプル画像生成を複数 GPU で行うようになりました。 PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) DKnight54 氏に感謝します。
- mps デバイスのサポートが改善されました。 PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) akx 氏に感謝します。CUDA ではなく mps が存在する場合には自動的に mps デバイスを使用します。
- `networks/resize_lora.py` に Conv2d の新しいランクを指定するオプション `--new_conv_rank` が追加されました。 PR [#1102](https://github.com/kohya-ss/sd-scripts/pull/1102) mgz-dev 氏に感謝します。
- 学習スクリプトに VRAMが少ない環境向け最適化を無効にするオプション `--highvram` を追加しました。VRAM に余裕がある場合に指定すると動作が高速化されます。
- 現在は latents のキャッシュ部分のみ高速化されます。
- IPEX サポートが改善されました。 PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Disty0 氏に感謝します。
- `svd_merge_lora.py` が場合によってエラーになる不具合が修正されました。 PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) mgz-dev 氏に感謝します。
- DyLoRA が SDXL で動くよう修正されました。PR [#1126](https://github.com/kohya-ss/sd-scripts/pull/1126) tamlog06 氏に感謝します。
- SD 1/2 および SDXL 共通の生成スクリプト `gen_img.py` を追加しました。基本的な機能は SD 1/2、SDXL 向けスクリプトと同じですが、いくつかの新機能が追加されています。
- プロンプトを動的に生成する外部スクリプトをサポートしました。 `--from_module` で呼び出せます。(ドキュメントはのちほど追加します)
- プロンプト重みづけ後の正規化方法を `--emb_normalize_mode` で指定できます。`original` は元の方法、`abs` は絶対値の平均値で正規化、`none` は正規化を行いません。
- Gradual Latent Hires fix を各生成スクリプトに追加しました。詳細は [こちら](./docs/gen_img_README-ja.md#about-gradual-latent)。
### Dec 21, 2023 / 2023/12/21
### Jan 27, 2024 / 2024/1/27: v0.8.3
- The issues in multi-GPU training are fixed. Thanks to Isotr0py! PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) and [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000)
- `--ddp_gradient_as_bucket_view` and `--ddp_bucket_view` options are added to `sdxl_train.py`. Please specify these options for multi-GPU training.
- IPEX support is updated. Thanks to Disty0!
- Fixed the bug that the size of the bucket becomes less than `min_bucket_reso`. Thanks to Cauldrath! PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008)
- `--sample_at_first` option is added to each training script. This option is useful to generate images at the first step, before training. Thanks to shirayu! PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907)
- `--ss` option is added to the sampling prompt in training. You can specify the scheduler for the sampling like `--ss euler_a`. Thanks to shirayu! PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906)
- `keep_tokens_separator` is added to the dataset config. This option is useful to keep (prevent from shuffling) the tokens in the captions. See [#975](https://github.com/kohya-ss/sd-scripts/pull/975) for details. Thanks to Linaqruf! (A short example follows this list.)
- You can specify the separator with an option like `--keep_tokens_separator "|||"` or with `keep_tokens_separator: "|||"` in `.toml`. The tokens before `|||` are not shuffled.
- Attention processor hook is added. See [#961](https://github.com/kohya-ss/sd-scripts/pull/961) for details. Thanks to rockerBOO!
- The optimizer `PagedAdamW` is added. Thanks to xzuyn! PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955)
- NaN replacement in SDXL VAE is sped up. Thanks to liubo0902! PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009)
- Fixed the path error in `finetune/make_captions.py`. Thanks to CjangCjengh! PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986)
- Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380!
- `safetensors` is updated. Please see [Upgrade](#upgrade) and update the library.
- Fixed a bug that the training crashes when `network_multiplier` is specified with multi-GPU training. PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) Thanks to fireicewolf!
- Fixed a bug that the training crashes when training ControlNet-LLLite.
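As an illustration of the `keep_tokens_separator` option described in the list above, a caption file might look like this (hypothetical tags; only the `|||` separator is from the documentation):
```
masterpiece, 1girl ||| brown hair, smile, outdoors
```
The tokens before `|||` (`masterpiece, 1girl`) stay in place during caption shuffling, while the tags after it may be shuffled.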
- マルチGPUでの学習の不具合が修正されました。Isotr0py 氏に感謝します。 PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) および [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000)
- `sdxl_train.py` に `--ddp_gradient_as_bucket_view` と `--ddp_bucket_view` オプションが追加されました。マルチGPUでの学習時にはこれらのオプションを指定してください。
- IPEX サポートが更新されました。Disty0 氏に感謝します。
- Aspect Ratio Bucketing で bucket のサイズが `min_bucket_reso` 未満になる不具合を修正しました。Cauldrath 氏に感謝します。 PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008)
- 各学習スクリプトに `--sample_at_first` オプションが追加されました。学習前に画像を生成することで、学習結果が比較しやすくなります。shirayu 氏に感謝します。 PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907)
- 学習時のプロンプトに `--ss` オプションが追加されました。`--ss euler_a` のようにスケジューラを指定できます。shirayu 氏に感謝します。 PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906)
- データセット設定に `keep_tokens_separator` が追加されました。キャプション内のトークンをどの位置までシャッフルしないかを指定できます。詳細は [#975](https://github.com/kohya-ss/sd-scripts/pull/975) を参照してください。Linaqruf 氏に感謝します。
- オプションで `--keep_tokens_separator "|||"` のように指定するか、`.toml` で `keep_tokens_separator: "|||"` のように指定します。`|||` の前のトークンはシャッフルされません。
- Attention processor hook が追加されました。詳細は [#961](https://github.com/kohya-ss/sd-scripts/pull/961) を参照してください。rockerBOO 氏に感謝します。
- オプティマイザ `PagedAdamW` が追加されました。xzuyn 氏に感謝します。 PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955)
- 学習時、SDXL VAE で NaN が発生した時の置き換えが高速化されました。liubo0902 氏に感謝します。 PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009)
- `finetune/make_captions.py` で相対パス指定時のエラーが修正されました。CjangCjengh 氏に感謝します。 PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986)
- `--fp8_base` 指定時に `--save_state` での保存がエラーになる不具合が修正されました。 PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) feffy380 氏に感謝します。
- `safetensors` がバージョンアップされていますので、[Upgrade](#upgrade) を参照し更新をお願いします。
- 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。
- ControlNet-LLLite の学習がエラーになる不具合を修正しました。
### Dec 3, 2023 / 2023/12/3
### Jan 23, 2024 / 2024/1/23: v0.8.2
- `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
- Min SNR Gamma with V-prediction (SD 2.1) is fixed. Thanks to feffy380! PR [#934](https://github.com/kohya-ss/sd-scripts/pull/934)
- See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details.
- `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
- The default values are same as the previous version.
- Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py`.
- `--ds_timesteps_1` and `--ds_timesteps_2` options denote the timesteps of the Deep Shrink for the first and second stages.
- `--ds_depth_1` and `--ds_depth_2` options denote the depth (block index) of the Deep Shrink for the first and second stages.
- `--ds_ratio` option denotes the ratio of the Deep Shrink. `0.5` means the half of the original latent size for the Deep Shrink.
- `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available.
- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
- Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
- PyTorch 2.1 or later is required.
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.
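A minimal sketch of such a run for an SDXL LoRA (model and dataset paths are placeholders, not from the original text):
```batch
accelerate launch sdxl_train_network.py --pretrained_model_name_or_path ../models/sdxl_base.safetensors --dataset_config ../dataset/config.toml --network_module networks.lora --mixed_precision fp16 --fp8_base --output_dir ../output --output_name lora_fp8
```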
- `finetune\tag_images_by_wd14_tagger.py` の `--caption_separator` オプションでカンマ以外の区切り文字を指定できるようになりました。KohakuBlueleaf 氏に感謝します。 PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
- V-prediction (SD 2.1) での Min SNR Gamma が修正されました。feffy380 氏に感謝します。 PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934)
- 詳細は [#673](https://github.com/kohya-ss/sd-scripts/issues/673) を参照してください。
- `networks/extract_lora_from_models.py` に `--min_diff` と `--clamp_quantile` オプションが追加されました。wkpark 氏に感謝します。 PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
- デフォルト値は前のバージョンと同じです。
- `sdxl_gen_img.py` と `gen_img_diffusers.py` で Deep Shrink hires fix をサポートしました。
- `--ds_timesteps_1` と `--ds_timesteps_2` オプションは Deep Shrink の第一段階と第二段階の timesteps を指定します。
- `--ds_depth_1` と `--ds_depth_2` オプションは Deep Shrink の第一段階と第二段階の深さ(ブロックの index)を指定します。
- `--ds_ratio` オプションは Deep Shrink の比率を指定します。`0.5` を指定すると Deep Shrink 適用時の latent は元のサイズの半分になります。
- `--dst1`、`--dst2`、`--dsd1`、`--dsd2`、`--dsr` プロンプトオプションも使用できます。
- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc.
- This is an experimental option and may be removed or changed in the future.
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
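Putting the new memory options together, an invocation might look like this (file names are placeholders, and you may also need the script's model-version flags; the `--load_*` options are the ones documented above):
```batch
python networks/extract_lora_from_models.py --model_org sdxl_base.safetensors --model_tuned sdxl_tuned.safetensors --save_to extracted_lora.safetensors --load_precision fp16 --load_original_model_to cpu --load_tuned_model_to cuda
```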
### Nov 5, 2023 / 2023/11/5
- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
- Fixed a bug in multi-GPU Textual Inversion training.
- `sdxl_train.py` now supports different learning rates for each Text Encoder.
- Example:
- `--learning_rate 1e-6`: train U-Net only
- `--train_text_encoder --learning_rate 1e-6`: train U-Net and two Text Encoders with the same learning rate (same as the previous version)
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train U-Net and two Text Encoders with the different learning rates
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train two Text Encoders only
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 0`: train U-Net and one Text Encoder only
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 0 --learning_rate_te2 1e-6`: train one Text Encoder only
- (実験的) LoRA 等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
- `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。
- PyTorch 2.1 以降が必要です。
- PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。
- 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。
- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。
- 実験的オプションのため、将来的に削除または仕様変更される可能性があります。
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
- また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
- `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。
- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。
- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。
- マルチ GPU での Textual Inversion 学習の不具合を修正しました。
- `train_db.py` and `fine_tune.py` now support different learning rates for Text Encoder. Specify with `--learning_rate_te` option.
- To train Text Encoder with `fine_tune.py`, specify `--train_text_encoder` option too. `train_db.py` trains Text Encoder by default.
- Fixed the bug that Text Encoder is not trained when block lr is specified in `sdxl_train.py`.
- Debiased Estimation loss is added to each training script. Thanks to sdbds!
- Specify `--debiased_estimation_loss` option to enable it. See PR [#889](https://github.com/kohya-ss/sd-scripts/pull/889) for details.
- Training of Text Encoder is improved in `train_network.py` and `sdxl_train_network.py`. Thanks to KohakuBlueleaf! PR [#895](https://github.com/kohya-ss/sd-scripts/pull/895)
- The moving average of the loss is now displayed in the progress bar in each training script. Thanks to shirayu! PR [#899](https://github.com/kohya-ss/sd-scripts/pull/899)
- PagedAdamW32bit optimizer is supported. Specify `--optimizer_type=PagedAdamW32bit`. Thanks to xzuyn! PR [#900](https://github.com/kohya-ss/sd-scripts/pull/900)
- Other bug fixes and improvements.
- `sdxl_train.py` で、二つの Text Encoder それぞれに独立した学習率が指定できるようになりました。サンプルは上の英語版を参照してください。
- `train_db.py` および `fine_tune.py` で Text Encoder に別の学習率を指定できるようになりました。`--learning_rate_te` オプションで指定してください。
- `fine_tune.py` で Text Encoder を学習するには `--train_text_encoder` オプションをあわせて指定してください。`train_db.py` はデフォルトで学習します。
- `sdxl_train.py` で block lr を指定すると Text Encoder が学習されない不具合を修正しました。
- Debiased Estimation loss が各学習スクリプトに追加されました。sdbds 氏に感謝します。
- `--debiased_estimation_loss` を指定すると有効になります。詳細は PR [#889](https://github.com/kohya-ss/sd-scripts/pull/889) を参照してください。
- `train_network.py` と `sdxl_train_network.py` で Text Encoder の学習が改善されました。KohakuBlueleaf 氏に感謝します。 PR [#895](https://github.com/kohya-ss/sd-scripts/pull/895)
- 各学習スクリプトで移動平均の loss がプログレスバーに表示されるようになりました。shirayu 氏に感謝します。 PR [#899](https://github.com/kohya-ss/sd-scripts/pull/899)
- PagedAdamW32bit オプティマイザがサポートされました。`--optimizer_type=PagedAdamW32bit` と指定してください。xzuyn 氏に感謝します。 PR [#900](https://github.com/kohya-ss/sd-scripts/pull/900)
- その他のバグ修正と改善。

- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例

```toml
[general]
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = 1.0

... subset settings ...

[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = -1.0

... subset settings ...
```
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.

View File

@@ -1,11 +1,7 @@
 import torch
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
-        ipex_init()
-except Exception:
-    pass
+from library.device_utils import init_ipex
+init_ipex()
 from typing import Union, List, Optional, Dict, Any, Tuple
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput
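For context, `init_ipex` and the related helpers used in the diffs below come from the new `library/device_utils.py`. A minimal sketch of what they might look like, based only on the calls visible in these diffs (`init_ipex()`, `get_preferred_device()`, `clean_memory_on_device()`) and the removed try/except block above; this is not the repository's exact implementation:

```python
import gc
import torch

def init_ipex():
    # set up Intel IPEX / XPU support if present; silently continue otherwise
    # (this mirrors the try/except block the diffs remove from each script)
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401

        if torch.xpu.is_available():
            from library.ipex import ipex_init

            ipex_init()
    except Exception:
        pass

def get_preferred_device() -> torch.device:
    # prefer CUDA, then MPS (Apple Silicon), then CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def clean_memory_on_device(device: torch.device):
    # one place to release cached memory on whatever backend is in use
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "xpu":
        torch.xpu.empty_cache()
    elif device.type == "mps":
        torch.mps.empty_cache()
```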

View File

@@ -452,3 +452,36 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
- `--network_show_meta` : 追加ネットワークのメタデータを表示します。
---
# About Gradual Latent
Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options.
- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first.
- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size.
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
__Please specify `euler_a` for the sampler__, because the sampler's source code has been modified; it will not work with other samplers.
It is more effective with SD 1.5. It is quite subtle with SDXL.
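An illustrative invocation (a sketch only: the checkpoint path and prompt are placeholders, and `--sampler`/`--steps` are assumed to follow the generation scripts' usual options):
```batch
python gen_img.py --ckpt model.safetensors --sampler euler_a --steps 30 --gradual_latent_timesteps 750 --gradual_latent_ratio 0.5 --gradual_latent_ratio_step 0.125 --gradual_latent_ratio_every_n_steps 3 --prompt "a highly detailed landscape"
```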
# Gradual Latent について
latent のサイズを徐々に大きくしていく Hires fix です。`gen_img.py`、`sdxl_gen_img.py`、`gen_img_diffusers.py` に以下のオプションが追加されています。
- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。
それぞれのオプションは、プロンプトオプション `--glt`、`--glr`、`--gls`、`--gle` でも指定できます。
サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。
SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。

View File

@@ -2,27 +2,27 @@
 # XXX dropped option: hypernetwork training
 import argparse
-import gc
 import math
 import os
 from multiprocessing import Value
 import toml
 from tqdm import tqdm
 import torch
+from library.device_utils import init_ipex, clean_memory_on_device
+init_ipex()
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
-        ipex_init()
-except Exception:
-    pass
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
+from library.utils import setup_logging, add_logging_arguments
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)
 import library.train_util as train_util
 import library.config_util as config_util
 from library.config_util import (
@@ -42,6 +42,7 @@ from library.custom_train_functions import (
 def train(args):
     train_util.verify_training_args(args)
     train_util.prepare_dataset_args(args, True)
+    setup_logging(args, reset=True)

     cache_latents = args.cache_latents
@@ -54,11 +55,11 @@ def train(args):
     if args.dataset_class is None:
         blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
         if args.dataset_config is not None:
-            print(f"Load dataset config from {args.dataset_config}")
+            logger.info(f"Load dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)
             ignored = ["train_data_dir", "in_json"]
             if any(getattr(args, attr) is not None for attr in ignored):
-                print(
+                logger.warning(
                     "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                         ", ".join(ignored)
                     )
@@ -91,7 +92,7 @@ def train(args):
         train_util.debug_dataset(train_dataset_group)
         return
     if len(train_dataset_group) == 0:
-        print(
+        logger.error(
             "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
         )
         return
@@ -102,7 +103,7 @@ def train(args):
     ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

     # acceleratorを準備する
-    print("prepare accelerator")
+    logger.info("prepare accelerator")
     accelerator = train_util.prepare_accelerator(args)

     # mixed precisionに対応した型を用意しておき適宜castする
@@ -163,9 +164,7 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory_on_device(accelerator.device)

         accelerator.wait_for_everyone()
@@ -212,8 +211,8 @@ def train(args):
     _, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)

     # dataloaderを準備する
-    # DataLoaderのプロセス数:0はメインプロセスになる
-    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1 ただし最大で指定された数まで
+    # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
+    n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count or max_data_loader_n_workers
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset_group,
         batch_size=1,
@@ -228,7 +227,9 @@ def train(args):
         args.max_train_steps = args.max_train_epochs * math.ceil(
             len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
         )
-        accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+        accelerator.print(
+            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
+        )

     # データセット側にも学習ステップを送信
     train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -291,6 +292,8 @@ def train(args):
     if accelerator.is_main_process:
         init_kwargs = {}
+        if args.wandb_run_name:
+            init_kwargs["wandb"] = {"name": args.wandb_run_name}
         if args.log_tracker_config is not None:
             init_kwargs = toml.load(args.log_tracker_config)
         accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
@@ -464,12 +467,13 @@ def train(args):
         train_util.save_sd_model_on_train_end(
             args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
         )
-        print("model saved.")
+        logger.info("model saved.")


 def setup_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()

+    add_logging_arguments(parser)
     train_util.add_sd_models_arguments(parser)
     train_util.add_dataset_arguments(parser, False, True, True)
     train_util.add_training_arguments(parser, False)
@@ -478,7 +482,9 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)

-    parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
+    parser.add_argument(
+        "--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
+    )
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
     parser.add_argument(
         "--learning_rate_te",

View File

@@ -21,6 +21,10 @@ import torch.nn.functional as F
 import os
 from urllib.parse import urlparse
 from timm.models.hub import download_cached_file
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 class BLIP_Base(nn.Module):
     def __init__(self,
@@ -235,6 +239,6 @@ def load_checkpoint(model,url_or_filename):
         del state_dict[key]
     msg = model.load_state_dict(state_dict,strict=False)
-    print('load checkpoint from %s'%url_or_filename)
+    logger.info('load checkpoint from %s'%url_or_filename)
     return model,msg

View File

@@ -8,6 +8,10 @@ import json
 import re

 from tqdm import tqdm
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
 PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
@@ -36,13 +40,13 @@ def clean_tags(image_key, tags):
     tokens = tags.split(", rating")
     if len(tokens) == 1:
         # WD14 taggerのときはこちらになるのでメッセージは出さない
-        # print("no rating:")
-        # print(f"{image_key} {tags}")
+        # logger.info("no rating:")
+        # logger.info(f"{image_key} {tags}")
         pass
     else:
         if len(tokens) > 2:
-            print("multiple ratings:")
-            print(f"{image_key} {tags}")
+            logger.info("multiple ratings:")
+            logger.info(f"{image_key} {tags}")
         tags = tokens[0]

     tags = ", " + tags.replace(", ", ", , ") + ", "  # カンマ付きで検索をするための身も蓋もない対策
@@ -124,43 +128,43 @@ def clean_caption(caption):
 def main(args):
     if os.path.exists(args.in_json):
-        print(f"loading existing metadata: {args.in_json}")
+        logger.info(f"loading existing metadata: {args.in_json}")
         with open(args.in_json, "rt", encoding='utf-8') as f:
             metadata = json.load(f)
     else:
-        print("no metadata / メタデータファイルがありません")
+        logger.error("no metadata / メタデータファイルがありません")
         return

-    print("cleaning captions and tags.")
+    logger.info("cleaning captions and tags.")
     image_keys = list(metadata.keys())
     for image_key in tqdm(image_keys):
         tags = metadata[image_key].get('tags')
         if tags is None:
-            print(f"image does not have tags / メタデータにタグがありません: {image_key}")
+            logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
         else:
             org = tags
             tags = clean_tags(image_key, tags)
             metadata[image_key]['tags'] = tags
             if args.debug and org != tags:
-                print("FROM: " + org)
-                print("TO: " + tags)
+                logger.info("FROM: " + org)
+                logger.info("TO: " + tags)

         caption = metadata[image_key].get('caption')
         if caption is None:
-            print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
+            logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
         else:
             org = caption
             caption = clean_caption(caption)
             metadata[image_key]['caption'] = caption
             if args.debug and org != caption:
-                print("FROM: " + org)
-                print("TO: " + caption)
+                logger.info("FROM: " + org)
+                logger.info("TO: " + caption)

     # metadataを書き出して終わり
-    print(f"writing metadata: {args.out_json}")
+    logger.info(f"writing metadata: {args.out_json}")
     with open(args.out_json, "wt", encoding='utf-8') as f:
         json.dump(metadata, f, indent=2)
-    print("done!")
+    logger.info("done!")


 def setup_parser() -> argparse.ArgumentParser:
@@ -178,10 +182,10 @@ if __name__ == '__main__':
     args, unknown = parser.parse_known_args()
     if len(unknown) == 1:
-        print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
-        print("All captions and tags in the metadata are processed.")
-        print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
-        print("メタデータ内のすべてのキャプションとタグが処理されます。")
+        logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
+        logger.warning("All captions and tags in the metadata are processed.")
+        logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
+        logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
         args.in_json = args.out_json
         args.out_json = unknown[0]
     elif len(unknown) > 0:

View File

@@ -9,14 +9,22 @@ from pathlib import Path
 from PIL import Image
 from tqdm import tqdm
 import numpy as np
 import torch
+from library.device_utils import init_ipex, get_preferred_device
+init_ipex()
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

 sys.path.append(os.path.dirname(__file__))
 from blip.blip import blip_decoder, is_url
 import library.train_util as train_util
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+DEVICE = get_preferred_device()

 IMAGE_SIZE = 384
@@ -47,7 +55,7 @@ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
@@ -74,21 +82,21 @@ def main(args):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
cwd = os.getcwd()
print("Current Working Directory is: ", cwd)
logger.info(f"Current Working Directory is: {cwd}")
os.chdir("finetune")
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
args.caption_weights = os.path.join("..", args.caption_weights)
print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
print(f"loading BLIP caption: {args.caption_weights}")
logger.info(f"loading BLIP caption: {args.caption_weights}")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")
logger.info("BLIP loaded")
# run captioning
def run_batch(path_imgs):
@@ -108,7 +116,7 @@ def main(args):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f'{image_path} {caption}')
# option to use a DataLoader to speed up image loading
if args.max_data_loader_n_workers is not None:
@@ -138,7 +146,7 @@ def main(args):
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, img_tensor))
@@ -148,7 +156,7 @@ def main(args):
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:


@@ -5,12 +5,19 @@ import re
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -35,8 +42,8 @@ def remove_words(captions, debug):
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
logger.info(caption)
logger.info(cap)
removed_caps.append(cap)
return removed_caps
@@ -70,16 +77,16 @@ def main(args):
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""
print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
# ideally this would be downloaded explicitly instead of relying on the cache
print(f"loading GIT: {args.model_id}")
logger.info(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
logger.info("GIT loaded")
# run captioning
def run_batch(path_imgs):
@@ -97,7 +104,7 @@ def main(args):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f"{image_path} {caption}")
# option to use a DataLoader to speed up image loading
if args.max_data_loader_n_workers is not None:
@@ -126,7 +133,7 @@ def main(args):
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
@@ -137,7 +144,7 @@ def main(args):
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:


@@ -5,26 +5,30 @@ from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
print("merge caption texts to metadata json.")
logger.info("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding='utf-8').strip()
@@ -38,12 +42,12 @@ def main(args):
metadata[image_key]['caption'] = caption
if args.debug:
print(image_key, caption)
logger.info(f"{image_key} {caption}")
# write out the metadata and finish
print(f"writing metadata: {args.out_json}")
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:


@@ -5,26 +5,30 @@ from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
print("merge tags to metadata json.")
logger.info("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding='utf-8').strip()
@@ -38,13 +42,13 @@ def main(args):
metadata[image_key]['tags'] = tags
if args.debug:
print(image_key, tags)
logger.info(f"{image_key} {tags}")
# write out the metadata and finish
print(f"writing metadata: {args.out_json}")
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
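Editor's note: both merge scripts share the same incremental behavior: when --in_json is omitted and the output file already exists, the output is re-read as the input, so re-running merges new captions/tags into the existing metadata (overwriting entries for images seen before). The defaulting logic, isolated as a sketch:

    from pathlib import Path

    def resolve_in_json(in_json, out_json):
        # Re-running with only out_json present merges into the existing file
        # instead of starting from a fresh metadata dict.
        if in_json is None and Path(out_json).is_file():
            return out_json
        return in_json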


@@ -8,13 +8,25 @@ from tqdm import tqdm
import numpy as np
from PIL import Image
import cv2
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
import library.model_util as model_util
import library.stable_cascade_utils as sc_utils
import library.train_util as train_util
from library.utils import setup_logging
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = get_preferred_device()
IMAGE_TRANSFORMS = transforms.Compose(
[
@@ -34,7 +46,7 @@ def collate_fn_remove_corrupted(batch):
return batch
def get_npz_filename(data_dir, image_key, is_full_path, recursive):
def get_npz_filename(data_dir, image_key, is_full_path, recursive, stable_cascade):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
@@ -42,31 +54,32 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive):
base_name = image_key
relative_path = ""
ext = ".npz" if not stable_cascade else train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name) + ".npz"
return os.path.join(data_dir, relative_path, base_name) + ext
else:
return os.path.join(data_dir, base_name) + ".npz"
return os.path.join(data_dir, base_name) + ext
def main(args):
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
if args.bucket_reso_steps % 32 > 0:
print(
logger.warning(
f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
)
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
logger.info(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding="utf-8") as f:
metadata = json.load(f)
else:
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
weight_dtype = torch.float32
@@ -75,13 +88,20 @@ def main(args):
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
if not args.stable_cascade:
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
divisor = 8
else:
vae = sc_utils.load_effnet(args.model_name_or_path, DEVICE)
divisor = 32
vae.eval()
vae.to(DEVICE, dtype=weight_dtype)
# compute bucket sizes
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
assert (
len(max_reso) == 2
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_manager = train_util.BucketManager(
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
@@ -89,7 +109,7 @@ def main(args):
if not args.bucket_no_upscale:
bucket_manager.make_buckets()
else:
print(
logger.warning(
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
)
@@ -130,7 +150,7 @@ def main(args):
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
@@ -146,6 +166,10 @@ def main(args):
# the resolution recorded in metadata is in latent units, so truncate to multiples of 8
metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
# record additional information
metadata[image_key]["original_size"] = (image.width, image.height)
metadata[image_key]["train_resized_size"] = resized_size
if not args.bucket_no_upscale:
# when not upscaling, verify that the resized size matches the bucket size on at least one axis
assert (
@@ -160,9 +184,9 @@ def main(args):
), f"internal error resized size is small: {resized_size}, {reso}"
# if the cache file already exists, check its shape etc. and skip if unchanged
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive, args.stable_cascade)
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug, divisor):
continue
# add to the batch
@@ -183,15 +207,15 @@ def main(args):
for i, reso in enumerate(bucket_manager.resos):
count = bucket_counts.get(reso, 0)
if count > 0:
print(f"bucket {i} {reso}: {count}")
logger.info(f"bucket {i} {reso}: {count}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error: {np.mean(img_ar_errors)}")
logger.info(f"mean ar error: {np.mean(img_ar_errors)}")
# write out the metadata and finish
print(f"writing metadata: {args.out_json}")
logger.info(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding="utf-8") as f:
json.dump(metadata, f, indent=2)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
@@ -200,7 +224,14 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
parser.add_argument(
"--stable_cascade",
action="store_true",
help="prepare EffNet latents for stable cascade / stable cascade用のEffNetのlatentsを準備する",
)
parser.add_argument(
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
@@ -223,10 +254,16 @@ def setup_parser() -> argparse.ArgumentParser:
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
)
parser.add_argument(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
"--bucket_no_upscale",
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
)
parser.add_argument(
"--full_path",
@@ -234,7 +271,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
"--flip_aug",
action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
)
parser.add_argument(
"--skip_existing",


@@ -11,6 +11,10 @@ from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# from wd14 tagger
IMAGE_SIZE = 448
@@ -58,7 +62,7 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
image = preprocess_image(image)
tensor = torch.tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
@@ -79,7 +83,7 @@ def main(args):
# a deprecation warning is emitted; deal with it when the API actually goes away
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files += FILES_ONNX
@@ -95,7 +99,7 @@ def main(args):
force_filename=file,
)
else:
print("using existing wd14 tagger model")
logger.info("using existing wd14 tagger model")
# load images
if args.onnx:
@@ -103,8 +107,8 @@ def main(args):
import onnxruntime as ort
onnx_path = f"{args.model_dir}/model.onnx"
print("Running wd14 tagger with onnx")
print(f"loading onnx model: {onnx_path}")
logger.info("Running wd14 tagger with onnx")
logger.info(f"loading onnx model: {onnx_path}")
if not os.path.exists(onnx_path):
raise Exception(
@@ -121,7 +125,7 @@ def main(args):
if args.batch_size != batch_size and type(batch_size) != str:
# some rebatch model may use 'N' as dynamic axes
print(
logger.warning(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
)
args.batch_size = batch_size
@@ -156,7 +160,7 @@ def main(args):
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
tag_freq = {}
@@ -237,7 +241,10 @@ def main(args):
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
# option to use a DataLoader to speed up image loading
if args.max_data_loader_n_workers is not None:
@@ -269,7 +276,7 @@ def main(args):
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
@@ -284,11 +291,11 @@ def main(args):
if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
print("\nTag frequencies:")
print("Tag frequencies:")
for tag, freq in sorted_tags:
print(f"{tag}: {freq}")
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
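Editor's note: the tagger's multi-line debug print becomes four single-line logger.info calls. With a record-prefixing formatter (timestamp/level per record), an embedded \n would leave continuation lines without a prefix, so emitting one record per line keeps the log grep-able. A small illustration, assuming a basic prefixing format rather than the repo's actual one:

    import logging

    logging.basicConfig(format="%(levelname)s %(message)s", level=logging.INFO)
    logger = logging.getLogger(__name__)

    # one record per line: each line carries the level prefix
    logger.info("img.png:")
    logger.info("\tCharacter tags: 1girl")
    logger.info("\tGeneral tags: sky, cloud")
    # a single record containing "\n" would prefix only its first line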

gen_img.py (new file, 3334 lines)

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -3,7 +3,10 @@ import argparse
import random
import re
from typing import List, Optional, Union
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def prepare_scheduler_for_custom_training(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"):
@@ -21,7 +24,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
# fix beta: zero terminal SNR
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
def enforce_zero_terminal_snr(betas):
# Convert betas to alphas_bar_sqrt
@@ -49,8 +52,8 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# print("original:", noise_scheduler.betas)
# print("fixed:", betas)
# logger.info(f"original: {noise_scheduler.betas}")
# logger.info(f"fixed: {betas}")
noise_scheduler.betas = betas
noise_scheduler.alphas = alphas
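Editor's note: the body of enforce_zero_terminal_snr is elided in this hunk; for orientation, the rescaling from the cited paper (arXiv:2305.08891) shifts and scales sqrt(alpha-bar) so its terminal value is exactly zero while the initial value is preserved. A sketch of the standard formulation (the in-tree version may differ in detail):

    import torch

    def enforce_zero_terminal_snr_sketch(betas: torch.Tensor) -> torch.Tensor:
        alphas = 1.0 - betas
        alphas_bar_sqrt = torch.cumprod(alphas, dim=0).sqrt()
        a0, aT = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
        # shift so the last value is 0, rescale so the first is unchanged
        alphas_bar_sqrt = (alphas_bar_sqrt - aT) * a0 / (a0 - aT)
        alphas_bar = alphas_bar_sqrt ** 2
        alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
        return 1.0 - alphas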
@@ -79,13 +82,13 @@ def get_snr_scale(timesteps, noise_scheduler):
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1)
# # show debug info
# print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
# logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
return scale
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
scale = get_snr_scale(timesteps, noise_scheduler)
# print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
return loss
@@ -268,7 +271,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights

library/device_utils.py (new file, 89 lines)

@@ -0,0 +1,89 @@
import functools
import gc
import torch
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
try:
HAS_CUDA = torch.cuda.is_available()
except Exception:
HAS_CUDA = False
try:
HAS_MPS = torch.backends.mps.is_available()
except Exception:
HAS_MPS = False
try:
import intel_extension_for_pytorch as ipex # noqa
HAS_XPU = torch.xpu.is_available()
except Exception:
HAS_XPU = False
def clean_memory():
gc.collect()
if HAS_CUDA:
torch.cuda.empty_cache()
if HAS_XPU:
torch.xpu.empty_cache()
if HAS_MPS:
torch.mps.empty_cache()
def clean_memory_on_device(device: torch.device):
r"""
Clean memory on the specified device; called from training scripts.
"""
gc.collect()
# device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda":
torch.cuda.empty_cache()
if device.type == "xpu":
torch.xpu.empty_cache()
if device.type == "mps":
torch.mps.empty_cache()
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
r"""
Do not call this function from training scripts. Use accelerator.device instead.
"""
if HAS_CUDA:
device = torch.device("cuda")
elif HAS_XPU:
device = torch.device("xpu")
elif HAS_MPS:
device = torch.device("mps")
else:
device = torch.device("cpu")
logger.info(f"get_preferred_device() -> {device}")
return device
def init_ipex():
"""
Apply the IPEX-to-CUDA hijacks via `library.ipex.ipex_init`.
This function should run right after importing torch and before doing anything else.
If IPEX is not available, this function does nothing.
"""
try:
if HAS_XPU:
from library.ipex import ipex_init
is_initialized, error_message = ipex_init()
if not is_initialized:
logger.error("failed to initialize ipex: {error_message}")
else:
return
except Exception as e:
logger.error("failed to initialize ipex: {e}")


@@ -4,7 +4,10 @@ from pathlib import Path
import argparse
import os
from library.utils import fire_in_thread
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
api = HfApi(
@@ -33,9 +36,9 @@ def upload(
try:
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
print("===========================================")
print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
print("===========================================")
logger.error("===========================================")
logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
logger.error("===========================================")
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
@@ -56,9 +59,9 @@ def upload(
path_in_repo=path_in_repo,
)
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
print("===========================================")
print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
print("===========================================")
logger.error("===========================================")
logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
logger.error("===========================================")
if args.async_upload and not force_sync_upload:
fire_in_thread(uploader)


@@ -9,161 +9,171 @@ from .hijacks import ipex_hijacks
def ipex_init(): # pylint: disable=too-many-statements
try:
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
torch.cuda.device_count = torch.xpu.device_count
torch.cuda.device_of = torch.xpu.device_of
torch.cuda.get_device_name = torch.xpu.get_device_name
torch.cuda.get_device_properties = torch.xpu.get_device_properties
torch.cuda.init = torch.xpu.init
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
return True, "Skipping IPEX hijack"
else:
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
torch.cuda.device_count = torch.xpu.device_count
torch.cuda.device_of = torch.xpu.device_of
torch.cuda.get_device_name = torch.xpu.get_device_name
torch.cuda.get_device_properties = torch.xpu.get_device_properties
torch.cuda.init = torch.xpu.init
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
# Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
torch.cuda.memory_cached = torch.xpu.memory_reserved
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
# Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
torch.cuda.memory_cached = torch.xpu.memory_reserved
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
# RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
torch.cuda.manual_seed = torch.xpu.manual_seed
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
torch.cuda.seed = torch.xpu.seed
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
# RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
torch.cuda.manual_seed = torch.xpu.manual_seed
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
torch.cuda.seed = torch.xpu.seed
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
# AMP:
torch.cuda.amp = torch.xpu.amp
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
# AMP:
torch.cuda.amp = torch.xpu.amp
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.version.cuda = "11.7"
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
torch.cuda.get_device_properties.major = 11
torch.cuda.get_device_properties.minor = 7
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.version.cuda = "12.1"
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
ipex_hijacks()
if not torch.xpu.has_fp64_dtype():
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
ipex_hijacks()
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
torch.cuda.is_xpu_hijacked = True
except Exception as e:
return False, e
return True, None
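Editor's note: ipex_init is now re-entrant. Every entry script calls init_ipex() at import time, so the function first checks torch.cuda.is_xpu_hijacked and returns (True, "Skipping IPEX hijack") instead of patching twice; the flag is set only after all hijacks succeed. The guard, reduced to its shape as a sketch:

    import torch

    def hijack_once():
        if getattr(torch.cuda, "is_xpu_hijacked", False):
            return True, "Skipping IPEX hijack"   # already patched in this process
        try:
            ...  # map torch.cuda.* onto torch.xpu.* here
            torch.cuda.is_xpu_hijacked = True     # mark success last
        except Exception as e:
            return False, e
        return True, None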


@@ -1,41 +1,98 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long
original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
block_multiply = input.element_size()
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
# ARC GPUs can't allocate more than 4GB to a single block, so we slice the attention layers
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
# Find something divisible with the input_tokens
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
# Find slice sizes for SDPA
@cache
def find_sdpa_slice_sizes(query_shape, query_element_size):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
do_split = True
# Find something divisible with the input_tokens
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
split_2_slice_size = input_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size_2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
do_split_2 = True
# Find something divisible with the input_tokens
while (split_2_slice_size * slice_block_size_2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
else:
do_split = False
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > sdpa_slice_trigger_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
# Find slice sizes for BMM
@cache
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = input_tokens
split_3_slice_size = mat2_atten_shape
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
if input.device.type != "xpu":
return original_torch_bmm(input, mat2, out=out)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
# Slice BMM
if do_split:
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
@@ -44,11 +101,21 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
if do_split_3:
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
out=out
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
@@ -57,58 +124,18 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
)
else:
return original_torch_bmm(input, mat2, out=out)
torch.xpu.synchronize(input.device)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
if len(query.shape) == 3:
batch_size_attention, query_tokens, shape_three = query.shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query.shape
block_multiply = query.element_size()
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
do_split = True
# Find something divisible with the batch_size_attention
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
split_2_slice_size = query_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size_2 = split_slice_size * shape_three * shape_four / 1024 / 1024 * block_multiply
do_split_2 = True
# Find something divisible with the query_tokens
while (split_2_slice_size * slice_block_size_2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
split_3_slice_size = shape_three
if split_2_slice_size * slice_block_size_2 > 4:
slice_block_size_3 = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * block_multiply
do_split_3 = True
# Find something divisible with the shape_three
while (split_3_slice_size * slice_block_size_3) > 4:
split_3_slice_size = split_3_slice_size // 2
if split_3_slice_size <= 1:
split_3_slice_size = 1
break
else:
do_split_3 = False
else:
do_split_2 = False
else:
do_split = False
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
# Slice SDPA
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
@@ -145,7 +172,6 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
dropout_p=dropout_p, is_causal=is_causal
)
else:
return original_scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
torch.xpu.synchronize(query.device)
return hidden_states
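Editor's note: the slicing math is easiest to see with numbers. find_slice_size halves the candidate until slice_size * slice_block_size (a rough MB-per-block estimate) drops to at most IPEX_ATTENTION_SLICE_RATE (default 4). A worked sketch with fp16 shapes plugged into the torch_bmm_32_bit formula:

    def find_slice_size(slice_size, slice_block_size, rate=4.0):
        while slice_size * slice_block_size > rate:
            slice_size //= 2
            if slice_size <= 1:
                return 1
        return slice_size

    # bmm with input (B=16, tokens=4096, ...) and mat2 last dim 4096, fp16 (2 bytes):
    slice_block_size = 4096 * 4096 / 1024 / 1024 * 2   # 32.0 "MB" per batch element
    block_size = 16 * slice_block_size                 # 512.0 -> far above the rate
    print(find_slice_size(16, slice_block_size))       # 1: the batch split alone is not
                                                       # enough, so the token-level
                                                       # (do_split_2) pass kicks in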


@@ -1,10 +1,62 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND
from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
@cache
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
if slice_size is not None:
batch_size_attention = slice_size
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
do_split = False
do_split_2 = False
do_split_3 = False
if query_device_type != "xpu":
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
r"""
Processor for implementing sliced attention.
@@ -18,7 +70,9 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
input_ndim = hidden_states.ndim
@@ -54,49 +108,62 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
block_multiply = query.element_size()
slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
block_size = query_tokens * slice_block_size
split_2_slice_size = query_tokens
if block_size > 4:
do_split_2 = True
#Find something divisible with the query_tokens
while (split_2_slice_size * slice_block_size) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
@@ -115,6 +182,131 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
return hidden_states
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None,
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
if do_split:
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def ipex_diffusers():
#ARC GPUs can't allocate more than 4GB to a single block:
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
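A hedged usage note: the swap happens at the diffusers module level, so it only needs to run once, before any pipeline or attention processor is constructed. Something like (the model id is only an example):

import diffusers
ipex_diffusers()  # from here on, diffusers instantiates the XPU-sliced processors
pipe = diffusers.StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")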

View File

@@ -1,75 +1,26 @@
import contextlib
import importlib
import os
from functools import wraps
from contextlib import nullcontext
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import numpy as np
device_supports_fp64 = torch.xpu.has_fp64_dtype()
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
class CondFunc: # pylint: disable=missing-class-docstring
def __new__(cls, orig_func, sub_func, cond_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
except ImportError:
pass
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)
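CondFunc is thus a conditional monkey-patch: the dotted path is resolved to a module attribute, the attribute is replaced by the wrapper, and each call is routed to sub_func only when cond_func approves. A standalone sketch of the pattern the hijacks below rely on:

# Sketch: redirect torch.ones(..., device="cuda") to XPU, leave other calls alone.
CondFunc('torch.ones',
         lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device="xpu", **kwargs),
         lambda orig_func, *args, device=None, **kwargs: device is not None and "cuda" in str(device))
torch.ones(2, 2, device="cuda")  # rerouted: allocated on "xpu"
torch.ones(2, 2)                 # untouched: uses the default device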
_utils = torch.utils.data._utils
def _shutdown_workers(self):
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
return
if hasattr(self, "_shutdown") and not self._shutdown:
self._shutdown = True
try:
if hasattr(self, '_pin_memory_thread'):
self._pin_memory_thread_done_event.set()
self._worker_result_queue.put((None, None))
self._pin_memory_thread.join()
self._worker_result_queue.cancel_join_thread()
self._worker_result_queue.close()
self._workers_done_event.set()
for worker_id in range(len(self._workers)):
if self._persistent_workers or self._workers_status[worker_id]:
self._mark_worker_as_unavailable(worker_id, shutdown=True)
for w in self._workers: # pylint: disable=invalid-name
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
for q in self._index_queues: # pylint: disable=invalid-name
q.cancel_join_thread()
q.close()
finally:
if self._worker_pids_set:
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
self._worker_pids_set = False
for w in self._workers: # pylint: disable=invalid-name
if w.is_alive():
w.terminate()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return contextlib.nullcontext()
return nullcontext()
@property
def is_cuda(self):
return self.device.type == 'xpu' or self.device.type == 'cuda'
def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
@@ -77,28 +28,19 @@ def check_device(device):
def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
def ipex_no_cuda(orig_func, *args, **kwargs):
torch.cuda.is_available = lambda: False
orig_func(*args, **kwargs)
torch.cuda.is_available = torch.xpu.is_available
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
if len(args) > 0 and args[0] == "cuda":
return original_autocast("xpu", *args[1:], **kwargs)
else:
return original_autocast(*args, **kwargs)
# Autocast
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
@wraps(torch.amp.autocast_mode.autocast.__init__)
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
if device_type == "cuda":
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
else:
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
# Embedding BF16
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
# Latent antialias:
# Latent Antialias CPU Offload:
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
return_device = tensor.device
@@ -109,19 +51,33 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
original_linalg_solve = torch.linalg.solve
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
return_device = A.device
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
else:
return original_linalg_solve(A, B, *args, **kwargs)
if torch.xpu.has_fp64_dtype():
# Diffusers Float64 (Alchemist GPUs don't support 64 bit):
original_from_numpy = torch.from_numpy
@wraps(torch.from_numpy)
def from_numpy(ndarray):
if ndarray.dtype == float:
return original_from_numpy(ndarray.astype('float32'))
else:
return original_from_numpy(ndarray)
original_as_tensor = torch.as_tensor
@wraps(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
if check_device(device):
device = return_xpu(device)
if isinstance(data, np.ndarray) and data.dtype == float and not (
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
return original_as_tensor(data, dtype=torch.float32, device=device)
else:
return original_as_tensor(data, dtype=dtype, device=device)
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
# 64 bit attention workarounds for Alchemist:
# 32 bit attention workarounds for Alchemist:
try:
from .attention import torch_bmm_32_bit as original_torch_bmm
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
@@ -129,124 +85,214 @@ else:
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
# dtype errors:
# Data Type Errors:
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
@wraps(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
if query.dtype != key.dtype:
key = key.to(dtype=query.dtype)
if query.dtype != value.dtype:
value = value.to(dtype=query.dtype)
if attn_mask is not None and query.dtype != attn_mask.dtype:
attn_mask = attn_mask.to(dtype=query.dtype)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
@property
def is_cuda(self):
return self.device.type == 'xpu'
# A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm
@wraps(torch.nn.functional.group_norm)
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
def ipex_hijacks():
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.Tensor.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.UntypedStorage.__init__',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.UntypedStorage.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.randn',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.ones',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.zeros',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.linspace',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
orig_func(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
if hasattr(torch.xpu, "Generator"):
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
# A1111 BF16
original_functional_layer_norm = torch.nn.functional.layer_norm
@wraps(torch.nn.functional.layer_norm)
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
# Training
original_functional_linear = torch.nn.functional.linear
@wraps(torch.nn.functional.linear)
def functional_linear(input, weight, bias=None):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_linear(input, weight, bias=bias)
original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
else:
CondFunc('torch.Generator',
lambda orig_func, device=None: orig_func(return_xpu(device)),
lambda orig_func, device=None: check_device(device))
# A1111 Embedding BF16
original_torch_cat = torch.cat
@wraps(torch.cat)
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
# TiledVAE and ControlNet:
CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
CondFunc('torch.instance_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
# SwinIR BF16:
original_functional_pad = torch.nn.functional.pad
@wraps(torch.nn.functional.pad)
def functional_pad(input, pad, mode='constant', value=None):
if mode == 'reflect' and input.dtype == torch.bfloat16:
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
else:
return original_functional_pad(input, pad, mode=mode, value=value)
# Functions with dtype errors:
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
# Training:
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
# BF16:
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
# SwinIR BF16:
CondFunc('torch.nn.functional.pad',
lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
# Diffusers Float64 (Alchemist GPUs don't support 64 bit):
if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float)
original_torch_tensor = torch.tensor
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
if check_device(device):
device = return_xpu(device)
if not device_supports_fp64:
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
if dtype == torch.float64:
dtype = torch.float32
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
dtype = torch.float32
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
# Broken functions when torch.cuda.is_available is True:
# Pin Memory:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True)
original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_to(self, device, *args, **kwargs)
# Functions that make compile mad with CondFunc:
torch.nn.DataParallel = DummyDataParallel
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
original_Tensor_cuda = torch.Tensor.cuda
@wraps(torch.Tensor.cuda)
def Tensor_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_cuda(self, device, *args, **kwargs)
original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
if check_device(device):
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
else:
return original_UntypedStorage_init(*args, device=device, **kwargs)
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
else:
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs):
if check_device(device):
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_empty(*args, device=device, **kwargs)
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, **kwargs):
if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_randn(*args, device=device, **kwargs)
original_torch_ones = torch.ones
@wraps(torch.ones)
def torch_ones(*args, device=None, **kwargs):
if check_device(device):
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_ones(*args, device=device, **kwargs)
original_torch_zeros = torch.zeros
@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs):
if check_device(device):
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_zeros(*args, device=device, **kwargs)
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
if check_device(device):
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_linspace(*args, device=device, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
if check_device(map_location):
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
else:
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
# Hijack Functions:
def ipex_hijacks():
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
torch.UntypedStorage.__init__ = UntypedStorage_init
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.empty = torch_empty
torch.randn = torch_randn
torch.ones = torch_ones
torch.zeros = torch_zeros
torch.linspace = torch_linspace
torch.Generator = torch_Generator
torch.load = torch_load
torch.autocast = ipex_autocast
torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel
torch.UntypedStorage.is_cuda = is_cuda
torch.amp.autocast_mode.autocast.__init__ = autocast_init
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
torch.nn.functional.group_norm = functional_group_norm
torch.nn.functional.layer_norm = functional_layer_norm
torch.nn.functional.linear = functional_linear
torch.nn.functional.conv2d = functional_conv2d
torch.nn.functional.interpolate = interpolate
torch.linalg.solve = linalg_solve
torch.nn.functional.pad = functional_pad
torch.bmm = torch_bmm
torch.cat = torch_cat
if not device_supports_fp64:
torch.from_numpy = from_numpy
torch.as_tensor = as_tensor
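How the two halves are meant to be activated, as a sketch; the actual wiring is library.device_utils.init_ipex (which the training scripts in this compare call), delegating to ipex_init in library/ipex:

# Sketch of the activation order.
import torch
import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import

if torch.xpu.is_available():
    ipex_hijacks()      # reroute torch's CUDA entry points to XPU (this file)
    ipex_diffusers()    # install the 4GB-safe attention processors (see above)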

View File

@@ -9,7 +9,7 @@ import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
import diffusers
from diffusers import SchedulerMixin, StableDiffusionPipeline
@@ -17,7 +17,6 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import logging
try:
from diffusers.utils import PIL_INTERPOLATION
except ImportError:
@@ -520,6 +519,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
image_encoder: CLIPVisionModelWithProjection = None,
clip_skip: int = 1,
):
super().__init__(
@@ -531,32 +531,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker,
image_encoder=image_encoder,
)
self.clip_skip = clip_skip
self.custom_clip_skip = clip_skip
self.__init__additional__()
# else:
# def __init__(
# self,
# vae: AutoencoderKL,
# text_encoder: CLIPTextModel,
# tokenizer: CLIPTokenizer,
# unet: UNet2DConditionModel,
# scheduler: SchedulerMixin,
# safety_checker: StableDiffusionSafetyChecker,
# feature_extractor: CLIPFeatureExtractor,
# ):
# super().__init__(
# vae=vae,
# text_encoder=text_encoder,
# tokenizer=tokenizer,
# unet=unet,
# scheduler=scheduler,
# safety_checker=safety_checker,
# feature_extractor=feature_extractor,
# )
# self.__init__additional__()
def __init__additional__(self):
if not hasattr(self, "vae_scale_factor"):
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
@@ -624,7 +603,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
clip_skip=self.custom_clip_skip,
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
@@ -646,7 +625,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
print(height, width)
logger.info(f'{height} {width}')
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (

View File

@@ -3,19 +3,20 @@
import math
import os
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex
init_ipex()
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# model parameters for the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
@@ -571,9 +572,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
# support checkpoint without position_ids (invalid checkpoint)
if "text_model.embeddings.position_ids" not in text_model_dict:
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
# remove position_ids for newer transformer, which causes error :(
if "text_model.embeddings.position_ids" in text_model_dict:
text_model_dict.pop("text_model.embeddings.position_ids")
return text_model_dict
@@ -947,7 +948,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
# logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
@@ -1005,7 +1006,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
unet = UNet2DConditionModel(**unet_config).to(device)
info = unet.load_state_dict(converted_unet_checkpoint)
print("loading u-net:", info)
logger.info(f"loading u-net: {info}")
# Convert the VAE model.
vae_config = create_vae_diffusers_config()
@@ -1013,7 +1014,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
vae = AutoencoderKL(**vae_config).to(device)
info = vae.load_state_dict(converted_vae_checkpoint)
print("loading vae:", info)
logger.info(f"loading vae: {info}")
# convert text_model
if v2:
@@ -1047,7 +1048,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
# logging.set_verbosity_error() # don't show annoying warning
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
# logging.set_verbosity_warning()
# print(f"config: {text_model.config}")
# logger.info(f"config: {text_model.config}")
cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
@@ -1070,7 +1071,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
)
text_model = CLIPTextModel._from_config(cfg)
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
print("loading text encoder:", info)
logger.info(f"loading text encoder: {info}")
return text_model, vae, unet
@@ -1145,7 +1146,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
# whether to fabricate dummy weights for the last layer, etc.
if make_dummy_weights:
print("make dummy weights for resblock.23, text_projection and logit scale.")
logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
keys = list(new_sd.keys())
for key in keys:
if key.startswith("transformer.resblocks.22."):
@@ -1242,8 +1243,13 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
if vae is None:
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
# original U-Net cannot be saved, so we need to convert it to the Diffusers version
# TODO this consumes a lot of memory
diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
diffusers_unet.load_state_dict(unet.state_dict())
pipeline = StableDiffusionPipeline(
unet=unet,
unet=diffusers_unet,
text_encoder=text_encoder,
vae=vae,
scheduler=scheduler,
@@ -1259,14 +1265,14 @@ VAE_PREFIX = "first_stage_model."
def load_vae(vae_id, dtype):
print(f"load VAE: {vae_id}")
logger.info(f"load VAE: {vae_id}")
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
# Diffusers local/remote
try:
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
except EnvironmentError as e:
print(f"exception occurs in loading vae: {e}")
print("retry with subfolder='vae'")
logger.error(f"exception occurs in loading vae: {e}")
logger.error("retry with subfolder='vae'")
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
return vae
@@ -1338,13 +1344,13 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
if __name__ == "__main__":
resos = make_bucket_resolutions((512, 768))
print(len(resos))
print(resos)
logger.info(f"{len(resos)}")
logger.info(f"{resos}")
aspect_ratios = [w / h for w, h in resos]
print(aspect_ratios)
logger.info(f"{aspect_ratios}")
ars = set()
for ar in aspect_ratios:
if ar in ars:
print("error! duplicate ar:", ar)
logger.error(f"error! duplicate ar: {ar}")
ars.add(ar)

View File

@@ -113,6 +113,10 @@ import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
@@ -1262,9 +1266,9 @@ class CrossAttnUpBlock2D(nn.Module):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, spda):
def set_use_sdpa(self, sdpa):
for attn in self.attentions:
attn.set_use_sdpa(spda)
attn.set_use_sdpa(sdpa)
def forward(
self,
@@ -1380,7 +1384,7 @@ class UNet2DConditionModel(nn.Module):
):
super().__init__()
assert sample_size is not None, "sample_size must be specified"
print(
logger.info(
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
)
@@ -1514,7 +1518,7 @@ class UNet2DConditionModel(nn.Module):
def set_gradient_checkpointing(self, value=False):
modules = self.down_blocks + [self.mid_block] + self.up_blocks
for module in modules:
print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
module.gradient_checkpointing = value
# endregion
@@ -1709,14 +1713,14 @@ class InferUNet2DConditionModel:
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
logger.info("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
logger.info(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1

View File

@@ -5,6 +5,12 @@ from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
r"""
# Metadata Example
@@ -51,11 +57,13 @@ ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_STABLE_CASCADE = "stable-cascade"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_STABILITY_AI_STABLE_CASCADE = "https://github.com/Stability-AI/StableCascade"
IMPL_DIFFUSERS = "diffusers"
PRED_TYPE_EPSILON = "epsilon"
@@ -109,6 +117,7 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
stable_cascade: Optional[bool] = None,
):
# if state_dict is None, hash is not calculated
@@ -120,7 +129,9 @@ def build_metadata(
# hash = precalculate_safetensors_hashes(state_dict)
# metadata["modelspec.hash_sha256"] = hash
if sdxl:
if stable_cascade:
arch = ARCH_STABLE_CASCADE
elif sdxl:
arch = ARCH_SD_XL_V1_BASE
elif v2:
if v_parameterization:
@@ -138,9 +149,11 @@ def build_metadata(
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
if stable_cascade:
impl = IMPL_STABILITY_AI_STABLE_CASCADE
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
@@ -231,8 +244,8 @@ def build_metadata(
# # assert all values are filled
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
print(f"Internal error: some metadata values are None: {metadata}")
logger.error(f"Internal error: some metadata values are None: {metadata}")
return metadata
@@ -246,7 +259,7 @@ def get_title(metadata: dict) -> Optional[str]:
def load_metadata_from_safetensors(model: str) -> dict:
if not model.endswith(".safetensors"):
return {}
with safetensors.safe_open(model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:

View File

@@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if up1 is not None:
uncond_pool = up1
dtype = self.unet.dtype
unet_dtype = self.unet.dtype
dtype = unet_dtype
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
dtype = torch.float16
self.unet.to(dtype)
# 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image):
@@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if is_cancelled_callback is not None and is_cancelled_callback():
return None
self.unet.to(unet_dtype)
return latents
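The itemsize == 1 test above is a dtype-agnostic fp8 probe, and the hasattr guard keeps it safe on torch builds older than 2.1, where torch.dtype has no itemsize; for reference, assuming a build with float8 support:

# fp8 dtypes are the only 1-byte floating point types:
assert torch.float16.itemsize == 2
assert torch.bfloat16.itemsize == 2
assert torch.float8_e4m3fn.itemsize == 1  # triggers the fp16 fallback above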
def latents_to_image(self, latents):

View File

@@ -7,7 +7,10 @@ from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
VAE_SCALE_FACTOR = 0.13025
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
@@ -100,7 +103,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
key = key.replace(".ln_final", ".final_layer_norm")
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
elif ".embeddings.position_ids" in key:
key = None # remove this key: make position_ids by ourselves
key = None # remove this key: position_ids is not used in newer transformers
return key
keys = list(checkpoint.keys())
@@ -126,16 +129,12 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
# not present in the original SD checkpoint, so add position_ids
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
# logit_scale is not included in Diffusers, but return it separately so it can be restored when saving
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
# temporary workaround for text_projection.weight.weight for Playground-v2
if "text_projection.weight.weight" in new_sd:
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
del new_sd["text_projection.weight.weight"]
@@ -190,20 +189,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
checkpoint = None
# U-Net
print("building U-Net")
logger.info("building U-Net")
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
print("loading U-Net from checkpoint")
logger.info("loading U-Net from checkpoint")
unet_sd = {}
for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."):
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
print("U-Net: ", info)
logger.info(f"U-Net: {info}")
# Text Encoders
print("building text encoders")
logger.info("building text encoders")
# Text Encoder 1 is same to Stability AI's SDXL
text_model1_cfg = CLIPTextConfig(
@@ -256,7 +255,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
with init_empty_weights():
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
print("loading text encoders from checkpoint")
logger.info("loading text encoders from checkpoint")
te1_sd = {}
te2_sd = {}
for k in list(state_dict.keys()):
@@ -265,27 +264,27 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
if "text_model.embeddings.position_ids" not in te1_sd:
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
if "text_model.embeddings.position_ids" in te1_sd:
te1_sd.pop("text_model.embeddings.position_ids")
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
print("text encoder 1:", info1)
logger.info(f"text encoder 1: {info1}")
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
print("text encoder 2:", info2)
logger.info(f"text encoder 2: {info2}")
# prepare vae
print("building VAE")
logger.info("building VAE")
vae_config = model_util.create_vae_diffusers_config()
with init_empty_weights():
vae = AutoencoderKL(**vae_config)
print("loading VAE from checkpoint")
logger.info("loading VAE from checkpoint")
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
print("VAE:", info)
logger.info(f"VAE: {info}")
ckpt_info = (epoch, global_step) if epoch is not None else None
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info

View File

@@ -30,7 +30,10 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
IN_CHANNELS: int = 4
OUT_CHANNELS: int = 4
@@ -332,7 +335,7 @@ class ResnetBlock2D(nn.Module):
def forward(self, x, emb):
if self.training and self.gradient_checkpointing:
# print("ResnetBlock2D: gradient_checkpointing")
# logger.info("ResnetBlock2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -366,7 +369,7 @@ class Downsample2D(nn.Module):
def forward(self, hidden_states):
if self.training and self.gradient_checkpointing:
# print("Downsample2D: gradient_checkpointing")
# logger.info("Downsample2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -653,7 +656,7 @@ class BasicTransformerBlock(nn.Module):
def forward(self, hidden_states, context=None, timestep=None):
if self.training and self.gradient_checkpointing:
# print("BasicTransformerBlock: checkpointing")
# logger.info("BasicTransformerBlock: checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -796,7 +799,7 @@ class Upsample2D(nn.Module):
def forward(self, hidden_states, output_size=None):
if self.training and self.gradient_checkpointing:
# print("Upsample2D: gradient_checkpointing")
# logger.info("Upsample2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -1046,7 +1049,7 @@ class SdxlUNet2DConditionModel(nn.Module):
for block in blocks:
for module in block:
if hasattr(module, "set_use_memory_efficient_attention"):
# print(module.__class__.__name__)
# logger.info(module.__class__.__name__)
module.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa: bool) -> None:
@@ -1061,7 +1064,7 @@ class SdxlUNet2DConditionModel(nn.Module):
for block in blocks:
for module in block.modules():
if hasattr(module, "gradient_checkpointing"):
# print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
# logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
module.gradient_checkpointing = value
# endregion
@@ -1083,7 +1086,7 @@ class SdxlUNet2DConditionModel(nn.Module):
def call_module(module, h, emb, context):
x = h
for layer in module:
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
# logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
if isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, Transformer2DModel):
@@ -1135,14 +1138,14 @@ class InferSdxlUNet2DConditionModel:
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
print("Deep Shrink is disabled.")
logger.info("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
print(
logger.info(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
@@ -1229,7 +1232,7 @@ class InferSdxlUNet2DConditionModel:
if __name__ == "__main__":
import time
print("create unet")
logger.info("create unet")
unet = SdxlUNet2DConditionModel()
unet.to("cuda")
@@ -1238,7 +1241,7 @@ if __name__ == "__main__":
unet.train()
# pseudo training loop for checking memory usage
print("preparing optimizer")
logger.info("preparing optimizer")
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
@@ -1253,12 +1256,12 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
logger.info("start training")
steps = 10
batch_size = 1
for step in range(steps):
print(f"step {step}")
logger.info(f"step {step}")
if step == 1:
time_start = time.perf_counter()
@@ -1278,4 +1281,4 @@ if __name__ == "__main__":
optimizer.zero_grad(set_to_none=True)
time_end = time.perf_counter()
print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")

View File

@@ -1,14 +1,21 @@
import argparse
import gc
import math
import os
from typing import Optional
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -21,7 +28,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
(
load_stable_diffusion_format,
@@ -47,8 +54,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
@@ -62,7 +68,7 @@ def _load_target_model(
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print(f"load StableDiffusion checkpoint: {name_or_path}")
logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
(
text_encoder1,
text_encoder2,
@@ -76,7 +82,7 @@ def _load_target_model(
from diffusers import StableDiffusionXLPipeline
variant = "fp16" if weight_dtype == torch.float16 else None
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
try:
try:
pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -84,12 +90,12 @@ def _load_target_model(
)
except EnvironmentError as ex:
if variant is not None:
print("try to load fp32 model")
logger.info("try to load fp32 model")
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
else:
raise ex
except EnvironmentError as ex:
print(
logger.error(
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
)
raise ex
@@ -112,7 +118,7 @@ def _load_target_model(
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
print("U-Net converted to original U-Net")
logger.info("U-Net converted to original U-Net")
logit_scale = None
ckpt_info = None
@@ -120,13 +126,13 @@ def _load_target_model(
# load the VAE
if vae_path is not None:
vae = model_util.load_vae(vae_path, weight_dtype)
print("additional VAE loaded")
logger.info("additional VAE loaded")
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
logger.info("prepare tokenizers")
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
tokenizers = []
@@ -135,14 +141,14 @@ def load_tokenizers(args: argparse.Namespace):
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
print(f"load tokenizer from cache: {local_tokenizer_path}")
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
if tokenizer is None:
tokenizer = CLIPTokenizer.from_pretrained(original_path)
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
print(f"save Tokenizer to cache: {local_tokenizer_path}")
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
if i == 1:
@@ -151,7 +157,7 @@ def load_tokenizers(args: argparse.Namespace):
tokenizers.append(tokenizer)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
logger.info(f"update token length: {args.max_token_length}")
return tokenizers
@@ -332,23 +338,23 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
if args.clip_skip is not None:
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
# if args.multires_noise_iterations:
# print(
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
# )
# else:
# if args.noise_offset is None:
# args.noise_offset = DEFAULT_NOISE_OFFSET
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
# print(
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
# )
# print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
@@ -357,7 +363,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
print(
logger.warning(
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
)

View File

@@ -26,7 +26,10 @@ from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def slice_h(x, num_slices):
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
@@ -89,7 +92,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
# sliced_tensor = torch.chunk(x, num_div, dim=1)
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
# print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
# logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
# normed_tensor = []
# for i in range(num_div):
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
@@ -243,7 +246,7 @@ class SlicingEncoder(nn.Module):
self.num_slices = num_slices
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # deeper layers need less slicing, so reduce the divisor accordingly
# print(f"initial divisor: {div}")
# logger.info(f"initial divisor: {div}")
if div >= 2:
div = int(div)
for resnet in self.mid_block.resnets:
@@ -253,11 +256,11 @@ class SlicingEncoder(nn.Module):
for i, down_block in enumerate(self.down_blocks[::-1]):
if div >= 2:
div = int(div)
# print(f"down block: {i} divisor: {div}")
# logger.info(f"down block: {i} divisor: {div}")
for resnet in down_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
if down_block.downsamplers is not None:
# print("has downsample")
# logger.info("has downsample")
for downsample in down_block.downsamplers:
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
div *= 2
@@ -307,7 +310,7 @@ class SlicingEncoder(nn.Module):
def downsample_forward(self, _self, num_slices, hidden_states):
assert hidden_states.shape[1] == _self.channels
assert _self.use_conv and _self.padding == 0
print("downsample forward", num_slices, hidden_states.shape)
logger.info(f"downsample forward {num_slices} {hidden_states.shape}")
org_device = hidden_states.device
cpu_device = torch.device("cpu")
@@ -350,7 +353,7 @@ class SlicingEncoder(nn.Module):
hidden_states = torch.cat([hidden_states, x], dim=2)
hidden_states = hidden_states.to(org_device)
# print("downsample forward done", hidden_states.shape)
# logger.info(f"downsample forward done {hidden_states.shape}")
return hidden_states
@@ -426,7 +429,7 @@ class SlicingDecoder(nn.Module):
self.num_slices = num_slices
div = num_slices / (2 ** (len(self.up_blocks) - 1))
print(f"initial divisor: {div}")
logger.info(f"initial divisor: {div}")
if div >= 2:
div = int(div)
for resnet in self.mid_block.resnets:
@@ -436,11 +439,11 @@ class SlicingDecoder(nn.Module):
for i, up_block in enumerate(self.up_blocks):
if div >= 2:
div = int(div)
# print(f"up block: {i} divisor: {div}")
# logger.info(f"up block: {i} divisor: {div}")
for resnet in up_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
if up_block.upsamplers is not None:
# print("has upsample")
# logger.info("has upsample")
for upsample in up_block.upsamplers:
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
div *= 2
@@ -528,7 +531,7 @@ class SlicingDecoder(nn.Module):
del x
hidden_states = torch.cat(sliced, dim=2)
# print("us hidden_states", hidden_states.shape)
# logger.info(f"us hidden_states {hidden_states.shape}")
del sliced
hidden_states = hidden_states.to(org_device)

library/stable_cascade.py (1654 lines; diff suppressed because it is too large)


@@ -0,0 +1,668 @@
import argparse
import json
import math
import os
import time
from typing import List, Optional
import numpy as np
import toml
import torch
import torchvision
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
from accelerate import init_empty_weights, Accelerator, PartialState
from PIL import Image
from library import stable_cascade as sc
from library.sdxl_model_util import _load_state_dict_on_device
from library.device_utils import clean_memory_on_device
from library.train_util import (
save_sd_model_on_epoch_end_or_stepwise_common,
save_sd_model_on_train_end_common,
line_to_prompt_dict,
get_hidden_states_stable_cascade,
)
from library import sai_model_spec
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLIP_TEXT_MODEL_NAME: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz"
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
# compression_factor_b is Stage C's effective downscale factor, compression_factor_a is Stage A's
latent_height = math.ceil(height / compression_factor_b)
latent_width = math.ceil(width / compression_factor_b)
stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
latent_height = math.ceil(height / compression_factor_a)
latent_width = math.ceil(width / compression_factor_a)
stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
return stage_c_latent_shape, stage_b_latent_shape
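As a quick check of the arithmetic above, at the default 1024x1024 resolution Stage C latents come out at 24x24 and Stage B latents at 256x256. A minimal sketch using the same compression factors:

import math

height = width = 1024
# Stage C: ceil(1024 / 42.67) = 24  -> (batch_size, 16, 24, 24)
# Stage B: ceil(1024 / 4.0)  = 256 -> (batch_size, 4, 256, 256)
print(math.ceil(height / 42.67), math.ceil(width / 4.0))  # 24 256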
# region load and save
def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder:
logger.info(f"Loading EfficientNet encoder from {effnet_checkpoint_path}")
effnet = sc.EfficientNetEncoder()
effnet_checkpoint = load_file(effnet_checkpoint_path)
info = effnet.load_state_dict(effnet_checkpoint if "state_dict" not in effnet_checkpoint else effnet_checkpoint["state_dict"])
logger.info(info)
del effnet_checkpoint
return effnet
def load_tokenizer(args: argparse.Namespace):
# TODO commonize with sdxl_train_util.load_tokenizers
logger.info("prepare tokenizers")
original_paths = [CLIP_TEXT_MODEL_NAME]
tokenizers = []
for i, original_path in enumerate(original_paths):
tokenizer: Optional[CLIPTokenizer] = None
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
if tokenizer is None:
tokenizer = CLIPTokenizer.from_pretrained(original_path)
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
tokenizers.append(tokenizer)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
logger.info(f"update token length: {args.max_token_length}")
return tokenizers[0]
def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.StageC:
# Generator
logger.info(f"Instantiating Stage C generator")
with init_empty_weights():
generator_c = sc.StageC()
logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
stage_c_checkpoint = load_file(stage_c_checkpoint_path)
stage_c_checkpoint = convert_state_dict_mha_to_normal_attn(stage_c_checkpoint)
logger.info(f"Loading state dict")
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
logger.info(info)
return generator_c
def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.StageB:
logger.info(f"Instantiating Stage B generator")
with init_empty_weights():
generator_b = sc.StageB()
logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
stage_b_checkpoint = load_file(stage_b_checkpoint_path)
stage_b_checkpoint = convert_state_dict_mha_to_normal_attn(stage_b_checkpoint)
logger.info(f"Loading state dict")
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
logger.info(info)
return generator_b
def load_clip_text_model(text_model_checkpoint_path, dtype=None, device="cpu", save_text_model=False):
# CLIP encoders
logger.info(f"Loading CLIP text model")
if save_text_model or text_model_checkpoint_path is None:
logger.info(f"Loading CLIP text model from {CLIP_TEXT_MODEL_NAME}")
text_model = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_MODEL_NAME)
if save_text_model:
sd = text_model.state_dict()
logger.info(f"Saving CLIP text model to {text_model_checkpoint_path}")
save_file(sd, text_model_checkpoint_path)
else:
logger.info(f"Loading CLIP text model from {text_model_checkpoint_path}")
# copy from sdxl_model_util.py
text_model2_cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
text_model = CLIPTextModelWithProjection(text_model2_cfg)
text_model_checkpoint = load_file(text_model_checkpoint_path)
info = _load_state_dict_on_device(text_model, text_model_checkpoint, device, dtype=dtype)
logger.info(info)
return text_model
def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.StageA:
logger.info(f"Loading Stage A vqGAN from {stage_a_checkpoint_path}")
stage_a = sc.StageA().to(device)
stage_a_checkpoint = load_file(stage_a_checkpoint_path)
info = stage_a.load_state_dict(
stage_a_checkpoint if "state_dict" not in stage_a_checkpoint else stage_a_checkpoint["state_dict"]
)
logger.info(info)
return stage_a
def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") -> sc.Previewer:
logger.info(f"Loading Previewer from {previewer_checkpoint_path}")
previewer = sc.Previewer().to(device)
previewer_checkpoint = load_file(previewer_checkpoint_path)
info = previewer.load_state_dict(
previewer_checkpoint if "state_dict" not in previewer_checkpoint else previewer_checkpoint["state_dict"]
)
logger.info(info)
return previewer
def convert_state_dict_mha_to_normal_attn(state_dict):
# convert nn.MultiheadAttention to to_q/k/v and out_proj
print("convert_state_dict_mha_to_normal_attn")
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "in_proj_bias" in key:
value = state_dict.pop(key)
qkv = torch.chunk(value, 3, dim=0)
state_dict[key.replace("in_proj_bias", "to_q.bias")] = qkv[0]
state_dict[key.replace("in_proj_bias", "to_k.bias")] = qkv[1]
state_dict[key.replace("in_proj_bias", "to_v.bias")] = qkv[2]
elif "in_proj_weight" in key:
value = state_dict.pop(key)
qkv = torch.chunk(value, 3, dim=0)
state_dict[key.replace("in_proj_weight", "to_q.weight")] = qkv[0]
state_dict[key.replace("in_proj_weight", "to_k.weight")] = qkv[1]
state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
elif "out_proj.bias" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = value
elif "out_proj.weight" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = value
return state_dict
def convert_state_dict_normal_attn_to_mha(state_dict):
# convert to_q/k/v and out_proj to nn.MultiheadAttention
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "to_q.bias" in key:
q = state_dict.pop(key)
k = state_dict.pop(key.replace("to_q.bias", "to_k.bias"))
v = state_dict.pop(key.replace("to_q.bias", "to_v.bias"))
state_dict[key.replace("to_q.bias", "in_proj_bias")] = torch.cat([q, k, v])
elif "to_q.weight" in key:
q = state_dict.pop(key)
k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
elif "out_proj.bias" in key:
v = state_dict.pop(key)
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = v
elif "out_proj.weight" in key:
v = state_dict.pop(key)
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = v
return state_dict
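The two converters above are inverses: nn.MultiheadAttention stores q/k/v stacked row-wise in a single in_proj tensor, so chunking into three along dim 0 and concatenating back round-trips exactly. A small self-contained check:

import torch

embed_dim = 4
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)  # rows stacked as [q; k; v]
q, k, v = torch.chunk(in_proj_weight, 3, dim=0)          # MHA -> separate projections
assert torch.equal(torch.cat([q, k, v], dim=0), in_proj_weight)  # and back to MHA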
def get_sai_model_spec(args, lora=False):
timestamp = time.time()
reso = args.resolution
title = args.metadata_title if args.metadata_title is not None else args.output_name
if args.min_timestep is not None or args.max_timestep is not None:
min_time_step = args.min_timestep if args.min_timestep is not None else 0
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
timesteps = (min_time_step, max_time_step)
else:
timesteps = None
metadata = sai_model_spec.build_metadata(
None,
False,
False,
False,
lora,
False,
timestamp,
title=title,
reso=reso,
is_stable_diffusion_ckpt=False,
author=args.metadata_author,
description=args.metadata_description,
license=args.metadata_license,
tags=args.metadata_tags,
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
stable_cascade=True,
)
return metadata
def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata):
state_dict = stage_c.state_dict()
if save_dtype is not None:
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
state_dict = convert_state_dict_normal_attn_to_mha(state_dict)
save_file(state_dict, ckpt_file, metadata=sai_metadata)
# save text model
if text_model is not None:
text_model_sd = text_model.state_dict()
if save_dtype is not None:
text_model_sd = {k: v.to(save_dtype) for k, v in text_model_sd.items()}
text_model_ckpt_file = os.path.splitext(ckpt_file)[0] + "_text_model.safetensors"
save_file(text_model_sd, text_model_ckpt_file)
def save_stage_c_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
stage_c,
text_model,
):
def stage_c_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(args)
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
save_sd_model_on_epoch_end_or_stepwise_common(
args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None
)
def save_stage_c_model_on_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
stage_c,
text_model,
):
def stage_c_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(args)
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)
# endregion
# region sample generation
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
previewer,
tokenizer,
text_encoder,
stage_c,
gdf,
prompt_replacement=None,
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# ignore sample_every_n_steps when sample_every_n_epochs is specified
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
stage_c = accelerator.unwrap_model(stage_c)
text_encoder = accelerator.unwrap_model(text_encoder)
# read prompts
if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
else:
logger.error(f"unsupported prompt file format / プロンプトファイルの形式が未対応です: {args.sample_prompts}")
return
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# preprocess prompts
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
tokenizer,
text_encoder,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to processes in order, so that generation order roughly matches enum order. This probably only holds when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
tokenizer,
text_encoder,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
)
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
# with torch.cuda.device(torch.cuda.current_device()):
# torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
tokenizer,
text_model,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
):
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 1024)
height = prompt_dict.get("height", 1024)
scale = prompt_dict.get("scale", 4)
seed = prompt_dict.get("seed")
# controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
logger.info(f"prompt: {prompt}")
logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
negative_prompt = "" if negative_prompt is None else negative_prompt
cfg = scale
timesteps = sample_steps
shift = 2
t_start = 1.0
stage_c_latent_shape, _ = calculate_latent_sizes(height, width, batch_size=1)
# PREPARE CONDITIONS
input_ids = tokenizer(
[prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
cond_text, cond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
input_ids = tokenizer(
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
uncond_text, uncond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
device = accelerator.device
dtype = stage_c.dtype
cond_text = cond_text.to(device, dtype=dtype)
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
uncond_text = uncond_text.to(device, dtype=dtype)
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
zero_img_emb = torch.zeros(1, 768, device=device)
# a dict is not ideal, but changing everything downstream of GDF is a pain, so use a dict for now
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
with torch.no_grad(): # , torch.cuda.amp.autocast(dtype=dtype):
sampling_c = gdf.sample(
stage_c,
conditions,
stage_c_latent_shape,
unconditions,
device=device,
cfg=cfg,
shift=shift,
timesteps=timesteps,
t_start=t_start,
)
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
pass  # iterate the sampler to completion; the loop variable keeps the final latent
sampled_c = sampled_c.to(previewer.device, dtype=previewer.dtype)
image = previewer(sampled_c)[0]
image = torch.clamp(image, 0, 1)
image = image.cpu().numpy().transpose(1, 2, 0)
image = image * 255
image = image.astype(np.uint8)
image = Image.fromarray(image)
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send logs only when wandb is enabled
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError:  # checked once beforehand, so this should not fail here
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except Exception:  # wandb is disabled
pass
# endregion
def add_effnet_arguments(parser):
parser.add_argument(
"--effnet_checkpoint_path",
type=str,
required=True,
help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
)
return parser
def add_text_model_arguments(parser):
parser.add_argument(
"--text_model_checkpoint_path",
type=str,
help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
)
parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
return parser
def add_stage_a_arguments(parser):
parser.add_argument(
"--stage_a_checkpoint_path",
type=str,
required=True,
help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
)
return parser
def add_stage_b_arguments(parser):
parser.add_argument(
"--stage_b_checkpoint_path",
type=str,
required=True,
help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
)
return parser
def add_stage_c_arguments(parser):
parser.add_argument(
"--stage_c_checkpoint_path",
type=str,
required=True,
help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
)
return parser
def add_previewer_arguments(parser):
parser.add_argument(
"--previewer_checkpoint_path",
type=str,
required=False,
help="path to previewer checkpoint / previewerのチェックポイントのパス",
)
return parser
def add_training_arguments(parser):
parser.add_argument(
"--adaptive_loss_weight",
action="store_true",
help="if specified, use adaptive loss weight. if not, use P2 loss weight"
+ " / Adaptive Loss Weightを使用する。指定しない場合はP2 Loss Weightを使用する",
)
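These add_*_arguments helpers are meant to be chained onto a single parser. A minimal usage sketch, assuming the helpers above are in scope (the checkpoint paths are placeholders):

import argparse

parser = argparse.ArgumentParser()
parser = add_effnet_arguments(parser)     # --effnet_checkpoint_path (required)
parser = add_stage_c_arguments(parser)    # --stage_c_checkpoint_path (required)
parser = add_previewer_arguments(parser)  # --previewer_checkpoint_path (optional)
parser = add_training_arguments(parser)   # --adaptive_loss_weight
args = parser.parse_args(["--effnet_checkpoint_path", "effnet.safetensors",
                          "--stage_c_checkpoint_path", "stage_c.safetensors"])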

(file diff suppressed because it is too large)


@@ -1,6 +1,266 @@
import logging
import sys
import threading
import torch
from torchvision import transforms
from typing import *
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
def add_logging_arguments(parser):
parser.add_argument(
"--console_log_level",
type=str,
default=None,
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
)
parser.add_argument(
"--console_log_file",
type=str,
default=None,
help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
)
parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
def setup_logging(args=None, log_level=None, reset=False):
if logging.root.handlers:
if reset:
# remove all handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
else:
return
# log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
if log_level is None and args is not None:
log_level = args.console_log_level
if log_level is None:
log_level = "INFO"
log_level = getattr(logging, log_level)
msg_init = None
if args is not None and args.console_log_file:
handler = logging.FileHandler(args.console_log_file, mode="w")
else:
handler = None
if not args or not args.console_log_simple:
try:
from rich.logging import RichHandler
from rich.console import Console
handler = RichHandler(console=Console(stderr=True))
except ImportError:
# print("rich is not installed, using basic logging")
msg_init = "rich is not installed, using basic logging"
if handler is None:
handler = logging.StreamHandler(sys.stdout) # same as print
# note: propagate is a Logger attribute, not a Handler attribute, so setting it on a handler would be a no-op
formatter = logging.Formatter(
fmt="%(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
logging.root.setLevel(log_level)
logging.root.addHandler(handler)
if msg_init is not None:
logger = logging.getLogger(__name__)
logger.info(msg_init)
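This is the pattern the rest of this diff repeats at the top of every module: call setup_logging() once, then take a per-module logger:

from library.utils import setup_logging

setup_logging()

import logging

logger = logging.getLogger(__name__)
logger.info("goes through the root handler configured above")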
# TODO make inf_utils.py
# region Gradual Latent hires fix
class GradualLatent:
def __init__(
self,
ratio,
start_timesteps,
every_n_steps,
ratio_step,
s_noise=1.0,
gaussian_blur_ksize=None,
gaussian_blur_sigma=0.5,
gaussian_blur_strength=0.5,
unsharp_target_x=True,
):
self.ratio = ratio
self.start_timesteps = start_timesteps
self.every_n_steps = every_n_steps
self.ratio_step = ratio_step
self.s_noise = s_noise
self.gaussian_blur_ksize = gaussian_blur_ksize
self.gaussian_blur_sigma = gaussian_blur_sigma
self.gaussian_blur_strength = gaussian_blur_strength
self.unsharp_target_x = unsharp_target_x
def __str__(self) -> str:
return (
f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
+ f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
+ f"unsharp_target_x={self.unsharp_target_x})"
)
def apply_unsharp_mask(self, x: torch.Tensor):
if self.gaussian_blur_ksize is None:
return x
blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
# mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
mask = (x - blurred) * self.gaussian_blur_strength
sharpened = x + mask
return sharpened
def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.float()
x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
# apply unsharp mask / アンシャープマスクを適用する
if unsharp and self.gaussian_blur_ksize:
x = self.apply_unsharp_mask(x)
return x
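apply_unsharp_mask implements the classic unsharp formula sharpened = x + (x - blur(x)) * strength. A standalone sketch with the same torchvision call (the tensor is illustrative):

import torch
from torchvision import transforms

x = torch.rand(1, 4, 32, 32)  # a latent-like tensor
blurred = transforms.functional.gaussian_blur(x, kernel_size=3, sigma=0.5)
sharpened = x + (x - blurred) * 0.5  # gaussian_blur_strength = 0.5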
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.resized_size = None
self.gradual_latent = None
def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
self.resized_size = size
self.gradual_latent = gradual_latent
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`,
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
otherwise a tuple is returned where the first element is the sample tensor.
"""
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if not self.is_scale_input_called:
# logger.warning(
print(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
"See `StableDiffusionPipeline` for a usage example."
)
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
else:
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
sigma_from = self.sigmas[self.step_index]
sigma_to = self.sigmas[self.step_index + 1]
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
dt = sigma_down - sigma
device = model_output.device
if self.resized_size is None:
prev_sample = sample + derivative * dt
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
model_output.shape, dtype=model_output.dtype, device=device, generator=generator
)
s_noise = 1.0
else:
print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
s_noise = self.gradual_latent.s_noise
if self.gradual_latent.unsharp_target_x:
prev_sample = sample + derivative * dt
prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
else:
sample = self.gradual_latent.interpolate(sample, self.resized_size)
derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
prev_sample = sample + derivative * dt
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
dtype=model_output.dtype,
device=device,
generator=generator,
)
prev_sample = prev_sample + noise * sigma_up * s_noise
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
# endregion
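The sigma_up / sigma_down split above is the standard Euler ancestral decomposition: the deterministic step moves to sigma_down, and fresh noise with standard deviation sigma_up restores the target level, since sigma_down^2 + sigma_up^2 = sigma_to^2. A quick numeric check:

sigma_from, sigma_to = 2.0, 1.0
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
assert abs(sigma_down**2 + sigma_up**2 - sigma_to**2) < 1e-9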


@@ -2,10 +2,13 @@ import argparse
import os
import torch
from safetensors.torch import load_file
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(file):
print(f"loading: {file}")
logger.info(f"loading: {file}")
if os.path.splitext(file)[1] == ".safetensors":
sd = load_file(file)
else:


@@ -2,7 +2,10 @@ import os
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# whether to skip input_blocks / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
@@ -125,7 +128,7 @@ class LLLiteModule(torch.nn.Module):
return
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
# logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
cx = self.conditioning1(cond_image)
if not self.is_conv2d:
# reshape / b,c,h,w -> b,h*w,c
@@ -155,7 +158,7 @@ class LLLiteModule(torch.nn.Module):
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
if self.use_zeros_for_batch_uncond:
cx[0::2] = 0.0 # uncond is zero
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
# logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
# reduce the input dimensionality with down, then concatenate with the conditioning image embedding
# concatenating along the channel dim instead of adding, in the hope that the model mixes them well
@@ -286,7 +289,7 @@ class ControlNetLLLite(torch.nn.Module):
# create module instances
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
def forward(self, x):
return x # dummy
@@ -319,7 +322,7 @@ class ControlNetLLLite(torch.nn.Module):
return info
def apply_to(self):
print("applying LLLite for U-Net...")
logger.info("applying LLLite for U-Net...")
for module in self.unet_modules:
module.apply_to()
self.add_module(module.lllite_name, module)
@@ -374,19 +377,19 @@ if __name__ == "__main__":
# sdxl_original_unet.USE_REENTRANT = False
# test shape etc
print("create unet")
logger.info("create unet")
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
unet.to("cuda").to(torch.float16)
print("create ControlNet-LLLite")
logger.info("create ControlNet-LLLite")
control_net = ControlNetLLLite(unet, 32, 64)
control_net.apply_to()
control_net.to("cuda")
print(control_net)
logger.info(control_net)
# print number of parameters
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
# log the number of parameters
logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}")
input()
@@ -398,12 +401,12 @@ if __name__ == "__main__":
# # visualize
# import torchviz
# print("run visualize")
# logger.info("run visualize")
# controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y)
# print("make_dot")
# logger.info("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render")
# logger.info("render")
# image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input()
@@ -414,12 +417,12 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
logger.info("start training")
steps = 10
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
for step in range(steps):
print(f"step {step}")
logger.info(f"step {step}")
batch_size = 1
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
@@ -439,7 +442,7 @@ if __name__ == "__main__":
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
print(sample_param)
logger.info(f"{sample_param}")
# from safetensors.torch import save_file


@@ -6,7 +6,10 @@ import re
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# whether to skip input_blocks / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
@@ -270,7 +273,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
# create module instances
self.lllite_modules = apply_to_modules(self, target_modules)
print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
# def prepare_optimizer_params(self):
def prepare_params(self):
@@ -281,8 +284,8 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
train_params.append(p)
else:
non_train_params.append(p)
print(f"count of trainable parameters: {len(train_params)}")
print(f"count of non-trainable parameters: {len(non_train_params)}")
logger.info(f"count of trainable parameters: {len(train_params)}")
logger.info(f"count of non-trainable parameters: {len(non_train_params)}")
for p in non_train_params:
p.requires_grad_(False)
@@ -388,7 +391,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
matches = pattern.findall(module_name)
if matches is not None:
for m in matches:
print(module_name, m)
logger.info(f"{module_name} {m}")
module_name = module_name.replace(m, m.replace("_", "@"))
module_name = module_name.replace("_", ".")
module_name = module_name.replace("@", "_")
@@ -407,7 +410,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
def replace_unet_linear_and_conv2d():
print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
@@ -419,10 +422,10 @@ if __name__ == "__main__":
replace_unet_linear_and_conv2d()
# test shape etc
print("create unet")
logger.info("create unet")
unet = SdxlUNet2DConditionModelControlNetLLLite()
print("enable ControlNet-LLLite")
logger.info("enable ControlNet-LLLite")
unet.apply_lllite(32, 64, None, False, 1.0)
unet.to("cuda") # .to(torch.float16)
@@ -439,14 +442,14 @@ if __name__ == "__main__":
# unet_sd[converted_key] = model_sd[key]
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
# print(info)
# logger.info(info)
# print(unet)
# logger.info(unet)
# print number of parameters
# log the number of parameters
params = unet.prepare_params()
print("number of parameters", sum(p.numel() for p in params))
# print("type any key to continue")
logger.info(f"number of parameters {sum(p.numel() for p in params)}")
# logger.info("type any key to continue")
# input()
unet.set_use_memory_efficient_attention(True, False)
@@ -455,12 +458,12 @@ if __name__ == "__main__":
# # visualize
# import torchviz
# print("run visualize")
# logger.info("run visualize")
# controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y)
# print("make_dot")
# logger.info("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render")
# logger.info("render")
# image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input()
@@ -471,13 +474,13 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
logger.info("start training")
steps = 10
batch_size = 1
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
for step in range(steps):
print(f"step {step}")
logger.info(f"step {step}")
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
x = torch.randn(batch_size, 4, 128, 128).cuda()
@@ -494,9 +497,9 @@ if __name__ == "__main__":
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
print(sample_param)
logger.info(sample_param)
# from safetensors.torch import save_file
# print("save weights")
# logger.info("save weights")
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)


@@ -12,10 +12,15 @@
import math
import os
import random
from typing import List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import torch
from torch import nn
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class DyLoRAModule(torch.nn.Module):
"""
@@ -165,7 +170,15 @@ class DyLoRAModule(torch.nn.Module):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
@@ -182,6 +195,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
if unit is not None:
unit = int(unit)
else:
@@ -223,7 +237,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -267,11 +281,11 @@ class DyLoRANetwork(torch.nn.Module):
self.apply_to_conv = apply_to_conv
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info("create LoRA network from weights")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
if self.apply_to_conv:
print(f"apply LoRA to Conv2d with kernel size (3,3).")
logger.info("apply LoRA to Conv2d with kernel size (3,3).")
# create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
@@ -306,9 +320,23 @@ class DyLoRANetwork(torch.nn.Module):
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
loras.append(lora)
return loras
text_encoders = text_encoder if isinstance(text_encoder, list) else [text_encoder]
self.text_encoder_loras = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
logger.info(f"create LoRA for Text Encoder {index}")
else:
index = None
logger.info("create LoRA for Text Encoder")
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -316,7 +344,7 @@ class DyLoRANetwork(torch.nn.Module):
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras = create_modules(True, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
def set_multiplier(self, multiplier):
self.multiplier = multiplier
@@ -336,12 +364,12 @@ class DyLoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -359,12 +387,12 @@ class DyLoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -375,7 +403,7 @@ class DyLoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
"""
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):


@@ -10,7 +10,10 @@ from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name):
if model_util.is_safetensors(file_name):
@@ -40,13 +43,13 @@ def split_lora_model(lora_sd, unit):
rank = value.size()[0]
if rank > max_rank:
max_rank = rank
print(f"Max rank: {max_rank}")
logger.info(f"Max rank: {max_rank}")
rank = unit
split_models = []
new_alpha = None
while rank < max_rank:
print(f"Splitting rank {rank}")
logger.info(f"Splitting rank {rank}")
new_sd = {}
for key, value in lora_sd.items():
if "lora_down" in key:
@@ -57,7 +60,7 @@ def split_lora_model(lora_sd, unit):
# for some reason the result breaks when scaled...
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
# scale = math.sqrt(this_rank / rank) # rank is > unit
# print(key, value.size(), this_rank, rank, value, scale)
# logger.info(f"{key} {value.size()} {this_rank} {rank} {value} {scale}")
# new_alpha = value * scale # always same
# new_sd[key] = new_alpha
new_sd[key] = value
@@ -69,10 +72,10 @@ def split_lora_model(lora_sd, unit):
def split(args):
print("loading Model...")
logger.info("loading Model...")
lora_sd, metadata = load_state_dict(args.model)
print("Splitting Model...")
logger.info("Splitting Model...")
original_rank, split_models = split_lora_model(lora_sd, args.unit)
comment = metadata.get("ss_training_comment", "")
@@ -94,7 +97,7 @@ def split(args):
filename, ext = os.path.splitext(args.save_to)
model_file_name = filename + f"-{new_rank:04d}{ext}"
print(f"saving model to: {model_file_name}")
logger.info(f"saving model to: {model_file_name}")
save_to_file(model_file_name, state_dict, new_metadata)


@@ -11,7 +11,10 @@ from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1
@@ -43,6 +46,9 @@ def svd(
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
):
def str_to_dtype(p):
if p == "float":
@@ -57,28 +63,51 @@ def svd(
if v_parameterization is None:
v_parameterization = v2
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision)
work_device = "cpu"
# load models
if not sdxl:
print(f"loading original SD model : {model_org}")
logger.info(f"loading original SD model : {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
print(f"loading tuned SD model : {model_tuned}")
if load_dtype is not None:
text_encoder_o = text_encoder_o.to(load_dtype)
unet_o = unet_o.to(load_dtype)
logger.info(f"loading tuned SD model : {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
if load_dtype is not None:
text_encoder_t = text_encoder_t.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
print(f"loading original SDXL model : {model_org}")
device_org = load_original_model_to if load_original_model_to else "cpu"
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
logger.info(f"loading original SDXL model : {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
print(f"loading original SDXL model : {model_tuned}")
if load_dtype is not None:
text_encoder_o1 = text_encoder_o1.to(load_dtype)
text_encoder_o2 = text_encoder_o2.to(load_dtype)
unet_o = unet_o.to(load_dtype)
logger.info(f"loading original SDXL model : {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype is not None:
text_encoder_t1 = text_encoder_t1.to(load_dtype)
text_encoder_t2 = text_encoder_t2.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# create LoRA network to extract weights: Use dim (rank) as alpha
@@ -100,38 +129,54 @@ def svd(
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
# Text Encoder might be the same
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
diff = diff.float()
diffs[lora_name] = diff
# clear target Text Encoder to save memory
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.")
logger.warning("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = []
diffs = {}
diffs = {} # clear diffs
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = diff.float()
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
if args.device:
diff = diff.to(args.device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
diffs[lora_name] = diff
# clear LoRA network, target U-Net to save memory
del lora_network_o
del lora_network_t
del unet_t
# make LoRA with svd
print("calculating by svd")
logger.info("calculating by svd")
lora_weights = {}
with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())):
if args.device:
mat = mat.to(args.device)
mat = mat.to(torch.float) # calc by float
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
@@ -143,7 +188,7 @@ def svd(
if device:
mat = mat.to(device)
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
# logger.info(f"{lora_name} {mat.size()} {mat.device} {rank} {in_dim} {out_dim}")
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
if conv2d:
@@ -171,8 +216,8 @@ def svd(
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
U = U.to("cpu").contiguous()
Vh = Vh.to("cpu").contiguous()
U = U.to(work_device, dtype=save_dtype).contiguous()
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U, Vh)
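The extraction above is a truncated SVD of the weight delta: U, scaled by the singular values, becomes lora_up and Vh becomes lora_down. A toy version on a synthetic low-rank delta, without the quantile clamping the script applies:

import torch

rank = 4
W_base = torch.randn(16, 16)
W_tuned = W_base + 0.01 * (torch.randn(16, rank) @ torch.randn(rank, 16))
diff = (W_tuned - W_base).float()

U, S, Vh = torch.linalg.svd(diff)
lora_up = U[:, :rank] @ torch.diag(S[:rank])  # out_dim x rank
lora_down = Vh[:rank, :]                      # rank x in_dim
assert torch.allclose(lora_up @ lora_down, diff, atol=1e-4)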
@@ -188,7 +233,7 @@ def svd(
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}")
logger.info(f"Loading extracted LoRA weights: {info}")
dir_name = os.path.dirname(save_to)
if dir_name and not os.path.exists(dir_name):
@@ -215,7 +260,7 @@ def svd(
metadata.update(sai_metadata)
lora_network_save.save_weights(save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {save_to}")
logger.info(f"LoRA weights are saved to: {save_to}")
def setup_parser() -> argparse.ArgumentParser:
@@ -230,6 +275,13 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
)
parser.add_argument(
"--load_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
)
parser.add_argument(
"--save_precision",
type=str,
@@ -285,6 +337,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--load_original_model_to",
type=str,
default=None,
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
parser.add_argument(
"--load_tuned_model_to",
type=str,
default=None,
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
return parser


@@ -11,7 +11,12 @@ from transformers import CLIPTextModel
import numpy as np
import torch
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -46,7 +51,7 @@ class LoRAModule(torch.nn.Module):
# if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else:
self.lora_dim = lora_dim
@@ -177,7 +182,7 @@ class LoRAInfModule(LoRAModule):
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(f"{conved.size()} {weight.size()} {module.stride} {module.padding}")
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
@@ -216,7 +221,7 @@ class LoRAInfModule(LoRAModule):
self.region_mask = None
def default_forward(self, x):
# print("default_forward", self.lora_name, x.size())
# logger.info(f"default_forward {self.lora_name} {x.size()}")
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x):
@@ -245,7 +250,8 @@ class LoRAInfModule(LoRAModule):
if mask is None:
# raise ValueError(f"mask is None for resolution {area}")
# emb_layers in SDXL doesn't have mask
# print(f"mask is None for resolution {area}, {x.size()}")
# if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4:
@@ -263,6 +269,8 @@ class LoRAInfModule(LoRAModule):
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx)
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
# mask = mask.squeeze(-1)
lx = lx * mask
x = self.org_forward(x)
@@ -291,7 +299,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
query[-self.network.batch_size :] = x[-self.network.batch_size :]
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
# logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}")
return query
def sub_prompt_forward(self, x):
@@ -306,7 +314,7 @@ class LoRAInfModule(LoRAModule):
lx = x[emb_idx :: self.network.num_sub_prompts]
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
# logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}")
x = self.org_forward(x)
x[emb_idx :: self.network.num_sub_prompts] += lx
@@ -314,7 +322,7 @@ class LoRAInfModule(LoRAModule):
return x
def to_out_forward(self, x):
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
# logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}")
if self.network.is_last_network:
masks = [None] * self.network.num_sub_prompts
@@ -332,7 +340,7 @@ class LoRAInfModule(LoRAModule):
)
self.network.shared[self.lora_name] = (lx, masks)
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
@@ -351,7 +359,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
# if num_sub_prompts > num of LoRAs, fill with zero
for i in range(len(masks)):
if masks[i] is None:
@@ -374,7 +382,7 @@ class LoRAInfModule(LoRAModule):
x1 = x1 + lx1
out[self.network.batch_size + i] = x1
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
# logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}")
return out
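The hunks above and below all apply one mechanical convention: module-level print calls become calls on a per-module logger (info for progress, warning for silent fallbacks, error for invalid arguments), and commented-out debug prints become commented-out logger.info f-strings. A minimal sketch of the pattern, assuming library.utils.setup_logging() installs handlers once per process:

from library.utils import setup_logging
setup_logging()  # configure handlers/formatting once, before any logging call
import logging
logger = logging.getLogger(__name__)  # one logger per module

logger.info("normal progress message")               # was: print(...)
logger.warning("default applied because arg unset")  # was: print(...) for fallbacks
logger.error("unknown option given")                 # was: print(...) for bad input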
@@ -511,7 +519,9 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else:
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
)
block_dims = [network_dim] * num_total_blocks
if block_alphas is not None:
@@ -520,7 +530,7 @@ def get_block_dims_and_alphas(
len(block_alphas) == num_total_blocks
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
else:
print(
logger.warning(
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
)
block_alphas = [network_alpha] * num_total_blocks
@@ -540,13 +550,13 @@ def get_block_dims_and_alphas(
else:
if conv_alpha is None:
conv_alpha = 1.0
print(
logger.warning(
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
)
conv_block_alphas = [conv_alpha] * num_total_blocks
else:
if conv_dim is not None:
print(
logger.warning(
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
)
conv_block_dims = [conv_dim] * num_total_blocks
@@ -586,7 +596,7 @@ def get_block_lr_weight(
elif name == "zeros":
return [0.0 + base_lr] * max_len
else:
print(
logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name)
)
@@ -598,14 +608,14 @@ def get_block_lr_weight(
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
@@ -613,24 +623,24 @@ def get_block_lr_weight(
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("apply block learning rate / 階層別学習率を適用します。")
logger.info("apply block learning rate / 階層別学習率を適用します。")
if down_lr_weight != None:
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
else:
print("down_lr_weight: all 1.0, すべて1.0")
logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:", mid_lr_weight)
logger.info(f"mid_lr_weight: {mid_lr_weight}")
else:
print("mid_lr_weight: 1.0")
logger.info("mid_lr_weight: 1.0")
if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
else:
print("up_lr_weight: all 1.0, すべて1.0")
logger.info("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight
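For orientation, a hedged sketch of how the named presets in the error message above (cosine, sine, linear, reverse_linear, zeros) could map a block index to a weight. This helper is hypothetical and illustrative only; the script's exact formulas may differ:

import math

def preset_lr_weights(name: str, max_len: int, base_lr: float = 0.0):
    # illustrative ramps over max_len blocks (hypothetical helper)
    if name == "linear":
        return [i / (max_len - 1) + base_lr for i in range(max_len)]
    if name == "reverse_linear":
        return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
    if name == "sine":
        return [math.sin(math.pi * i / (max_len - 1) / 2) + base_lr for i in range(max_len)]
    if name == "cosine":
        return [math.sin(math.pi * i / (max_len - 1) / 2) + base_lr for i in reversed(range(max_len))]
    if name == "zeros":
        return [base_lr] * max_len
    raise ValueError(f"unknown lr_weight preset: {name}")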
@@ -711,7 +721,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(lora_name, value.size(), dim)
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -786,20 +796,26 @@ class LoRANetwork(torch.nn.Module):
self.module_dropout = module_dropout
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
logger.info(f"create LoRA network from block_dims")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
logger.info(f"conv_block_dims: {conv_block_dims}")
logger.info(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
# create module instances
def create_modules(
@@ -825,9 +841,14 @@ class LoRANetwork(torch.nn.Module):
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
is_group_conv2d = is_conv2d and child_module.groups > 1
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
# if is_group_conv2d:
# logger.info(f"skip group conv2d: {name}.{child_name}")
# continue
if is_linear or (is_conv2d and not is_group_conv2d):
lora_name = prefix + "." + name + ("." + child_name if child_name else "")
lora_name = lora_name.replace(".", "_")
dim = None
@@ -884,31 +905,36 @@ class LoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")
logger.info(f"create LoRA for Text Encoder {index}:")
else:
index = None
print(f"create LoRA for Text Encoder:")
logger.info(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
# XXX temporary solution for Stable Cascade Stage C: replace all modules
if "StageC" in unet.__class__.__name__:
logger.info("replace all modules for Stable Cascade Stage C")
target_modules = ["Linear", "Conv2d"]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0:
print(
logger.warning(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
print(f"\t{name}")
logger.info(f"\t{name}")
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
@@ -926,6 +952,10 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -939,12 +969,12 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -966,12 +996,12 @@ class LoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -982,7 +1012,7 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
# define per-layer multipliers for the block-wise learningate rates; the argument order is reversed, but leave it as-is for now
def set_block_lr_weight(
@@ -1113,7 +1143,7 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.set_network(self)
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
self.batch_size = batch_size
self.num_sub_prompts = num_sub_prompts
self.current_size = (height, width)
@@ -1128,7 +1158,7 @@ class LoRANetwork(torch.nn.Module):
device = ref_weight.device
def resize_add(mh, mw):
# print(mh, mw, mh * mw)
# logger.info(mh, mw, mh * mw)
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
m = m.to(device, dtype=dtype)
mask_dic[mh * mw] = m
@@ -1139,6 +1169,13 @@ class LoRANetwork(torch.nn.Module):
resize_add(h, w)
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
resize_add(h + h % 2, w + w % 2)
# deep shrink
if ds_ratio is not None:
hd = int(h * ds_ratio)
wd = int(w * ds_ratio)
resize_add(hd, wd)
h = (h + 1) // 2
w = (w + 1) // 2
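The ds_ratio branch above pre-registers masks at the extra Deep-Shrink-scaled resolutions, because masks are later looked up by flattened spatial size (h * w) rather than by layer name. A hedged illustration of that lookup, assuming a mask_dic keyed as above:

# hypothetical illustration: attention blocks see tokens of shape (B, h*w, C),
# so the stored mask is recovered from the sequence length alone
def lookup_mask(mask_dic, lx):
    seq_len = lx.size()[1] if lx.ndim == 3 else lx.size()[2] * lx.size()[3]
    return mask_dic.get(seq_len)  # None if this resolution was never registered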


@@ -9,8 +9,15 @@ from diffusers import UNet2DConditionModel
import numpy as np
from tqdm import tqdm
from transformers import CLIPTextModel
import torch
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def make_unet_conversion_map() -> Dict[str, str]:
unet_conversion_map_layer = []
@@ -248,7 +255,7 @@ def create_network_from_weights(
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -291,12 +298,12 @@ class LoRANetwork(torch.nn.Module):
super().__init__()
self.multiplier = multiplier
print(f"create LoRA network from weights")
logger.info("create LoRA network from weights")
# convert SDXL Stability AI's U-Net modules to Diffusers
converted = self.convert_unet_modules(modules_dim, modules_alpha)
if converted:
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
# create module instances
def create_modules(
@@ -331,7 +338,7 @@ class LoRANetwork(torch.nn.Module):
lora_name = lora_name.replace(".", "_")
if lora_name not in modules_dim:
# print(f"skipped {lora_name} (not found in modules_dim)")
# logger.info(f"skipped {lora_name} (not found in modules_dim)")
skipped.append(lora_name)
continue
@@ -362,18 +369,18 @@ class LoRANetwork(torch.nn.Module):
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
if len(skipped_te) > 0:
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
# extend U-Net target modules to include Conv2d 3x3
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras: List[LoRAModule]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
if len(skipped_un) > 0:
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
# assertion
names = set()
@@ -420,11 +427,11 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
for lora in self.text_encoder_loras:
lora.apply_to(multiplier)
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
for lora in self.unet_loras:
lora.apply_to(multiplier)
@@ -433,16 +440,16 @@ class LoRANetwork(torch.nn.Module):
lora.unapply_to()
def merge_to(self, multiplier=1.0):
print("merge LoRA weights to original weights")
logger.info("merge LoRA weights to original weights")
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
lora.merge_to(multiplier)
print(f"weights are merged")
logger.info(f"weights are merged")
def restore_from(self, multiplier=1.0):
print("restore LoRA weights from original weights")
logger.info("restore LoRA weights from original weights")
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
lora.restore_from(multiplier)
print(f"weights are restored")
logger.info(f"weights are restored")
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
# convert SDXL Stability AI's state dict to Diffusers' based state dict
@@ -463,7 +470,7 @@ class LoRANetwork(torch.nn.Module):
my_state_dict = self.state_dict()
for key in state_dict.keys():
if state_dict[key].size() != my_state_dict[key].size():
# print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
# logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
return super().load_state_dict(state_dict, strict)
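The size-mismatch branch above typically reconciles a LoRA weight saved for a Linear layer with a model that implements the same projection as a 1x1 Conv2d (or vice versa); the tensors carry identical data and differ only in shape, so a view suffices. A hedged illustration with made-up dimensions:

import torch

w_linear = torch.randn(8, 320)        # e.g. lora_down for Linear: (rank, in_features)
w_conv = w_linear.view(8, 320, 1, 1)  # same data viewed for Conv2d 1x1: (rank, in, kH, kW)
assert torch.equal(w_conv.flatten(), w_linear.flatten())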
@@ -476,7 +483,7 @@ if __name__ == "__main__":
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
@@ -490,7 +497,7 @@ if __name__ == "__main__":
image_prefix = args.model_id.replace("/", "_") + "_"
# load Diffusers model
print(f"load model from {args.model_id}")
logger.info(f"load model from {args.model_id}")
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
if args.sdxl:
# use_safetensors=True does not work with 0.18.2
@@ -503,7 +510,7 @@ if __name__ == "__main__":
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
# load LoRA weights
print(f"load LoRA weights from {args.lora_weights}")
logger.info(f"load LoRA weights from {args.lora_weights}")
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -512,10 +519,10 @@ if __name__ == "__main__":
lora_sd = torch.load(args.lora_weights)
# create by LoRA weights and load weights
print(f"create LoRA network")
logger.info(f"create LoRA network")
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
print(f"load LoRA network weights")
logger.info(f"load LoRA network weights")
lora_network.load_state_dict(lora_sd)
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
@@ -544,34 +551,34 @@ if __name__ == "__main__":
random.seed(seed)
# create image with original weights
print(f"create image with original weights")
logger.info(f"create image with original weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "original.png")
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
print(f"apply LoRA network to the model")
logger.info(f"apply LoRA network to the model")
lora_network.apply_to(multiplier=1.0)
print(f"create image with applied LoRA")
logger.info(f"create image with applied LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "applied_lora.png")
# unapply LoRA network to the model
print(f"unapply LoRA network to the model")
logger.info(f"unapply LoRA network to the model")
lora_network.unapply_to()
print(f"create image with unapplied LoRA")
logger.info(f"create image with unapplied LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "unapplied_lora.png")
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
print(f"merge LoRA network to the model")
logger.info(f"merge LoRA network to the model")
lora_network.merge_to(multiplier=1.0)
print(f"create image with LoRA")
logger.info(f"create image with LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "merged_lora.png")
@@ -579,31 +586,31 @@ if __name__ == "__main__":
# restore (unmerge) LoRA weights: numerically unstable
# restore the merged weights; due to floating-point error they may not exactly match the originals
# restoring the original weights from a saved state_dict is the reliable approach
print(f"restore (unmerge) LoRA weights")
logger.info(f"restore (unmerge) LoRA weights")
lora_network.restore_from(multiplier=1.0)
print(f"create image without LoRA")
logger.info(f"create image without LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "unmerged_lora.png")
# restore original weights
print(f"restore original weights")
logger.info(f"restore original weights")
pipe.unet.load_state_dict(org_unet_sd)
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
if args.sdxl:
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
print(f"create image with restored original weights")
logger.info(f"create image with restored original weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "restore_original.png")
# use convenience function to merge LoRA weights
print(f"merge LoRA weights with convenience function")
logger.info(f"merge LoRA weights with convenience function")
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
print(f"create image with merged LoRA weights")
logger.info(f"create image with merged LoRA weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "convenience_merged_lora.png")
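Given the argparse options visible above (--model_id at minimum; the LoRA path, SDXL switch, prompt, and seed options are assumed here to mirror the attribute names args.lora_weights, args.sdxl, args.prompt, and args.seed), a test invocation of this script might look like the following (script path hypothetical):

python networks/lora_diffusers.py --model_id stabilityai/stable-diffusion-xl-base-1.0 --sdxl \
    --lora_weights my_lora.safetensors --prompt "a photo of a cat" --seed 42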


@@ -14,7 +14,10 @@ from transformers import CLIPTextModel
import numpy as np
import torch
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -49,7 +52,7 @@ class LoRAModule(torch.nn.Module):
# if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else:
self.lora_dim = lora_dim
@@ -197,7 +200,7 @@ class LoRAInfModule(LoRAModule):
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
@@ -236,7 +239,7 @@ class LoRAInfModule(LoRAModule):
self.region_mask = None
def default_forward(self, x):
# print("default_forward", self.lora_name, x.size())
# logger.info("default_forward", self.lora_name, x.size())
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x):
@@ -278,7 +281,7 @@ class LoRAInfModule(LoRAModule):
# apply mask for LoRA result
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx)
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
# logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
lx = lx * mask
x = self.org_forward(x)
@@ -307,7 +310,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
query[-self.network.batch_size :] = x[-self.network.batch_size :]
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
# logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
return query
def sub_prompt_forward(self, x):
@@ -322,7 +325,7 @@ class LoRAInfModule(LoRAModule):
lx = x[emb_idx :: self.network.num_sub_prompts]
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
# logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
x = self.org_forward(x)
x[emb_idx :: self.network.num_sub_prompts] += lx
@@ -330,7 +333,7 @@ class LoRAInfModule(LoRAModule):
return x
def to_out_forward(self, x):
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
# logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
if self.network.is_last_network:
masks = [None] * self.network.num_sub_prompts
@@ -348,7 +351,7 @@ class LoRAInfModule(LoRAModule):
)
self.network.shared[self.lora_name] = (lx, masks)
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
@@ -367,7 +370,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# for i in range(len(masks)):
# if masks[i] is None:
# masks[i] = torch.zeros_like(masks[-1])
@@ -389,7 +392,7 @@ class LoRAInfModule(LoRAModule):
x1 = x1 + lx1
out[self.network.batch_size + i] = x1
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
# logger.info("to_out_forward", x.size(), out.size(), has_real_uncond)
return out
@@ -526,7 +529,7 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else:
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
block_dims = [network_dim] * num_total_blocks
if block_alphas is not None:
@@ -535,7 +538,7 @@ def get_block_dims_and_alphas(
len(block_alphas) == num_total_blocks
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
else:
print(
logger.warning(
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
)
block_alphas = [network_alpha] * num_total_blocks
@@ -555,13 +558,13 @@ def get_block_dims_and_alphas(
else:
if conv_alpha is None:
conv_alpha = 1.0
print(
logger.warning(
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
)
conv_block_alphas = [conv_alpha] * num_total_blocks
else:
if conv_dim is not None:
print(
logger.warning(
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
)
conv_block_dims = [conv_dim] * num_total_blocks
@@ -601,7 +604,7 @@ def get_block_lr_weight(
elif name == "zeros":
return [0.0 + base_lr] * max_len
else:
print(
logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name)
)
@@ -613,14 +616,14 @@ def get_block_lr_weight(
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
@@ -628,24 +631,24 @@ def get_block_lr_weight(
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("apply block learning rate / 階層別学習率を適用します。")
logger.info("apply block learning rate / 階層別学習率を適用します。")
if down_lr_weight != None:
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
else:
print("down_lr_weight: all 1.0, すべて1.0")
logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:", mid_lr_weight)
logger.info(f"mid_lr_weight: {mid_lr_weight}")
else:
print("mid_lr_weight: 1.0")
logger.info("mid_lr_weight: 1.0")
if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
else:
print("up_lr_weight: all 1.0, すべて1.0")
logger.info("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight
@@ -726,7 +729,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(lora_name, value.size(), dim)
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -801,20 +804,20 @@ class LoRANetwork(torch.nn.Module):
self.module_dropout = module_dropout
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
logger.info(f"create LoRA network from block_dims")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
logger.info(f"conv_block_dims: {conv_block_dims}")
logger.info(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances
def create_modules(
@@ -899,15 +902,15 @@ class LoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")
logger.info(f"create LoRA for Text Encoder {index}:")
else:
index = None
print(f"create LoRA for Text Encoder:")
logger.info(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -915,15 +918,15 @@ class LoRANetwork(torch.nn.Module):
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0:
print(
logger.warning(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
print(f"\t{name}")
logger.info(f"\t{name}")
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
@@ -954,12 +957,12 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -981,12 +984,12 @@ class LoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -997,7 +1000,7 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
# define per-layer multipliers for the block-wise learning rates; the argument order is reversed, but leave it as-is for now
def set_block_lr_weight(
@@ -1144,7 +1147,7 @@ class LoRANetwork(torch.nn.Module):
device = ref_weight.device
def resize_add(mh, mw):
# print(mh, mw, mh * mw)
# logger.info(mh, mw, mh * mw)
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
m = m.to(device, dtype=dtype)
mask_dic[mh * mw] = m


@@ -5,27 +5,34 @@ from library import model_util
import library.train_util as train_util
import argparse
from transformers import CLIPTokenizer
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from here
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = get_preferred_device()
def interrogate(args):
weights_dtype = torch.float16
# prepare everything needed
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
args.pretrained_model_name_or_path = args.sd_model
args.vae = None
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
print(f"loading LoRA: {args.model}")
logger.info(f"loading LoRA: {args.model}")
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# check whether the LoRA has weights for the text encoder; ideally this should be done on the lora module side
@@ -35,11 +42,11 @@ def interrogate(args):
has_te_weight = True
break
if not has_te_weight:
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
return
del vae
print("loading tokenizer")
logger.info("loading tokenizer")
if args.v2:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
@@ -53,7 +60,7 @@ def interrogate(args):
# probe the tokens one by one
token_id_start = 0
token_id_end = max(tokenizer.all_special_ids)
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")
def get_all_embeddings(text_encoder):
embs = []
@@ -79,24 +86,24 @@ def interrogate(args):
embs.extend(encoder_hidden_states)
return torch.stack(embs)
print("get original text encoder embeddings.")
logger.info("get original text encoder embeddings.")
orig_embs = get_all_embeddings(text_encoder)
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
info = network.load_state_dict(weights_sd, strict=False)
print(f"Loading LoRA weights: {info}")
logger.info(f"Loading LoRA weights: {info}")
network.to(DEVICE, dtype=weights_dtype)
network.eval()
del unet
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
print("get text encoder embeddings with lora.")
logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
logger.info("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder)
# compare: for now, simply by the mean absolute difference
print("comparing...")
logger.info("comparing...")
diffs = {}
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
diff = torch.mean(torch.abs(orig_emb - lora_emb))
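The loop above measures, token by token, how far the text-encoder output drifts once the LoRA is applied; the most-affected tokens are the ones the LoRA most likely targets. A hedged sketch of how the loop plausibly concludes (the actual tail of the function is not shown in this diff):

        diffs[i] = float(diff)  # inside the loop above

# after the loop: rank token ids by drift and report the most affected (hypothetical)
top = sorted(diffs.items(), key=lambda kv: kv[1], reverse=True)[:50]
for token_id, d in top:
    logger.info(f"{d:.4f}  {tokenizer.decode([token_id])}")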


@@ -7,7 +7,10 @@ from safetensors.torch import load_file, save_file
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -61,10 +64,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
@@ -73,10 +76,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
@@ -104,7 +107,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
@@ -118,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
@@ -151,10 +154,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if "alpha" in key:
continue
@@ -196,8 +199,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same
dims_list = list(set(base_dims.values()))
@@ -239,7 +242,7 @@ def merge(args):
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
@@ -264,18 +267,18 @@ def merge(args):
)
if args.v2:
# TODO read sai modelspec
print(
logger.warning(
"Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
print(f"saving SD model to: {args.save_to}")
logger.info(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
)
else:
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...")
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -289,12 +292,12 @@ def merge(args):
)
if v2:
# TODO read sai modelspec
print(
logger.warning(
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
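The heart of merge_to_sd_model above is the standard LoRA merge: each original weight is updated in place with the scaled product of the low-rank factors. A minimal sketch for the Linear case, assuming the alpha/dim scaling used throughout these scripts:

import torch

down = torch.randn(8, 320)    # lora_down.weight: (rank, in_features)
up = torch.randn(320, 8)      # lora_up.weight: (out_features, rank)
alpha, ratio = 8.0, 1.0
scale = alpha / down.size(0)  # alpha / rank

weight = torch.randn(320, 320)                 # original module weight
weight = weight + ratio * (up @ down) * scale  # merged, as in the Linear branch above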


@@ -6,7 +6,10 @@ import torch
from safetensors.torch import load_file, save_file
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
@@ -54,10 +57,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
@@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
@@ -96,10 +99,10 @@ def merge_lora_models(models, ratios, merge_dtype):
alpha = None
dim = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
@@ -117,7 +120,7 @@ def merge_lora_models(models, ratios, merge_dtype):
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}")
logger.info(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
@@ -142,19 +145,21 @@ def merge(args):
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"\nsaving SD model to: {args.save_to}")
logger.info("")
logger.info(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"\nsaving model to: {args.save_to}")
logger.info(f"")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
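Assuming the options referenced in merge() above keep their attribute names as flags (--models, --ratios, --save_to, and optionally --sd_model / --v2), merging two LoRAs at different strengths might be invoked like this (script path hypothetical):

python networks/merge_lora.py --models lora_a.safetensors lora_b.safetensors \
    --ratios 0.8 0.5 --save_to merged_lora.safetensors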


@@ -8,7 +8,10 @@ from transformers import CLIPTextModel
import numpy as np
import torch
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -237,7 +240,7 @@ class OFTNetwork(torch.nn.Module):
self.dim = dim
self.alpha = alpha
print(
logger.info(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
)
@@ -258,7 +261,7 @@ class OFTNetwork(torch.nn.Module):
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_")
# print(oft_name)
# logger.info(oft_name)
oft = module_class(
oft_name,
@@ -279,7 +282,7 @@ class OFTNetwork(torch.nn.Module):
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
# assertion
names = set()
@@ -316,7 +319,7 @@ class OFTNetwork(torch.nn.Module):
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
print("enable OFT for U-Net")
logger.info("enable OFT for U-Net")
for oft in self.unet_ofts:
sd_for_lora = {}
@@ -326,7 +329,7 @@ class OFTNetwork(torch.nn.Module):
oft.load_state_dict(sd_for_lora, False)
oft.merge_to()
print(f"weights are merged")
logger.info(f"weights are merged")
# it might be good to allow separate learning rates for the two Text Encoders
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
@@ -338,11 +341,11 @@ class OFTNetwork(torch.nn.Module):
for oft in ofts:
params.extend(oft.parameters())
# print num of params
# logger.info num of params
num_params = 0
for p in params:
num_params += p.numel()
print(f"OFT params: {num_params}")
logger.info(f"OFT params: {num_params}")
return params
param_data = {"params": enumerate_params(self.unet_ofts)}
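For contrast with LoRA's additive low-rank delta, OFT learns a multiplicative, (near-)orthogonal transform of the frozen weight; the constraint alpha logged above relaxes strict orthogonality. A hedged sketch of one common parametrization (a Cayley transform of a skew-symmetric matrix), not necessarily the exact scheme this module uses:

import torch

d = 16
Q = torch.randn(d, d)
Q = Q - Q.T                            # skew-symmetric: Q^T = -Q
I = torch.eye(d)
R = (I + Q) @ torch.linalg.inv(I - Q)  # orthogonal by the Cayley transform

W = torch.randn(d, 32)                 # frozen pretrained weight
W_adapted = R @ W                      # OFT-style multiplicative update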


@@ -2,80 +2,91 @@
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo
import os
import argparse
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np
from library import train_util
from library import model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6
# Model save and load functions
def load_state_dict(file_name, dtype):
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location='cpu')
metadata = None
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location="cpu")
metadata = None
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd, metadata
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if model_util.is_safetensors(file_name):
save_file(model, file_name, metadata)
else:
torch.save(model, file_name)
if model_util.is_safetensors(file_name):
save_file(state_dict, file_name, metadata)
else:
torch.save(state_dict, file_name)
# Indexing functions
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S)-1))
return index
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_fro(S, target):
S_squared = S.pow(2)
s_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S)-1))
S_squared = S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S) - 1))
return index
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv/target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S)-1))
max_sv = S[0]
min_sv = max_sv / target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S) - 1))
return index
return index
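A quick worked example of the three selection rules on a toy spectrum (values are made up for illustration):

import torch

S = torch.tensor([4.0, 2.0, 1.0, 0.5, 0.25])  # singular values, descending
index_sv_cumulative(S, 0.90)  # -> 3: smallest k with cumsum(S)/sum(S) >= 0.90
index_sv_fro(S, 0.90)         # -> 2: same rule on S**2 (Frobenius norm retained)
index_sv_ratio(S, 8.0)        # -> 3: count of values within a factor 8 of S[0]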
# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size, kernel_size, _ = weight.size()
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
@@ -92,17 +103,17 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size = weight.size()
U, S, Vh = torch.linalg.svd(weight.to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
del U, S, Vh, weight
@@ -113,7 +124,7 @@ def merge_conv(lora_down, lora_up, device):
in_rank, in_size, kernel_size, k_ = lora_down.shape
out_size, out_rank, _, _ = lora_up.shape
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
@@ -127,236 +138,274 @@ def merge_linear(lora_down, lora_up, device):
in_rank, in_size = lora_down.shape
out_size, out_rank = lora_up.shape
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
weight = lora_up @ lora_down
del lora_up, lora_down
return weight
# Calculate new rank
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict = {}
if dynamic_method=="sv_ratio":
if dynamic_method == "sv_ratio":
# Calculate new dim and alpha based off ratio
new_rank = index_sv_ratio(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
elif dynamic_method=="sv_cumulative":
elif dynamic_method == "sv_cumulative":
# Calculate new dim and alpha based off cumulative sum
new_rank = index_sv_cumulative(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
elif dynamic_method=="sv_fro":
elif dynamic_method == "sv_fro":
# Calculate new dim and alpha based off sqrt sum of squares
new_rank = index_sv_fro(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
else:
new_rank = rank
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
new_rank = 1
new_alpha = float(scale*new_rank)
elif new_rank > rank: # cap max rank at rank
new_alpha = float(scale * new_rank)
elif new_rank > rank: # cap max rank at rank
new_rank = rank
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
# Calculate resize info
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))
S_squared = S.pow(2)
s_fro = torch.sqrt(torch.sum(S_squared))
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
fro_percent = float(s_red_fro/s_fro)
fro_percent = float(s_red_fro / s_fro)
param_dict["new_rank"] = new_rank
param_dict["new_alpha"] = new_alpha
param_dict["sum_retained"] = (s_rank)/s_sum
param_dict["sum_retained"] = (s_rank) / s_sum
param_dict["fro_retained"] = fro_percent
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
param_dict["max_ratio"] = S[0] / S[new_rank - 1]
return param_dict
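rank_resize only chooses the rank; the actual factorization in extract_linear / extract_conv is a plain truncated SVD with the singular values folded into the up factor, as sketched here for the linear case:

import torch

W = torch.randn(320, 768)               # merged full weight delta: (out, in)
U, S, Vh = torch.linalg.svd(W)
r = 8                                   # new_rank chosen by rank_resize
lora_up = U[:, :r] @ torch.diag(S[:r])  # (out, r)
lora_down = Vh[:r, :]                   # (r, in)
W_approx = lora_up @ lora_down          # best rank-r approximation of W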
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
network_alpha = None
network_dim = None
verbose_str = "\n"
fro_list = []
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
network_alpha = None
network_dim = None
verbose_str = "\n"
fro_list = []
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and "alpha" in key:
network_alpha = value
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
scale = network_alpha/network_dim
scale = network_alpha / network_dim
if dynamic_method:
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
if dynamic_method:
logger.info(
f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
)
lora_down_weight = None
lora_up_weight = None
lora_down_weight = None
lora_up_weight = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
weight_name = None
if 'lora_down' in key:
block_down_name = key.rsplit('.lora_down', 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
continue
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
weight_name = None
if "lora_down" in key:
block_down_name = key.rsplit(".lora_down", 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
continue
# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
weights_loaded = lora_down_weight is not None and lora_up_weight is not None
if weights_loaded:
if weights_loaded:
conv2d = (len(lora_down_weight.size()) == 4)
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha/lora_down_weight.size()[0]
conv2d = len(lora_down_weight.size()) == 4
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha / lora_down_weight.size()[0]
if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
if verbose:
max_ratio = param_dict['max_ratio']
sum_retained = param_dict['sum_retained']
fro_retained = param_dict['fro_retained']
if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))
if verbose:
max_ratio = param_dict["max_ratio"]
sum_retained = param_dict["sum_retained"]
fro_retained = param_dict["fro_retained"]
if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))
verbose_str+=f"{block_down_name:75} | "
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
verbose_str += f"{block_down_name:75} | "
verbose_str += (
f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
)
if verbose and dynamic_method:
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str+=f"\n"
if verbose and dynamic_method:
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str += "\n"
new_alpha = param_dict['new_alpha']
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
new_alpha = param_dict["new_alpha"]
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
del param_dict
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
del param_dict
if verbose:
print(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
print("resizing complete")
return o_lora_sd, network_dim, new_alpha
if verbose:
print(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
logger.info("resizing complete")
return o_lora_sd, network_dim, new_alpha
def resize(args):
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
if args.save_to is None or not (
args.save_to.endswith(".ckpt")
or args.save_to.endswith(".pt")
or args.save_to.endswith(".pth")
or args.save_to.endswith(".safetensors")
):
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
if args.dynamic_method and not args.dynamic_param:
raise Exception("If using dynamic_method, then dynamic_param is required")
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
if args.dynamic_method and not args.dynamic_param:
raise Exception("If using dynamic_method, then dynamic_param is required")
print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
print("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
logger.info("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
# update metadata
if metadata is None:
metadata = {}
logger.info("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
)
comment = metadata.get("ss_training_comment", "")
# update metadata
if metadata is None:
metadata = {}
if not args.dynamic_method:
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
else:
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
metadata["ss_network_dim"] = 'Dynamic'
metadata["ss_network_alpha"] = 'Dynamic'
comment = metadata.get("ss_training_comment", "")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
if not args.dynamic_method:
conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})"
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
else:
metadata["ss_training_comment"] = (
f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
)
metadata["ss_network_dim"] = "Dynamic"
metadata["ss_network_alpha"] = "Dynamic"
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument("--verbose", action="store_true",
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
parser.add_argument("--dynamic_param", type=float, default=None,
help="Specify target for dynamic reduction")
return parser
parser.add_argument(
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat",
)
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument(
"--new_conv_rank",
type=int,
default=None,
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
)
parser.add_argument(
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--model",
type=str,
default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument(
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
)
parser.add_argument(
"--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する"
)
parser.add_argument(
"--dynamic_method",
type=str,
default=None,
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
)
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
resize(args)
args = parser.parse_args()
resize(args)
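With the new flag, Conv2d 3x3 modules can be resized to a different rank than linear ones, e.g. python networks/resize_lora.py --model in.safetensors --save_to out.safetensors --new_rank 16 --new_conv_rank 8 --dynamic_method sv_fro --dynamic_param 0.95 (file names here are illustrative). For sv_fro, the script picks per module the smallest rank that retains the requested share of the Frobenius norm; a minimal sketch of that selection, with index_sv_fro as a simplified stand-in for the script's helper:
import torch
def index_sv_fro(S: torch.Tensor, target: float) -> int:
    # smallest k whose top-k singular values retain `target` of the
    # Frobenius norm; dynamic_param is a norm ratio, so compare squares
    S_squared = S.pow(2)
    cumulative = torch.cumsum(S_squared, dim=0) / S_squared.sum()
    index = int(torch.searchsorted(cumulative, target**2).item())
    return max(1, min(index + 1, len(S)))
# example: retain 95% of the Frobenius norm of a random weight matrix
S = torch.linalg.svdvals(torch.randn(320, 768))
new_rank = index_sv_fro(S, 0.95)
new_alpha = float(1.0 * new_rank)  # scale (= alpha/dim) assumed 1.0 here
print(new_rank, new_alpha)
The sv_ratio and sv_cumulative modes select analogously from the max-singular-value ratio and the plain cumulative sum, and --new_rank remains the hard cap in every mode.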

View File

@@ -8,7 +8,10 @@ from tqdm import tqdm
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
@@ -78,10 +81,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
@@ -92,7 +95,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
# W <- W + U * D
weight = module.weight
# print(module_name, down_weight.size(), up_weight.size())
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
@@ -107,7 +110,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
@@ -121,7 +124,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
@@ -154,10 +157,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
logger.info(f"merging...")
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
@@ -200,8 +203,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same
dims_list = list(set(base_dims.values()))
@@ -243,7 +246,7 @@ def merge(args):
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
(
text_model1,
@@ -265,14 +268,14 @@ def merge(args):
None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
)
print(f"saving SD model to: {args.save_to}")
logger.info(f"saving SD model to: {args.save_to}")
sdxl_model_util.save_stable_diffusion_checkpoint(
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...")
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -286,7 +289,7 @@ def merge(args):
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
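For reference, the per-module update this merge loop performs on a linear weight is W <- W + ratio * (up @ down) * (alpha / rank), exactly as in the hunk above. A self-contained sketch with illustrative shapes (not the script's API):
import torch
def merge_linear_lora(weight, down_weight, up_weight, alpha, ratio=1.0):
    rank = down_weight.size(0)  # LoRA inner dimension
    scale = alpha / rank        # network alpha over rank
    # W <- W + ratio * (U @ D) * scale
    return weight + ratio * (up_weight @ down_weight) * scale
W = torch.randn(768, 320)   # hypothetical out_dim x in_dim
down = torch.randn(8, 320)  # lora_down: rank x in_dim
up = torch.randn(768, 8)    # lora_up: out_dim x rank
merged = merge_linear_lora(W, down, up, alpha=4.0, ratio=0.8)
print(merged.shape)  # torch.Size([768, 320])
Conv2d modules follow the same idea, with 1x1 kernels squeezed to matrices and 3x3 kernels merged through the conv2d trick shown above.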

View File

@@ -1,4 +1,3 @@
import math
import argparse
import os
import time
@@ -8,7 +7,10 @@ from tqdm import tqdm
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLAMP_QUANTILE = 0.99
@@ -41,12 +43,12 @@ def save_to_file(file_name, state_dict, dtype, metadata):
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
merged_sd = {}
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
@@ -56,7 +58,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
# merge
print(f"merging...")
logger.info(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if "lora_down" not in key:
continue
@@ -73,7 +75,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
kernel_size = None if not conv2d else down_weight.size()[2:4]
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
# logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
# make original weight if not exist
if lora_module_name not in merged_sd:
@@ -110,7 +112,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
merged_sd[lora_module_name] = weight
# extract from merged weights
print("extract new lora...")
logger.info("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
@@ -188,7 +190,7 @@ def merge(args):
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
)
print(f"calculating hashes and creating metadata...")
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -203,12 +205,12 @@ def merge(args):
)
if v2:
# TODO read sai modelspec
print(
logger.warning(
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
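After the ratios are summed into full weights, the script rebuilds a low-rank pair per module via SVD and clamps outliers at CLAMP_QUANTILE (0.99 above). A simplified sketch for a linear weight; the real code also handles 1x1/3x3 conv shapes and the separate conv rank:
import torch
CLAMP_QUANTILE = 0.99
def extract_low_rank(weight: torch.Tensor, new_rank: int):
    # keep the top new_rank singular components of the merged weight
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    U = U[:, :new_rank] @ torch.diag(S[:new_rank])
    Vh = Vh[:new_rank, :]
    # clamp extreme values so a few outliers don't dominate the new LoRA
    hi_val = float(torch.quantile(torch.cat([U.flatten(), Vh.flatten()]), CLAMP_QUANTILE))
    return U.clamp(-hi_val, hi_val), Vh.clamp(-hi_val, hi_val)
up, down = extract_low_rank(torch.randn(768, 320), new_rank=8)
print(up.shape, down.shape)  # torch.Size([768, 8]) torch.Size([8, 320])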

View File

@@ -1,20 +1,20 @@
accelerate==0.23.0
transformers==4.30.2
diffusers[torch]==0.21.2
accelerate==0.25.0
transformers==4.36.2
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
einops==0.7.0
pytorch-lightning==1.9.0
# bitsandbytes==0.39.1
tensorboard==2.10.1
safetensors==0.3.1
safetensors==0.4.2
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.15.1
huggingface-hub==0.20.1
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
@@ -29,5 +29,7 @@ huggingface-hub==0.15.1
# protobuf==3.20.3
# open clip for SDXL
open-clip-torch==2.20.0
# For logging
rich==13.7.0
# for kohya_ss library
-e .
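Net effect of the requirements changes: accelerate, transformers, diffusers, safetensors, huggingface-hub, and einops move to newer pins, and rich becomes a new dependency for the logging setup, so existing environments need a fresh pip install -r requirements.txt (the trailing -e . keeps the repo's own library installed in editable mode).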

File diff suppressed because it is too large.

View File

@@ -8,14 +8,11 @@ import os
import random
from einops import repeat
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler
@@ -25,6 +22,10 @@ from safetensors.torch import load_file
from library import model_util, sdxl_model_util
import networks.lora as lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
# scheduler: The settings around here seem to be the same as SD1/2
@@ -87,7 +88,7 @@ if __name__ == "__main__":
guidance_scale = 7
seed = None # 1
DEVICE = "cuda"
DEVICE = get_preferred_device()
DTYPE = torch.float16 # bfloat16 may work
parser = argparse.ArgumentParser()
@@ -142,7 +143,7 @@ if __name__ == "__main__":
vae_dtype = DTYPE
if DTYPE == torch.float16:
print("use float32 for vae")
logger.info("use float32 for vae")
vae_dtype = torch.float32
vae.to(DEVICE, dtype=vae_dtype)
vae.eval()
@@ -189,7 +190,7 @@ if __name__ == "__main__":
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
# print("emb1", emb1.shape)
# logger.info("emb1", emb1.shape)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
@@ -219,7 +220,7 @@ if __name__ == "__main__":
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
# print("hidden_states2", text_embedding2_penu.shape)
# logger.info("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# 連結して終了 concat and finish
@@ -228,7 +229,7 @@ if __name__ == "__main__":
# cond
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
# logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
# uncond
@@ -325,4 +326,4 @@ if __name__ == "__main__":
seed = int(seed)
generate_image(prompt, prompt2, negative_prompt, seed)
print("Done!")
logger.info("Done!")

View File

@@ -1,7 +1,6 @@
# training with captions
import argparse
import gc
import math
import os
from multiprocessing import Value
@@ -9,22 +8,24 @@ from typing import List
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util
import library.train_util as train_util
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
@@ -96,8 +97,11 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
setup_logging(args, reset=True)
assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
assert (
not args.weighted_captions
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
assert (
not args.train_text_encoder or not args.cache_text_encoder_outputs
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
@@ -122,18 +126,18 @@ def train(args):
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -144,7 +148,7 @@ def train(args):
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -174,7 +178,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
@@ -190,7 +194,7 @@ def train(args):
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare the accelerator
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare a dtype matching mixed precision and cast as needed
@@ -257,9 +261,7 @@ def train(args):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -352,8 +354,8 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# prepare the dataloader
# DataLoader with num_workers=0 runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1, but at most the specified number
# note: persistent_workers cannot be used when the DataLoader has 0 worker processes
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -368,7 +370,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# also send the number of training steps to the dataset side
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -412,8 +416,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
@@ -438,7 +441,9 @@ def train(args):
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
@@ -457,6 +462,8 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
@@ -540,7 +547,7 @@ def train(args):
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified")
# logger.info("text encoder outputs verified")
# get size embeddings
orig_size = batch["original_sizes_hw"]
@@ -727,12 +734,13 @@ def train(args):
logit_scale,
ckpt_info,
)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
@@ -755,7 +763,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument(
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
)
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--no_half_vae",

View File

@@ -2,7 +2,6 @@
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
import argparse
import gc
import json
import math
import os
@@ -13,14 +12,11 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
import accelerate
@@ -47,6 +43,12 @@ from library.custom_train_functions import (
apply_debiased_estimation,
)
import networks.control_net_lllite_for_train as control_net_lllite_for_train
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO share this with the other scripts
@@ -67,6 +69,7 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
@@ -80,11 +83,11 @@ def train(args):
# prepare the dataset
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -116,7 +119,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -126,7 +129,9 @@ def train(args):
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else:
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
logger.warning(
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
)
if args.cache_text_encoder_outputs:
assert (
@@ -134,7 +139,7 @@ def train(args):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare the accelerator
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -166,9 +171,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -233,14 +236,14 @@ def train(args):
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(unet.prepare_params())
print(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
logger.info(f"trainable params count: {len(trainable_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare the dataloader
# DataLoader with num_workers=0 runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1, but at most the specified number
# note: persistent_workers cannot be used when the DataLoader has 0 worker processes
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -256,7 +259,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# also send the number of training steps to the dataset side
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -293,8 +298,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
@@ -325,8 +329,10 @@ def train(args):
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -342,6 +348,8 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
@@ -548,12 +556,13 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
@@ -569,8 +578,12 @@ def setup_parser() -> argparse.ArgumentParser:
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument(
"--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
)
parser.add_argument(
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
)
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
parser.add_argument(
"--network_dropout",

View File

@@ -1,5 +1,4 @@
import argparse
import gc
import json
import math
import os
@@ -10,14 +9,11 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
@@ -43,6 +39,12 @@ from library.custom_train_functions import (
apply_debiased_estimation,
)
import networks.control_net_lllite as control_net_lllite
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO share this with the other scripts
@@ -63,6 +65,7 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
@@ -76,11 +79,11 @@ def train(args):
# prepare the dataset
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -112,7 +115,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -122,7 +125,9 @@ def train(args):
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else:
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
logger.warning(
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
)
if args.cache_text_encoder_outputs:
assert (
@@ -130,7 +135,7 @@ def train(args):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare the accelerator
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -165,9 +170,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -201,14 +204,14 @@ def train(args):
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(network.prepare_optimizer_params())
print(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
logger.info(f"trainable params count: {len(trainable_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare the dataloader
# DataLoader with num_workers=0 runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1, but at most the specified number
# note: persistent_workers cannot be used when the DataLoader has 0 worker processes
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -224,7 +227,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# also send the number of training steps to the dataset side
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -266,8 +271,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
@@ -298,8 +302,10 @@ def train(args):
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -518,12 +524,13 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
@@ -539,8 +546,12 @@ def setup_parser() -> argparse.ArgumentParser:
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument(
"--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
)
parser.add_argument(
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
)
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
parser.add_argument(
"--network_dropout",

View File

@@ -1,18 +1,15 @@
import argparse
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
@@ -65,13 +62,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
if args.cache_text_encoder_outputs:
if not args.lowram:
# reduce memory consumption
print("move vae and unet to cpu to save memory")
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
with accelerator.autocast():
@@ -86,17 +82,16 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
if not args.lowram:
print("move vae and unet back to original device")
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# outputs are fetched from the Text Encoders every step, so keep them on the GPU
text_encoders[0].to(accelerator.device)
text_encoders[1].to(accelerator.device)
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -148,7 +143,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified")
# logger.info("text encoder outputs verified")
return encoder_hidden_states1, encoder_hidden_states2, pool2

View File

@@ -2,14 +2,11 @@ import argparse
import os
import regex
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex
init_ipex()
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util

stable_cascade_gen_img.py (new file, 367 lines)
View File

@@ -0,0 +1,367 @@
import argparse
import importlib
import math
import os
import random
import time
import numpy as np
from safetensors.torch import load_file, save_file
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
from PIL import Image
from accelerate import init_empty_weights
import library.stable_cascade as sc
import library.stable_cascade_utils as sc_utils
import library.device_utils as device_utils
from library import train_util
from library.sdxl_model_util import _load_state_dict_on_device
def main(args):
device = device_utils.get_preferred_device()
loading_device = device if not args.lowvram else "cpu"
text_model_device = "cpu"
dtype = torch.float32
if args.bf16:
dtype = torch.bfloat16
elif args.fp16:
dtype = torch.float16
text_model_dtype = torch.float32
# EfficientNet encoder
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
effnet.eval().requires_grad_(False).to(loading_device)
generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
generator_c.eval().requires_grad_(False).to(loading_device)
# if args.xformers or args.sdpa:
print(f"Stage C: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
generator_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
generator_b.eval().requires_grad_(False).to(loading_device)
# if args.xformers or args.sdpa:
print(f"Stage B: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
generator_b.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
# CLIP encoders
tokenizer = sc_utils.load_tokenizer(args)
text_model = sc_utils.load_clip_text_model(
args.text_model_checkpoint_path, text_model_dtype, text_model_device, args.save_text_model
)
text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device)
# image_model = (
# CLIPVisionModelWithProjection.from_pretrained(clip_image_model_name).requires_grad_(False).to(dtype).to(device)
# )
# vqGAN
stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
stage_a.eval().requires_grad_(False)
# previewer
if args.previewer_checkpoint_path is not None:
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=dtype, device=loading_device)
previewer.eval().requires_grad_(False)
else:
previewer = None
# LoRA
if args.network_module:
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
net_kwargs = {}
if args.network_args and i < len(args.network_args):
network_args = args.network_args[i]
# TODO escape special chars
network_args = network_args.split(";")
for net_arg in network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs
)
if network is None:
return
mergeable = network.is_mergeable()
assert mergeable, "not-mergeable network is not supported yet."
network.merge_to(text_model, generator_c, weights_sd, dtype, device)
# the mysterious gdf class
gdf_c = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
input_scaler=sc.VPScaler(),
target=sc.EpsilonTarget(),
noise_cond=sc.CosineTNoiseCond(),
loss_weight=None,
)
gdf_b = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
input_scaler=sc.VPScaler(),
target=sc.EpsilonTarget(),
noise_cond=sc.CosineTNoiseCond(),
loss_weight=None,
)
# Stage C Parameters
# extras.sampling_configs["cfg"] = 4
# extras.sampling_configs["shift"] = 2
# extras.sampling_configs["timesteps"] = 20
# extras.sampling_configs["t_start"] = 1.0
# # Stage B Parameters
# extras_b.sampling_configs["cfg"] = 1.1
# extras_b.sampling_configs["shift"] = 1
# extras_b.sampling_configs["timesteps"] = 10
# extras_b.sampling_configs["t_start"] = 1.0
b_cfg = 1.1
b_shift = 1
b_timesteps = 10
b_t_start = 1.0
# caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
# height, width = 1024, 1024
while True:
print("type caption:")
# if Ctrl+Z is pressed, it will raise EOFError
try:
caption = input()
except EOFError:
break
caption = caption.strip()
if caption == "":
continue
# parse options: '--w' and '--h' for size, '--l' for cfg, '--s' for timesteps, '--f' for shift,
# '--t' for t_start, '--n' for negative prompt, '--d' for seed. if not specified, use default values
# e.g. "caption --w 1024 --h 1024 --l 4 --s 20 --f 2.0"
tokens = caption.split()
width = height = 1024
cfg = 4
timesteps = 20
shift = 2
t_start = 1.0
negative_prompt = ""
seed = None
caption_tokens = []
i = 0
while i < len(tokens):
token = tokens[i]
if i == len(tokens) - 1:
caption_tokens.append(token)
i += 1
continue
if token == "--w":
width = int(tokens[i + 1])
elif token == "--h":
height = int(tokens[i + 1])
elif token == "--l":
cfg = float(tokens[i + 1])
elif token == "--s":
timesteps = int(tokens[i + 1])
elif token == "--f":
shift = float(tokens[i + 1])
elif token == "--t":
t_start = float(tokens[i + 1])
elif token == "--n":
negative_prompt = tokens[i + 1]
elif token == "--d":
seed = int(tokens[i + 1])
else:
caption_tokens.append(token)
i += 1
continue
i += 2
caption = " ".join(caption_tokens)
stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)
# PREPARE CONDITIONS
# cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
input_ids = tokenizer(
[caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade(
tokenizer.model_max_length, input_ids, tokenizer, text_model
)
cond_text = cond_text.to(device, dtype=dtype)
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
# uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
input_ids = tokenizer(
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade(
tokenizer.model_max_length, input_ids, tokenizer, text_model
)
uncond_text = uncond_text.to(device, dtype=dtype)
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
zero_img_emb = torch.zeros(1, 768, device=device)
# would rather not use a dict, but changing GDF and everything downstream is a hassle, so keep a dict for now
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
unconditions = {
"clip_text_pooled": uncond_pooled,
"clip": uncond_pooled,
"clip_text": uncond_text,
"clip_img": zero_img_emb,
}
conditions_b = {}
conditions_b.update(conditions)
unconditions_b = {}
unconditions_b.update(unconditions)
# seed everything
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
if args.lowvram:
generator_c = generator_c.to(device)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampling_c = gdf_c.sample(
generator_c,
conditions,
stage_c_latent_shape,
unconditions,
device=device,
cfg=cfg,
shift=shift,
timesteps=timesteps,
t_start=t_start,
)
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
sampled_c = sampled_c
conditions_b["effnet"] = sampled_c
unconditions_b["effnet"] = torch.zeros_like(sampled_c)
if previewer is not None:
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
preview = previewer(sampled_c)
preview = preview.clamp(0, 1)
preview = preview.permute(0, 2, 3, 1).squeeze(0)
preview = preview.detach().float().cpu().numpy()
preview = Image.fromarray((preview * 255).astype(np.uint8))
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
os.makedirs(args.outdir, exist_ok=True)
preview.save(os.path.join(args.outdir, f"preview_{timestamp_str}.png"))
if args.lowvram:
generator_c = generator_c.to(loading_device)
device_utils.clean_memory_on_device(device)
generator_b = generator_b.to(device)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampling_b = gdf_b.sample(
generator_b,
conditions_b,
stage_b_latent_shape,
unconditions_b,
device=device,
cfg=b_cfg,
shift=b_shift,
timesteps=b_timesteps,
t_start=b_t_start,
)
for sampled_b, _, _ in tqdm(sampling_b, total=b_timesteps):
sampled_b = sampled_b
if args.lowvram:
generator_b = generator_b.to(loading_device)
device_utils.clean_memory_on_device(device)
stage_a = stage_a.to(device)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampled = stage_a.decode(sampled_b).float()
# print(sampled.shape, sampled.min(), sampled.max())
if args.lowvram:
stage_a = stage_a.to(loading_device)
device_utils.clean_memory_on_device(device)
# float 0-1 to PIL Image
sampled = sampled.clamp(0, 1)
sampled = sampled.mul(255).to(dtype=torch.uint8)
sampled = sampled.permute(0, 2, 3, 1)
sampled = sampled.cpu().numpy()
sampled = Image.fromarray(sampled[0])
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
os.makedirs(args.outdir, exist_ok=True)
sampled.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png"))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
sc_utils.add_effnet_arguments(parser)
train_util.add_tokenizer_arguments(parser)
sc_utils.add_stage_a_arguments(parser)
sc_utils.add_stage_b_arguments(parser)
sc_utils.add_stage_c_arguments(parser)
sc_utils.add_previewer_arguments(parser)
sc_utils.add_text_model_arguments(parser)
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--xformers", action="store_true")
parser.add_argument("--sdpa", action="store_true")
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
parser.add_argument(
"--network_module",
type=str,
default=None,
nargs="*",
help="additional network module to use / 追加ネットワークを使う時そのモジュール名",
)
parser.add_argument(
"--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
)
parser.add_argument(
"--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率"
)
parser.add_argument(
"--network_args",
type=str,
default=None,
nargs="*",
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
)
args = parser.parse_args()
main(args)
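# Example invocation (a sketch: the script name and the checkpoint-path flags added
# by the sc_utils.add_*_arguments helpers are assumptions, not verified names):
#   python stable_cascade_gen_img.py --bf16 --sdpa --lowvram --outdir ./outputs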

File diff suppressed because it is too large

View File

@@ -0,0 +1,564 @@
# training with captions
import argparse
import math
import os
from multiprocessing import Value
from typing import List
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
import library.train_util as train_util
from library.sdxl_train_util import add_sdxl_training_arguments
import library.stable_cascade_utils as sc_utils
import library.stable_cascade as sc
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
# TODO add assertions for other unsupported options
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)  # initialize the random number sequence
tokenizer = sc_utils.load_tokenizer(args)
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer])
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer])
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare dtypes matching the mixed-precision setting and cast as needed
weight_dtype, save_dtype = train_util.prepare_dtype(args)
effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype
# load the models
loading_device = accelerator.device if args.lowram else "cpu"
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, device=loading_device)  # dtype is kept as-is
text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device)
if args.sample_at_first or args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
# Previewer is small enough to be loaded on CPU
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=torch.float32, device="cpu")
previewer.eval()
else:
previewer = None
# enable xformers or memory-efficient attention in the model
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
# prepare for training
if cache_latents:
effnet.to(accelerator.device, dtype=effnet_dtype)
effnet.requires_grad_(False)
effnet.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
effnet,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX,
32,
)
effnet.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# prepare for training: put the models into the proper state
if args.gradient_checkpointing:
accelerator.print("enable gradient checkpointing")
stage_c.set_gradient_checkpointing(True)
train_stage_c = args.learning_rate > 0
train_text_encoder1 = False
if args.train_text_encoder:
accelerator.print("enable text encoder training")
if args.gradient_checkpointing:
text_encoder1.gradient_checkpointing_enable()
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means do not train
train_text_encoder1 = lr_te1 > 0
assert (
train_text_encoder1
), "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
if not train_text_encoder1:
text_encoder1.to(weight_dtype)
text_encoder1.requires_grad_(train_text_encoder1)
text_encoder1.train(train_text_encoder1)
else:
text_encoder1.to(weight_dtype)
text_encoder1.requires_grad_(False)
text_encoder1.eval()
# cache the Text Encoder outputs
if args.cache_text_encoder_outputs:
# Text Encoders are in eval mode with no grad
with torch.no_grad(), accelerator.autocast():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer,),
(text_encoder1,),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
)
accelerator.wait_for_everyone()
if not cache_latents:
effnet.requires_grad_(False)
effnet.eval()
effnet.to(accelerator.device, dtype=effnet_dtype)
stage_c.requires_grad_(True)
if not train_stage_c:
stage_c.to(accelerator.device, dtype=weight_dtype)  # stage_c will not be prepared by accelerator, so move it here
training_models = []
params_to_optimize = []
if train_stage_c:
training_models.append(stage_c)
params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
if train_text_encoder1:
training_models.append(text_encoder1)
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()
accelerator.print(f"train stage-C: {train_stage_c}, text_encoder1: {train_text_encoder1}")
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")
# prepare the classes needed for training
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# prepare the dataloader
# a DataLoader worker count of 0 means loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number
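# with os.cpu_count() == 1 this evaluates to 0 workers, which falls back to loading in the main process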
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
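# worked example: 1000 batches per epoch, 2 processes, gradient_accumulation_steps=4
# -> ceil(1000 / 2 / 4) = 125 steps per epoch, so max_train_epochs=10 gives max_train_steps=1250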
# also send the training step count to the dataset side
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare the lr scheduler
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# experimental feature: train in fp16/bf16 including gradients, casting the entire model to fp16/bf16
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
stage_c.to(weight_dtype)
text_encoder1.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
stage_c.to(weight_dtype)
text_encoder1.to(weight_dtype)
# accelerator apparently handles the details here
if train_stage_c:
stage_c = accelerator.prepare(stage_c)
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# when Text Encoder outputs are cached, move the Text Encoder to CPU
if args.cache_text_encoder_outputs:
# move Text Encoder to CPU for sampling images; cast to fp32 because Text Encoder doesn't work with fp16 on CPU
text_encoder1.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
# experimental feature: train in fp16 including gradients, patching PyTorch to enable grad scaling in fp16
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume training if specified
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1  # never 0
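# e.g. num_train_epochs=10 and save_n_epoch_ratio=3 -> floor(10 / 3) = 3, i.e. a save
# roughly every 3 epochs; the `or 1` guards against a zero interval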
# start training
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
# the mysterious GDF class
gdf = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
input_scaler=sc.VPScaler(),
target=sc.EpsilonTarget(),
noise_cond=sc.CosineTNoiseCond(),
loss_weight=sc.AdaptiveLossWeight() if args.adaptive_loss_weight else sc.P2LossWeight(),
)
# the following two variables appear to be left at their defaults
# gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
# gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
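# Rough sketch of what gdf.diffuse computes per sample, inferred from the component
# names above (an assumption, not a verbatim copy of the GDF internals):
#   logSNR = CosineSchedule(t), clamped to [0.0001, 0.9999]
#   a, b = sigmoid(logSNR) ** 0.5, sigmoid(-logSNR) ** 0.5   # VPScaler: a**2 + b**2 = 1
#   noised = a * latents + b * noise                         # forward diffusion
#   target = noise                                           # EpsilonTarget: predict the added noise
# P2LossWeight / AdaptiveLossWeight then map logSNR to the per-sample loss_weight.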
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
sc_utils.sample_images(accelerator, args, 0, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
# convert to latents
# XXX Effnet preprocessing is included in encode method
latents = effnet.encode(batch["images"].to(effnet_dtype)).latent_dist.sample().to(weight_dtype)
# if NaNs are present, warn and replace them with zeros
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
# # debug: decode latent with previewer and save it
# import time
# import numpy as np
# from PIL import Image
# ts = time.time()
# images = previewer(latents.to(previewer.device, dtype=previewer.dtype))
# for i, img in enumerate(images):
# img = img.detach().cpu().numpy().transpose(1, 2, 0)
# img = np.clip(img, 0, 1)
# img = (img * 255).astype(np.uint8)
# img = Image.fromarray(img)
# img.save(f"logs/previewer_{i}_{ts}.png")
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
# TODO support weighted captions
input_ids1 = input_ids1.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states, pool = train_util.get_hidden_states_stable_cascade(
args.max_token_length,
input_ids1,
tokenizer,
text_encoder1,
None if not args.full_fp16 else weight_dtype,
accelerator,
)
else:
encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
pool = pool.unsqueeze(1) # add extra dimension b,1280 -> b,1,1280
# FORWARD PASS
with torch.no_grad():
noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1)
zero_img_emb = torch.zeros(noised.shape[0], 768, device=accelerator.device)
with accelerator.autocast():
pred = stage_c(
noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
)
loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
loss_adjusted = (loss * loss_weight).mean()
if args.adaptive_loss_weight:
gdf.loss_weight.update_buckets(logSNR, loss) # use loss instead of loss_adjusted
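# assumption about AdaptiveLossWeight: it buckets recent raw losses by logSNR and
# weights future samples inversely, which is why the raw loss (not loss_adjusted) is fed here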
accelerator.backward(loss_adjusted)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
sc_utils.sample_images(accelerator, args, None, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
# save the model every specified number of steps
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(stage_c),
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
)
current_loss = loss_adjusted.detach().item()  # this is a mean, so batch size should not matter
if args.logging_dir is not None:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(stage_c),
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
)
sc_utils.sample_images(accelerator, args, epoch + 1, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
is_main_process = accelerator.is_main_process
# if is_main_process:
stage_c = accelerator.unwrap_model(stage_c)
text_encoder1 = accelerator.unwrap_model(text_encoder1)
accelerator.end_training()
if args.save_state: # and is_main_process:
train_util.save_state_on_train_end(args, accelerator)
del accelerator  # delete this since memory is needed afterwards
if is_main_process:
sc_utils.save_stage_c_model_on_end(
args, save_dtype, epoch, global_step, stage_c, text_encoder1 if train_text_encoder1 else None
)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
sc_utils.add_effnet_arguments(parser)
sc_utils.add_stage_c_arguments(parser)
sc_utils.add_text_model_arguments(parser)
sc_utils.add_previewer_arguments(parser)
sc_utils.add_training_arguments(parser)
train_util.add_tokenizer_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
add_sdxl_training_arguments(parser) # cache text encoder outputs
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--learning_rate_te1",
type=float,
default=None,
help="learning rate for text encoder / text encoderの学習率",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
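# Example invocation (a sketch: the script name and the checkpoint-path flags from the
# add_*_arguments helpers are assumptions; the remaining flags are referenced in this script):
#   accelerate launch stable_cascade_train_stage_c.py --dataset_config config.toml \
#     --mixed_precision bf16 --learning_rate 1e-5 --gradient_checkpointing \
#     --train_text_encoder --learning_rate_te1 5e-6 --cache_latents --sdpa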

View File

@@ -16,7 +16,10 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
@@ -41,18 +44,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -63,7 +66,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -90,7 +93,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -98,7 +101,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
print("load model")
logger.info("load model")
if args.sdxl:
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
else:
@@ -113,8 +116,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# dataloaderを準備する
train_dataset_group.set_caching_mode("latents")
# a DataLoader worker count of 0 means loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number
# note that a DataLoader worker count of 0 cannot use persistent_workers
n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count, capped at max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -152,7 +155,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
print(f"Skipping {image_info.latents_npz} because it already exists.")
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)

View File

@@ -16,7 +16,10 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
@@ -48,18 +51,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -70,7 +73,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -95,14 +98,14 @@ def cache_to_disk(args: argparse.Namespace) -> None:
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, _ = train_util.prepare_dtype(args)
# モデルを読み込む
print("load model")
logger.info("load model")
if args.sdxl:
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
text_encoders = [text_encoder1, text_encoder2]
@@ -118,8 +121,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# dataloaderを準備する
train_dataset_group.set_caching_mode("text")
# a DataLoader worker count of 0 means loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number
# note that a DataLoader worker count of 0 cannot use persistent_workers
n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count, capped at max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -147,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1

View File

@@ -1,6 +1,10 @@
import argparse
import cv2
import logging
from library.utils import setup_logging
setup_logging()
logger = logging.getLogger(__name__)
def canny(args):
img = cv2.imread(args.input)
@@ -10,7 +14,7 @@ def canny(args):
# canny_img = 255 - canny_img
cv2.imwrite(args.output, canny_img)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -6,7 +6,10 @@ import torch
from diffusers import StableDiffusionPipeline
import library.model_util as model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def convert(args):
# check the arguments
@@ -30,7 +33,7 @@ def convert(args):
# load the model
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
print(f"loading {msg}: {args.model_to_load}")
logger.info(f"loading {msg}: {args.model_to_load}")
if is_load_ckpt:
v2_model = args.v2
@@ -48,13 +51,13 @@ def convert(args):
if args.v1 == args.v2:
# auto-detect the version
v2_model = unet.config.cross_attention_dim == 1024
print("checking model version: model is " + ("v2" if v2_model else "v1"))
logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
else:
v2_model = not args.v1
# convert and save
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
print(f"converting and saving as {msg}: {args.model_to_save}")
logger.info(f"converting and saving as {msg}: {args.model_to_save}")
if is_save_ckpt:
original_model = args.model_to_load if is_load_ckpt else None
@@ -70,15 +73,15 @@ def convert(args):
save_dtype=save_dtype,
vae=vae,
)
print(f"model saved. total converted state_dict keys: {key_count}")
logger.info(f"model saved. total converted state_dict keys: {key_count}")
else:
print(
logger.info(
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
)
model_util.save_diffusers_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -15,6 +15,10 @@ import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
KP_REYE = 11
KP_LEYE = 19
@@ -24,7 +28,7 @@ SCORE_THRES = 0.90
def detect_faces(detector, image, min_size):
preds = detector(image) # bgr
# print(len(preds))
# logger.info(len(preds))
faces = []
for pred in preds:
@@ -78,7 +82,7 @@ def process(args):
assert args.crop_ratio is None or args.resize_face_size is None, "resize_face_size cannot be specified together with crop_ratio / crop_ratio指定時はresize_face_sizeは指定できません"
# load the anime face detection model
print("loading face detector.")
logger.info("loading face detector.")
detector = create_detector('yolov3')
# parse the crop arguments
@@ -97,7 +101,7 @@ def process(args):
crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
# process the images
print("processing.")
logger.info("processing.")
output_extension = ".png"
os.makedirs(args.dst_dir, exist_ok=True)
@@ -111,7 +115,7 @@ def process(args):
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
if image.shape[2] == 4:
print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
image = image[:, :, :3].copy()  # without copy, the alpha info apparently stays attached internally
h, w = image.shape[:2]
@@ -144,11 +148,11 @@ def process(args):
# resize based on face size
scale = args.resize_face_size / face_size
if scale < cur_crop_width / w:
print(
logger.warning(
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい顔が相対的に大きすぎるので顔サイズが変わります: {path}")
scale = cur_crop_width / w
if scale < cur_crop_height / h:
print(
logger.warning(
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい顔が相対的に大きすぎるので顔サイズが変わります: {path}")
scale = cur_crop_height / h
elif crop_h_ratio is not None:
@@ -157,10 +161,10 @@ def process(args):
else:
# crop size explicitly specified
if w < cur_crop_width:
print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
scale = cur_crop_width / w
if h < cur_crop_height:
print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
scale = cur_crop_height / h
if args.resize_fit:
scale = max(cur_crop_width / w, cur_crop_height / h)
@@ -198,7 +202,7 @@ def process(args):
face_img = face_img[y:y + cur_crop_height]
# # debug
# print(path, cx, cy, angle)
# logger.info(path, cx, cy, angle)
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
# cv2.imshow("image", crp)
# if cv2.waitKey() == 27:

View File

@@ -11,10 +11,16 @@ from typing import Dict, List
import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torch import nn
from tqdm import tqdm
from PIL import Image
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
@@ -216,7 +222,7 @@ class Upscaler(nn.Module):
upsampled_images = upsampled_images / 127.5 - 1.0
# convert upsample images to latents with batch size
# print("Encoding upsampled (LANCZOS4) images...")
# logger.info("Encoding upsampled (LANCZOS4) images...")
upsampled_latents = []
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
@@ -227,7 +233,7 @@ class Upscaler(nn.Module):
upsampled_latents = torch.cat(upsampled_latents, dim=0)
# upscale (refine) latents with this model with batch size
print("Upscaling latents...")
logger.info("Upscaling latents...")
upscaled_latents = []
for i in range(0, upsampled_latents.shape[0], batch_size):
with torch.no_grad():
@@ -242,7 +248,7 @@ def create_upscaler(**kwargs):
weights = kwargs["weights"]
model = Upscaler()
print(f"Loading weights from {weights}...")
logger.info(f"Loading weights from {weights}...")
if os.path.splitext(weights)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -255,20 +261,20 @@ def create_upscaler(**kwargs):
# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()
us_dtype = torch.float16 # TODO: support fp32/bf16
os.makedirs(args.output_dir, exist_ok=True)
# load VAE with Diffusers
assert args.vae_path is not None, "VAE path is required"
print(f"Loading VAE from {args.vae_path}...")
logger.info(f"Loading VAE from {args.vae_path}...")
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
vae.to(DEVICE, dtype=us_dtype)
# prepare model
print("Preparing model...")
logger.info("Preparing model...")
upscaler: Upscaler = create_upscaler(weights=args.weights)
# print("Loading weights from", args.weights)
# logger.info("Loading weights from", args.weights)
# upscaler.load_state_dict(torch.load(args.weights))
upscaler.eval()
upscaler.to(DEVICE, dtype=us_dtype)
@@ -303,14 +309,14 @@ def upscale_images(args: argparse.Namespace):
image_debug.save(dest_file_name)
# upscale
print("Upscaling...")
logger.info("Upscaling...")
upscaled_latents = upscaler.upscale(
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
)
upscaled_latents /= 0.18215
# decode with batch
print("Decoding...")
logger.info("Decoding...")
upscaled_images = []
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
with torch.no_grad():

View File

@@ -5,7 +5,10 @@ import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def is_unet_key(key):
# VAE or TextEncoder, the last one is for SDXL
@@ -45,10 +48,10 @@ def merge(args):
# check if all models are safetensors
for model in args.models:
if not model.endswith("safetensors"):
print(f"Model {model} is not a safetensors model")
logger.info(f"Model {model} is not a safetensors model")
exit()
if not os.path.isfile(model):
print(f"Model {model} does not exist")
logger.info(f"Model {model} does not exist")
exit()
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
@@ -65,7 +68,7 @@ def merge(args):
if merged_sd is None:
# load first model
print(f"Loading model {model}, ratio = {ratio}...")
logger.info(f"Loading model {model}, ratio = {ratio}...")
merged_sd = {}
with safe_open(model, framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
@@ -81,11 +84,11 @@ def merge(args):
value = ratio * value.to(dtype) # first model's value * ratio
merged_sd[key] = value
print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
continue
# load other models
print(f"Loading model {model}, ratio = {ratio}...")
logger.info(f"Loading model {model}, ratio = {ratio}...")
with safe_open(model, framework="pt", device=args.device) as f:
model_keys = f.keys()
@@ -93,7 +96,7 @@ def merge(args):
_, new_key = replace_text_encoder_key(key)
if new_key not in merged_sd:
if args.show_skipped and new_key not in first_model_keys:
print(f"Skip: {new_key}")
logger.info(f"Skip: {new_key}")
continue
value = f.get_tensor(key)
@@ -104,7 +107,7 @@ def merge(args):
for key in merged_sd.keys():
if key in model_keys:
continue
print(f"Key {key} not in model {model}, use first model's value")
logger.warning(f"Key {key} not in model {model}, use first model's value")
if key in supplementary_key_ratios:
supplementary_key_ratios[key] += ratio
else:
@@ -112,7 +115,7 @@ def merge(args):
# add supplementary keys' value (including VAE and TextEncoder)
if len(supplementary_key_ratios) > 0:
print("add first model's value")
logger.info("add first model's value")
with safe_open(args.models[0], framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
_, new_key = replace_text_encoder_key(key)
@@ -120,7 +123,7 @@ def merge(args):
continue
if is_unet_key(new_key): # not VAE or TextEncoder
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
value = f.get_tensor(key) # original key
@@ -134,7 +137,7 @@ def merge(args):
if not output_file.endswith(".safetensors"):
output_file = output_file + ".safetensors"
print(f"Saving to {output_file}...")
logger.info(f"Saving to {output_file}...")
# convert to save_dtype
for k in merged_sd.keys():
@@ -142,7 +145,7 @@ def merge(args):
save_file(merged_sd, output_file)
print("Done!")
logger.info("Done!")
if __name__ == "__main__":

View File

@@ -7,7 +7,10 @@ from safetensors.torch import load_file
from library.original_unet import UNet2DConditionModel, SampleOutput
import library.model_util as model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class ControlNetInfo(NamedTuple):
unet: Any
@@ -51,7 +54,7 @@ def load_control_net(v2, unet, model):
# convert keys from the control SD model, extract only the U-Net parts, and load them into the Diffusers U-Net
# load the state dict
print(f"ControlNet: loading control SD model : {model}")
logger.info(f"ControlNet: loading control SD model : {model}")
if model_util.is_safetensors(model):
ctrl_sd_sd = load_file(model)
@@ -61,7 +64,7 @@ def load_control_net(v2, unet, model):
# make the weights loadable into the U-Net; ControlNet uses an SD-format state dict, so load that
is_difference = "difference" in ctrl_sd_sd
print("ControlNet: loading difference:", is_difference)
logger.info(f"ControlNet: loading difference: {is_difference}")
# some keys do not exist in the ControlNet, so first build all SD-format keys from the current U-Net
# these also serve as the source weights for Transfer Control
@@ -89,13 +92,13 @@ def load_control_net(v2, unet, model):
# create the ControlNet U-Net
ctrl_unet = UNet2DConditionModel(**unet_config)
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
print("ControlNet: loading Control U-Net:", info)
logger.info(f"ControlNet: loading Control U-Net: {info}")
# create the parts of the ControlNet other than the U-Net
# TODO support middle only
ctrl_net = ControlNet()
info = ctrl_net.load_state_dict(zero_conv_sd)
print("ControlNet: loading ControlNet:", info)
logger.info("ControlNet: loading ControlNet: {info}")
ctrl_unet.to(unet.device, dtype=unet.dtype)
ctrl_net.to(unet.device, dtype=unet.dtype)
@@ -117,7 +120,7 @@ def load_preprocess(prep_type: str):
return canny
print("Unsupported prep type:", prep_type)
logger.info(f"Unsupported prep type: {prep_type}")
return None
@@ -174,13 +177,26 @@ def call_unet_and_control_net(
cnet_idx = step % cnet_cnt
cnet_info = control_nets[cnet_idx]
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
return original_unet(sample, timestep, encoder_hidden_states)
guided_hint = guided_hints[cnet_idx]
# gradual latent support: match the size of guided_hint to the size of sample
if guided_hint.shape[-2:] != sample.shape[-2:]:
# print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}")
org_dtype = guided_hint.dtype
if org_dtype == torch.bfloat16:
guided_hint = guided_hint.to(torch.float32)
guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic")
if org_dtype == torch.bfloat16:
guided_hint = guided_hint.to(org_dtype)
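# (assumption: the float32 round-trip above exists because bicubic interpolation
# has historically lacked a bfloat16 kernel, so bf16 hints must be upcast first)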
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
outs = unet_forward(
True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net
)
outs = [o * cnet_info.weight for o in outs]
# U-Net
@@ -192,7 +208,7 @@ def call_unet_and_control_net(
# ControlNet
cnet_outs_list = []
for i, cnet_info in enumerate(control_nets):
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
continue
guided_hint = guided_hints[i]
@@ -232,7 +248,7 @@ def unet_forward(
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
print("Forward upsample size to force interpolation output size.")
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time

View File

@@ -6,7 +6,10 @@ import shutil
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
# Split the max_resolution string by "," and strip any whitespaces
@@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
proc = "Resized" if current_pixels > max_pixels else "Saved"
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
# If other files with same basename, copy them with resolution suffix
if copy_associated_files:
@@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
continue
for max_resolution in max_resolutions:
new_asoc_file = base + '+' + max_resolution + ext
print(f"Copy {asoc_file} as {new_asoc_file}")
logger.info(f"Copy {asoc_file} as {new_asoc_file}")
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))

View File

@@ -1,6 +1,10 @@
import json
import argparse
from safetensors import safe_open
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
@@ -10,10 +14,10 @@ with safe_open(args.model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:
print("No metadata found")
logger.error("No metadata found")
else:
# metadata is json dict, but not pretty printed
# sort by key and pretty print
print(json.dumps(metadata, indent=4, sort_keys=True))

View File

@@ -0,0 +1,191 @@
# Stable Cascadeのlatentsをdiskにキャッシュする
# cache latents of Stable Cascade to disk
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import stable_cascade_utils as sc_utils
from library import config_util
from library import train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
# check cache latents arg
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)  # initialize the random number sequence
# prepare the tokenizer (needed to run the dataset)
tokenizer = sc_utils.load_tokenizer(args)
tokenizers = [tokenizer]
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
# if the dataset's cache_latents is not called, raw images are returned
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare dtypes matching the mixed-precision setting and cast as needed
weight_dtype, _ = train_util.prepare_dtype(args)
effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype
# load the model
logger.info("load model")
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, accelerator.device)
effnet.to(accelerator.device, dtype=effnet_dtype)
effnet.requires_grad_(False)
effnet.eval()
# prepare the dataloader
train_dataset_group.set_caching_mode("latents")
# a DataLoader worker count of 0 means loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# prepare the model with accelerator; this should make multi-GPU usable
train_dataloader = accelerator.prepare(train_dataloader)
# loop to fetch the data
for batch in tqdm(train_dataloader):
b_size = len(batch["images"])
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
flip_aug = batch["flip_aug"]
random_crop = batch["random_crop"]
bucket_reso = batch["bucket_reso"]
# split the batch and process it
for i in range(0, b_size, vae_batch_size):
images = batch["images"][i : i + vae_batch_size]
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
image_infos = []
# note: no enumerate here; its index was unused and shadowed the outer loop variable i
for image, absolute_path, resized_size in zip(images, absolute_paths, resized_sizes):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.image = image
image_info.bucket_reso = bucket_reso
image_info.resized_size = resized_size
image_info.latents_npz = os.path.splitext(absolute_path)[0] + train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug, 32):
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)
if len(image_infos) > 0:
train_util.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_tokenizer_arguments(parser)
sc_utils.add_effnet_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
config_util.add_config_arguments(parser)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
cache_to_disk(args)
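# Example invocation (a sketch: the script name and the Effnet checkpoint flag from
# add_effnet_arguments are assumptions; the other flags appear in this script):
#   python stable_cascade_cache_latents.py --cache_latents_to_disk --skip_existing \
#     --dataset_config config.toml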

View File

@@ -0,0 +1,183 @@
# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import train_util
from library import sdxl_train_util
from library import stable_cascade_utils as sc_utils
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
# check cache arg
assert (
args.cache_text_encoder_outputs_to_disk
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed)  # initialize the random number sequence
# prepare the tokenizer (needed to run the dataset)
tokenizer = sc_utils.load_tokenizer(args)
tokenizers = [tokenizer]
# prepare the dataset
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# prepare the accelerator
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# prepare dtypes matching the mixed-precision setting and cast as needed
weight_dtype, _ = train_util.prepare_dtype(args)
# load the model
logger.info("load model")
text_encoder = sc_utils.load_clip_text_model(
args.text_model_checkpoint_path, weight_dtype, accelerator.device, args.save_text_model
)
text_encoders = [text_encoder]
for text_encoder in text_encoders:
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
# prepare the dataloader
train_dataset_group.set_caching_mode("text")
# a DataLoader worker count of 0 means loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# prepare the model with accelerator; this should make multi-GPU usable
train_dataloader = accelerator.prepare(train_dataloader)
# loop to fetch the data
for batch in tqdm(train_dataloader):
absolute_paths = batch["absolute_paths"]
input_ids1_list = batch["input_ids1_list"]
image_infos = []
for absolute_path, input_ids1 in zip(absolute_paths, input_ids1_list):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1
image_infos.append(image_info)
if len(image_infos) > 0:
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
train_util.cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, None, weight_dtype
)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_tokenizer_arguments(parser)
sc_utils.add_text_model_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
config_util.add_config_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
cache_to_disk(args)
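# Example invocation (a sketch: the script name and the text-model checkpoint flag from
# add_text_model_arguments are assumptions; the other flags appear in this script):
#   python stable_cascade_cache_text_encoder_outputs.py --cache_text_encoder_outputs_to_disk \
#     --skip_existing --dataset_config config.toml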

View File

@@ -1,5 +1,4 @@
import argparse
import gc
import json
import math
import os
@@ -10,17 +9,11 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
@@ -40,6 +33,12 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO unify this with the other scripts
@@ -61,6 +60,7 @@ def train(args):
# training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
@@ -74,11 +74,11 @@ def train(args):
# prepare the dataset
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -108,7 +108,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -119,7 +119,7 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# prepare the accelerator
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -224,10 +224,8 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
@@ -241,8 +239,8 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare the dataloader
# a DataLoader worker count of 0 means loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified number
# note that a DataLoader worker count of 0 cannot use persistent_workers
n_workers = min(args.max_data_loader_n_workers, os.cpu_count())  # cpu_count, capped at max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -258,7 +256,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# also send the training step count to the dataset side
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -314,8 +314,10 @@ def train(args):
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -336,6 +338,8 @@ def train(args):
)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
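
Editor's note: the new wandb_run_name plumbing relies on accelerate's init_trackers(init_kwargs=...) contract, where the outer key selects the tracker and the inner dict is forwarded to that tracker's init (wandb.init(name=...) for wandb). Note the design choice visible above: a --log_tracker_config TOML file replaces init_kwargs wholesale rather than merging, so a run name set there wins over --wandb_run_name. A condensed sketch with a placeholder project name:

init_kwargs = {}
if args.wandb_run_name:
    init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
    init_kwargs = toml.load(args.log_tracker_config)  # replaces, does not merge
accelerator.init_trackers("my_project", init_kwargs=init_kwargs)
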
@@ -570,12 +574,13 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, controlnet, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)

View File (train_db.py)

@@ -1,7 +1,6 @@
# DreamBooth training
# XXX dropped option: fine_tune
import gc
import argparse
import itertools
import math
@@ -10,17 +9,11 @@ from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
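
Editor's note: each script's inline IPEX try/except is collapsed into init_ipex() from library/device_utils. A sketch of what the helper presumably wraps, mirroring the block it replaces (the exact implementation may differ):

import torch

def init_ipex():
    # best-effort Intel XPU setup, mirroring the removed per-script block
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401

        if torch.xpu.is_available():
            from library.ipex import ipex_init

            ipex_init()
    except Exception:
        pass
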
@@ -40,6 +33,12 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# perlin_noise,
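
Editor's note: the print-to-logger migration follows one pattern in every script, shown in the import block above: configure handlers once at import time with setup_logging(), take a module-level logger, then re-run setup_logging(args, reset=True) inside train() once the CLI logging options are known. Condensed:

from library.utils import setup_logging, add_logging_arguments

setup_logging()  # default handlers before argparse runs
import logging

logger = logging.getLogger(__name__)

def train(args):
    setup_logging(args, reset=True)  # re-apply with the user-selected level/format
    logger.info("prepare accelerator")
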
@@ -47,6 +46,7 @@ from library.custom_train_functions import (
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -59,11 +59,11 @@ def train(args):
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -98,13 +98,13 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する / prepare accelerator
print("prepare accelerator")
logger.info("prepare accelerator")
if args.gradient_accumulation_steps > 1:
print(
logger.warning(
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
)
print(
logger.warning(
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
)
@@ -143,9 +143,7 @@ def train(args):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -182,8 +180,8 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する / prepare dataloader
# DataLoaderのプロセス数0はメインプロセスになる / with 0 DataLoader workers, loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで / cpu_count-1, but at most the specified number
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意 / note: persistent_workers cannot be used when n_workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -198,7 +196,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信 / also send the train step count to the dataset side
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -268,6 +268,8 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
@@ -452,12 +454,13 @@ def train(args):
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)

View File (train_network.py)

@@ -1,6 +1,5 @@
import importlib
import argparse
import gc
import math
import os
import sys
@@ -11,18 +10,13 @@ from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
@@ -46,6 +40,12 @@ from library.custom_train_functions import (
add_v_prediction_like_loss,
apply_debiased_estimation,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
class NetworkTrainer:
@@ -117,7 +117,7 @@ class NetworkTrainer:
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
):
for t_enc in text_encoders:
t_enc.to(accelerator.device)
t_enc.to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids = batch["input_ids"].to(accelerator.device)
@@ -141,6 +141,7 @@ class NetworkTrainer:
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
@@ -158,18 +159,18 @@ class NetworkTrainer:
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if use_user_config:
print(f"Loading dataset config from {args.dataset_config}")
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -180,7 +181,7 @@ class NetworkTrainer:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -209,7 +210,7 @@ class NetworkTrainer:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -222,7 +223,7 @@ class NetworkTrainer:
self.assert_extra_args(args, train_dataset_group)
# acceleratorを準備する / prepare accelerator
print("preparing accelerator")
logger.info("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -271,13 +272,12 @@ class NetworkTrainer:
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
)
@@ -309,11 +309,12 @@ class NetworkTrainer:
)
if network is None:
return
network_has_multiplier = hasattr(network, "set_multiplier")
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
print(
logger.warning(
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
)
args.scale_weight_norms = False
@@ -348,8 +349,8 @@ class NetworkTrainer:
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する / prepare dataloader
# DataLoaderのプロセス数0はメインプロセスになる / with 0 DataLoader workers, loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで / cpu_count-1, but at most the specified number
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意 / note: persistent_workers cannot be used when n_workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -389,17 +390,33 @@ class NetworkTrainer:
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.requires_grad_(False)
unet.to(dtype=weight_dtype)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
# TODO めちゃくちゃ冗長なのでコードを整理する / TODO: very redundant, clean up this code
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding does not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
@@ -407,8 +424,8 @@ class NetworkTrainer:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
if args.gradient_checkpointing:
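
Editor's note: the experimental --fp8_base path stores the frozen U-Net and text encoder weights as torch.float8_e4m3fn (hence the torch >= 2.1 assert) while gradients, the LoRA network, and the embeddings stay in the fp16/bf16 compute dtype; nn.Embedding has no fp8 kernel, which is why the embeddings are cast back. A standalone sketch of the storage-vs-compute split (illustrative only, not the repo's exact code path):

import torch

compute_dtype = torch.bfloat16

# frozen base layer stored in fp8 to roughly halve resident memory vs fp16
lin = torch.nn.Linear(4096, 4096, bias=False).requires_grad_(False)
lin.to(dtype=torch.float8_e4m3fn)

x = torch.randn(2, 4096, dtype=compute_dtype)
w = lin.weight.to(compute_dtype)  # upcast just-in-time for the matmul
y = torch.nn.functional.linear(x, w)
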
@@ -421,9 +438,6 @@ class NetworkTrainer:
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
# set top parameter requires_grad = True for gradient checkpointing works
if not train_text_encoder: # train U-Net only
unet.parameters().__next__().requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
@@ -685,7 +699,7 @@ class NetworkTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
@@ -754,7 +768,17 @@ class NetworkTrainer:
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
# get multiplier for each sample
if network_has_multiplier:
multipliers = batch["network_multipliers"]
# if all multipliers are same, use single multiplier
if torch.all(multipliers == multipliers[0]):
multipliers = multipliers[0].item()
else:
raise NotImplementedError("multipliers for each sample is not supported yet")
# print(f"set multiplier: {multipliers}")
accelerator.unwrap_model(network).set_multiplier(multipliers)
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
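
Editor's note: the new network_has_multiplier path duck-types on a set_multiplier method and, for now, only accepts a batch where every sample carries the same multiplier. A minimal sketch of a LoRA-style module that would honor it (hypothetical names, assuming the usual down/up projection form):

import torch

class LoRALinear(torch.nn.Module):
    def __init__(self, org_module: torch.nn.Linear, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.lora_down = torch.nn.Linear(org_module.in_features, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, org_module.out_features, bias=False)
        torch.nn.init.zeros_(self.lora_up.weight)  # start as an identity delta
        self.scale = alpha / rank
        self.multiplier = 1.0
        self.org_forward = org_module.forward

    def set_multiplier(self, multiplier: float):
        # scales the LoRA delta; 0.0 disables the adapter entirely
        self.multiplier = multiplier

    def forward(self, x):
        return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.scale * self.multiplier
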
@@ -778,10 +802,24 @@ class NetworkTrainer:
args, noise_scheduler, latents
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
for x in noisy_latents:
x.requires_grad_(True)
for t in text_encoder_conds:
t.requires_grad_(True)
# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
if args.v_parameterization:
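
Editor's note: flagging requires_grad on noisy_latents and the text encoder conds matters under gradient checkpointing with a frozen base model, because a checkpointed segment only records a backward graph if at least one of its inputs requires grad. A self-contained demonstration of the failure mode being avoided (illustrative, not the repo's code):

import torch
from torch.utils.checkpoint import checkpoint

frozen = torch.nn.Linear(8, 8).requires_grad_(False)  # stands in for the frozen U-Net
x = torch.randn(2, 8)

x.requires_grad_(True)  # without this, backward through the checkpoint has no graph
y = checkpoint(frozen, x, use_reentrant=True)
y.sum().backward()  # succeeds: the graph hangs off the input
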
@@ -808,10 +846,11 @@ class NetworkTrainer:
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし / already the mean, so no need to divide by batch_size
accelerator.backward(loss)
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
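
Editor's note: moving all_reduce_network inside the sync_gradients branch means the manual DDP grad sync runs once per optimizer step instead of once per micro-batch under gradient accumulation. A sketch of what such a manual sync plausibly does for parameters DDP did not wrap (hypothetical body, using accelerate's real reduce API):

def all_reduce_network(self, accelerator, network):
    # average gradients across processes for params outside DDP's buckets
    for param in network.parameters():
        if param.grad is not None:
            param.grad = accelerator.reduce(param.grad, reduction="mean")
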
@@ -903,12 +942,13 @@ class NetworkTrainer:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
@@ -916,7 +956,9 @@ def setup_parser() -> argparse.ArgumentParser:
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument(
"--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
)
parser.add_argument(
"--save_model_as",
type=str,
@@ -928,10 +970,17 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール")
parser.add_argument(
"--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)"
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
)
parser.add_argument(
"--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
)
parser.add_argument(
"--network_dim",
type=int,
default=None,
help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
)
parser.add_argument(
"--network_alpha",
@@ -946,14 +995,25 @@ def setup_parser() -> argparse.ArgumentParser:
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする0またはNoneはdropoutなし、1は全ニューロンをdropout",
)
parser.add_argument(
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
parser.add_argument(
"--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する"
"--network_args",
type=str,
default=None,
nargs="*",
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
)
parser.add_argument(
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
"--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
)
parser.add_argument(
"--network_train_text_encoder_only",
action="store_true",
help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
)
parser.add_argument(
"--training_comment",
type=str,
default=None,
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
)
parser.add_argument(
"--dim_from_weights",

View File (train_textual_inversion.py)

@@ -1,22 +1,15 @@
import argparse
import gc
import math
import os
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
@@ -37,6 +30,12 @@ from library.custom_train_functions import (
add_v_prediction_like_loss,
apply_debiased_estimation,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
imagenet_templates_small = [
"a photo of a {}",
@@ -173,6 +172,7 @@ class TextualInversionTrainer:
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -183,7 +183,7 @@ class TextualInversionTrainer:
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
# acceleratorを準備する / prepare accelerator
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする / prepare a dtype matching mixed precision and cast as needed
@@ -293,7 +293,7 @@ class TextualInversionTrainer:
]
}
else:
print("Train with captions.")
logger.info("Train with captions.")
user_config = {
"datasets": [
{
@@ -368,9 +368,7 @@ class TextualInversionTrainer:
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -387,8 +385,8 @@ class TextualInversionTrainer:
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する / prepare dataloader
# DataLoaderのプロセス数0はメインプロセスになる / with 0 DataLoader workers, loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで / cpu_count-1, but at most the specified number
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意 / note: persistent_workers cannot be used when n_workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -441,9 +439,10 @@ class TextualInversionTrainer:
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
unwrapped_text_encoder.text_model.encoder.requires_grad_(False)
unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False)
unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
unet.requires_grad_(False)
@@ -503,6 +502,8 @@ class TextualInversionTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
@@ -603,7 +604,7 @@ class TextualInversionTrainer:
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = text_encoder.get_input_embeddings().parameters()
params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -615,9 +616,11 @@ class TextualInversionTrainer:
for text_encoder, orig_embeds_params, index_no_updates in zip(
text_encoders, orig_embeds_params_list, index_no_updates_list
):
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
# if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
index_no_updates
] = orig_embeds_params[index_no_updates]
]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
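
Editor's note: the rewritten restore step fixes a dtype mismatch: under full_fp16/bf16 the live embedding matrix is half precision while orig_embeds_params was cloned in fp32, so the indexed assignment needs the cast. Condensed, with an explicit no-grad guard (assumption: the surrounding loop already runs under torch.no_grad()):

emb = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
with torch.no_grad():
    # keep rows outside the trained token range pinned to their original values
    emb[index_no_updates] = orig_embeds_params.to(emb.dtype)[index_no_updates]
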
@@ -725,24 +728,24 @@ class TextualInversionTrainer:
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = accelerator.unwrap_model(text_encoder)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
accelerator.end_training()
if args.save_state and is_main_process:
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
@@ -758,7 +761,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
)
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
@@ -768,7 +773,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
)
parser.add_argument(
"--use_object_template",
action="store_true",

View File (train_textual_inversion_XTI.py)

@@ -1,20 +1,16 @@
import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
@@ -38,6 +34,12 @@ from library.custom_train_functions import (
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
imagenet_templates_small = [
"a photo of a {}",
@@ -96,12 +98,13 @@ def train(args):
if args.output_name is None:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
setup_logging(args, reset=True)
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
print(
logger.warning(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
)
assert (
@@ -116,7 +119,7 @@ def train(args):
tokenizer = train_util.load_tokenizer(args)
# acceleratorを準備する / prepare accelerator
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする / prepare a dtype matching mixed precision and cast as needed
@@ -129,7 +132,7 @@ def train(args):
if args.init_word is not None:
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
logger.warning(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
)
else:
@@ -143,7 +146,7 @@ def train(args):
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}")
logger.info(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
@@ -171,7 +174,7 @@ def train(args):
tokenizer.add_tokens(token_strings_XTI)
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
print(f"tokens are added (XTI): {token_ids_XTI}")
logger.info(f"tokens are added (XTI): {token_ids_XTI}")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
@@ -180,7 +183,7 @@ def train(args):
if init_token_ids is not None:
for i, token_id in enumerate(token_ids_XTI):
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
@@ -188,22 +191,22 @@ def train(args):
assert len(token_ids) == len(
embeddings
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
# logger.info(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids_XTI, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print(f"weighs loaded")
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
logger.info(f"weighs loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する / prepare dataset
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.info(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -211,14 +214,14 @@ def train(args):
else:
use_dreambooth_method = args.in_json is None
if use_dreambooth_method:
print("Use DreamBooth method.")
logger.info("Use DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else:
print("Train with captions.")
logger.info("Train with captions.")
user_config = {
"datasets": [
{
@@ -242,7 +245,7 @@ def train(args):
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 / rewrite every caption to "tokenstring tokenstring1 tokenstring2 ... tokenstringn" (a very crude implementation)
if use_template:
print(f"use template for training captions. is object: {args.use_object_template}")
logger.info(f"use template for training captions. is object: {args.use_object_template}")
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
replace_to = " ".join(token_strings)
captions = []
@@ -266,7 +269,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
return
if len(train_dataset_group) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
if cache_latents:
@@ -288,9 +291,7 @@ def train(args):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -299,13 +300,13 @@ def train(args):
text_encoder.gradient_checkpointing_enable()
# 学習に必要なクラスを準備する / prepare classes needed for training
print("prepare optimizer, data loader etc.")
logger.info("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する / prepare dataloader
# DataLoaderのプロセス数0はメインプロセスになる / with 0 DataLoader workers, loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで / cpu_count-1, but at most the specified number
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意 / note: persistent_workers cannot be used when n_workers is 0
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -320,7 +321,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
logger.info(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信 / also send the train step count to the dataset side
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -334,7 +337,7 @@ def train(args):
)
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
@@ -372,15 +375,17 @@ def train(args):
# 学習する / start training
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
logger.info("running training / 学習開始")
logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
logger.info(f" num epochs / epoch数: {num_train_epochs}")
logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}")
logger.info(
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
)
logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
@@ -394,16 +399,21 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
# function for saving/removing
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"\nsaving checkpoint: {ckpt_file}")
logger.info("")
logger.info(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -411,12 +421,13 @@ def train(args):
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
logger.info(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs):
print(f"\nepoch {epoch+1}/{num_train_epochs}")
logger.info("")
logger.info(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
text_encoder.train()
@@ -586,7 +597,7 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def save_weights(file, updated_embs, save_dtype):
@@ -647,6 +658,7 @@ def load_weights(file):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
@@ -662,7 +674,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
)
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
@@ -672,7 +686,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
)
parser.add_argument(
"--use_object_template",
action="store_true",