Compare commits

...

505 Commits

Author SHA1 Message Date
Kohya S
1476040787 fix: update help text for huber loss parameters in train_util.py 2024-12-01 21:26:39 +09:00
Kohya S
cc11989755 fix: refactor huber-loss calculation in multiple training scripts 2024-12-01 21:20:28 +09:00
Kohya S
0fe6320f09 fix flux_train.py is not working 2024-12-01 14:13:37 +09:00
Kohya S
14f642f88b fix: huber_schedule exponential not working on sd3_train.py 2024-12-01 13:30:35 +09:00
Kohya S.
a5a27fe4c3 Merge pull request #1808 from recris/huber-loss-flux
Implement pseudo Huber loss for Flux and SD3
2024-12-01 13:15:33 +09:00
recris
7b61e9eb58 Fix issues found in review (pt 2) 2024-11-30 11:36:40 +00:00
Kohya S
9c885e549d fix: improve pos_embed handling for oversized images and update resolution_area_to_latent_size, when sample image size > train image size 2024-11-30 18:25:50 +09:00
recris
740ec1d526 Fix issues found in review 2024-11-28 20:38:32 +00:00
recris
420a180d93 Implement pseudo Huber loss for Flux and SD3 2024-11-27 18:37:09 +00:00
kohya-ss
2a61fc0784 docs: fix typo from block_to_swap to blocks_to_swap in README 2024-11-20 21:20:35 +09:00
Kohya S
2a188f07e6 Fix to work DOP with bock swap 2024-11-17 16:12:10 +09:00
Kohya S.
0047bb1fc3 Merge pull request #1779 from kohya-ss/faster-block-swap
Improve block swap speed and apply to LoRA
2024-11-14 19:47:10 +09:00
Kohya S
fd2d879ac8 docs: update README 2024-11-14 19:43:08 +09:00
Kohya S
5c5b544b91 refactor: remove unused prepare_split_model method from FluxNetworkTrainer 2024-11-14 19:35:43 +09:00
Kohya S
2bb0f547d7 update grad hook creation to fix TE lr in sd3 fine tuning 2024-11-14 19:33:12 +09:00
Kohya S
2cb7a6db02 feat: add block swap for FLUX.1/SD3 LoRA training 2024-11-12 21:39:13 +09:00
Kohya S
17cf249d76 Merge branch 'sd3' into faster-block-swap 2024-11-12 08:49:15 +09:00
Kohya S
cde90b8903 feat: implement block swapping for FLUX.1 LoRA (WIP) 2024-11-12 08:49:05 +09:00
Kohya S
3fe94b058a update comment 2024-11-12 08:09:07 +09:00
Kohya S.
92482c7a07 Merge pull request #1774 from sdbds/avif_get_imagesize
Support avif get image size
2024-11-12 08:02:16 +09:00
Kohya S
7feaae5f06 Merge branch 'sd3' into faster-block-swap 2024-11-11 21:16:01 +09:00
Kohya S
02bd76e6c7 Refactor block swapping to utilize custom offloading utilities 2024-11-11 21:15:36 +09:00
sdbds
26bd4540a6 init 2024-11-11 09:25:28 +08:00
Kohya S
8fac3c3b08 update README 2024-11-09 19:56:02 +09:00
Kohya S.
2a2042a762 Merge pull request #1770 from feffy380/fix-size-from-cache
fix: sort order when getting image size from cache file
2024-11-09 19:51:03 +09:00
feffy380
b3248a8eef fix: sort order when getting image size from cache file 2024-11-07 14:31:05 +01:00
Kohya S
186aa5b97d fix illeagal block is swapped #1764 2024-11-07 22:16:05 +09:00
Kohya S
b8d3feca77 Merge branch 'sd3' into faster-block-swap 2024-11-07 21:43:48 +09:00
Kohya S
123474d784 Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3 2024-11-07 21:43:37 +09:00
Kohya S
e877b306c8 Merge branch 'dev' into sd3 2024-11-07 21:43:04 +09:00
Kohya S
6adb69be63 Merge branch 'main' into dev 2024-11-07 21:42:44 +09:00
Kohya S.
387b40ea37 Merge pull request #1769 from Dango233/patch-1
Update README.md
2024-11-07 21:41:12 +09:00
Kohya S
e5ac095749 add about dev and sd3 branch to README 2024-11-07 21:39:47 +09:00
Dango233
5eb6d209d5 Update README.md 2024-11-07 20:33:31 +08:00
Kohya S.
f264f4091f Update README.md 2024-11-07 21:30:31 +09:00
Kohya S
5e86323f12 Update README and clean-up the code for SD3 timesteps 2024-11-07 21:27:12 +09:00
Kohya S.
588ea9e123 Merge pull request #1768 from Dango233/dango/timesteps_fix
Dango/timesteps fix
2024-11-07 20:56:04 +09:00
Dango233
bafd10d558 Fix typo 2024-11-07 18:21:04 +08:00
Dango233
e54462a4a9 Fix SD3 trained lora loading and merging 2024-11-07 09:54:12 +00:00
Dango233
40ed54bfc0 Simplify Timestep weighting
* Remove diffusers dependency in ts & sigma calc
* support Shift setting
* Add uniform distribution
* Default to Uniform distribution and shift 1
2024-11-07 09:53:54 +00:00
Kohya S
43849030cf Fix to work without latent cache #1758 2024-11-06 21:33:28 +09:00
Kohya S
aab943cea3 remove unused weight swapping functions from utils.py 2024-11-05 23:27:41 +09:00
Kohya S
81c0c965a2 faster block swap 2024-11-05 21:22:42 +09:00
Kohya S
5e32ee26a1 fix crashing in DDP training closes #1751 2024-11-02 15:32:16 +09:00
Kohya S
e0db59695f update multi-res training in SD3.5M 2024-11-02 11:13:04 +09:00
Kohya S.
264328d117 Merge pull request #1719 from kohya-ss/sd3_5_support
SD3.5 Large support
2024-11-01 21:55:48 +09:00
Kohya S
82daa98fe8 remove duplicate resolution for scaled pos embed 2024-11-01 21:43:47 +09:00
Kohya S
9aa6f52ac3 Fix memory leak in latent caching. bmp failed to cache 2024-11-01 21:43:21 +09:00
Kohya S
830df4abcc Fix crashing if image is too tall or wide. 2024-10-31 21:39:07 +09:00
Kohya S
9e23368e3d Update SD3 training 2024-10-31 19:58:41 +09:00
Kohya S
1434d8506f Support SD3.5M multi resolutional training 2024-10-31 19:58:22 +09:00
Kohya S
70a179e446 Fix to use SDPA instead of xformers 2024-10-30 14:34:19 +09:00
Kohya S
8c3c825b5f Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-scripts into sd3_5_support 2024-10-30 12:51:55 +09:00
Kohya S
bdddc20d68 support SD3.5M 2024-10-30 12:51:49 +09:00
kohya-ss
b502f58488 Fix emb_dim to work. 2024-10-29 23:29:50 +09:00
kohya-ss
c9a1417157 Merge branch 'sd3' into sd3_5_support 2024-10-29 22:30:01 +09:00
kohya-ss
ce5b532582 Fix additional LoRA to work 2024-10-29 22:29:24 +09:00
Kohya S
1e2f7b0e44 Support for checkpoint files with a mysterious prefix "model.diffusion_model." 2024-10-29 22:11:04 +09:00
kohya-ss
80bb3f4ecf Merge branch 'sd3_5_support' of https://github.com/kohya-ss/sd-scripts into sd3_5_support 2024-10-29 21:52:08 +09:00
kohya-ss
d4e19fbd5e Support Lora 2024-10-29 21:52:04 +09:00
kohya-ss
0af4edd8a6 Fix split_qkv 2024-10-29 21:51:56 +09:00
Kohya S
75554867ce Fix error on saving T5XXL 2024-10-29 08:34:31 +09:00
Kohya S
af8e216035 Fix sample image gen to work with block swap 2024-10-28 22:08:57 +09:00
kohya-ss
1065dd1b56 Fix to work dropout_rate for TEs 2024-10-27 19:36:36 +09:00
kohya-ss
d4f7849592 prevent unintended cast for disk cached TE outputs 2024-10-27 19:35:56 +09:00
kohya-ss
a1255d637f Fix SD3 LoRA training to work (WIP) 2024-10-27 17:03:36 +09:00
Kohya S
db2b4d41b9 Add dropout rate arguments for CLIP-L, CLIP-G, and T5, fix Text Encoders LoRA not trained 2024-10-27 16:42:58 +09:00
Kohya S
b649bbf2b6 Merge branch 'sd3' into sd3_5_support 2024-10-27 10:20:35 +09:00
Kohya S
731664b8c3 Merge branch 'dev' into sd3 2024-10-27 10:20:14 +09:00
Kohya S
e070bd9973 Merge branch 'main' into dev 2024-10-27 10:19:55 +09:00
Kohya S
ca44e3e447 reduce VRAM usage, instead of increasing main RAM usage 2024-10-27 10:19:05 +09:00
Kohya S
150579db32 Merge branch 'sd3' into sd3_5_support 2024-10-26 22:03:41 +09:00
Kohya S
8549669f89 Merge branch 'dev' into sd3 2024-10-26 22:02:57 +09:00
Kohya S
900d551a6a Merge branch 'main' into dev 2024-10-26 22:02:36 +09:00
Kohya S
56b4ea963e Fix LoRA metadata hash calculation bug in svd_merge_lora.py, sdxl_merge_lora.py, and resize_lora.py closes #1722 2024-10-26 22:01:10 +09:00
kohya-ss
014064fd81 fix sample image generation without seed failed close #1726 2024-10-26 18:59:45 +09:00
kohya-ss
56bf761164 fix errors in SD3 LoRA training with Text Encoders close #1724 2024-10-26 17:29:24 +09:00
kohya-ss
0031d916f0 add latent scaling/shifting 2024-10-25 23:20:38 +09:00
kohya-ss
d2c549d7b2 support SD3 LoRA 2024-10-25 21:58:31 +09:00
Kohya S
f52fb66e8f Merge branch 'sd3' into sd3_5_support 2024-10-25 19:03:58 +09:00
Kohya S
5fba6f514a Merge branch 'dev' into sd3 2024-10-25 19:03:27 +09:00
Kohya S
b1e6504007 update README 2024-10-25 18:56:25 +09:00
Kohya S.
b8ae745d0c Merge pull request #1717 from catboxanon/fix/remove-vpred-warnings
Remove v-pred warnings
2024-10-25 18:49:40 +09:00
Kohya S.
c632af860e Merge pull request #1715 from catboxanon/vpred-ztsnr-fixes
Update debiased estimation loss function to accommodate V-pred
2024-10-25 18:48:14 +09:00
Kohya S
f8c5146d71 support block swap with fused_optimizer_pass 2024-10-24 22:02:05 +09:00
Kohya S
0286114bd2 support SD3.5L, fix final saving 2024-10-24 21:28:42 +09:00
Kohya S
e3c43bda49 reduce memory usage in sample image generation 2024-10-24 20:35:47 +09:00
Kohya S
623017f716 refactor SD3 CLIP to transformers etc. 2024-10-24 19:49:28 +09:00
catboxanon
be14c06267 Remove v-pred warnings
Different model architectures, such as SDXL, can take advantage of
v-pred. It doesn't make sense to include these warnings anymore.
2024-10-22 12:13:51 -04:00
catboxanon
0e7c592933 Remove scale_v_pred_loss_like_noise_pred deprecation
https://github.com/kohya-ss/sd-scripts/pull/1715#issuecomment-2427876376
2024-10-22 11:19:34 -04:00
catboxanon
e1b63c2249 Only add warning for deprecated scaling vpred loss function 2024-10-21 08:12:53 -04:00
catboxanon
8fc30f8205 Fix training for V-pred and ztSNR
1) Updates debiased estimation loss function for V-pred.
2) Prevents now-deprecated scaling of loss if ztSNR is enabled.
2024-10-21 07:34:33 -04:00
Kohya S
138dac4aea update README 2024-10-20 09:22:38 +09:00
Kohya S
7fe8e162cb fix to work ControlNetSubset with custom_attributes 2024-10-20 08:45:27 +09:00
Kohya S.
aa932429d1 Merge pull request #1710 from kohya-ss/diff_output_prsv
Differential Output Preservation loss for LoRA
2024-10-19 19:56:13 +09:00
Kohya S
09b4d1e9b6 Merge branch 'sd3' into diff_output_prsv 2024-10-19 19:30:23 +09:00
kohya-ss
2c45d979e6 update README, remove unnecessary autocast 2024-10-19 19:21:12 +09:00
kohya-ss
ef70aa7b42 add FLUX.1 support 2024-10-18 23:39:48 +09:00
Kohya S
d8d7142665 fix to work caching latents #1696 2024-10-18 23:16:30 +09:00
Kohya S
3cc5b8db99 Diff Output Preserv loss for SDXL 2024-10-18 20:57:13 +09:00
Kohya S
2500f5a798 fix latents caching not working closes #1696 2024-10-15 07:16:34 +09:00
Kohya S.
1275e148df Merge pull request #1690 from kohya-ss/multi-gpu-caching
Caching latents and Text Encoder outputs with multiple GPUs
2024-10-13 19:25:59 +09:00
kohya-ss
2d5f7fa709 update README 2024-10-13 19:23:21 +09:00
kohya-ss
886ffb4d65 Merge branch 'sd3' into multi-gpu-caching 2024-10-13 19:14:06 +09:00
Kohya S.
d02a6ef7c4 Merge pull request #1660 from kohya-ss/fast_image_sizes
Fast image sizes
2024-10-13 19:11:37 +09:00
kohya-ss
bfc3a65acd fix to work cache latents/text encoder outputs 2024-10-13 19:08:16 +09:00
kohya-ss
2244cf5b83 load images in parallel when caching latents 2024-10-13 18:22:19 +09:00
Kohya S
c65cf3812d Merge branch 'sd3' into fast_image_sizes 2024-10-13 17:31:11 +09:00
kohya-ss
74228c9953 update cache_latents/text_encoder_outputs 2024-10-13 16:27:22 +09:00
kohya-ss
5bb9f7fb1a Merge branch 'sd3' into multi-gpu-caching 2024-10-13 11:52:42 +09:00
Kohya S
e277b5789e Update FLUX.1 support for compact models 2024-10-12 21:49:07 +09:00
kohya-ss
ecaea909b1 update README 2024-10-12 20:26:57 +09:00
kohya-ss
c80c304779 Refactor caching in train scripts 2024-10-12 20:18:41 +09:00
kohya-ss
ff4083b910 Merge branch 'sd3' into multi-gpu-caching 2024-10-12 16:39:36 +09:00
kohya-ss
0d3058b65a update README 2024-10-12 14:46:35 +09:00
Kohya S.
d005652d03 Merge pull request #1686 from Akegarasu/sd3
fix: fix some distributed training error in windows
2024-10-12 14:33:02 +09:00
Kohya S.
43bfeea600 Merge pull request #1655 from kohya-ss/sdxl-ctrl-net
ControlNet training for SDXL
2024-10-11 22:27:53 +09:00
Kohya S
035c4a8552 update docs and help text 2024-10-11 22:23:15 +09:00
Kohya S
f2bc820133 support weighted captions for SD/SDXL 2024-10-11 08:48:55 +09:00
Akegarasu
9f4dac5731 torch 2.4 2024-10-10 14:08:55 +08:00
Akegarasu
3de42b6edb fix: distributed training in windows 2024-10-10 14:03:59 +08:00
Kohya S
886f75345c support weighted captions for sdxl LoRA and fine tuning 2024-10-10 08:27:15 +09:00
Kohya S
126159f7c4 Merge branch 'sd3' into sdxl-ctrl-net 2024-10-07 20:39:53 +09:00
Kohya S
83e3048cb0 load Diffusers format, check schnell/dev 2024-10-06 21:32:21 +09:00
Kohya S
ba08a89894 call optimizer eval/train for sample_at_first, also set train after resuming closes #1667 2024-10-04 20:35:16 +09:00
Kohya S
c2440f9e53 fix cond image normlization, add independent LR for control 2024-10-03 21:32:21 +09:00
Kohya S
33e942e36e Merge branch 'sd3' into fast_image_sizes 2024-10-01 08:38:09 +09:00
Kohya S
793999d116 sample generation in SDXL ControlNet training 2024-09-30 23:39:32 +09:00
Kohya S
d78f6a775c Merge branch 'sd3' into sdxl-ctrl-net 2024-09-29 23:23:07 +09:00
Kohya S
8bea039a8d Merge branch 'dev' into sd3 2024-09-29 23:19:12 +09:00
Kohya S
012e7e63a5 fix to work linear/cosine scheduler closes #1651 ref #1393 2024-09-29 23:18:16 +09:00
Kohya S
0243c65877 fix typo 2024-09-29 23:09:56 +09:00
Kohya S
8919b31145 use original ControlNet instead of Diffusers 2024-09-29 23:07:34 +09:00
Kohya S
56a63f01ae Merge branch 'sd3' into multi-gpu-caching 2024-09-29 10:12:18 +09:00
青龍聖者@bdsqlsz
e0c3630203 Support Sdxl Controlnet (#1648)
* Create sdxl_train_controlnet.py

* add fuse_background_pass

* Update sdxl_train_controlnet.py

* add fuse and fix error

* update

* Update sdxl_train_controlnet.py

* Update sdxl_train_controlnet.py

* Update sdxl_train_controlnet.py

* update

* Update sdxl_train_controlnet.py
2024-09-29 10:11:15 +09:00
Kohya S
d050638571 Merge branch 'dev' into sd3 2024-09-29 10:00:01 +09:00
Kohya S
1567549220 update help text #1632 2024-09-29 09:51:36 +09:00
Kohya S
fe2aa32484 adjust min/max bucket reso divisible by reso steps #1632 2024-09-29 09:49:25 +09:00
Kohya S
1a0f5b0c38 re-fix sample generation is not working in FLUX1 split mode #1647 2024-09-29 00:35:29 +09:00
Kohya S
822fe57859 add workaround for 'Some tensors share memory' error #1614 2024-09-28 20:57:27 +09:00
Kohya S
a9aa52658a fix sample generation is not working in FLUX1 fine tuning #1647 2024-09-28 17:12:56 +09:00
kohya-ss
24b1fdb664 remove debug print 2024-09-26 22:22:06 +09:00
kohya-ss
9249d00311 experimental support for multi-gpus latents caching 2024-09-26 22:19:56 +09:00
Kohya S
3ebb65f945 Merge branch 'dev' into sd3 2024-09-26 21:41:25 +09:00
Kohya S
ce49ced699 update readme 2024-09-26 21:37:40 +09:00
Kohya S
a94bc84dec fix to work bitsandbytes optimizers with full path #1640 2024-09-26 21:37:31 +09:00
Kohya S.
4296e286b8 Merge pull request #1640 from sdbds/ademamix8bit
New optimizer:AdEMAMix8bit and PagedAdEMAMix8bit
2024-09-26 21:20:19 +09:00
Kohya S
392e8dedd8 fix flip_aug, alpha_mask, random_crop issue in caching in caching strategy 2024-09-26 21:14:11 +09:00
Kohya S
2cd6aa281c Merge branch 'dev' into sd3 2024-09-26 20:52:08 +09:00
Kohya S
bf91bea2e4 fix flip_aug, alpha_mask, random_crop issue in caching 2024-09-26 20:51:40 +09:00
Kohya S
da94fd934e fix typos 2024-09-26 08:27:48 +09:00
Kohya S
56a7bc171d new block swap for FLUX.1 fine tuning 2024-09-26 08:26:31 +09:00
sdbds
1beddd84e5 delete code for cleaning 2024-09-25 22:58:26 +08:00
Kohya S
65fb69f808 Merge branch 'dev' into sd3 2024-09-25 20:56:16 +09:00
Kohya S
e74f58148c update README 2024-09-25 20:55:50 +09:00
Kohya S.
c1d16a76d6 Merge pull request #1628 from recris/huber-timesteps
Make timesteps work in the standard way when Huber loss is used
2024-09-25 20:52:55 +09:00
sdbds
ab7b231870 init 2024-09-25 19:38:52 +08:00
Kohya S.
fba769222b Merge branch 'dev' into sd3 2024-09-23 21:20:02 +09:00
Kohya S
29177d2f03 retain alpha in pil_resize backport #1619 2024-09-23 21:14:03 +09:00
recris
e1f23af1bc make timestep sampling behave in the standard way when huber loss is used 2024-09-21 13:21:56 +01:00
Kohya S.
95ff9dba0c Merge pull request #1619 from emcmanus/patch-1
Retain alpha in `pil_resize` for `--alpha_mask`
2024-09-20 22:24:49 +09:00
Kohya S
583d4a436c add compatibility for int LR (D-Adaptation etc.) #1620 2024-09-20 22:22:24 +09:00
Kohya S.
24f8975fb7 Merge pull request #1620 from Akegarasu/sd3
fix: backward compatibility for text_encoder_lr
2024-09-20 22:16:39 +09:00
Akegarasu
0535cd29b9 fix: backward compatibility for text_encoder_lr 2024-09-20 10:05:22 +08:00
Ed McManus
de4bb657b0 Update utils.py
Cleanup
2024-09-19 14:38:32 -07:00
Ed McManus
3957372ded Retain alpha in pil_resize
Currently the alpha channel is dropped by `pil_resize()` when `--alpha_mask` is supplied and the image width does not exceed the bucket.

This codepath is entered on the last line, here:
```
def trim_and_resize_if_required(
    random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
    image_height, image_width = image.shape[0:2]
    original_size = (image_width, image_height)  # size before resize

    if image_width != resized_size[0] or image_height != resized_size[1]:
        # リサイズする
        if image_width > resized_size[0] and image_height > resized_size[1]:
            image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # INTER_AREAでやりたいのでcv2でリサイズ
        else:
            image = pil_resize(image, resized_size)
```
2024-09-19 14:30:03 -07:00
Kohya S
b844c70d14 Merge branch 'dev' into sd3 2024-09-19 21:51:33 +09:00
Kohya S.
0b7927e50b Merge pull request #1615 from Maru-mee/patch-1
Bug fix: alpha_mask load
2024-09-19 21:49:49 +09:00
Kohya S
706a48d50e Merge branch 'dev' into sd3 2024-09-19 21:15:39 +09:00
Kohya S
d7e14721e2 Merge branch 'main' into dev 2024-09-19 21:15:19 +09:00
Kohya S
9c757c2fba fix SDXL block index to match LBW 2024-09-19 21:14:57 +09:00
Maru-mee
e7040669bc Bug fix: alpha_mask load 2024-09-19 15:47:06 +09:00
Kohya S
1286e00bb0 fix to call train/eval in schedulefree #1605 2024-09-18 21:31:54 +09:00
Kohya S
e74502117b update README 2024-09-18 08:04:32 +09:00
Kohya S.
bbd160b4ca sd3 schedule free opt (#1605)
* New ScheduleFree support for Flux (#1600)

* init

* use no schedule

* fix typo

* update for eval()

* fix typo

* update

* Update train_util.py

* Update requirements.txt

* update sfwrapper WIP

* no need to check schedulefree optimizer

* remove debug print

* comment out schedulefree wrapper

* update readme

---------

Co-authored-by: 青龍聖者@bdsqlsz <865105819@qq.com>
2024-09-18 07:55:04 +09:00
Kohya S
a2ad7e5644 blocks_to_swap=0 means no swap 2024-09-17 21:42:14 +09:00
Kohya S
0cbe95bcc7 fix text_encoder_lr to work with int closes #1608 2024-09-17 21:21:28 +09:00
Kohya S
d8d15f1a7e add support for specifying blocks in FLUX.1 LoRA training 2024-09-16 23:14:09 +09:00
Kohya S
96c677b459 fix to work lienar/cosine lr scheduler closes #1602 ref #1393 2024-09-16 10:42:09 +09:00
Kohya S
be078bdaca fix typo 2024-09-15 13:59:17 +09:00
Kohya S
9f44ef1330 add diffusers to FLUX.1 conversion script 2024-09-15 13:52:23 +09:00
Kohya S
6445bb2bc9 update README 2024-09-14 22:37:26 +09:00
Kohya S
c9ff4de905 Add support for specifying rank for each layer in FLUX.1 2024-09-14 22:17:52 +09:00
Kohya S
2d8ee3c280 OFT for FLUX.1 2024-09-14 15:48:16 +09:00
Kohya S
0485f236a0 Merge branch 'dev' into sd3 2024-09-13 22:39:24 +09:00
Kohya S
93d9fbf607 improve OFT implementation closes #944 2024-09-13 22:37:11 +09:00
Kohya S
c15a3a1a65 Merge branch 'dev' into sd3 2024-09-13 21:30:49 +09:00
Kohya S
43ad73860d Merge branch 'main' into dev 2024-09-13 21:29:51 +09:00
Kohya S
b755ebd0a4 add LBW support for SDXL merge LoRA 2024-09-13 21:29:31 +09:00
Kohya S
f4a0bea6dc format by black 2024-09-13 21:26:06 +09:00
terracottahaniwa
734d2e5b2b Support Lora Block Weight (LBW) to svd_merge_lora.py (#1575)
* support lora block weight

* solve license incompatibility

* Fix issue: lbw index calculation
2024-09-13 20:45:35 +09:00
Kohya S
f3ce80ef8f Merge branch 'dev' into sd3 2024-09-13 19:49:16 +09:00
Kohya S
9d2860760d Merge branch 'main' into dev 2024-09-13 19:48:53 +09:00
Kohya S
3387dc7306 formatting, update README 2024-09-13 19:45:42 +09:00
Kohya S
57ae44eb61 refactor to make safer 2024-09-13 19:45:00 +09:00
Maru-mee
1d7118a622 Support : OFT merge to base model (#1580)
* Support : OFT merge to base model

* Fix typo

* Fix typo_2

* Delete unused parameter 'eye'
2024-09-13 19:01:36 +09:00
Kohya S
cefe52629e fix to work old notation for TE LR in .toml 2024-09-12 12:36:07 +09:00
Kohya S
237317fffd update README 2024-09-11 22:23:43 +09:00
Plat
a823fd9fb8 Improve wandb logging (#1576)
* fix: wrong training steps were recorded to wandb, and no log was sent when logging_dir was not specified

* fix: checking of whether wandb is enabled

* feat: log images to wandb with their positive prompt as captions

* feat: logging sample images' caption for sd3 and flux

* fix: import wandb before use
2024-09-11 22:21:16 +09:00
Kohya S
c7c666b182 fix typo 2024-09-11 22:12:31 +09:00
Kohya S.
d83f2e92da Merge pull request #1592 from cocktailpeanut/sd3
Critical typo fix
2024-09-11 22:10:37 +09:00
cocktailpeanut
8311e88225 typo fix 2024-09-11 09:02:29 -04:00
Kohya S
eaafa5c9da Merge branch 'dev' into sd3 2024-09-11 21:46:21 +09:00
Kohya S
6dbfd47a59 Fix to work PIECEWISE_CONSTANT, update requirement.txt and README #1393 2024-09-11 21:44:36 +09:00
青龍聖者@bdsqlsz
fd68703f37 Add New lr scheduler (#1393)
* add new lr scheduler

* fix bugs and use num_cycles / 2

* Update requirements.txt

* add num_cycles for min lr

* keep PIECEWISE_CONSTANT

* allow use float with warmup or decay ratio.

* Update train_util.py
2024-09-11 21:25:45 +09:00
Kohya S
65b8a064f6 update README 2024-09-10 21:20:38 +09:00
Kohya S
d10ff62a78 support individual LR for CLIP-L/T5XXL 2024-09-10 20:32:09 +09:00
Kohya S
d29af146b8 add negative prompt for flux inference script 2024-09-09 23:01:15 +09:00
Kohya S
ce144476cf Merge branch 'dev' into sd3 2024-09-07 10:59:22 +09:00
Kohya S
62ec3e6424 Merge branch 'main' into dev 2024-09-07 10:52:49 +09:00
Kohya S.
de25945a93 Merge pull request #1550 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.24.3
Bump crate-ci/typos from 1.19.0 to 1.24.3
2024-09-07 10:50:46 +09:00
Kohya S
0005867ba5 update README, format code 2024-09-07 10:45:18 +09:00
Kohya S.
16bb5699ac Merge pull request #1426 from sdbds/resize
Replacing CV2 resize to Pil resize
2024-09-07 10:22:52 +09:00
Kohya S.
319e4d9831 Merge pull request #1433 from millie-v/sample-image-without-cuda
Generate sample images without having CUDA (such as on Macs)
2024-09-07 10:19:55 +09:00
Kohya S
2889108d85 feat: Add --cpu_offload_checkpointing option to LoRA training 2024-09-05 20:58:33 +09:00
Kohya S
d9129522a6 set dtype before calling ae closes #1562 2024-09-05 12:20:07 +09:00
Kohya S
90ed2dfb52 feat: Add support for merging CLIP-L and T5XXL LoRA models 2024-09-05 08:39:29 +09:00
Kohya S
56cb2fc885 support T5XXL LoRA, reduce peak memory usage #1560 2024-09-04 23:15:27 +09:00
Kohya S
b7cff0a754 update README 2024-09-04 21:35:47 +09:00
Kohya S
b65ae9b439 T5XXL LoRA training, fp8 T5XXL support 2024-09-04 21:33:17 +09:00
Kohya S
6abacf04da update README 2024-09-02 13:05:26 +09:00
Kohya S
4f6d915d15 update help and README 2024-09-01 19:12:29 +09:00
Kohya S.
1e30aa83b4 Merge pull request #1541 from sdbds/flux_shift
Add Flux_Shift for solving the problem of multi-resolution training blurry
2024-09-01 19:00:16 +09:00
Kohya S
92e7600cc2 Move freeze_blocks to sd3_train because it's only for sd3 2024-09-01 18:57:07 +09:00
青龍聖者@bdsqlsz
ef510b3cb9 Sd3 freeze x_block (#1417)
* Update sd3_train.py

* add freeze block lr

* Update train_util.py

* update
2024-09-01 18:41:01 +09:00
Kohya S.
928e0fc096 Merge pull request #1529 from Akegarasu/sd3
fix: text_encoder_conds referenced before assignment
2024-09-01 18:29:27 +09:00
dependabot[bot]
1bcf8d600b Bump crate-ci/typos from 1.19.0 to 1.24.3
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.19.0 to 1.24.3.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.19.0...v1.24.3)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-09-01 01:33:04 +00:00
Kohya S.
f8f5b16958 Merge pull request #1540 from kohya-ss/dependabot/pip/opencv-python-4.8.1.78
Bump opencv-python from 4.7.0.68 to 4.8.1.78
2024-08-31 21:37:07 +09:00
Kohya S.
826ab5ce2e Merge pull request #1532 from nandometzger/main
Update train_util.py, bug fix
2024-08-31 21:36:33 +09:00
sdbds
25c9040f4f Update flux_train_utils.py 2024-08-31 19:53:59 +08:00
dependabot[bot]
3a6154b7b0 Bump opencv-python from 4.7.0.68 to 4.8.1.78
Bumps [opencv-python](https://github.com/opencv/opencv-python) from 4.7.0.68 to 4.8.1.78.
- [Release notes](https://github.com/opencv/opencv-python/releases)
- [Commits](https://github.com/opencv/opencv-python/commits)

---
updated-dependencies:
- dependency-name: opencv-python
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-08-31 06:21:16 +00:00
Nando Metzger
2a3aefb4e4 Update train_util.py, bug fix 2024-08-30 08:15:05 +02:00
Akegarasu
35882f8d5b fix 2024-08-29 23:03:43 +08:00
Akegarasu
34f2315047 fix: text_encoder_conds referenced before assignment 2024-08-29 22:33:37 +08:00
Kohya S
8fdfd8c857 Update safetensors to version 0.4.4 in requirements.txt #1524 2024-08-29 22:26:29 +09:00
Kohya S
8ecf0fc4bf Refactor code to ensure args.guidance_scale is always a float #1525 2024-08-29 22:10:57 +09:00
Kohya S.
930d709e3d Merge pull request #1525 from Akegarasu/sd3
make guidance_scale keep float in args
2024-08-29 22:08:57 +09:00
Kohya S.
daa6ad5165 Update README.md 2024-08-29 21:25:30 +09:00
Kohya S
a0cfb0894c Cleaned up README 2024-08-29 21:20:33 +09:00
Akegarasu
6c0e8a5a17 make guidance_scale keep float in args 2024-08-29 14:50:29 +08:00
Kohya S
a61cf73a5c update readme 2024-08-27 21:44:10 +09:00
Kohya S
3be712e3e0 feat: Update direct loading fp8 ckpt for LoRA training 2024-08-27 21:40:02 +09:00
Kohya S
0087a46e14 FLUX.1 LoRA supports CLIP-L 2024-08-27 19:59:40 +09:00
Kohya S
72287d39c7 feat: Add shift option to --timestep_sampling in FLUX.1 fine-tuning and LoRA training 2024-08-25 16:01:24 +09:00
Kohya S
ea9242653c Merge branch 'dev' into sd3 2024-08-24 21:24:44 +09:00
Kohya S
d5c076cf90 update readme 2024-08-24 21:21:39 +09:00
Kohya S.
4ca29edbff Merge pull request #1505 from liesened/patch-2
Add v-pred support for SDXL train
2024-08-24 21:16:53 +09:00
Kohya S
5639c2adc0 fix typo 2024-08-24 16:37:49 +09:00
Kohya S
cf689e7aa6 feat: Add option to split projection layers and apply LoRA 2024-08-24 16:35:43 +09:00
Kohya S
2e89cd2cc6 Fix issue with attention mask not being applied in single blocks 2024-08-24 12:39:54 +09:00
liesen
1e8108fec9 Handle args.v_parameterization properly for MinSNR and changed prediction target 2024-08-24 01:38:17 +03:00
Kohya S
81411a398e speed up getting image sizes 2024-08-22 22:02:29 +09:00
Kohya S
99744af53a Merge branch 'dev' into sd3 2024-08-22 21:34:24 +09:00
Kohya S
afb971f9c3 fix SD1.5 LoRA extraction #1490 2024-08-22 21:33:15 +09:00
Kohya S
bf9f798985 chore: fix typos, remove debug print 2024-08-22 19:59:38 +09:00
Kohya S
b0a980844a added a script to extract LoRA 2024-08-22 19:57:29 +09:00
Kohya S
2d8fa3387a Fix to remove zero pad for t5xxl output 2024-08-22 19:56:27 +09:00
Kohya S
a4d27a232b Fix --debug_dataset to work. 2024-08-22 19:55:31 +09:00
kohya-ss
98c91a7625 Fix bug in FLUX multi GPU training 2024-08-22 12:37:41 +09:00
Kohya S
e1cd19c0c0 add stochastic rounding, fix single block 2024-08-21 21:04:10 +09:00
Kohya S
2b07a92c8d Fix error in applying mask in Attention and add LoRA converter script 2024-08-21 12:30:23 +09:00
Kohya S
e17c42cb0d Add BFL/Diffusers LoRA converter #1467 #1458 #1483 2024-08-21 12:28:45 +09:00
Kohya S
7e459c00b2 Update T5 attention mask handling in FLUX 2024-08-21 08:02:33 +09:00
Kohya S
6ab48b09d8 feat: Support multi-resolution training with caching latents to disk 2024-08-20 21:39:43 +09:00
Kohya S.
388b3b4b74 Merge pull request #1482 from kohya-ss/flux-merge-lora
Flux merge lora
2024-08-20 19:34:57 +09:00
Kohya S
dbed5126bd chore: formatting 2024-08-20 19:33:47 +09:00
Kohya S
9381332020 revert merge function add add option to use new func 2024-08-20 19:32:26 +09:00
Kohya S
6f6faf9b5a fix to work with ai-toolkit LoRA 2024-08-20 19:16:25 +09:00
Kohya S.
92b1f6d968 Merge pull request #1469 from exveria1015/sd3
Flux の LoRA マージ機能を修正
2024-08-20 19:06:06 +09:00
Kohya S
c62c95e862 update about multi-resolution training in FLUX.1 2024-08-20 08:21:01 +09:00
Kohya S
9e72be0a13 Fix debug_dataset to work 2024-08-20 08:19:00 +09:00
Kohya S
486fe8f70a feat: reduce memory usage and add memory efficient option for model saving 2024-08-19 22:30:24 +09:00
Kohya S
6e72a799c8 reduce peak VRAM usage by excluding some blocks to cuda 2024-08-19 21:55:28 +09:00
Kohya S
d034032a5d update README fix option name 2024-08-19 13:08:49 +09:00
Kohya S
a450488928 update readme 2024-08-18 16:56:50 +09:00
Kohya S
ef535ec6bb add memory efficient training for FLUX.1 2024-08-18 16:54:18 +09:00
exveria1015
7e688913ae fix: Flux の LoRA マージ機能を修正 2024-08-18 12:38:05 +09:00
kohya-ss
25f77f6ef0 fix flux fine tuning to work 2024-08-17 15:54:32 +09:00
Kohya S
400955d3ea add fine tuning FLUX.1 (WIP) 2024-08-17 15:36:18 +09:00
Kohya S
7367584e67 fix sd3 training to work without cachine TE outputs #1465 2024-08-17 14:38:34 +09:00
Kohya S
e45d3f8634 add merge LoRA script 2024-08-16 22:19:21 +09:00
Kohya S
3921a4efda add t5xxl max token length, support schnell 2024-08-16 17:06:05 +09:00
Kohya S.
739a8969bc Merge pull request #1461 from fireicewolf/sd3-devel
Fix AttributeError: 'FluxNetworkTrainer' object has no attribute 'sample_prompts_te_outputs'
2024-08-16 14:15:24 +09:00
DukeG
08ef886bfe Fix AttributeError: 'FluxNetworkTrainer' object has no attribute 'sample_prompts_te_outputs'
Move "self.sample_prompts_te_outputs = None" from Line 150 to Line 26.
2024-08-16 11:00:08 +08:00
Kohya S
35b6cb0cd1 update for torchvision 2024-08-15 22:07:35 +09:00
Kohya S
8aaa1967bd fix encoding latents closes #1456 2024-08-15 22:07:23 +09:00
Kohya S.
e2d822cad7 Merge pull request #1452 from fireicewolf/sd3-devel
Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_model', while loading T5 model in GPU.
2024-08-15 21:12:19 +09:00
Kohya S
7db4222119 add sample image generation during training 2024-08-14 22:15:26 +09:00
DukeG
9760d097b0 Fix AttributeError: 'T5EncoderModel' object has no attribute 'text_model'
While loading T5 model in GPU.
2024-08-14 19:58:54 +08:00
Kohya S
56d7651f08 add experimental split mode for FLUX 2024-08-13 22:28:39 +09:00
kohya-ss
9711c96f96 update README 2024-08-13 21:03:17 +09:00
kohya-ss
f5ce754bc2 Merge branch 'dev' into sd3 2024-08-13 21:00:44 +09:00
kohya-ss
4cf42cc5d4 Merge branch 'sd3' of https://github.com/kohya-ss/sd-scripts into sd3 2024-08-13 21:00:21 +09:00
kohya-ss
0415d200f5 update dependencies closes #1450 2024-08-13 21:00:16 +09:00
Kohya S
a7d5dabde3 Update readme 2024-08-12 17:09:19 +09:00
kohya-ss
4af36f9632 update to work interactive mode 2024-08-12 13:24:10 +09:00
Kohya S
9e09a69df1 update README 2024-08-12 08:19:45 +09:00
Kohya S
74f91c2ff7 correct option name closes #1446 2024-08-11 21:54:10 +09:00
Kohya S
d25ae361d0 fix apply_t5_attn_mask to work 2024-08-11 19:07:07 +09:00
Kohya S
82314ac2e7 update readme for ai toolkit settings 2024-08-11 11:14:08 +09:00
Kohya S
8a0f12dde8 update FLUX LoRA training 2024-08-10 23:42:05 +09:00
Kohya S
358f13f2c9 fix alpha is ignored 2024-08-10 14:03:59 +09:00
Kohya S
808d2d1f48 fix typos 2024-08-09 23:02:51 +09:00
Kohya S
36b2e6fc28 add FLUX.1 LoRA training 2024-08-09 22:56:48 +09:00
Kohya S
da4d0fe016 support attn mask for l+g/t5 2024-08-05 20:51:34 +09:00
Kohya S
231df197dd Fix npz path for verification 2024-08-05 20:26:30 +09:00
Kohya S
002d75179a sample images for training 2024-07-29 23:18:34 +09:00
Kohya S
1a977e847a fix typos 2024-07-27 13:51:50 +09:00
Kohya S
41dee60383 Refactor caching mechanism for latents and text encoder outputs, etc. 2024-07-27 13:50:05 +09:00
sdbds
9ca7a5b6cc instead cv2 LANCZOS4 resize to pil resize 2024-07-20 21:59:11 +08:00
sdbds
1f16b80e88 Revert "judge image size for using diff interpolation"
This reverts commit 87526942a6.
2024-07-20 21:35:24 +08:00
Millie
2e67978ee2 Generate sample images without having CUDA (such as on Macs) 2024-07-18 11:52:58 -07:00
sdbds
87526942a6 judge image size for using diff interpolation 2024-07-12 22:56:38 +08:00
Kohya S
082f13658b reduce peak GPU memory usage before training 2024-07-12 21:28:01 +09:00
Kohya S
b8896aad40 update README 2024-07-11 08:01:23 +09:00
Kohya S
6f0e235f2c Fix shift value in SD3 inference. 2024-07-11 08:00:45 +09:00
Kohya S
3d402927ef WIP: update new latents caching 2024-07-09 23:15:38 +09:00
Kohya S
9dc7997803 fix typo 2024-07-09 20:37:00 +09:00
Kohya S
3ea4fce5e0 load models one by one 2024-07-08 22:04:43 +09:00
Kohya S
c9de7c4e9a WIP: new latents caching 2024-07-08 19:48:28 +09:00
Kohya S
50e3d62474 fix to work T5XXL with fp16 2024-07-08 19:46:23 +09:00
Kohya S
ea18d5ba6d Fix to work full_bf16 and full_fp16. 2024-06-29 17:45:50 +09:00
Kohya S
19086465e8 Fix fp16 mixed precision, model is in bf16 without full_bf16 2024-06-29 17:21:25 +09:00
Kohya S
66cf435479 re-fix assertion ref #1389 2024-06-27 13:14:09 +09:00
Kohya S
381598c8bb fix resolution in metadata for sd3 2024-06-26 21:15:02 +09:00
Kohya S
828a581e29 fix assertion for experimental impl ref #1389 2024-06-26 20:43:31 +09:00
Kohya S
8f2ba27869 support text_encoder_batch_size for caching 2024-06-26 20:36:22 +09:00
Kohya S
0b3e4f7ab6 show file name if error in load_image ref #1385 2024-06-25 20:03:09 +09:00
Kohya S
4802e4aaec workaround for long caption ref #1382 2024-06-24 23:13:14 +09:00
Kohya S
0fe4eafac9 fix to use zero for initial latent 2024-06-24 23:12:48 +09:00
Kohya S
d53ea22b2a sd3 training 2024-06-23 23:38:20 +09:00
Kohya S
a518e3c819 Merge branch 'dev' into sd3 2024-06-23 14:12:07 +09:00
Kohya S
9dd1ee458c Merge branch 'main' into dev 2024-06-23 14:11:51 +09:00
Kohya S
25f961bc77 fix to work cache_latents/text_encoder_outputs 2024-06-23 13:24:30 +09:00
Kohya S
e5268286bf add sd3 models and inference script 2024-06-15 22:20:24 +09:00
Kohya S
56bb81c9e6 add grad_hook after restore state closes #1344 2024-06-12 21:39:35 +09:00
Kohya S
22413a5247 Merge pull request #1359 from kohya-ss/train_resume_step
Train resume step
2024-06-11 19:52:03 +09:00
Kohya S
18d7597b0b update README 2024-06-11 19:51:30 +09:00
Kohya S
4a441889d4 Merge branch 'dev' into train_resume_step 2024-06-11 19:27:37 +09:00
Kohya S
3259928ce4 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2024-06-09 19:26:42 +09:00
Kohya S
1a104dc75e make forward/backward pathes same ref #1363 2024-06-09 19:26:36 +09:00
Kohya S
58fb64819a set static graph flag when DDP ref #1363 2024-06-09 19:26:09 +09:00
Kohya S
5bfe5e411b Merge pull request #1361 from shirayu/update/github_actions/crate-ci/typos-1.21.0
Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and updated _typos.toml (Close #1307)
2024-06-06 21:23:24 +09:00
Yuta Hayashibe
4ecbac131a Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and updated _typos.toml (Close #1307) 2024-06-05 16:31:55 +09:00
Kohya S
4dbcef429b update for corner cases 2024-06-04 21:26:55 +09:00
Kohya S
321e24d83b Merge pull request #1353 from KohakuBlueleaf/train_resume_step
Resume correct step for "resume from state" feature.
2024-06-04 19:30:11 +09:00
Kohya S
e5bab69e3a fix alpha mask without disk cache closes #1351, ref #1339 2024-06-02 21:11:40 +09:00
Kohaku-Blueleaf
3eb27ced52 Skip the final 1 step 2024-05-31 12:24:15 +08:00
Kohaku-Blueleaf
b2363f1021 Final implementation 2024-05-31 12:20:20 +08:00
Kohya S
0d96e10b3e Merge pull request #1339 from kohya-ss/alpha-masked-loss
Alpha masked loss
2024-05-27 21:41:16 +09:00
Kohya S
fc85496f7e update docs for masked loss 2024-05-27 21:25:06 +09:00
Kohya S
2870be9b52 Merge branch 'dev' into alpha-masked-loss 2024-05-27 21:08:43 +09:00
Kohya S
71ad3c0f45 Update masked_loss_README-ja.md
add sample images
2024-05-27 21:07:57 +09:00
Kohya S
ffce3b5098 Merge pull request #1349 from rockerBOO/patch-4
Update issue link
2024-05-27 21:00:46 +09:00
Kohya S
a4c3155148 add doc for mask loss 2024-05-27 20:59:40 +09:00
Kohya S
58cadf476b Merge branch 'dev' into alpha-masked-loss 2024-05-27 20:02:32 +09:00
Dave Lage
d50c1b3c5c Update issue link 2024-05-27 01:11:01 -04:00
Kohya S
e8cfd4ba1d fix to work cond mask and alpha mask 2024-05-26 22:01:37 +09:00
Kohya S
fb12b6d8e5 Merge pull request #1347 from rockerBOO/lora-plus-log-info
Add LoRA+ LR Ratio info message to logger
2024-05-26 19:45:03 +09:00
rockerBOO
00513b9b70 Add LoRA+ LR Ratio info message to logger 2024-05-23 22:27:12 -04:00
Kohya S
da6fea3d97 simplify and update alpha mask to work with various cases 2024-05-19 21:26:18 +09:00
Kohya S
f2dd43e198 revert kwargs to explicit declaration 2024-05-19 19:23:59 +09:00
u-haru
db6752901f 画像のアルファチャンネルをlossのマスクとして使用するオプションを追加 (#1223)
* Add alpha_mask parameter and apply masked loss

* Fix type hint in trim_and_resize_if_required function

* Refactor code to use keyword arguments in train_util.py

* Fix alpha mask flipping logic

* Fix alpha mask initialization

* Fix alpha_mask transformation

* Cache alpha_mask

* Update alpha_masks to be on CPU

* Set flipped_alpha_masks to Null if option disabled

* Check if alpha_mask is None

* Set alpha_mask to None if option disabled

* Add description of alpha_mask option to docs
2024-05-19 19:07:25 +09:00
Kohya S
febc5c59fa update README 2024-05-19 19:03:43 +09:00
Kohya S
4c798129b0 update README 2024-05-19 19:00:32 +09:00
Kohya S
38e4c602b1 Merge pull request #1277 from Cauldrath/negative_learning
Allow negative learning rate
2024-05-19 18:55:50 +09:00
Kohya S
e4d9e3c843 remove dependency for omegaconf #ref 1284 2024-05-19 17:46:07 +09:00
Kohya S
de0e0b9468 Merge pull request #1284 from sdbds/fix_traincontrolnet
Fix train controlnet
2024-05-19 17:39:15 +09:00
Kohya S
c68baae480 add --log_config option to enable/disable output training config 2024-05-19 17:21:04 +09:00
Kohya S
47187f7079 Merge pull request #1285 from ccharest93/main
Hyperparameter tracking
2024-05-19 16:31:33 +09:00
Kohya S
e3ddd1fbbe update README and format code 2024-05-19 16:26:10 +09:00
Kohya S
0640f017ab Merge pull request #1322 from aria1th/patch-1
Accelerate: fix get_trainable_params in controlnet-llite training
2024-05-19 16:23:01 +09:00
Kohya S
2f19175dfe update README 2024-05-19 15:38:37 +09:00
Kohya S
146edce693 support Diffusers' based SDXL LoRA key for inference 2024-05-18 11:05:04 +09:00
Kohya S
153764a687 add prompt option '--f' for filename 2024-05-15 20:21:49 +09:00
Kohya S
589c2aa025 update README 2024-05-13 21:20:37 +09:00
Kohya S
16677da0d9 fix create_network_from_weights doesn't work 2024-05-12 22:15:07 +09:00
Kohya S
a384bf2187 Merge pull request #1313 from rockerBOO/patch-3
Add caption_separator to output for subset
2024-05-12 21:36:56 +09:00
Kohya S
1c296f7229 Merge pull request #1312 from rockerBOO/patch-2
Fix caption_separator missing in subset schema
2024-05-12 21:33:12 +09:00
Kohya S
e96a5217c3 Merge pull request #1291 from frodo821/patch-1
removed unnecessary `torch` import on line 115
2024-05-12 21:14:50 +09:00
Kohya S
39b82f26e5 update readme 2024-05-12 20:58:45 +09:00
Kohya S
3701507874 raise original error if error is occured in checking latents 2024-05-12 20:56:56 +09:00
Kohya S
78020936d2 Merge pull request #1278 from Cauldrath/catch_latent_error_file
Display name of error latent file
2024-05-12 20:46:25 +09:00
Kohya S
9ddb4d7a01 update readme and help message etc. 2024-05-12 17:55:08 +09:00
Kohya S
8d1b1acd33 Merge pull request #1266 from Zovjsra/feature/disable-mmap
Add "--disable_mmap_load_safetensors" parameter
2024-05-12 17:43:44 +09:00
Kohya S
02298e3c4a Merge pull request #1331 from kohya-ss/lora-plus
Lora plus
2024-05-12 17:04:58 +09:00
Kohya S
44190416c6 update docs etc. 2024-05-12 17:01:20 +09:00
Kohya S
3c8193f642 revert lora+ for lora_fa 2024-05-12 17:00:51 +09:00
Kohya S
c6a437054a Merge branch 'dev' into lora-plus 2024-05-12 16:18:57 +09:00
Kohya S
1ffc0b330a fix typo 2024-05-12 16:18:43 +09:00
Kohya S
e01e148705 Merge branch 'dev' into lora-plus 2024-05-12 16:17:52 +09:00
Kohya S
e9f3a622f4 Merge branch 'dev' into lora-plus 2024-05-12 16:17:27 +09:00
Kohya S
7983d3db5f Merge pull request #1319 from kohya-ss/fused-backward-pass
Fused backward pass
2024-05-12 15:09:39 +09:00
Kohya S
bee8cee7e8 update README for fused optimizer 2024-05-12 15:08:52 +09:00
Kohya S
f3d2cf22ff update README for fused optimizer 2024-05-12 15:03:02 +09:00
Kohya S
6dbc23cf63 Merge branch 'dev' into fused-backward-pass 2024-05-12 14:21:56 +09:00
Kohya S
c1ba0b4356 update readme 2024-05-12 14:21:10 +09:00
Kohya S
607e041f3d chore: Refactor optimizer group 2024-05-12 14:16:41 +09:00
AngelBottomless
793aeb94da fix get_trainable_params in controlnet-llite training 2024-05-07 18:21:31 +09:00
Kohya S
b56d5f7801 add experimental option to fuse params to optimizer groups 2024-05-06 21:35:39 +09:00
Kohya S
017b82ebe3 update help message for fused_backward_pass 2024-05-06 15:05:42 +09:00
Kohya S
2a359e0a41 Merge pull request #1259 from 2kpr/fused_backward_pass
Adafactor fused backward pass and optimizer step, lowers SDXL (@ 1024 resolution) VRAM usage to BF16(10GB)/FP32(16.4GB)
2024-05-06 15:01:56 +09:00
Kohya S
3fd8cdc55d fix dylora loraplus 2024-05-06 14:03:19 +09:00
Kohya S
7fe81502d0 update loraplus on dylora/lofa_fa 2024-05-06 11:09:32 +09:00
Kohya S
52e64c69cf add debug log 2024-05-04 18:43:52 +09:00
Kohya S
58c2d856ae support block dim/lr for sdxl 2024-05-03 22:18:20 +09:00
Dave Lage
8db0cadcee Add caption_separator to output for subset 2024-05-02 18:08:28 -04:00
Dave Lage
dbb7bb288e Fix caption_separator missing in subset schema 2024-05-02 17:39:35 -04:00
Kohya S
969f82ab47 move loraplus args from args to network_args, simplify log lr desc 2024-04-29 20:04:25 +09:00
Kohya S
834445a1d6 Merge pull request #1233 from rockerBOO/lora-plus
Add LoRA+ support
2024-04-29 18:05:12 +09:00
frodo821
fdbb03c360 removed unnecessary torch import on line 115
as per #1290
2024-04-23 14:29:05 +09:00
Cauldrath
040e26ff1d Regenerate failed file
If a latent file fails to load, print out the path and the error, then return false to regenerate it
2024-04-21 13:46:31 -04:00
Kohya S
0540c33aca pop weights if available #1247 2024-04-21 17:45:29 +09:00
Kohya S
52652cba1a disable main process check for deepspeed #1247 2024-04-21 17:41:32 +09:00
青龍聖者@bdsqlsz
5cb145d13b Update train_util.py 2024-04-20 21:56:24 +08:00
Maatra
b886d0a359 Cleaned typing to be in line with accelerate hyperparameters type resctrictions 2024-04-20 14:36:47 +01:00
青龍聖者@bdsqlsz
4477116a64 fix train controlnet 2024-04-20 21:26:09 +08:00
Maatra
2c9db5d9f2 passing filtered hyperparameters to accelerate 2024-04-20 14:11:43 +01:00
Cauldrath
fc374375de Allow negative learning rate
This can be used to train away from a group of images you don't want
As this moves the model away from a point instead of towards it, the change in the model is unbounded
So, don't set it too low. -4e-7 seemed to work well.
2024-04-18 23:29:01 -04:00
Cauldrath
feefcf256e Display name of error latent file
When trying to load stored latents, if an error occurs, this change will tell you what file failed to load
Currently it will just tell you that something failed without telling you which file
2024-04-18 23:15:36 -04:00
Zovjsra
64916a35b2 add disable_mmap to args 2024-04-16 16:40:08 +08:00
2kpr
4f203ce40d Fused backward pass 2024-04-14 09:56:58 -05:00
rockerBOO
68467bdf4d Fix unset or invalid LR from making a param_group 2024-04-11 17:33:19 -04:00
rockerBOO
75833e84a1 Fix default LR, Add overall LoRA+ ratio, Add log
`--loraplus_ratio` added for both TE and UNet
Add log for lora+
2024-04-08 19:23:02 -04:00
Kohya S
71e2c91330 Merge pull request #1230 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.19.0
Bump crate-ci/typos from 1.17.2 to 1.19.0
2024-04-07 21:14:18 +09:00
Kohya S
bfb352bc43 change huber_schedule from exponential to snr 2024-04-07 21:07:52 +09:00
Kohya S
c973b29da4 update readme 2024-04-07 20:51:52 +09:00
Kohya S
683f3d6ab3 Merge pull request #1212 from kohya-ss/dev
Version 0.8.6
2024-04-07 20:42:41 +09:00
Kohya S
dfa30790a9 update readme 2024-04-07 20:34:26 +09:00
Kohya S
d30ebb205c update readme, add metadata for network module 2024-04-07 14:58:17 +09:00
kabachuha
90b18795fc Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption (#1228)
* add huber loss and huber_c compute to train_util

* add reduction modes

* add huber_c retrieval from timestep getter

* move get timesteps and huber to own function

* add conditional loss to all training scripts

* add cond loss to train network

* add (scheduled) huber_loss to args

* fixup twice timesteps getting

* PHL-schedule should depend on noise scheduler's num timesteps

* *2 multiplier to huber loss cause of 1/2 a^2 conv.

The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another

* add option for smooth l1 (huber / delta)

* unify huber scheduling

* add snr huber scheduler

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-04-07 13:54:21 +09:00
Kohya S
089727b5ee update readme 2024-04-07 12:42:49 +09:00
Kohya S
921036dd91 Merge pull request #1240 from kohya-ss/verify-command-line-args
verify command line args if wandb is enabled
2024-04-07 12:27:03 +09:00
ykume
cd587ce62c verify command line args if wandb is enabled 2024-04-05 08:23:03 +09:00
rockerBOO
1933ab4b48 Fix default_lr being applied 2024-04-03 12:46:34 -04:00
Kohya S
b748b48dbb fix attention couple+deep shink cause error in some reso 2024-04-03 12:43:08 +09:00
rockerBOO
c7691607ea Add LoRA-FA for LoRA+ 2024-04-01 15:43:04 -04:00
rockerBOO
f99fe281cb Add LoRA+ support 2024-04-01 15:38:26 -04:00
dependabot[bot]
80e9f72234 Bump crate-ci/typos from 1.17.2 to 1.19.0
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.17.2 to 1.19.0.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.17.2...v1.19.0)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-04-01 01:50:22 +00:00
Kohya S
2258a1b753 add save/load hook to remove U-Net/TEs from state 2024-03-31 15:50:35 +09:00
Kohya S
059ee047f3 fix typo 2024-03-30 23:02:24 +09:00
Kohya S
2c2ca9d726 update tagger doc 2024-03-30 22:55:56 +09:00
Kohya S
f5323e3c4b update tagger doc 2024-03-30 22:10:37 +09:00
Kohya S
cae5aa0a56 update wd14 tagger and doc 2024-03-30 21:48:22 +09:00
Kohya S
6ba84288d9 Merge pull request #1216 from Disty0/dev
Rating support for WD Tagger
2024-03-30 18:50:49 +09:00
Kohya S
434dc408f9 update readme 2024-03-30 17:12:36 +09:00
Kohya S
ae3f625739 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2024-03-30 14:57:43 +09:00
Kohya S
f1f30ab418 fix to work with num_beams>1 closes #1149 2024-03-30 14:57:39 +09:00
Disty0
bc586ce190 Add --use_rating_tags and --character_tags_first for WD Tagger 2024-03-29 13:56:42 +03:00
Disty0
4012fd24f6 IPEX fix pin_memory 2024-03-28 21:08:16 +03:00
Disty0
954731d564 fix typo 2024-03-27 22:00:59 +03:00
Disty0
dd9763be31 Rating support for WD Tagger 2024-03-27 21:53:40 +03:00
Kohya S
b86af6798d Merge pull request #1213 from Disty0/dev
Add OpenVINO and ROCm ONNX Runtime for WD14
2024-03-27 23:15:33 +09:00
Disty0
6f7e93d5cc Add OpenVINO and ROCm ONNX Runtime for WD14 2024-03-27 03:21:13 +03:00
Kohya S
6c08e97e1f update readme 2024-03-26 20:48:08 +09:00
Kohya S
78e0a7630c Merge pull request #1206 from kohya-ss/dataset-cache
Add metadata caching for DreamBooth dataset
2024-03-26 19:49:23 +09:00
Kohya S
c86e356013 Merge branch 'dev' into dataset-cache 2024-03-26 19:43:40 +09:00
Kohya S
5a2afb3588 Merge pull request #1207 from kohya-ss/masked-loss
Add masked loss
2024-03-26 19:41:31 +09:00
Kohya S
ab1e389347 Merge branch 'dev' into masked-loss 2024-03-26 19:39:30 +09:00
Kohya S
ea05e3fd5b Merge pull request #1139 from kohya-ss/deep-speed
Deep speed
2024-03-26 19:33:57 +09:00
Kohya S
a2b8531627 make each script consistent, fix to work w/o DeepSpeed 2024-03-25 22:28:46 +09:00
Kohya S
c24422fb9d Merge branch 'dev' into deep-speed 2024-03-25 22:11:05 +09:00
Kohya S
9c4492b58a fix pytorch version 2.1.1 to 2.1.2 2024-03-24 23:17:25 +09:00
Kohya S
9bbb28c361 update PyTorch version and reorganize dependencies 2024-03-24 22:06:37 +09:00
Kohya S
1648ade6da format by black 2024-03-24 20:55:48 +09:00
Kohya S
993b2ab4c1 Merge branch 'dev' into deep-speed 2024-03-24 18:45:59 +09:00
Kohya S
8d5858826f Merge branch 'dev' into masked-loss 2024-03-24 18:19:53 +09:00
Kohya S
025347214d refactor metadata caching for DreamBooth dataset 2024-03-24 18:09:32 +09:00
Kohaku-Blueleaf
ae97c8bfd1 [Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initilization (#1178)
* support meta cached dataset

* add cache meta scripts

* random ip_noise_gamma strength

* random noise_offset strength

* use correct settings for parser

* cache path/caption/size only

* revert mess up commit

* revert mess up commit

* Update requirements.txt

* Add arguments for meta cache.

* remove pickle implementation

* Return sizes when enable cache

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-03-24 15:40:18 +09:00
Kohya S
381c44955e update readme and typing hint 2024-03-24 11:27:18 +09:00
Kohya S
ad97410ba5 Merge pull request #1205 from feffy380/patch-1
register reg images with correct subset
2024-03-24 11:14:07 +09:00
Kohya S
691f04322a update readme 2024-03-24 11:10:26 +09:00
Kohya S
79d1c12ab0 disable sample_every_n_xxx if value less than 1 ref #1202 2024-03-24 11:06:37 +09:00
feffy380
0c7baea88c register reg images with correct subset 2024-03-23 17:28:02 +01:00
Kohya S
f4a4c11cd3 support multiline captions ref #1155 2024-03-23 18:51:37 +09:00
Kohya S
594c7f7050 format by black 2024-03-23 16:11:31 +09:00
Kohya S
d17c0f5084 update dataset config doc 2024-03-21 08:31:29 +09:00
Kohya S
a35e7bd595 Merge pull request #1200 from BootsofLagrangian/deep-speed
Fix sdxl_train.py in deepspeed branch
2024-03-20 21:32:35 +09:00
BootsofLagrangian
d9456020d7 Fix most of ZeRO stage uses optimizer partitioning
- we have to prepare optimizer and ds_model at the same time.
 - pull/1139#issuecomment-1986790007

Signed-off-by: BootsofLagrangian <hard2251@yonsei.ac.kr>
2024-03-20 20:52:59 +09:00
Kohya S
fbb98f144e Merge branch 'dev' into deep-speed 2024-03-20 18:15:26 +09:00
Kohya S
9b6b39f204 Merge branch 'dev' into masked-loss 2024-03-20 18:14:36 +09:00
Kohya S
86e40fabbc Merge branch 'dev' into deep-speed 2024-03-17 19:30:42 +09:00
Kohya S
3419c3de0d common masked loss func, apply to all training script 2024-03-17 19:30:20 +09:00
Kohya S
7081a0cf0f extension of src image could be different than target image 2024-03-17 18:09:15 +09:00
Kohya S
0ef4fe70f0 Merge branch 'dev' into masked-loss 2024-03-17 11:18:18 +09:00
Kohya S
97524f1bda Merge branch 'dev' into deep-speed 2024-03-12 20:41:41 +09:00
Kohya S
74c266a597 Merge branch 'dev' into masked-loss 2024-03-12 20:40:57 +09:00
Kohya S
a9b64ffba8 support masked loss in sdxl_train ref #589 2024-02-27 21:43:55 +09:00
Kohya S
e3ccf8fbf7 make deepspeed_utils 2024-02-27 21:30:46 +09:00
Kohya S
0e4a5738df Merge pull request #1101 from BootsofLagrangian/deepspeed
support deepspeed
2024-02-27 18:59:00 +09:00
Kohya S
eefb3cc1e7 Merge branch 'deep-speed' into deepspeed 2024-02-27 18:57:42 +09:00
Kohya S
4a5546d40e fix typo 2024-02-26 23:39:56 +09:00
Kohya S
175193623b update readme 2024-02-26 23:29:41 +09:00
Kohya S
f2c727fc8c add minimal impl for masked loss 2024-02-26 23:19:58 +09:00
BootsofLagrangian
4d5186d1cf refactored codes, some function moved into train_utils.py 2024-02-22 16:20:53 +09:00
BootsofLagrangian
03f0816f86 the reason not working grad accum steps found. it was becasue of my accelerate settings 2024-02-09 17:47:49 +09:00
BootsofLagrangian
a98fecaeb1 forgot setting mixed_precision for deepspeed. sorry 2024-02-07 17:19:46 +09:00
BootsofLagrangian
2445a5b74e remove test requirements 2024-02-07 16:48:18 +09:00
BootsofLagrangian
62556619bd fix full_fp16 compatible and train_step 2024-02-07 16:42:05 +09:00
BootsofLagrangian
7d2a9268b9 apply offloading method runable for all trainer 2024-02-05 22:42:06 +09:00
BootsofLagrangian
3970bf4080 maybe fix branch to run offloading 2024-02-05 22:40:43 +09:00
BootsofLagrangian
4295f91dcd fix all trainer about vae 2024-02-05 20:19:56 +09:00
BootsofLagrangian
2824312d5e fix vae type error during training sdxl 2024-02-05 20:13:28 +09:00
BootsofLagrangian
64873c1b43 fix offload_optimizer_device typo 2024-02-05 17:11:50 +09:00
BootsofLagrangian
dfe08f395f support deepspeed 2024-02-04 03:12:42 +09:00
90 changed files with 23055 additions and 2141 deletions

View File

@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.17.2
uses: crate-ci/typos@v1.24.3

View File

@@ -1,12 +1,12 @@
SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。
SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。
## リポジトリについて
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
[README in English](./README.md) ←更新情報はこちらにあります
開発中のバージョンはdevブランチにあります。最新の変更点はdevブランチをご確認ください。
FLUX.1およびSD3/SD3.5対応はsd3ブランチで行っています。それらの学習を行う場合はsd3ブランチをご利用ください。
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています英語ですのであわせてご覧ください。bmaltais氏に感謝します。
以下のスクリプトがあります。
@@ -21,6 +21,7 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
* [データセット設定](./docs/config_README-ja.md)
* [SDXL学習](./docs/train_SDXL-en.md) (英語版)
* [DreamBoothの学習について](./docs/train_db_README-ja.md)
* [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
* [LoRAの学習について](./docs/train_network_README-ja.md)
@@ -44,9 +45,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
## Windows環境でのインストール
スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。
以下の例ではPyTorchは2.0.1CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。
なお、python -m venvの行で「python」とだけ表示された場合、py -m venvのようにpythonをpyに変更してください。
@@ -59,21 +58,21 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install xformers==0.0.20
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
accelerate config
```
コマンドプロンプトでも同一です。
(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。
注:`bitsandbytes==0.43.0``prodigyopt==1.0``lion-pytorch==0.0.6``requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
この例では PyTorch および xfomers は2.1.2CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
accelerate configの質問には以下のように答えてください。bf16で学習する場合、最後の質問にはbf16と答えてください。
※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます……。数字キーの0、1、2……で選択できますので、そちらを使ってください。
```txt
- This machine
- No distributed training
@@ -87,41 +86,6 @@ accelerate configの質問には以下のように答えてください。bf1
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``に「0」と答えてください。id `0`のGPUが使われます。
### オプション:`bitsandbytes`8bit optimizerを使う
`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます0.41.1または以降のバージョンを推奨)。
Windowsでは0.35.0または0.41.1を推奨します。
- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。
- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。
注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659
以下の手順に従い、`bitsandbytes`をインストールしてください。
### 0.35.0を使う場合
PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。
```powershell
cd sd-scripts
.\venv\Scripts\activate
pip install bitsandbytes==0.35.0
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
```
### 0.41.1を使う場合
jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。
```powershell
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
```
## アップグレード
新しいリリースがあった場合、以下のコマンドで更新できます。
@@ -151,4 +115,47 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## その他の情報
### LoRAの名称について
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
<!--
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
-->
### 学習中のサンプル画像生成
プロンプトファイルは例えば以下のようになります。
```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
`( )` や `[ ]` などの重みづけも動作します。

1381
README.md

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
[default.extend-identifiers]
ddPn08="ddPn08"
[default.extend-words]
NIN="NIN"
@@ -27,6 +28,7 @@ rik="rik"
koo="koo"
yos="yos"
wn="wn"
hime="hime"
[files]

View File

@@ -1,7 +1,10 @@
Original Source by kohya-ss
First version:
A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150
Some parts are manually added.
# Config Readme
This README is about the configuration files that can be passed with the `--dataset_config` option.
@@ -125,6 +128,8 @@ These are options related to the configuration of the data set. They cannot be d
* `batch_size`
* This corresponds to the command-line argument `--train_batch_size`.
* `max_bucket_reso`, `min_bucket_reso`
* Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`.
These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
@@ -143,11 +148,23 @@ These options are related to subset configuration.
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o |
| `caption_suffix` | `", from side"` | o | o | o |
| `caption_separator` | (not specified) | o | o | o |
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
* `num_repeats`
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
* `caption_prefix`, `caption_suffix`
* Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`.
* `caption_separator`
* Specifies the string to separate the tags. The default is `,`. This option is usually not necessary to set.
* `keep_tokens_separator`
* Specifies the string to separate the parts to be fixed in the caption. For example, if you specify `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` will remain, and the rest will be shuffled and dropped. The comma in between is not necessary. As a result, the prompt will be `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`, etc.
* `secondary_separator`
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
* `enable_wildcard`
* Enables wildcard notation. This will be explained later.
### DreamBooth-specific options
@@ -162,6 +179,7 @@ Options related to the configuration of DreamBooth subsets.
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `"sks girl"` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |
Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`.
@@ -172,6 +190,9 @@ Firstly, note that for `image_dir`, the path to the image files must be specifie
* `class_tokens`
* Sets the class tokens.
* Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur.
* `cache_info`
* Specifies whether to cache the image size and caption. If not specified, it is set to `false`. The cache is saved in `metadata_cache.json` in `image_dir`.
* Caching speeds up the loading of the dataset after the first time. It is effective when dealing with thousands of images or more.
* `is_reg`
* Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization.
@@ -276,4 +297,90 @@ As a temporary measure, we will list common errors and their solutions. If you e
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful.
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it.
## Miscellaneous
### Multi-line captions
By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption.
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
a girl with a microphone standing on a stage
detailed digital art of a girl with a microphone on a stage
```
It can be combined with wildcard notation.
In metadata files, you can also specify multiple-line captions. In the `.json` metadata file, use `\n` to represent a line break. If the caption file consists of multiple lines, `merge_captions_to_metadata.py` will create a metadata file in this format.
The tags in the metadata (`tags`) are added to each line of the caption.
```json
{
"/path/to/image.png": {
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
},
...
}
```
In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc.
### Example of configuration file : `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc.
```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]
[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator= "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above
[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
### Example of caption, secondary_separator notation: `secondary_separator = ";;;"`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped).
### Example of caption, enable_wildcard notation: `enable_wildcard = true`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`.
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`).
### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"`
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc.

View File

@@ -1,5 +1,3 @@
For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future.
`--dataset_config` で渡すことができる設定ファイルに関する説明です。
## 概要
@@ -120,6 +118,8 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
* `batch_size`
* コマンドライン引数の `--train_batch_size` と同等です。
* `max_bucket_reso`, `min_bucket_reso`
* bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。
これらの設定はデータセットごとに固定です。
つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
@@ -140,12 +140,28 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o |
| `caption_suffix` | `“, from side”` | o | o | o |
| `caption_separator` | (通常は設定しません) | o | o | o |
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
* `num_repeats`
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
* `caption_prefix`, `caption_suffix`
* キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。
* `caption_separator`
* タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。
* `keep_tokens_separator`
* キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb``ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh``aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。
* `secondary_separator`
* 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。
* `enable_wildcard`
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
### DreamBooth 方式専用のオプション
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
@@ -159,6 +175,7 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
| `image_dir` | `C:\hoge` | - | - | o必須 |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `“sks girl”` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |
まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats``class_tokens` で明示的に指定する必要があることに注意してください。
@@ -169,6 +186,9 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
* `class_tokens`
* クラストークンを設定します。
* 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。
* `cache_info`
* 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir``metadata_cache.json` というファイル名で保存されます。
* キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。
* `is_reg`
* サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
@@ -280,4 +300,89 @@ resolution = 768
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。
## その他
### 複数行キャプション
`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
a girl with a microphone standing on a stage
detailed digital art of a girl with a microphone on a stage
```
ワイルドカード記法と組み合わせることも可能です。
メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。
メタデータのタグ (`tags`) は、キャプションの各行に追加されます。
```json
{
"/path/to/image.png": {
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
},
...
}
```
この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...``test multiline caption2, open mouth, simple background ...` 等になります。
### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等
```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]
[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator= "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above
[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` drop されたケース)などになります。
### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
タグ文字列に `{``}` そのものを含めたい場合は `{{``}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。
### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general``1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。

View File

@@ -0,0 +1,57 @@
## マスクロスについて
マスクロスは、入力画像のマスクで指定された部分だけ損失計算することで、画像の一部分だけを学習することができる機能です。
たとえばキャラクタを学習したい場合、キャラクタ部分だけをマスクして学習することで、背景を無視して学習することができます。
マスクロスのマスクには、二種類の指定方法があります。
- マスク画像を用いる方法
- 透明度(アルファチャネル)を使用する方法
なお、サンプルは [ずんずんPJイラスト/3Dデータ](https://zunko.jp/con_illust.html) の「AI画像モデル用学習データ」を使用しています。
### マスク画像を用いる方法
学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。
- 学習画像
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441)
- マスク画像
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450)
```.toml
[[datasets.subsets]]
image_dir = "/path/to/a_zundamon"
caption_extension = ".txt"
conditioning_data_dir = "/path/to/a_zundamon_mask"
num_repeats = 8
```
マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。
DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存してください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。
### 透明度(アルファチャネル)を使用する方法
学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します127 ならおおむね 0.5)。
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e)
※それぞれの画像は透過PNG
学習時のスクリプトのオプションに `--alpha_mask` を指定するか、dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。
```toml
[[datasets.subsets]]
image_dir = "/path/to/image/dir"
caption_extension = ".txt"
num_repeats = 8
alpha_mask = true
```
## 学習時の注意事項
- 現時点では DreamBooth 方式の dataset のみ対応しています。
- マスクは latents のサイズ、つまり 1/8 に縮小されてから適用されます。そのため、細かい部分(たとえばアホ毛やイヤリングなど)はうまく学習できない可能性があります。マスクをわずかに拡張するなどの工夫が必要かもしれません。
- マスクロスを用いる場合、学習対象外の部分をキャプションに含める必要はないかもしれません。(要検証)
- `alpha_mask` の場合、マスクの有無を切り替えると latents キャッシュが自動的に再生成されます。

View File

@@ -0,0 +1,56 @@
## Masked Loss
Masked loss is a feature that allows you to train only part of an image by calculating the loss only for the part specified by the mask of the input image. For example, if you want to train a character, you can train only the character part by masking it, ignoring the background.
There are two ways to specify the mask for masked loss.
- Using a mask image
- Using transparency (alpha channel) of the image
The sample uses the "AI image model training data" from [ZunZunPJ Illustration/3D Data](https://zunko.jp/con_illust.html).
### Using a mask image
This is a method of preparing a mask image corresponding to each training image. Prepare a mask image with the same file name as the training image and save it in a different directory from the training image.
- Training image
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441)
- Mask image
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450)
```.toml
[[datasets.subsets]]
image_dir = "/path/to/a_zundamon"
caption_extension = ".txt"
conditioning_data_dir = "/path/to/a_zundamon_mask"
num_repeats = 8
```
The mask image is the same size as the training image, with the part to be trained drawn in white and the part to be ignored in black. It also supports grayscale (127 gives a loss weight of 0.5). The R channel of the mask image is used currently.
Use the dataset in the DreamBooth method, and save the mask image in the directory specified by `conditioning_data_dir`. It is the same as the ControlNet dataset, so please refer to [ControlNet-LLLite](train_lllite_README.md#Preparing-the-dataset) for details.
### Using transparency (alpha channel) of the image
The transparency (alpha channel) of the training image is used as a mask. The part with transparency 0 is ignored, the part with transparency 255 is trained. For semi-transparent parts, the loss weight changes according to the transparency (127 gives a weight of about 0.5).
![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e)
※Each image is a transparent PNG
Specify `--alpha_mask` in the training script options or specify `alpha_mask` in the subset of the dataset configuration file. For example, it will look like this.
```toml
[[datasets.subsets]]
image_dir = "/path/to/image/dir"
caption_extension = ".txt"
num_repeats = 8
alpha_mask = true
```
## Notes on training
- At the moment, only the dataset in the DreamBooth method is supported.
- The mask is applied after the size is reduced to 1/8, which is the size of the latents. Therefore, fine details (such as ahoge or earrings) may not be learned well. Some dilations of the mask may be necessary.
- If using masked loss, it may not be necessary to include parts that are not to be trained in the caption. (To be verified)
- In the case of `alpha_mask`, the latents cache is automatically regenerated when the enable/disable state of the mask is switched.

View File

@@ -648,7 +648,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
詳細については各自お調べください。
任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。
任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--lr_scheduler_args`でオプション引数を指定してください。
### オプティマイザの指定について

View File

@@ -582,7 +582,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
有关详细信息,请自行研究。
要使用任何调度程序,请像使用任何优化器一样使用“--scheduler_args”指定可选参数。
要使用任何调度程序,请像使用任何优化器一样使用“--lr_scheduler_args”指定可选参数。
### 关于指定优化器
使用 --optimizer_args 选项指定优化器选项参数。可以以key=value的格式指定多个值。此外您可以指定多个值以逗号分隔。例如要指定 AdamW 优化器的参数,``--optimizer_args weight_decay=0.01 betas=.9,.999``。

84
docs/train_SDXL-en.md Normal file
View File

@@ -0,0 +1,84 @@
## SDXL training
The documentation will be moved to the training documentation in the future. The following is a brief explanation of the training scripts for SDXL.
### Training scripts for SDXL
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- The full bfloat16 training might be unstable. Please use it at your own risk.
- The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- `--weighted_captions` option is not supported yet for both scripts.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
### Utility scripts for SDXL
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
- The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage.
### Tips for SDXL training
- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (4 to 8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings for SDXL
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
### ControlNet-LLLite
ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details.

View File

@@ -21,9 +21,13 @@ ComfyUIのカスタムードを用意しています。: https://github.com/k
## モデルの学習
### データセットの準備
通常のdatasetに加え`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
DreamBooth 方式の dataset`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。
たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。
finetuning 方式の dataset はサポートしていません。)
conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。
```toml
[[datasets.subsets]]

View File

@@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1
### Preparing the dataset
In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
In addition to the normal DreamBooth method dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
(We do not support the finetuning method dataset.)
```toml
[[datasets.subsets]]
@@ -183,7 +185,7 @@ for img_file in img_files:
### Creating a dataset configuration file
You can use the command line arguments of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
You can use the command line argument `--conditioning_data_dir` of `sdxl_train_control_net_lllite.py` to specify the conditioning image directory. However, if you want to use a `.toml` file, specify the conditioning image directory in `conditioning_data_dir`.
```toml
[general]

View File

@@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
* Text Encoderに関連するLoRAモジュールに、通常の学習率--learning_rateオプションで指定とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率5e-5などにしたほうが良い、という話もあるようです。
* `--network_args`
* 複数の引数を指定できます。後述します。
* `--alpha_mask`
* 画像のアルファ値をマスクとして使用します。透過画像を学習する際に使用します。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
`--network_train_unet_only``--network_train_text_encoder_only` の両方とも未指定時デフォルトはText EncoderとU-Netの両方のLoRAモジュールを有効にします。
@@ -181,16 +183,16 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf
詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。
SDXLは現在サポートしていません。
フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
SDXL では down/up 9 個、middle 3 個の値を指定してください。
`--network_args` で以下の引数を指定してください。
- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個SDXL では 9 個)の数値を指定します。
- プリセットからの指定 : `"down_lr_weight=sine"` のように指定しますサインカーブで重みを指定します。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します0.25~1.25になります)。
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定しますSDXL の場合は 3 個)
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
@@ -215,6 +217,9 @@ network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_l
フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。
SDXL では 23 個の値を指定してください。一部のブロックにはLoRA が存在しませんが、`sdxl_train.py` の[階層別学習率](./train_SDXL-en.md) との互換性のためです。
対応は、`0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out` です。
`--network_args` で以下の引数を指定してください。
- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。

View File

@@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中
* 当在Text Encoder相关的LoRA模块中使用与常规学习率`--learning_rate`选项指定不同的学习率时应指定此选项。可能最好将Text Encoder的学习率稍微降低例如5e-5
* `--network_args`
* 可以指定多个参数。将在下面详细说明。
* `--alpha_mask`
* 使用图像的 Alpha 值作为遮罩。这在学习透明图像时使用。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
当未指定`--network_train_unet_only``--network_train_text_encoder_only`默认情况将启用Text Encoder和U-Net的两个LoRA模块。

View File

@@ -0,0 +1,88 @@
# Image Tagging using WD14Tagger
This document is based on the information from this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger).
Using onnx for inference is recommended. Please install onnx with the following command:
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
```
The model weights will be automatically downloaded from Hugging Face.
# Usage
Run the script to perform tagging.
```powershell
python finetune/tag_images_by_wd14_tagger.py --onnx --repo_id <model repo id> --batch_size <batch size> <training data folder>
```
For example, if using the repository `SmilingWolf/wd-swinv2-tagger-v3` with a batch size of 4, and the training data is located in the parent folder `train_data`, it would be:
```powershell
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
```
On the first run, the model files will be automatically downloaded to the `wd14_tagger_model` folder (the folder can be changed with an option).
Tag files will be created in the same directory as the training data images, with the same filename and a `.txt` extension.
![Generated tag files](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
![Tags and image](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
## Example
To output in the Animagine XL 3.1 format, it would be as follows (enter on a single line in practice):
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
--batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
--use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
--always_first_tags "1girl,1boy" ..\train_data
```
## Available Repository IDs
[SmilingWolf's V2 and V3 models](https://huggingface.co/SmilingWolf) are available for use. Specify them in the format like `SmilingWolf/wd-vit-tagger-v3`. The default when omitted is `SmilingWolf/wd-v1-4-convnext-tagger-v2`.
# Options
## General Options
- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately.
- `--batch_size`: Number of images to process at once. Default is 1. Adjust according to VRAM capacity.
- `--caption_extension`: File extension for caption files. Default is `.txt`.
- `--max_data_loader_n_workers`: Maximum number of workers for DataLoader. Specifying a value of 1 or more will use DataLoader to speed up image loading. If unspecified, DataLoader will not be used.
- `--thresh`: Confidence threshold for outputting tags. Default is 0.35. Lowering the value will assign more tags but accuracy will decrease.
- `--general_threshold`: Confidence threshold for general tags. If omitted, same as `--thresh`.
- `--character_threshold`: Confidence threshold for character tags. If omitted, same as `--thresh`.
- `--recursive`: If specified, subfolders within the specified folder will also be processed recursively.
- `--append_tags`: Append tags to existing tag files.
- `--frequency_tags`: Output tag frequencies.
- `--debug`: Debug mode. Outputs debug information if specified.
## Model Download
- `--model_dir`: Folder to save model files. Default is `wd14_tagger_model`.
- `--force_download`: Re-download model files if specified.
## Tag Editing
- `--remove_underscore`: Remove underscores from output tags.
- `--undesired_tags`: Specify tags not to output. Multiple tags can be specified, separated by commas. For example, `black eyes,black hair`.
- `--use_rating_tags`: Output rating tags at the beginning of the tags.
- `--use_rating_tags_as_last_tag`: Add rating tags at the end of the tags.
- `--character_tags_first`: Output character tags first.
- `--character_tag_expand`: Expand character tag series names. For example, split the tag `chara_name_(series)` into `chara_name, series`.
- `--always_first_tags`: Specify tags to always output first when a certain tag appears in an image. Multiple tags can be specified, separated by commas. For example, `1girl,1boy`.
- `--caption_separator`: Separate tags with this string in the output file. Default is `, `.
- `--tag_replacement`: Perform tag replacement. Specify in the format `tag1,tag2;tag3,tag4`. If using `,` and `;`, escape them with `\`. \
For example, specify `aira tsubase,aira tsubase (uniform)` (when you want to train a specific costume), `aira tsubase,aira tsubase\, heir of shadows` (when the series name is not included in the tag).
When using `tag_replacement`, it is applied after `character_tag_expand`.
When specifying `remove_underscore`, specify `undesired_tags`, `always_first_tags`, and `tag_replacement` without including underscores.
When specifying `caption_separator`, separate `undesired_tags` and `always_first_tags` with `caption_separator`. Always separate `tag_replacement` with `,`.

View File

@@ -0,0 +1,88 @@
# WD14Taggerによるタグ付け
こちらのgithubページhttps://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。
onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
```
モデルの重みはHugging Faceから自動的にダウンロードしてきます。
# 使い方
スクリプトを実行してタグ付けを行います。
```
python fintune/tag_images_by_wd14_tagger.py --onnx --repo_id <モデルのrepo id> --batch_size <バッチサイズ> <教師データフォルダ>
```
レポジトリに `SmilingWolf/wd-swinv2-tagger-v3` を使用し、バッチサイズを4にして、教師データを親フォルダの `train_data`に置いた場合、以下のようになります。
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
```
初回起動時にはモデルファイルが `wd14_tagger_model` フォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。
タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。
![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
## 記述例
Animagine XL 3.1 方式で出力する場合、以下のようになります(実際には 1 行で入力してください)。
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
--batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
--use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
--always_first_tags "1girl,1boy" ..\train_data
```
## 使用可能なリポジトリID
[SmilingWolf 氏の V2、V3 のモデル](https://huggingface.co/SmilingWolf)が使用可能です。`SmilingWolf/wd-vit-tagger-v3` のように指定してください。省略時のデフォルトは `SmilingWolf/wd-v1-4-convnext-tagger-v2` です。
# オプション
## 一般オプション
- `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。
- `--batch_size` : 一度に処理する画像の数。デフォルトは1です。VRAMの容量に応じて増減してください。
- `--caption_extension` : キャプションファイルの拡張子。デフォルトは `.txt` です。
- `--max_data_loader_n_workers` : DataLoader の最大ワーカー数です。このオプションに 1 以上の数値を指定すると、DataLoader を用いて画像読み込みを高速化します。未指定時は DataLoader を用いません。
- `--thresh` : 出力するタグの信頼度の閾値。デフォルトは0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。
- `--general_threshold` : 一般タグの信頼度の閾値。省略時は `--thresh` と同じです。
- `--character_threshold` : キャラクタータグの信頼度の閾値。省略時は `--thresh` と同じです。
- `--recursive` : 指定すると、指定したフォルダ内のサブフォルダも再帰的に処理します。
- `--append_tags` : 既存のタグファイルにタグを追加します。
- `--frequency_tags` : タグの頻度を出力します。
- `--debug` : デバッグモード。指定するとデバッグ情報を出力します。
## モデルのダウンロード
- `--model_dir` : モデルファイルの保存先フォルダ。デフォルトは `wd14_tagger_model` です。
- `--force_download` : 指定するとモデルファイルを再ダウンロードします。
## タグ編集関連
- `--remove_underscore` : 出力するタグからアンダースコアを削除します。
- `--undesired_tags` : 出力しないタグを指定します。カンマ区切りで複数指定できます。たとえば `black eyes,black hair` のように指定します。
- `--use_rating_tags` : タグの最初にレーティングタグを出力します。
- `--use_rating_tags_as_last_tag` : タグの最後にレーティングタグを追加します。
- `--character_tags_first` : キャラクタータグを最初に出力します。
- `--character_tag_expand` : キャラクタータグのシリーズ名を展開します。たとえば `chara_name_(series)` のタグを `chara_name, series` に分割します。
- `--always_first_tags` : あるタグが画像に出力されたとき、そのタグを最初に出力するタグを指定します。カンマ区切りで複数指定できます。たとえば `1girl,1boy` のように指定します。
- `--caption_separator` : 出力するファイルでタグをこの文字列で区切ります。デフォルトは `, ` です。
- `--tag_replacement` : タグの置換を行います。`tag1,tag2;tag3,tag4` のように指定します。`,` および `;` を使う場合は `\` でエスケープしてください。\
たとえば `aira tsubase,aira tsubase (uniform)` (特定の衣装を学習させたいとき)、`aira tsubase,aira tsubase\, heir of shadows` (シリーズ名がタグに含まれないとき)のように指定します。
`tag_replacement``character_tag_expand` の後に適用されます。
`remove_underscore` 指定時は、`undesired_tags``always_first_tags``tag_replacement` はアンダースコアを含めずに指定してください。
`caption_separator` 指定時は、`undesired_tags``always_first_tags``caption_separator` で区切ってください。`tag_replacement` は必ず `,` で区切ってください。

View File

@@ -10,7 +10,9 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
@@ -37,11 +39,13 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
import library.strategy_sd as strategy_sd
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -49,7 +53,15 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
tokenizer = train_util.load_tokenizer(args)
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
if args.dataset_class is None:
@@ -78,16 +90,18 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
@@ -108,6 +122,7 @@ def train(args):
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -158,11 +173,12 @@ def train(args):
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -188,10 +204,13 @@ def train(args):
else:
text_encoder.eval()
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models:
m.requires_grad_(True)
@@ -210,7 +229,11 @@ def train(args):
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -246,13 +269,23 @@ def train(args):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
@@ -296,10 +329,19 @@ def train(args):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -311,32 +353,30 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with accelerator.accumulate(*training_models):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
latents = latents * 0.18215
b_size = latents.shape[0]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [text_encoder], input_ids_list, weights_list
)[0]
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -352,9 +392,10 @@ def train(args):
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
@@ -362,11 +403,11 @@ def train(args):
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -385,7 +426,7 @@ def train(args):
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
# 指定ステップごとにモデルを保存
@@ -410,7 +451,7 @@ def train(args):
)
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -423,7 +464,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -448,7 +489,9 @@ def train(args):
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
is_main_process = accelerator.is_main_process
if is_main_process:
@@ -457,7 +500,7 @@ def train(args):
accelerator.end_training()
if is_main_process and (args.save_state or args.save_state_on_train_end):
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -477,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
@@ -492,6 +536,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser
@@ -500,6 +549,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -134,8 +134,9 @@ class BLIP_Decoder(nn.Module):
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
image_embeds = self.visual_encoder(image)
if not sample:
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
# recent version of transformers seems to do repeat_interleave automatically
# if not sample:
# image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}

View File

@@ -6,75 +6,95 @@ from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
assert not args.recursive or (
args.recursive and args.full_path
), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
logger.info("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding='utf-8').strip()
logger.info("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding="utf-8").strip()
if not os.path.exists(caption_path):
caption_path = os.path.join(image_path, args.caption_extension)
if not os.path.exists(caption_path):
caption_path = os.path.join(image_path, args.caption_extension)
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
metadata[image_key]['caption'] = caption
if args.debug:
logger.info(f"{image_key} {caption}")
metadata[image_key]["caption"] = caption
if args.debug:
logger.info(f"{image_key} {caption}")
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
logger.info("done!")
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument(
"--in_json",
type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む",
)
parser.add_argument(
"--caption_extention",
type=str,
default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument(
"--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子"
)
parser.add_argument(
"--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--recursive",
action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = parser.parse_args()
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)
main(args)

View File

@@ -6,70 +6,88 @@ from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
assert not args.recursive or (
args.recursive and args.full_path
), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
logger.info("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding='utf-8').strip()
logger.info("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding="utf-8").strip()
if not os.path.exists(tags_path):
tags_path = os.path.join(image_path, args.caption_extension)
if not os.path.exists(tags_path):
tags_path = os.path.join(image_path, args.caption_extension)
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
metadata[image_key]['tags'] = tags
if args.debug:
logger.info(f"{image_key} {tags}")
metadata[image_key]["tags"] = tags
if args.debug:
logger.info(f"{image_key} {tags}")
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
logger.info("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--caption_extension", type=str, default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument(
"--in_json",
type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む",
)
parser.add_argument(
"--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--recursive",
action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
)
parser.add_argument(
"--caption_extension",
type=str,
default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子",
)
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
main(args)
args = parser.parse_args()
main(args)

View File

@@ -11,6 +11,7 @@ import cv2
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
@@ -18,8 +19,10 @@ from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = get_preferred_device()
@@ -89,7 +92,9 @@ def main(args):
# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
assert (
len(max_reso) == 2
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_manager = train_util.BucketManager(
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
@@ -107,7 +112,7 @@ def main(args):
def process_batch(is_last):
for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False)
bucket.clear()
# 読み込みの高速化のためにDataLoaderを使うオプション
@@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
parser.add_argument(
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
@@ -231,10 +238,16 @@ def setup_parser() -> argparse.ArgumentParser:
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
)
parser.add_argument(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
"--bucket_no_upscale",
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
)
parser.add_argument(
"--full_path",
@@ -242,7 +255,15 @@ def setup_parser() -> argparse.ArgumentParser:
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
"--flip_aug",
action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
)
parser.add_argument(
"--alpha_mask",
type=str,
default="",
help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する",
)
parser.add_argument(
"--skip_existing",

View File

@@ -11,7 +11,7 @@ from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
@@ -42,8 +42,10 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
image = image.astype(np.float32)
return image
@@ -62,12 +64,12 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
tensor = torch.tensor(image)
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
return (image, img_path)
def collate_fn_remove_corrupted(batch):
@@ -110,9 +112,8 @@ def main(args):
else:
logger.info("using existing wd14 tagger model")
# 画像を読み込む
# モデルを読み込む
if args.onnx:
import torch
import onnx
import onnxruntime as ort
@@ -130,10 +131,10 @@ def main(args):
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
except Exception:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0:
if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
# some rebatch model may use 'N' as dynamic axes
logger.warning(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
@@ -142,12 +143,23 @@ def main(args):
del model
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
),
)
if "OpenVINOExecutionProvider" in ort.get_available_providers():
# requires provider options for gpu support
# fp16 causes nonsense outputs
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU_FP32"}],
)
else:
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
),
)
else:
from tensorflow.keras.models import load_model
@@ -158,16 +170,52 @@ def main(args):
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
line = [row for row in reader]
header = line[0] # tag_id,name,category,count
rows = line[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
# preprocess tags in advance
if args.character_tag_expand:
for i, tag in enumerate(character_tags):
if tag.endswith(")"):
# chara_name_(series) -> chara_name, series
# chara_name_(costume)_(series) -> chara_name_(costume), series
tags = tag.split("(")
character_tag = "(".join(tags[:-1])
if character_tag.endswith("_"):
character_tag = character_tag[:-1]
series_tag = tags[-1].replace(")", "")
character_tags[i] = character_tag + args.caption_separator + series_tag
if args.remove_underscore:
rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags]
general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags]
character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags]
if args.tag_replacement is not None:
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;
escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
tag_replacements = escaped_tag_replacements.split(";")
for tag_replacement in tag_replacements:
tags = tag_replacement.split(",") # source, target
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
logger.info(f"replacing tag: {source} -> {target}")
if source in general_tags:
general_tags[general_tags.index(source)] = target
elif source in character_tags:
character_tags[character_tags.index(source)] = target
elif source in rating_tags:
rating_tags[rating_tags.index(source)] = target
# 画像を読み込む
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
@@ -176,7 +224,12 @@ def main(args):
caption_separator = args.caption_separator
stripped_caption_separator = caption_separator.strip()
undesired_tags = set(args.undesired_tags.split(stripped_caption_separator))
undesired_tags = args.undesired_tags.split(stripped_caption_separator)
undesired_tags = set([tag.strip() for tag in undesired_tags if tag.strip() != ""])
always_first_tags = None
if args.always_first_tags is not None:
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
@@ -191,22 +244,16 @@ def main(args):
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = label_names[:4]
# rating_index = ratings_names["probs"].argmax()
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
# Everything else is tags: pick any where prediction confidence > threshold
combined_tags = []
general_tag_text = ""
rating_tag_text = ""
character_tag_text = ""
general_tag_text = ""
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
tag_name = tag_name.replace("_", " ")
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
@@ -214,13 +261,37 @@ def main(args):
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
if args.remove_underscore and len(tag_name) > 3:
tag_name = tag_name.replace("_", " ")
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
combined_tags.append(tag_name)
# 最初の4つはratingなのでargmaxで選ぶ
# First 4 labels are actually ratings: pick one with argmax
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
ratings_probs = prob[:4]
rating_index = ratings_probs.argmax()
found_rating = rating_tags[rating_index]
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
# 一番最初に置くタグを指定する
# Always put some tags at the beginning
if always_first_tags is not None:
for tag in always_first_tags:
if tag in combined_tags:
combined_tags.remove(tag)
combined_tags.insert(0, tag)
# 先頭のカンマを取る
if len(general_tag_text) > 0:
@@ -253,6 +324,7 @@ def main(args):
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
@@ -277,9 +349,7 @@ def main(args):
continue
image, image_path = data
if image is not None:
image = image.detach().numpy()
else:
if image is None:
try:
image = Image.open(image_path)
if image.mode != "RGB":
@@ -310,7 +380,9 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--repo_id",
type=str,
@@ -328,7 +400,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
@@ -367,7 +441,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument(
"--undesired_tags",
type=str,
@@ -375,18 +451,49 @@ def setup_parser() -> argparse.ArgumentParser:
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する"
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
)
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
)
parser.add_argument(
"--always_first_tags",
type=str,
default=None,
help="comma-separated list of tags to always put at the beginning, e.g. `1girl,1boy`"
+ " / 必ず先頭に置くタグのカンマ区切りリスト、例 : `1girl,1boy`",
)
parser.add_argument(
"--caption_separator",
type=str,
default=", ",
help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
)
parser.add_argument(
"--tag_replacement",
type=str,
default=None,
help="tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\`. e.g. `tag1,tag2;tag3,tag4`"
+ " / タグの置換を `置換元1,置換先1;置換元2,置換先2; ...`で指定する。`\` で `,` と `;` をエスケープできる。例: `tag1,tag2;tag3,tag4`",
)
parser.add_argument(
"--character_tag_expand",
action="store_true",
help="expand tag tail parenthesis to another tag for character tags. `chara_name_(series)` becomes `chara_name, series`"
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
)
return parser

576
flux_minimal_inference.py Normal file
View File

@@ -0,0 +1,576 @@
# Minimum Inference Code for FLUX
import argparse
import datetime
import math
import os
import random
from typing import Callable, List, Optional
import einops
import numpy as np
import torch
from tqdm import tqdm
from PIL import Image
import accelerate
from transformers import CLIPTextModel
from safetensors.torch import load_file
from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
from networks import oft_flux
init_ipex()
from library.utils import setup_logging, str_to_dtype
setup_logging()
import logging
logger = logging.getLogger(__name__)
import networks.lora_flux as lora_flux
from library import flux_models, flux_utils, sd3_utils, strategy_flux
def time_shift(mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
neg_txt: Optional[torch.Tensor] = None,
neg_vec: Optional[torch.Tensor] = None,
neg_t5_attn_mask: Optional[torch.Tensor] = None,
cfg_scale: Optional[float] = None,
):
# this is ignored for schnell
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# prepare classifier free guidance
if neg_txt is not None and neg_vec is not None:
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
b_txt = torch.cat([neg_txt, txt], dim=0)
b_vec = torch.cat([neg_vec, vec], dim=0)
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
else:
b_t5_attn_mask = None
else:
b_img_ids = img_ids
b_txt_ids = txt_ids
b_txt = txt
b_vec = vec
b_t5_attn_mask = t5_attn_mask
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# classifier free guidance
if neg_txt is not None and neg_vec is not None:
b_img = torch.cat([img, img], dim=0)
else:
b_img = img
pred = model(
img=b_img,
img_ids=b_img_ids,
txt=b_txt,
txt_ids=b_txt_ids,
y=b_vec,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=b_t5_attn_mask,
)
# classifier free guidance
if neg_txt is not None and neg_vec is not None:
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
pred = pred_uncond + cfg_scale * (pred - pred_uncond)
img = img + (t_prev - t_curr) * pred
return img
def do_sample(
accelerator: Optional[accelerate.Accelerator],
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
l_pooled: torch.Tensor,
t5_out: torch.Tensor,
txt_ids: torch.Tensor,
num_steps: int,
guidance: float,
t5_attn_mask: Optional[torch.Tensor],
is_schnell: bool,
device: torch.device,
flux_dtype: torch.dtype,
neg_l_pooled: Optional[torch.Tensor] = None,
neg_t5_out: Optional[torch.Tensor] = None,
neg_t5_attn_mask: Optional[torch.Tensor] = None,
cfg_scale: Optional[float] = None,
):
logger.info(f"num_steps: {num_steps}")
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
# denoise initial noise
if accelerator:
with accelerator.autocast(), torch.no_grad():
x = denoise(
model,
img,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps,
guidance,
t5_attn_mask,
neg_t5_out,
neg_l_pooled,
neg_t5_attn_mask,
cfg_scale,
)
else:
with torch.autocast(device_type=device.type, dtype=flux_dtype), torch.no_grad():
x = denoise(
model,
img,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps,
guidance,
t5_attn_mask,
neg_t5_out,
neg_l_pooled,
neg_t5_attn_mask,
cfg_scale,
)
return x
def generate_image(
model,
clip_l: CLIPTextModel,
t5xxl,
ae,
prompt: str,
seed: Optional[int],
image_width: int,
image_height: int,
steps: Optional[int],
guidance: float,
negative_prompt: Optional[str],
cfg_scale: float,
):
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {seed}")
# make first noise with packed shape
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
device=device,
dtype=noise_dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
# prepare img and img ids
# this is needed only for img2img
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
# if img.shape[0] == 1 and bs > 1:
# img = repeat(img, "1 ... -> bs ...", bs=bs)
# txt2img only needs img_ids
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)
# prepare fp8 models
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
clip_l.to(clip_l_dtype) # fp8
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
clip_l.fp8_prepared = True
if is_fp8(t5xxl_dtype) and (not hasattr(t5xxl, "fp8_prepared") or not t5xxl.fp8_prepared):
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)
hidden_states = module.wo(hidden_states)
return hidden_states
return forward
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
t5xxl.to(t5xxl_dtype)
prepare_fp8(t5xxl.encoder, torch.bfloat16)
t5xxl.fp8_prepared = True
# prepare embeddings
logger.info("Encoding prompts...")
clip_l = clip_l.to(device)
t5xxl = t5xxl.to(device)
def encode(prpt: str):
tokens_and_masks = tokenize_strategy.tokenize(prpt)
with torch.no_grad():
if is_fp8(clip_l_dtype):
with accelerator.autocast():
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
else:
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
if is_fp8(t5xxl_dtype):
with accelerator.autocast():
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
else:
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
return l_pooled, t5_out, txt_ids, t5_attn_mask
l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt)
if negative_prompt:
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode(negative_prompt)
else:
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None
# NaN check
if torch.isnan(l_pooled).any():
raise ValueError("NaN in l_pooled")
if torch.isnan(t5_out).any():
raise ValueError("NaN in t5_out")
if args.offload:
clip_l = clip_l.cpu()
t5xxl = t5xxl.cpu()
# del clip_l, t5xxl
device_utils.clean_memory()
# generate image
logger.info("Generating image...")
model = model.to(device)
if steps is None:
steps = 4 if is_schnell else 50
img_ids = img_ids.to(device)
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
x = do_sample(
accelerator,
model,
noise,
img_ids,
l_pooled,
t5_out,
txt_ids,
steps,
guidance,
t5_attn_mask,
is_schnell,
device,
flux_dtype,
neg_l_pooled,
neg_t5_out,
neg_t5_attn_mask,
cfg_scale,
)
if args.offload:
model = model.cpu()
# del model
device_utils.clean_memory()
# unpack
x = x.float()
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
# decode
logger.info("Decoding image...")
ae = ae.to(device)
with torch.no_grad():
if is_fp8(ae_dtype):
with accelerator.autocast():
x = ae.decode(x)
else:
with torch.autocast(device_type=device.type, dtype=ae_dtype):
x = ae.decode(x)
if args.offload:
ae = ae.cpu()
x = x.clamp(-1, 1)
x = x.permute(0, 2, 3, 1)
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
# save image
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
img.save(output_path)
logger.info(f"Saved image to {output_path}")
if __name__ == "__main__":
target_height = 768 # 1024
target_width = 1360 # 1024
# steps = 50 # 28 # 50
# guidance_scale = 5
# seed = 1 # None # 1
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--ae", type=str, required=False)
parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
parser.add_argument("--output_dir", type=str, default=".")
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
parser.add_argument("--ae_dtype", type=str, default=None, help="dtype for ae")
parser.add_argument("--t5xxl_dtype", type=str, default=None, help="dtype for t5xxl")
parser.add_argument("--flux_dtype", type=str, default=None, help="dtype for flux")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--steps", type=int, default=None, help="Number of steps. Default is 4 for schnell, 50 for dev")
parser.add_argument("--guidance", type=float, default=3.5)
parser.add_argument("--negative_prompt", type=str, default=None)
parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument(
"--lora_weights",
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora_flux and lora_oft, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
seed = args.seed
steps = args.steps
guidance_scale = args.guidance
def is_fp8(dt):
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
dtype = str_to_dtype(args.dtype)
clip_l_dtype = str_to_dtype(args.clip_l_dtype, dtype)
t5xxl_dtype = str_to_dtype(args.t5xxl_dtype, dtype)
ae_dtype = str_to_dtype(args.ae_dtype, dtype)
flux_dtype = str_to_dtype(args.flux_dtype, dtype)
logger.info(f"Dtypes for clip_l, t5xxl, ae, flux: {clip_l_dtype}, {t5xxl_dtype}, {ae_dtype}, {flux_dtype}")
loading_device = "cpu" if args.offload else device
use_fp8 = [is_fp8(d) for d in [dtype, clip_l_dtype, t5xxl_dtype, ae_dtype, flux_dtype]]
if any(use_fp8):
accelerator = accelerate.Accelerator(mixed_precision="bf16")
else:
accelerator = None
# load clip_l
logger.info(f"Loading clip_l from {args.clip_l}...")
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
clip_l.eval()
logger.info(f"Loading t5xxl from {args.t5xxl}...")
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
t5xxl.eval()
# if is_fp8(clip_l_dtype):
# clip_l = accelerator.prepare(clip_l)
# if is_fp8(t5xxl_dtype):
# t5xxl = accelerator.prepare(t5xxl)
# DiT
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype
# if is_fp8(flux_dtype):
# model = accelerator.prepare(model)
# if args.offload:
# model = model.to("cpu")
t5xxl_max_length = 256 if is_schnell else 512
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
# AE
ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
ae.eval()
# if is_fp8(ae_dtype):
# ae = accelerator.prepare(ae)
# LoRA
lora_models: List[lora_flux.LoRANetwork] = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
weights_sd = load_file(weights_file)
is_lora = is_oft = False
for key in weights_sd.keys():
if key.startswith("lora"):
is_lora = True
if key.startswith("oft"):
is_oft = True
if is_lora or is_oft:
break
module = lora_flux if is_lora else oft_flux
lora_model, _ = module.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
if args.merge_lora_weights:
lora_model.merge_to([clip_l, t5xxl], model, weights_sd)
else:
lora_model.apply_to([clip_l, t5xxl], model)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)
lora_models.append(lora_model)
if not args.interactive:
generate_image(
model,
clip_l,
t5xxl,
ae,
args.prompt,
args.seed,
args.width,
args.height,
args.steps,
args.guidance,
args.negative_prompt,
args.cfg_scale,
)
else:
# loop for interactive
width = target_width
height = target_height
steps = None
guidance = args.guidance
cfg_scale = args.cfg_scale
while True:
print(
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
)
prompt = input()
if prompt == "":
break
# parse options
options = prompt.split("--")
prompt = options[0].strip()
seed = None
negative_prompt = None
for opt in options[1:]:
try:
opt = opt.strip()
if opt.startswith("w"):
width = int(opt[1:].strip())
elif opt.startswith("h"):
height = int(opt[1:].strip())
elif opt.startswith("s"):
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
elif opt.startswith("g"):
guidance = float(opt[1:].strip())
elif opt.startswith("m"):
mutipliers = opt[1:].strip().split(",")
if len(mutipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(mutipliers[i]))
elif opt.startswith("n"):
negative_prompt = opt[1:].strip()
if negative_prompt == "-":
negative_prompt = ""
elif opt.startswith("c"):
cfg_scale = float(opt[1:].strip())
except ValueError as e:
logger.error(f"Invalid option: {opt}, {e}")
generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
logger.info("Done!")

849
flux_train.py Normal file
View File

@@ -0,0 +1,849 @@
# training with captions
# Swap blocks between CPU and GPU:
# This implementation is inspired by and based on the work of 2kpr.
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
# The original idea has been adapted and extended to fit the current project's needs.
# Key features:
# - CPU offloading during forward and backward passes
# - Use of fused optimizer and grad_hook for efficient gradient processing
# - Per-block fused optimizer instances
import argparse
from concurrent.futures import ThreadPoolExecutor
import copy
import math
import os
from multiprocessing import Value
import time
from typing import List, Optional, Tuple, Union
import toml
from tqdm import tqdm
import torch
import torch.nn as nn
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
import library.train_util as train_util
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
# import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
# sdxl_train_util.verify_sdxl_training_args(args)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
# temporary: backward compatibility for deprecated options. remove in the future
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check
# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
logger.warning(
"cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
)
args.gradient_checkpointing = True
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, (
"blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
)
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if args.cache_latents:
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
)
t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
)
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
train_dataset_group.set_current_strategies()
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
# load VAE for caching latents
ae = None
if cache_latents:
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False)
ae.eval()
train_dataset_group.new_cache_latents(ae, accelerator)
ae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# prepare tokenize strategy
if args.t5xxl_max_token_length is None:
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
flux_tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length)
strategy_base.TokenizeStrategy.set_strategy(flux_tokenize_strategy)
# load clip_l, t5xxl for caching text encoder outputs
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
clip_l.eval()
t5xxl.eval()
clip_l.requires_grad_(False)
t5xxl.requires_grad_(False)
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# cache text encoder outputs
sample_prompts_te_outputs = None
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad here
clip_l.to(accelerator.device)
t5xxl.to(accelerator.device)
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([clip_l, t5xxl], accelerator)
# cache sample prompt's embeddings to free text encoder's memory
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = flux_tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
accelerator.wait_for_everyone()
# now we can delete Text Encoders to free memory
clip_l = None
t5xxl = None
clean_memory_on_device(accelerator.device)
# load FLUX
_, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)
if args.gradient_checkpointing:
flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
flux.requires_grad_(True)
# block swap
# backward compatibility
if args.blocks_to_swap is None:
blocks_to_swap = args.double_blocks_to_swap or 0
if args.single_blocks_to_swap is not None:
blocks_to_swap += args.single_blocks_to_swap // 2
if blocks_to_swap > 0:
logger.warning(
"double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
" / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
)
logger.info(
f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
)
args.blocks_to_swap = blocks_to_swap
del blocks_to_swap
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
# This idea is based on 2kpr's great work. Thank you!
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
flux.enable_block_swap(args.blocks_to_swap, accelerator.device)
if not cache_latents:
# load VAE here if not cached
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
ae.requires_grad_(False)
ae.eval()
ae.to(accelerator.device, dtype=weight_dtype)
training_models = []
params_to_optimize = []
training_models.append(flux)
name_and_params = list(flux.named_parameters())
# single param group for now
params_to_optimize.append({"params": [p for _, p in name_and_params], "lr": args.learning_rate})
param_names = [[n for n, _ in name_and_params]]
# calculate number of trainable parameters
n_params = 0
for group in params_to_optimize:
for p in group["params"]:
n_params += p.numel()
accelerator.print(f"number of trainable parameters: {n_params}")
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
if args.blockwise_fused_optimizers:
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
# This balances memory usage and management complexity.
# split params into groups. currently different learning rates are not supported
grouped_params = []
param_group = {}
for group in params_to_optimize:
named_parameters = list(flux.named_parameters())
assert len(named_parameters) == len(group["params"]), "number of parameters does not match"
for p, np in zip(group["params"], named_parameters):
# determine target layer and block index for each parameter
block_type = "other" # double, single or other
if np[0].startswith("double_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "double"
elif np[0].startswith("single_blocks"):
block_index = int(np[0].split(".")[1])
block_type = "single"
else:
block_index = -1
param_group_key = (block_type, block_index)
if param_group_key not in param_group:
param_group[param_group_key] = []
param_group[param_group_key].append(p)
block_types_and_indices = []
for param_group_key, param_group in param_group.items():
block_types_and_indices.append(param_group_key)
grouped_params.append({"params": param_group, "lr": args.learning_rate})
num_params = 0
for p in param_group:
num_params += p.numel()
accelerator.print(f"block {param_group_key}: {num_params} parameters")
# prepare optimizers for each group
optimizers = []
for group in grouped_params:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
optimizers.append(optimizer)
optimizer = optimizers[0] # avoid error in the following code
logger.info(f"using {len(optimizers)} optimizers for blockwise fused optimizers")
if train_util.is_schedulefree_optimizer(optimizers[0], args):
raise ValueError("Schedule-free optimizer is not supported with blockwise fused optimizers")
optimizer_train_fn = lambda: None # dummy function
optimizer_eval_fn = lambda: None # dummy function
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(optimizer, args)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
if args.blockwise_fused_optimizers:
# prepare lr schedulers for each optimizer
lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
lr_scheduler = lr_schedulers[0] # avoid error in the following code
else:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
flux.to(weight_dtype)
if clip_l is not None:
clip_l.to(weight_dtype)
t5xxl.to(weight_dtype) # TODO check works with fp16 or not
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
flux.to(weight_dtype)
if clip_l is not None:
clip_l.to(weight_dtype)
t5xxl.to(weight_dtype)
# if we don't cache text encoder outputs, move them to device
if not args.cache_text_encoder_outputs:
clip_l.to(accelerator.device)
t5xxl.to(accelerator.device)
clean_memory_on_device(accelerator.device)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=flux)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# accelerator does some magic
# if we doesn't swap blocks, we can move the model to device
flux = accelerator.prepare(flux, device_placement=[not is_swapping_blocks])
if is_swapping_blocks:
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group, param_name_group in zip(optimizer.param_groups, param_names):
for parameter, param_name in zip(param_group["params"], param_name_group):
if parameter.requires_grad:
def create_grad_hook(p_name, p_group):
def grad_hook(tensor: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, p_group)
tensor.grad = None
return grad_hook
parameter.register_post_accumulate_grad_hook(create_grad_hook(param_name, param_group))
elif args.blockwise_fused_optimizers:
# prepare for additional optimizers and lr schedulers
for i in range(1, len(optimizers)):
optimizers[i] = accelerator.prepare(optimizers[i])
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
# counters are used to determine when to step the optimizer
global optimizer_hooked_count
global num_parameters_per_group
global parameter_optimizer_map
optimizer_hooked_count = {}
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}
for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def grad_hook(parameter: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)
parameter.register_post_accumulate_grad_hook(grad_hook)
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
if is_swapping_blocks:
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
# For --sample_at_first
optimizer_eval_fn()
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
optimizer_train_fn()
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
if args.blockwise_fused_optimizers:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device, dtype=weight_dtype)
else:
with torch.no_grad():
# encode images to latents. images are [-1, 1]
latents = ae.encode(batch["images"].to(ae.dtype)).to(accelerator.device, dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list
else:
# not cached or training, so get from text encoders
tokens_and_masks = batch["input_ids_list"]
with torch.no_grad():
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler_copy, latents, noise, accelerator.device, weight_dtype
)
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance: ensure args.guidance_scale is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# call model
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = flux(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss: this is different from SD3
target = noise - latents
# calculate loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean()
# backward
accelerator.backward(loss)
if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.blockwise_fused_optimizers:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
optimizer_eval_fn()
flux_train_utils.sample_images(
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(flux),
)
optimizer_train_fn()
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
optimizer_eval_fn()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(flux),
)
flux_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
)
optimizer_train_fn()
is_main_process = accelerator.is_main_process
# if is_main_process:
flux = accelerator.unwrap_model(flux)
accelerator.end_training()
optimizer_eval_fn()
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
add_custom_train_arguments(parser) # TODO remove this from here
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument(
"--mem_eff_save",
action="store_true",
help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
)
parser.add_argument(
"--fused_optimizer_groups",
type=int,
default=None,
help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
)
parser.add_argument(
"--blockwise_fused_optimizers",
action="store_true",
help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
)
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--double_blocks_to_swap",
type=int,
default=None,
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
)
parser.add_argument(
"--single_blocks_to_swap",
type=int,
default=None,
help="[Deprecated] use 'blocks_to_swap' instead / 代わりに 'blocks_to_swap' を使用してください",
)
parser.add_argument(
"--cpu_offload_checkpointing",
action="store_true",
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

574
flux_train_network.py Normal file
View File

@@ -0,0 +1,574 @@
import argparse
import copy
import math
import random
from typing import Any, Optional
import torch
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class FluxNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare CLIP-L/T5XXL training flags
self.train_clip_l = not args.network_train_unet_only
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
# deprecated split_mode option
if args.split_mode:
if args.blocks_to_swap is not None:
logger.warning(
"split_mode is deprecated. Because `--blocks_to_swap` is set, `--split_mode` is ignored."
" / split_modeは非推奨です。`--blocks_to_swap`が設定されているため、`--split_mode`は無視されます。"
)
else:
logger.warning(
"split_mode is deprecated. Please use `--blocks_to_swap` instead. `--blocks_to_swap 18` is automatically set."
" / split_modeは非推奨です。代わりに`--blocks_to_swap`を使用してください。`--blocks_to_swap 18`が自動的に設定されました。"
)
args.blocks_to_swap = 18 # 18 is safe for most cases
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
self.is_schnell, model = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
)
if args.fp8_base:
# check dtype of model
if model.dtype == torch.float8_e4m3fnuz or model.dtype == torch.float8_e5m2 or model.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
elif model.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 FLUX model")
else:
logger.info(
"Cast FLUX model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
" / FLUXモデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
)
model.to(torch.float8_e4m3fn)
# if args.split_mode:
# model = self.prepare_split_model(model, weight_dtype, accelerator)
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
clip_l.eval()
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
if args.fp8_base and not args.fp8_base_unet:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = flux_utils.load_t5xxl(args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
t5xxl.eval()
if args.fp8_base and not args.fp8_base_unet:
# check dtype of model
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
def get_tokenize_strategy(self, args):
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.t5xxl_max_token_length is None:
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_flux.FluxTextEncodingStrategy(apply_t5_attn_mask=args.apply_t5_attn_mask)
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not
self.train_t5xxl = network.train_t5xxl
if self.train_t5xxl and args.cache_text_encoder_outputs:
raise ValueError(
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
)
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
if self.train_clip_l and not self.train_t5xxl:
return text_encoders[0:1] # only CLIP-L is needed for encoding because T5XXL is cached
else:
return None # no text encoders are needed for encoding because both are cached
else:
return text_encoders # both CLIP-L and T5XXL are needed for encoding
def get_text_encoders_train_flags(self, args, text_encoders):
return [self.train_clip_l, self.train_t5xxl]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device)
if text_encoders[1].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[1].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy, text_encoders, tokens_and_masks, args.apply_t5_attn_mask
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move CLIP-L back to cpu")
text_encoders[0].to("cpu")
logger.info("move t5XXL back to cpu")
text_encoders[1].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# # get size embeddings
# orig_size = batch["original_sizes_hw"]
# crop_size = batch["crop_top_lefts"]
# target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# # concat embeddings
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
# return
"""
class FluxUpperLowerWrapper(torch.nn.Module):
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
super().__init__()
self.flux_upper = flux_upper
self.flux_lower = flux_lower
self.target_device = device
def prepare_block_swap_before_forward(self):
pass
def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
self.flux_lower.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_upper.to(self.target_device)
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
self.flux_upper.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_lower.to(self.target_device)
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)
wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
clean_memory_on_device(accelerator.device)
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
)
clean_memory_on_device(accelerator.device)
"""
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
return latents
def get_noise_pred_and_target(
self,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet: flux_models.Flux,
network,
weight_dtype,
train_unet,
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
# pack latents and get img_ids
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
# get guidance
# ensure guidance_scale in args is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
img_ids.requires_grad_(True)
guidance_vec.requires_grad_(True)
# Predict the noise residual
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
img=img,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
"""
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
with accelerator.autocast():
# move flux lower to cpu, and then move flux upper to gpu
unet.to("cpu")
clean_memory_on_device(accelerator.device)
self.flux_upper.to(accelerator.device)
# upper model does not require grad
with torch.no_grad():
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# move flux upper back to cpu, and then move flux lower to gpu
self.flux_upper.to("cpu")
clean_memory_on_device(accelerator.device)
unet.to(accelerator.device)
# lower model requires grad
intermediate_img.requires_grad_(True)
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""
return model_pred
model_pred = call_dit(
img=packed_noisy_model_input,
img_ids=img_ids,
t5_out=t5_out,
txt_ids=txt_ids,
l_pooled=l_pooled,
timesteps=timesteps,
guidance_vec=guidance_vec,
t5_attn_mask=t5_attn_mask,
)
# unpack latents
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
# apply model prediction type
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
# flow matching loss: this is different from SD3
target = noise - latents
# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
unet.prepare_block_swap_before_forward()
with torch.no_grad():
model_pred_prior = call_dit(
img=packed_noisy_model_input[diff_output_pr_indices],
img_ids=img_ids[diff_output_pr_indices],
t5_out=t5_out[diff_output_pr_indices],
txt_ids=txt_ids[diff_output_pr_indices],
l_pooled=l_pooled[diff_output_pr_indices],
timesteps=timesteps[diff_output_pr_indices],
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
args,
model_pred_prior,
noisy_model_input[diff_output_pr_indices],
sigmas[diff_output_pr_indices] if sigmas is not None else None,
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
def update_metadata(self, metadata, args):
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_guidance_scale"] = args.guidance_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
if index == 0: # CLIP-L
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
else: # T5XXL
text_encoder.encoder.embed_tokens.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
if index == 0: # CLIP-L
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
else: # T5XXL
def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)
hidden_states = module.wo(hidden_states)
return hidden_states
return forward
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info(f"T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
# if we doesn't swap blocks, we can move the model to device
flux: flux_models.Flux = unet
flux = accelerator.prepare(flux, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(flux).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
accelerator.unwrap_model(flux).prepare_block_swap_before_forward()
return flux
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument(
"--split_mode",
action="store_true",
# help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
# + "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
help="[Deprecated] This option is deprecated. Please use `--blocks_to_swap` instead."
" / このオプションは非推奨です。代わりに`--blocks_to_swap`を使用してください。",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = FluxNetworkTrainer()
trainer.train(args)

View File

@@ -43,8 +43,8 @@ from diffusers import (
)
from einops import rearrange
from tqdm import tqdm
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
from accelerate import init_empty_weights
import PIL
from PIL import Image
from PIL.PngImagePlugin import PngInfo
@@ -58,6 +58,7 @@ import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
from library.sdxl_original_control_net import SdxlControlNet
from library.original_unet import FlashAttentionFunction
from networks.control_net_lllite import ControlNetLLLite
from library.utils import GradualLatent, EulerAncestralDiscreteSchedulerGL
@@ -86,7 +87,8 @@ CLIP_VISION_MODEL = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
"""
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
# def replace_unet_modules(unet: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa):
def replace_unet_modules(unet, mem_eff_attn, xformers, sdpa):
if mem_eff_attn:
logger.info("Enable memory efficient attention for U-Net")
@@ -351,8 +353,8 @@ class PipelineLike:
self.token_replacements_list.append({})
# ControlNet
self.control_nets: List[ControlNetInfo] = [] # only for SD 1.5
self.control_net_lllites: List[ControlNetLLLite] = []
self.control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = []
self.control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
self.control_net_enabled = True # control_netsが空ならTrueでもFalseでもControlNetは動作しない
self.gradual_latent: GradualLatent = None
@@ -541,7 +543,7 @@ class PipelineLike:
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
if self.control_net_lllites:
if self.control_net_lllites or (self.control_nets and self.is_sdxl):
# ControlNetのhintにguide imageを流用する。ControlNetの場合はControlNet側で行う
if isinstance(clip_guide_images, PIL.Image.Image):
clip_guide_images = [clip_guide_images]
@@ -730,7 +732,12 @@ class PipelineLike:
num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1
if self.control_nets:
guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
if not self.is_sdxl:
guided_hints = original_control_net.get_guided_hints(
self.control_nets, num_latent_input, batch_size, clip_guide_images
)
else:
clip_guide_images = clip_guide_images * 0.5 + 0.5 # [-1, 1] => [0, 1]
each_control_net_enabled = [self.control_net_enabled] * len(self.control_nets)
if self.control_net_lllites:
@@ -792,7 +799,7 @@ class PipelineLike:
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# disable ControlNet-LLLite if ratio is set. ControlNet is disabled in ControlNetInfo
# disable ControlNet-LLLite or SDXL ControlNet if ratio is set. ControlNet is disabled in ControlNetInfo
if self.control_net_lllites:
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_net_lllites, each_control_net_enabled)):
if not enabled or ratio >= 1.0:
@@ -801,9 +808,16 @@ class PipelineLike:
logger.info(f"ControlNetLLLite {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
control_net.set_cond_image(None)
each_control_net_enabled[j] = False
if self.control_nets and self.is_sdxl:
for j, ((control_net, ratio), enabled) in enumerate(zip(self.control_nets, each_control_net_enabled)):
if not enabled or ratio >= 1.0:
continue
if ratio < i / len(timesteps):
logger.info(f"ControlNet {j} is disabled (ratio={ratio} at {i} / {len(timesteps)})")
each_control_net_enabled[j] = False
# predict the noise residual
if self.control_nets and self.control_net_enabled:
if self.control_nets and self.control_net_enabled and not self.is_sdxl:
if regional_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
@@ -822,6 +836,31 @@ class PipelineLike:
text_embeddings,
text_emb_last,
).sample
elif self.control_nets:
input_resi_add_list = []
mid_add_list = []
for (control_net, _), enbld in zip(self.control_nets, each_control_net_enabled):
if not enbld:
continue
input_resi_add, mid_add = control_net(
latent_model_input, t, text_embeddings, vector_embeddings, clip_guide_images
)
input_resi_add_list.append(input_resi_add)
mid_add_list.append(mid_add)
if len(input_resi_add_list) == 0:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
else:
if len(input_resi_add_list) > 1:
# get mean of input_resi_add_list and mid_add_list
input_resi_add_mean = []
for i in range(len(input_resi_add_list[0])):
input_resi_add_mean.append(
torch.mean(torch.stack([input_resi_add_list[j][i] for j in range(len(input_resi_add_list))], dim=0))
)
input_resi_add = input_resi_add_mean
mid_add = torch.mean(torch.stack(mid_add_list), dim=0)
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings, input_resi_add, mid_add)
elif self.is_sdxl:
noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
else:
@@ -1435,6 +1474,7 @@ class BatchDataBase(NamedTuple):
clip_prompt: str
guide_image: Any
raw_prompt: str
file_name: Optional[str]
class BatchDataExt(NamedTuple):
@@ -1493,8 +1533,6 @@ def main(args):
highres_fix = args.highres_fix_scale is not None
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
if args.v_parameterization and not args.v2:
logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
@@ -1825,16 +1863,37 @@ def main(args):
upscaler.to(dtype).to(device)
# ControlNetの処理
control_nets: List[ControlNetInfo] = []
control_nets: List[Union[ControlNetInfo, Tuple[SdxlControlNet, float]]] = []
if args.control_net_models:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
if not is_sdxl:
for i, model in enumerate(args.control_net_models):
prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
prep = original_control_net.load_preprocess(prep_type)
control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
else:
for i, model_file in enumerate(args.control_net_models):
multiplier = (
1.0
if not args.control_net_multipliers or len(args.control_net_multipliers) <= i
else args.control_net_multipliers[i]
)
ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
logger.info(f"loading SDXL ControlNet: {model_file}")
from safetensors.torch import load_file
state_dict = load_file(model_file)
logger.info(f"Initializing SDXL ControlNet with multiplier: {multiplier}")
with init_empty_weights():
control_net = SdxlControlNet(multiplier=multiplier)
control_net.load_state_dict(state_dict)
control_net.to(dtype).to(device)
control_nets.append((control_net, ratio))
control_net_lllites: List[Tuple[ControlNetLLLite, float]] = []
if args.control_net_lllite_models:
@@ -2316,7 +2375,7 @@ def main(args):
# このバッチの情報を取り出す
(
return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image, _),
(step_first, _, _, _, init_image, mask_image, _, guide_image, _, _),
(
width,
height,
@@ -2339,6 +2398,7 @@ def main(args):
prompts = []
negative_prompts = []
raw_prompts = []
filenames = []
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
noises = [
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
@@ -2371,7 +2431,7 @@ def main(args):
all_guide_images_are_same = True
for i, (
_,
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename),
_,
) in enumerate(batch):
prompts.append(prompt)
@@ -2379,6 +2439,7 @@ def main(args):
seeds.append(seed)
clip_prompts.append(clip_prompt)
raw_prompts.append(raw_prompt)
filenames.append(filename)
if init_image is not None:
init_images.append(init_image)
@@ -2478,8 +2539,8 @@ def main(args):
# save image
highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts)
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames)
):
if highres_fix:
seed -= 1 # record original seed
@@ -2505,17 +2566,23 @@ def main(args):
metadata.add_text("crop-top", str(crop_top))
metadata.add_text("crop-left", str(crop_left))
if args.use_original_file_name and init_images is not None:
if type(init_images) is list:
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
else:
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
elif args.sequential_file_name:
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
if filename is not None:
fln = filename
else:
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
if args.use_original_file_name and init_images is not None:
if type(init_images) is list:
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
else:
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
elif args.sequential_file_name:
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
else:
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
if fln.endswith(".webp"):
image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100) # lossy
else:
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
if not args.no_preview and not highres_1st and args.interactive:
try:
@@ -2562,6 +2629,7 @@ def main(args):
# repeat prompt
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
filename = None
if pi == 0 or len(raw_prompts) > 1:
# parse prompt: if prompt is not changed, skip parsing
@@ -2783,6 +2851,12 @@ def main(args):
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
m = re.match(r"f (.+)", parg, re.IGNORECASE)
if m: # filename
filename = m.group(1)
logger.info(f"filename: {filename}")
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}")
@@ -2873,7 +2947,16 @@ def main(args):
b1 = BatchData(
False,
BatchDataBase(
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
global_step,
prompt,
negative_prompt,
seed,
init_image,
mask_image,
clip_prompt,
guide_image,
raw_prompt,
filename,
),
BatchDataExt(
width,
@@ -2916,7 +2999,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む"
)

View File

@@ -2216,8 +2216,6 @@ def main(args):
highres_fix = args.highres_fix_scale is not None
# assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません"
if args.v_parameterization and not args.v2:
logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")

138
library/adafactor_fused.py Normal file
View File

@@ -0,0 +1,138 @@
import math
import torch
from transformers import Adafactor
# stochastic rounding for bfloat16
# The implementation was provided by 2kpr. Thank you very much!
def copy_stochastic_(target: torch.Tensor, source: torch.Tensor):
"""
copies source into target using stochastic rounding
Args:
target: the target tensor with dtype=bfloat16
source: the target tensor with dtype=float32
"""
# create a random 16 bit integer
result = torch.randint_like(source, dtype=torch.int32, low=0, high=(1 << 16))
# add the random number to the lower 16 bit of the mantissa
result.add_(source.view(dtype=torch.int32))
# mask off the lower 16 bit of the mantissa
result.bitwise_and_(-65536) # -65536 = FFFF0000 as a signed int32
# copy the higher 16 bit into the target tensor
target.copy_(result.view(dtype=torch.float32))
del result
@torch.no_grad()
def adafactor_step_param(self, p, group):
if p.grad is None:
return
grad = p.grad
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
raise RuntimeError("Adafactor does not support sparse gradients.")
state = self.state[p]
grad_shape = grad.shape
factored, use_first_moment = Adafactor._get_options(group, grad_shape)
# State Initialization
if len(state) == 0:
state["step"] = 0
if use_first_moment:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
if factored:
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
else:
state["exp_avg_sq"] = torch.zeros_like(grad)
state["RMS"] = 0
else:
if use_first_moment:
state["exp_avg"] = state["exp_avg"].to(grad)
if factored:
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
else:
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
p_data_fp32 = p
if p.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
state["step"] += 1
state["RMS"] = Adafactor._rms(p_data_fp32)
lr = Adafactor._get_lr(group, state)
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
update = (grad**2) + group["eps"][0]
if factored:
exp_avg_sq_row = state["exp_avg_sq_row"]
exp_avg_sq_col = state["exp_avg_sq_col"]
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
# Approximation of exponential moving average of square of gradient
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
update.mul_(grad)
else:
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
update = exp_avg_sq.rsqrt().mul_(grad)
update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
update.mul_(lr)
if use_first_moment:
exp_avg = state["exp_avg"]
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
update = exp_avg
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
p_data_fp32.add_(-update)
# if p.dtype in {torch.float16, torch.bfloat16}:
# p.copy_(p_data_fp32)
if p.dtype == torch.bfloat16:
copy_stochastic_(p, p_data_fp32)
elif p.dtype == torch.float16:
p.copy_(p_data_fp32)
@torch.no_grad()
def adafactor_step(self, closure=None):
"""
Performs a single optimization step
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group["params"]:
adafactor_step_param(self, p, group)
return loss
def patch_adafactor_fused(optimizer: Adafactor):
optimizer.step_param = adafactor_step_param.__get__(optimizer)
optimizer.step = adafactor_step.__get__(optimizer)

View File

@@ -10,13 +10,7 @@ import json
from pathlib import Path
# from toolz import curry
from typing import (
List,
Optional,
Sequence,
Tuple,
Union,
)
from typing import Dict, List, Optional, Sequence, Tuple, Union
import toml
import voluptuous
@@ -41,12 +35,17 @@ from .train_util import (
DatasetGroup,
)
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def add_config_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
parser.add_argument(
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
)
# TODO: inherit Params class in Subset, Dataset
@@ -73,6 +72,7 @@ class BaseSubsetParams:
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
custom_attributes: Optional[Dict[str, Any]] = None
@dataclass
@@ -80,23 +80,25 @@ class DreamBoothSubsetParams(BaseSubsetParams):
is_reg: bool = False
class_tokens: Optional[str] = None
caption_extension: str = ".caption"
cache_info: bool = False
alpha_mask: bool = False
@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
metadata_file: Optional[str] = None
alpha_mask: bool = False
@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
conditioning_data_dir: str = None
caption_extension: str = ".caption"
cache_info: bool = False
@dataclass
class BaseDatasetParams:
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
max_token_length: int = None
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False
@@ -184,11 +186,13 @@ class ConfigSanitizer:
"keep_tokens": int,
"keep_tokens_separator": str,
"secondary_separator": str,
"caption_separator": str,
"enable_wildcard": bool,
"token_warmup_min": int,
"token_warmup_step": Any(float, int),
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -200,18 +204,22 @@ class ConfigSanitizer:
DB_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
"class_tokens": str,
"cache_info": bool,
}
DB_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
"is_reg": bool,
"alpha_mask": bool,
}
# FT means FineTuning
FT_SUBSET_DISTINCT_SCHEMA = {
Required("metadata_file"): str,
"image_dir": str,
"alpha_mask": bool,
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
"cache_info": bool,
}
CN_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
@@ -248,9 +256,10 @@ class ConfigSanitizer:
}
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
assert (
support_dreambooth or support_finetuning or support_controlnet
), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
assert support_dreambooth or support_finetuning or support_controlnet, (
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
+ " / DreamBooth モードか fine tuning モードか controlnet モードのども指定されていません。1つ以上指定してください。"
)
self.db_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
@@ -317,7 +326,10 @@ class ConfigSanitizer:
self.dataset_schema = validate_flex_dataset
elif support_dreambooth:
self.dataset_schema = self.db_dataset_schema
if support_controlnet:
self.dataset_schema = self.cn_dataset_schema
else:
self.dataset_schema = self.db_dataset_schema
elif support_finetuning:
self.dataset_schema = self.ft_dataset_schema
elif support_controlnet:
@@ -362,7 +374,9 @@ class ConfigSanitizer:
return self.argparse_config_validator(argparse_namespace)
except MultipleInvalid:
# XXX: this should be a bug
logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
logger.error(
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
)
raise
# NOTE: value would be overwritten by latter dict if there is already the same key
@@ -508,10 +522,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_separator: {subset.caption_separator}
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
@@ -519,8 +534,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
token_warmup_min: {subset.token_warmup_min}
token_warmup_step: {subset.token_warmup_step}
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""
),
" ",
@@ -547,11 +564,11 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
" ",
)
logger.info(f'{info}')
logger.info(f"{info}")
# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
logger.info(f"[Dataset {i}]")
dataset.make_buckets()
@@ -638,13 +655,17 @@ def load_user_config(file: str) -> dict:
with open(file, "r") as f:
config = json.load(f)
except Exception:
logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
logger.error(
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
)
raise
elif file.name.lower().endswith(".toml"):
try:
config = toml.load(file)
except Exception:
logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
logger.error(
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
)
raise
else:
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
@@ -671,13 +692,13 @@ if __name__ == "__main__":
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
logger.info("[argparse_namespace]")
logger.info(f'{vars(argparse_namespace)}')
logger.info(f"{vars(argparse_namespace)}")
user_config = load_user_config(config_args.dataset_config)
logger.info("")
logger.info("[user_config]")
logger.info(f'{user_config}')
logger.info(f"{user_config}")
sanitizer = ConfigSanitizer(
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
@@ -686,10 +707,10 @@ if __name__ == "__main__":
logger.info("")
logger.info("[sanitized_user_config]")
logger.info(f'{sanitized_user_config}')
logger.info(f"{sanitized_user_config}")
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
logger.info("")
logger.info("[blueprint]")
logger.info(f'{blueprint}')
logger.info(f"{blueprint}")

View File

@@ -0,0 +1,227 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
import torch
import torch.nn as nn
from library.device_utils import clean_memory_on_device
def synchronize_device(device: torch.device):
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "xpu":
torch.xpu.synchronize()
elif device.type == "mps":
torch.mps.synchronize()
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
# print(module_to_cpu.__class__, module_to_cuda.__class__)
# if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
# weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
modules_to_cpu = {k: v for k, v in layer_to_cpu.named_modules()}
for module_to_cuda_name, module_to_cuda in layer_to_cuda.named_modules():
if hasattr(module_to_cuda, "weight") and module_to_cuda.weight is not None:
module_to_cpu = modules_to_cpu.get(module_to_cuda_name, None)
if module_to_cpu is not None and module_to_cpu.weight.shape == module_to_cuda.weight.shape:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
else:
if module_to_cuda.weight.data.device.type != device.type:
# print(
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
# )
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
"""
not tested
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
synchronize_device()
# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
synchronize_device()
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
class Offloader:
"""
common offloading class
"""
def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
self.num_blocks = num_blocks
self.blocks_to_swap = blocks_to_swap
self.device = device
self.debug = debug
self.thread_pool = ThreadPoolExecutor(max_workers=1)
self.futures = {}
self.cuda_available = device.type == "cuda"
def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
if self.cuda_available:
swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
else:
swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)
def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
if self.debug:
start_time = time.perf_counter()
print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")
self.swap_weight_devices(block_to_cpu, block_to_cuda)
if self.debug:
print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter()-start_time:.2f}s")
return bidx_to_cpu, bidx_to_cuda # , event
block_to_cpu = blocks[block_idx_to_cpu]
block_to_cuda = blocks[block_idx_to_cuda]
self.futures[block_idx_to_cuda] = self.thread_pool.submit(
move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
)
def _wait_blocks_move(self, block_idx):
if block_idx not in self.futures:
return
if self.debug:
print(f"Wait for block {block_idx}")
start_time = time.perf_counter()
future = self.futures.pop(block_idx)
_, bidx_to_cuda = future.result()
assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"
if self.debug:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")
class ModelOffloader(Offloader):
"""
supports forward offloading
"""
def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
# register backward hooks
self.remove_handles = []
for i, block in enumerate(blocks):
hook = self.create_backward_hook(blocks, i)
if hook is not None:
handle = block.register_full_backward_hook(hook)
self.remove_handles.append(handle)
def __del__(self):
for handle in self.remove_handles:
handle.remove()
def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
waiting = block_index > 0 and block_index <= self.blocks_to_swap
if not swapping and not waiting:
return None
# create hook
block_idx_to_cpu = self.num_blocks - num_blocks_propagated
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1
def backward_hook(module, grad_input, grad_output):
if self.debug:
print(f"Backward hook for block {block_index}")
if swapping:
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)
if waiting:
self._wait_blocks_move(block_idx_to_wait)
return None
return backward_hook
def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if self.debug:
print("Prepare block devices before forward")
for b in blocks[0 : self.num_blocks - self.blocks_to_swap]:
b.to(self.device)
weighs_to_device(b, self.device) # make sure weights are on device
for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
synchronize_device(self.device)
clean_memory_on_device(self.device)
def wait_for_block(self, block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
self._wait_blocks_move(block_idx)
def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:
return
block_idx_to_cpu = block_idx
block_idx_to_cuda = self.num_blocks - self.blocks_to_swap + block_idx
self._submit_move_blocks(blocks, block_idx_to_cpu, block_idx_to_cuda)

View File

@@ -3,11 +3,14 @@ import argparse
import random
import re
from typing import List, Optional, Union
from .utils import setup_logging
from .utils import setup_logging
setup_logging()
import logging
import logging
logger = logging.getLogger(__name__)
def prepare_scheduler_for_custom_training(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"):
return
@@ -64,7 +67,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
else:
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
loss = loss * snr_weight
@@ -92,13 +95,18 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
loss = loss + loss / scale * v_pred_like_loss
return loss
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
weight = 1/torch.sqrt(snr_t)
if v_prediction:
weight = 1 / (snr_t + 1)
else:
weight = 1 / torch.sqrt(snr_t)
loss = weight * loss
return loss
# TODO train_utilと分散しているのでどちらかに寄せる
@@ -474,6 +482,25 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise
def apply_masked_loss(loss, batch):
if "conditioning_images" in batch:
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image / 2 + 0.5
# print(f"conditioning_image: {mask_image.shape}")
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
# alpha mask is 0 to 1
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
else:
return loss
# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
loss = loss * mask_image
return loss
"""
##########################################
# Perlin Noise

139
library/deepspeed_utils.py Normal file
View File

@@ -0,0 +1,139 @@
import os
import argparse
import torch
from accelerate import DeepSpeedPlugin, Accelerator
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
parser.add_argument(
"--offload_optimizer_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--fp16_master_weights_and_gradients",
action="store_true",
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
)
def prepare_deepspeed_args(args: argparse.Namespace):
if not args.deepspeed:
return
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
args.max_data_loader_n_workers = 1
def prepare_deepspeed_plugin(args: argparse.Namespace):
if not args.deepspeed:
return None
try:
import deepspeed
except ImportError as e:
logger.error(
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
)
exit(1)
deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_clipping=args.max_grad_norm,
offload_optimizer_device=args.offload_optimizer_device,
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device,
offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag,
zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
if args.full_fp16 or args.fp16_master_weights_and_gradients:
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
logger.info("[DeepSpeed] full fp16 enable.")
else:
logger.info(
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
)
if args.offload_optimizer_device is not None:
logger.info("[DeepSpeed] start to manually build cpu_adam.")
deepspeed.ops.op_builder.CPUAdamBuilder().load()
logger.info("[DeepSpeed] building cpu_adam done.")
return deepspeed_plugin
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
def prepare_deepspeed_model(args: argparse.Namespace, **models):
# remove None from models
models = {k: v for k, v in models.items() if v is not None}
class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()
for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(torch.nn.ModuleDict({key: model}))
def get_models(self):
return self.models
ds_model = DeepSpeedWrapper(**models)
return ds_model

1237
library/flux_models.py Normal file

File diff suppressed because it is too large Load Diff

582
library/flux_train_utils.py Normal file
View File

@@ -0,0 +1,582 @@
import argparse
import math
import os
import numpy as np
import toml
import json
import time
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from accelerate import Accelerator, PartialState
from transformers import CLIPTextModel
from tqdm import tqdm
from PIL import Image
from safetensors.torch import save_file
from library import flux_models, flux_utils, strategy_base, train_util
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from .utils import setup_logging, mem_eff_save_file
setup_logging()
import logging
logger = logging.getLogger(__name__)
# region sample images
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
flux,
ae,
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
flux = accelerator.unwrap_model(flux)
if text_encoders is not None:
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
flux,
text_encoders,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
flux,
text_encoders,
ae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
clean_memory_on_device(accelerator.device)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
flux: flux_models.Flux,
text_encoders: Optional[List[CLIPTextModel]],
ae: flux_models.AutoEncoder,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
# controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
# if negative_prompt is not None:
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
# if negative_prompt is None:
# negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
# logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
# encode prompts
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_conds = []
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prompt]
print(f"Using cached text encoder outputs for prompt: {prompt}")
if text_encoders is not None:
print(f"Encoding prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
# sample image
weight_dtype = ae.dtype # TOFO give dtype as argument
packed_latent_height = height // 16
packed_latent_width = width // 16
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
device=accelerator.device,
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask)
x = x.float()
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
# latent to image
clean_memory_on_device(accelerator.device)
org_vae_device = ae.device # will be on cpu
ae.to(accelerator.device) # distributed_state.device is same as accelerator.device
with accelerator.autocast(), torch.no_grad():
x = ae.decode(x)
ae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
x = x.clamp(-1, 1)
x = x.permute(0, 2, 3, 1)
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
def time_shift(mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
def get_schedule(
num_steps: int,
image_seq_len: int,
base_shift: float = 0.5,
max_shift: float = 1.15,
shift: bool = True,
) -> list[float]:
# extra step for zero
timesteps = torch.linspace(1, 0, num_steps + 1)
# shifting the schedule to favor high timesteps for higher signal images
if shift:
# eastimate mu based on linear estimation between two points
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
timesteps = time_shift(mu, 1.0, timesteps)
return timesteps.tolist()
def denoise(
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
vec: torch.Tensor,
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
img = img + (t_prev - t_curr) * pred
model.prepare_block_swap_before_forward()
return img
# endregion
# region train
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input, timesteps, sigmas
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
weighting = None
if args.model_prediction_type == "raw":
pass
elif args.model_prediction_type == "additive":
# add the model_pred to the noisy_model_input
model_pred = model_pred + noisy_model_input
elif args.model_prediction_type == "sigma_scaled":
# apply sigma scaling
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
return model_pred, weighting
def save_models(
ckpt_path: str,
flux: flux_models.Flux,
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
use_mem_eff_save: bool = False,
):
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
if save_dtype is not None and v.dtype != save_dtype:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
update_sd("", flux.state_dict())
if not use_mem_eff_save:
save_file(state_dict, ckpt_path, metadata=sai_metadata)
else:
mem_eff_save_file(state_dict, ckpt_path, metadata=sai_metadata)
def save_flux_model_on_train_end(
args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, flux: flux_models.Flux
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
def save_flux_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
flux: flux_models.Flux,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True, flux="dev")
save_models(ckpt_file, flux, sai_metadata, save_dtype, args.mem_eff_save)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
True,
True,
epoch,
num_train_epochs,
global_step,
sd_saver,
None,
)
# endregion
def add_flux_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--clip_l",
type=str,
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument(
"--t5xxl",
type=str,
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors")
parser.add_argument(
"--t5xxl_max_token_length",
type=int,
default=None,
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
)
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="apply attention mask to T5-XXL encode and FLUX double blocks / T5-XXLエンコードとFLUXダブルブロックにアテンションマスクを適用する",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=3.5,
help="the FLUX.1 dev variant is a guidance distilled model",
)
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
default="sigma",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and FLUX.1 shifting."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト、FLUX.1のシフト。",
)
parser.add_argument(
"--sigmoid_scale",
type=float,
default=1.0,
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率timestep-samplingが"sigmoid"の場合のみ有効)。',
)
parser.add_argument(
"--model_prediction_type",
choices=["raw", "additive", "sigma_scaled"],
default="sigma_scaled",
help="How to interpret and process the model prediction: "
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
" / モデル予測の解釈と処理方法:"
"rawそのまま使用、additiveイズ入力に加算、sigma_scaledシグマスケーリングを適用",
)
parser.add_argument(
"--discrete_flow_shift",
type=float,
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)

472
library/flux_utils.py Normal file
View File

@@ -0,0 +1,472 @@
from dataclasses import replace
import json
import os
from typing import List, Optional, Tuple, Union
import einops
import torch
from safetensors.torch import load_file
from safetensors import safe_open
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import flux_models
from library.utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
"""
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
Args:
ckpt_path (str): チェックポイントファイルまたはディレクトリのパス。
Returns:
Tuple[bool, bool, Tuple[int, int], List[str]]:
- bool: Diffusersかどうかを示すフラグ。
- bool: Schnellかどうかを示すフラグ。
- Tuple[int, int]: ダブルブロックとシングルブロックの数。
- List[str]: チェックポイントに含まれるキーのリスト。
"""
# check the state dict: Diffusers or BFL, dev or schnell, number of blocks
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")
if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
if "00001-of-00003" in ckpt_path:
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
else:
ckpt_paths = [ckpt_path]
keys = []
for ckpt_path in ckpt_paths:
with safe_open(ckpt_path, framework="pt") as f:
keys.extend(f.keys())
# if the key has annoying prefix, remove it
if keys[0].startswith("model.diffusion_model."):
keys = [key.replace("model.diffusion_model.", "") for key in keys]
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
# check number of double and single blocks
if not is_diffusers:
max_double_block_index = max(
[int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
)
max_single_block_index = max(
[int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
)
else:
max_double_block_index = max(
[
int(key.split(".")[1])
for key in keys
if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
]
)
max_single_block_index = max(
[
int(key.split(".")[1])
for key in keys
if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
]
)
num_double_blocks = max_double_block_index + 1
num_single_blocks = max_single_block_index + 1
return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths
def load_flow_model(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> Tuple[bool, flux_models.Flux]:
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
# build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"):
params = flux_models.configs[name].params
# set the number of blocks
if params.depth != num_double_blocks:
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
params = replace(params, depth=num_double_blocks)
if params.depth_single_blocks != num_single_blocks:
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
params = replace(params, depth_single_blocks=num_single_blocks)
model = flux_models.Flux(params)
if dtype is not None:
model = model.to(dtype)
# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")
# if the key has annoying prefix, remove it
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model
def load_ae(
ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.AutoEncoder:
logger.info("Building AutoEncoder")
with torch.device("meta"):
# dev and schnell have the same AE params
ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = ae.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded AE: {info}")
return ae
def load_clip_l(
ckpt_path: Optional[str],
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> CLIPTextModel:
logger.info("Building CLIP-L")
CLIPL_CONFIG = {
"_name_or_path": "clip-vit-large-patch14/",
"architectures": ["CLIPModel"],
"initializer_factor": 1.0,
"logit_scale_init_value": 2.6592,
"model_type": "clip",
"projection_dim": 768,
# "text_config": {
"_name_or_path": "",
"add_cross_attention": False,
"architectures": None,
"attention_dropout": 0.0,
"bad_words_ids": None,
"bos_token_id": 0,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": None,
"decoder_start_token_id": None,
"diversity_penalty": 0.0,
"do_sample": False,
"dropout": 0.0,
"early_stopping": False,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"finetuning_task": None,
"forced_bos_token_id": None,
"forced_eos_token_id": None,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"id2label": {"0": "LABEL_0", "1": "LABEL_1"},
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"is_decoder": False,
"is_encoder_decoder": False,
"label2id": {"LABEL_0": 0, "LABEL_1": 1},
"layer_norm_eps": 1e-05,
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 77,
"min_length": 0,
"model_type": "clip_text_model",
"no_repeat_ngram_size": 0,
"num_attention_heads": 12,
"num_beam_groups": 1,
"num_beams": 1,
"num_hidden_layers": 12,
"num_return_sequences": 1,
"output_attentions": False,
"output_hidden_states": False,
"output_scores": False,
"pad_token_id": 1,
"prefix": None,
"problem_type": None,
"projection_dim": 768,
"pruned_heads": {},
"remove_invalid_values": False,
"repetition_penalty": 1.0,
"return_dict": True,
"return_dict_in_generate": False,
"sep_token_id": None,
"task_specific_params": None,
"temperature": 1.0,
"tie_encoder_decoder": False,
"tie_word_embeddings": True,
"tokenizer_class": None,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": None,
"torchscript": False,
"transformers_version": "4.16.0.dev0",
"use_bfloat16": False,
"vocab_size": 49408,
"hidden_act": "gelu",
"hidden_size": 1280,
"intermediate_size": 5120,
"num_attention_heads": 20,
"num_hidden_layers": 32,
# },
# "text_config_dict": {
"hidden_size": 768,
"intermediate_size": 3072,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"projection_dim": 768,
# },
# "torch_dtype": "float32",
# "transformers_version": None,
}
config = CLIPConfig(**CLIPL_CONFIG)
with init_empty_weights():
clip = CLIPTextModel._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-L: {info}")
return clip
def load_t5xxl(
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
"architectures": [
"T5EncoderModel"
],
"classifier_dropout": 0.0,
"d_ff": 10240,
"d_kv": 64,
"d_model": 4096,
"decoder_start_token_id": 0,
"dense_act_fn": "gelu_new",
"dropout_rate": 0.1,
"eos_token_id": 1,
"feed_forward_proj": "gated-gelu",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 24,
"num_heads": 64,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_max_distance": 128,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.41.2",
"use_cache": true,
"vocab_size": 32128
}
"""
config = json.loads(T5_CONFIG_JSON)
config = T5Config(**config)
with init_empty_weights():
t5xxl = T5EncoderModel._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded T5xxl: {info}")
return t5xxl
def get_t5xxl_actual_dtype(t5xxl: T5EncoderModel) -> torch.dtype:
# nn.Embedding is the first layer, but it could be casted to bfloat16 or float32
return t5xxl.encoder.block[0].layer[0].SelfAttention.q.weight.dtype
def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
return img_ids
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
return x
def pack_latents(x: torch.Tensor) -> torch.Tensor:
"""
x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
"""
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return x
# region Diffusers
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38
BFL_TO_DIFFUSERS_MAP = {
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
"txt_in.weight": ["context_embedder.weight"],
"txt_in.bias": ["context_embedder.bias"],
"img_in.weight": ["x_embedder.weight"],
"img_in.bias": ["x_embedder.bias"],
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().linear2.bias": ["proj_out.bias"],
"final_layer.linear.weight": ["proj_out.weight"],
"final_layer.linear.bias": ["proj_out.bias"],
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}
def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
# make reverse map from diffusers map
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
for b in range(num_double_blocks):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("double_blocks."):
block_prefix = f"transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for b in range(num_single_blocks):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("single_blocks."):
block_prefix = f"single_transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
for i, weight in enumerate(weights):
diffusers_to_bfl_map[weight] = (i, key)
return diffusers_to_bfl_map
def convert_diffusers_sd_to_bfl(
diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
) -> dict[str, torch.Tensor]:
diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)
# iterate over three safetensors files to reduce memory usage
flux_sd = {}
for diffusers_key, tensor in diffusers_sd.items():
if diffusers_key in diffusers_to_bfl_map:
index, bfl_key = diffusers_to_bfl_map[diffusers_key]
if bfl_key not in flux_sd:
flux_sd[bfl_key] = []
flux_sd[bfl_key].append((index, tensor))
else:
logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}")
# concat tensors if multiple tensors are mapped to a single key, sort by index
for key, values in flux_sd.items():
if len(values) == 1:
flux_sd[key] = values[0][1]
else:
flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])
# special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
if "final_layer.adaLN_modulation.1.weight" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
if "final_layer.adaLN_modulation.1.bias" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])
return flux_sd
# endregion

View File

@@ -32,6 +32,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.nn.Module.cuda = torch.nn.Module.xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
@@ -147,9 +148,9 @@ def ipex_init(): # pylint: disable=too-many-statements
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 2024
ipex._C._DeviceProperties.minor = 0
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]

View File

@@ -5,7 +5,7 @@ from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
@@ -122,15 +122,15 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
mat2[start_idx:end_idx],
out=out
)
torch.xpu.synchronize(input.device)
else:
return original_torch_bmm(input, mat2, out=out)
torch.xpu.synchronize(input.device)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
# Slice SDPA
@@ -153,7 +153,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
@@ -161,7 +161,7 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
@@ -169,9 +169,9 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
torch.xpu.synchronize(query.device)
else:
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
torch.xpu.synchronize(query.device)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
return hidden_states

View File

@@ -12,7 +12,7 @@ device_supports_fp64 = torch.xpu.has_fp64_dtype()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
@@ -42,7 +42,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
if antialias or align_corners is not None or mode == 'bicubic':
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
@@ -190,6 +190,16 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
else:
return original_Tensor_cuda(self, device, *args, **kwargs)
original_Tensor_pin_memory = torch.Tensor.pin_memory
@wraps(torch.Tensor.pin_memory)
def Tensor_pin_memory(self, device=None, *args, **kwargs):
if device is None:
device = "xpu"
if check_device(device):
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_pin_memory(self, device, *args, **kwargs)
original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
@@ -216,7 +226,9 @@ def torch_empty(*args, device=None, **kwargs):
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, **kwargs):
def torch_randn(*args, device=None, dtype=None, **kwargs):
if dtype == bytes:
dtype = None
if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
else:
@@ -256,11 +268,13 @@ def torch_Generator(device=None):
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
def torch_load(f, map_location=None, *args, **kwargs):
if map_location is None:
map_location = "xpu"
if check_device(map_location):
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
else:
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
return original_torch_load(f, *args, map_location=map_location, **kwargs)
# Hijack Functions:
@@ -268,6 +282,7 @@ def ipex_hijacks():
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
torch.Tensor.pin_memory = Tensor_pin_memory
torch.UntypedStorage.__init__ = UntypedStorage_init
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.empty = torch_empty

View File

@@ -6,8 +6,10 @@ import os
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
r"""
@@ -55,12 +57,18 @@ ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"
@@ -113,7 +121,12 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
):
"""
sd3: only supports "m", flux: only supports "dev"
"""
# if state_dict is None, hash is not calculated
metadata = {}
@@ -126,6 +139,13 @@ def build_metadata(
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif sd3 is not None:
arch = ARCH_SD3_M + "-" + sd3
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV
else:
arch = ARCH_FLUX_1_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
@@ -142,9 +162,12 @@ def build_metadata(
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
if flux is not None:
# Flux
impl = IMPL_FLUX
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
@@ -202,7 +225,7 @@ def build_metadata(
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl:
if sdxl or sd3 is not None or flux is not None:
reso = 1024
elif v2 and v_parameterization:
reso = 768
@@ -213,7 +236,9 @@ def build_metadata(
metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
if v_parameterization:
if flux is not None:
del metadata["modelspec.prediction_type"]
elif v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
@@ -236,7 +261,7 @@ def build_metadata(
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
logger.error(f"Internal error: some metadata values are None: {metadata}")
return metadata
@@ -250,7 +275,7 @@ def get_title(metadata: dict) -> Optional[str]:
def load_metadata_from_safetensors(model: str) -> dict:
if not model.endswith(".safetensors"):
return {}
with safetensors.safe_open(model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:

1426
library/sd3_models.py Normal file

File diff suppressed because it is too large Load Diff

945
library/sd3_train_utils.py Normal file
View File

@@ -0,0 +1,945 @@
import argparse
import math
import os
import toml
import json
import time
from typing import Dict, List, Optional, Tuple, Union
import torch
from safetensors.torch import save_file
from accelerate import Accelerator, PartialState
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
# from transformers import CLIPTokenizer
# from library import model_util
# , sdxl_model_util, train_util, sdxl_original_unet
# from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils, strategy_base, train_util
def save_models(
ckpt_path: str,
mmdit: Optional[sd3_models.MMDiT],
vae: Optional[sd3_models.SDVAE],
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
):
r"""
Save models to checkpoint file. Only supports unified checkpoint format.
"""
state_dict = {}
def update_sd(prefix, sd):
for k, v in sd.items():
key = prefix + k
if save_dtype is not None:
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
update_sd("model.diffusion_model.", mmdit.state_dict())
update_sd("first_stage_model.", vae.state_dict())
# do not support unified checkpoint format for now
# if clip_l is not None:
# update_sd("text_encoders.clip_l.", clip_l.state_dict())
# if clip_g is not None:
# update_sd("text_encoders.clip_g.", clip_g.state_dict())
# if t5xxl is not None:
# update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
save_file(state_dict, ckpt_path, metadata=sai_metadata)
if clip_l is not None:
clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors")
save_file(clip_l.state_dict(), clip_l_path)
if clip_g is not None:
clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors")
save_file(clip_g.state_dict(), clip_g_path)
if t5xxl is not None:
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
t5xxl_state_dict = t5xxl.state_dict()
# replace "shared.weight" with copy of it to avoid annoying shared tensor error on safetensors.save_file
shared_weight = t5xxl_state_dict["shared.weight"]
shared_weight_copy = shared_weight.detach().clone()
t5xxl_state_dict["shared.weight"] = shared_weight_copy
save_file(t5xxl_state_dict, t5xxl_path)
def save_sd3_model_on_train_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
)
save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)
train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合している
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
def save_sd3_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
):
def sd_saver(ckpt_file, epoch_no, global_step):
sai_metadata = train_util.get_sai_model_spec(
None, args, False, False, False, is_stable_diffusion_ckpt=True, sd3=mmdit.model_type
)
save_models(ckpt_file, mmdit, vae, clip_l, clip_g, t5xxl, sai_metadata, save_dtype)
train_util.save_sd_model_on_epoch_end_or_stepwise_common(
args,
on_epoch_end,
accelerator,
True,
True,
epoch,
num_train_epochs,
global_step,
sd_saver,
None,
)
def add_sd3_training_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--clip_l",
type=str,
required=False,
help="CLIP-L model path. if not specified, use ckpt's state_dict / CLIP-Lモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--clip_g",
type=str,
required=False,
help="CLIP-G model path. if not specified, use ckpt's state_dict / CLIP-Gモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--t5xxl",
type=str,
required=False,
help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--save_clip",
action="store_true",
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
)
parser.add_argument(
"--save_t5xxl",
action="store_true",
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
)
parser.add_argument(
"--t5xxl_device",
type=str,
default=None,
help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
)
parser.add_argument(
"--t5xxl_dtype",
type=str,
default=None,
help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtypemixed precisionからを使用",
)
parser.add_argument(
"--t5xxl_max_token_length",
type=int,
default=256,
help="maximum token length for T5-XXL. 256 is the default value / T5-XXLの最大トークン長。デフォルトは256",
)
parser.add_argument(
"--apply_lg_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to CLIP-L and G / CLIP-LとGにアテンションマスクゼロ埋めを適用する",
)
parser.add_argument(
"--apply_t5_attn_mask",
action="store_true",
help="apply attention mask (zero embs) to T5-XXL / T5-XXLにアテンションマスクゼロ埋めを適用する",
)
parser.add_argument(
"--clip_l_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-L encoder, default is 0.0 / CLIP-Lエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--clip_g_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for CLIP-G encoder, default is 0.0 / CLIP-Gエンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--t5_dropout_rate",
type=float,
default=0.0,
help="Dropout rate for T5 encoder, default is 0.0 / T5エンコーダのドロップアウト率、デフォルトは0.0",
)
parser.add_argument(
"--pos_emb_random_crop_rate",
type=float,
default=0.0,
help="Random crop rate for positional embeddings, default is 0.0. Only for SD3.5M"
" / 位置埋め込みのランダムクロップ率、デフォルトは0.0。SD3.5M以外では予期しない動作になります",
)
parser.add_argument(
"--enable_scaled_pos_embed",
action="store_true",
help="Scale position embeddings for each resolution during multi-resolution training. Only for SD3.5M"
" / 複数解像度学習時に解像度ごとに位置埋め込みをスケーリングする。SD3.5M以外では予期しない動作になります",
)
# Dependencies of Diffusers noise sampler has been removed for clarity in training
parser.add_argument(
"--training_shift",
type=float,
default=1.0,
help="Discrete flow shift for training timestep distribution adjustment, applied in addition to the weighting scheme, default is 1.0. /タイムステップ分布のための離散フローシフト、重み付けスキームの上に適用される、デフォルトは1.0。",
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
if args.clip_skip is not None:
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
# if args.multires_noise_iterations:
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
# )
# else:
# if args.noise_offset is None:
# args.noise_offset = DEFAULT_NOISE_OFFSET
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
# )
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
logger.warning(
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
)
# temporary copied from sd3_minimal_inferece.py
def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
start = sampling.timestep(sampling.sigma_max)
end = sampling.timestep(sampling.sigma_min)
timesteps = torch.linspace(start, end, steps)
sigs = []
for x in range(len(timesteps)):
ts = timesteps[x]
sigs.append(sampling.sigma(ts))
sigs += [0.0]
return torch.FloatTensor(sigs)
def max_denoise(model_sampling, sigmas):
max_sigma = float(model_sampling.sigma_max)
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
def do_sample(
height: int,
width: int,
seed: int,
cond: Tuple[torch.Tensor, torch.Tensor],
neg_cond: Tuple[torch.Tensor, torch.Tensor],
mmdit: sd3_models.MMDiT,
steps: int,
guidance_scale: float,
dtype: torch.dtype,
device: str,
):
latent = torch.zeros(1, 16, height // 8, width // 8, device=device)
latent = latent.to(dtype).to(device)
# noise = get_noise(seed, latent).to(device)
if seed is not None:
generator = torch.manual_seed(seed)
else:
generator = None
noise = (
torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu")
.to(latent.dtype)
.to(device)
)
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
sigmas = get_all_sigmas(model_sampling, steps).to(device)
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype)
y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype)
x = noise_scaled.to(device).to(dtype)
# print(x.shape)
# with torch.no_grad():
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
timestep = model_sampling.timestep(sigma_hat).float()
timestep = torch.FloatTensor([timestep, timestep]).to(device)
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
mmdit.prepare_block_swap_before_forward()
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * guidance_scale
# print(denoised.shape)
# d = to_d(x, sigma_hat, denoised)
dims_to_append = x.ndim - sigma_hat.ndim
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
"""Converts a denoiser output to a Karras ODE derivative."""
d = (x - denoised) / sigma_hat_dims
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
x = x.to(dtype)
mmdit.prepare_block_swap_before_forward()
return x
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
mmdit,
vae,
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts) and sample_prompts_te_outputs is None:
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
mmdit = accelerator.unwrap_model(mmdit)
text_encoders = None if text_encoders is None else [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
mmdit,
text_encoders,
vae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
mmdit,
text_encoders,
vae,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
clean_memory_on_device(accelerator.device)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
mmdit: sd3_models.MMDiT,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
vae: sd3_models.SDVAE,
save_dir,
prompt_dict,
epoch,
steps,
sample_prompts_te_outputs,
prompt_replacement,
):
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
# controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
if negative_prompt is None:
negative_prompt = ""
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
logger.info(f"prompt: {prompt}")
logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
# encode prompts
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
def encode_prompt(prpt):
text_encoder_conds = []
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prpt]
print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None:
print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prpt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
return text_encoder_conds
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(prompt)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# encode negative prompts
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encode_prompt(negative_prompt)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# sample image
clean_memory_on_device(accelerator.device)
with accelerator.autocast(), torch.no_grad():
# mmdit may be fp8, so we need weight_dtype here. vae is always in that dtype.
latents = do_sample(height, width, seed, cond, neg_cond, mmdit, sample_steps, scale, vae.dtype, accelerator.device)
# latent to image
clean_memory_on_device(accelerator.device)
org_vae_device = vae.device # will be on cpu
vae.to(accelerator.device)
latents = vae.process_out(latents.to(vae.device, dtype=vae.dtype))
image = vae.decode(latents)
vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)
image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
decoded_np = decoded_np.astype(np.uint8)
image = Image.fromarray(decoded_np)
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
# region Diffusers
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import BaseOutput
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps
self._step_index = None
self._begin_index = None
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def scale_noise(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
"""
Forward process in flow-matching
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
sample = sigma * noise + (1.0 - sigma) * sample
return sample
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
"""
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
# backwards compatibility
# if self.config.prediction_type == "vector_field":
denoised = sample - model_output * sigma
# 2. Convert to an ODE derivative
derivative = (sample - denoised) / sigma_hat
dt = self.sigmas[self.step_index + 1] - sigma_hat
prev_sample = sample + derivative * dt
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
# endregion
def get_noisy_model_input_and_timesteps(args, latents, noise, device, dtype) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
t_min = args.min_timestep if args.min_timestep is not None else 0
t_max = args.max_timestep if args.max_timestep is not None else 1000
shift = args.training_shift
# weighting shift, value >1 will shift distribution to noisy side (focus more on overall structure), value <1 will shift towards less-noisy side (focus more on details)
u = (u * shift) / (1 + (shift - 1) * u)
indices = (u * (t_max - t_min) + t_min).long()
timesteps = indices.to(device=device, dtype=dtype)
# sigmas according to flowmatching
sigmas = timesteps / 1000
sigmas = sigmas.view(-1, 1, 1, 1)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input, timesteps, sigmas

302
library/sd3_utils.py Normal file
View File

@@ -0,0 +1,302 @@
from dataclasses import dataclass
import math
import re
from typing import Dict, List, Optional, Union
import torch
import safetensors
from safetensors.torch import load_file
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sd3_models
# TODO move some of functions to model_util.py
from library import sdxl_model_util
# region models
# TODO remove dependency on flux_utils
from library.utils import load_safetensors
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
logger.info(f"Analyzing state dict state...")
# analyze configs
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None
# x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
x_block_self_attn_layers = []
re_attn = re.compile(r"\.(\d+)\.x_block\.attn2\.ln_k\.weight")
for key in list(state_dict.keys()):
m = re_attn.search(key)
if m:
x_block_self_attn_layers.append(int(m.group(1)))
context_embedder_in_features = context_shape[1]
context_embedder_out_features = context_shape[0]
# only supports 3-5-large, medium or 3-medium
if qk_norm is not None:
if len(x_block_self_attn_layers) == 0:
model_type = "3-5-large"
else:
model_type = "3-5-medium"
else:
model_type = "3-medium"
params = sd3_models.SD3Params(
patch_size=patch_size,
depth=depth,
num_patches=num_patches,
pos_embed_max_size=pos_embed_max_size,
adm_in_channels=adm_in_channels,
qk_norm=qk_norm,
x_block_self_attn_layers=x_block_self_attn_layers,
context_embedder_in_features=context_embedder_in_features,
context_embedder_out_features=context_embedder_out_features,
model_type=model_type,
)
logger.info(f"Analyzed state dict state: {params}")
return params
def load_mmdit(
state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
) -> sd3_models.MMDiT:
mmdit_sd = {}
mmdit_prefix = "model.diffusion_model."
for k in list(state_dict.keys()):
if k.startswith(mmdit_prefix):
mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k)
# load MMDiT
logger.info("Building MMDit")
params = analyze_state_dict_state(mmdit_sd)
with init_empty_weights():
mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)
logger.info("Loading state dict...")
info = mmdit.load_state_dict(mmdit_sd, strict=False, assign=True)
logger.info(f"Loaded MMDiT: {info}")
return mmdit
def load_clip_l(
clip_l_path: Optional[str],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
clip_l_sd = None
if clip_l_path is None:
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_l: remove prefix "text_encoders.clip_l."
logger.info("clip_l is included in the checkpoint")
clip_l_sd = {}
prefix = "text_encoders.clip_l."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
elif clip_l_path is None:
logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
return None
# load clip_l
logger.info("Building CLIP-L")
config = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=768,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
clip = CLIPTextModelWithProjection(config)
if clip_l_sd is None:
logger.info(f"Loading state dict from {clip_l_path}")
clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
if "text_projection.weight" not in clip_l_sd:
logger.info("Adding text_projection.weight to clip_l_sd")
clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)
info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-L: {info}")
return clip
def load_clip_g(
clip_g_path: Optional[str],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
clip_g_sd = None
if state_dict is not None:
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_g: remove prefix "text_encoders.clip_g."
logger.info("clip_g is included in the checkpoint")
clip_g_sd = {}
prefix = "text_encoders.clip_g."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
elif clip_g_path is None:
logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
return None
# load clip_g
logger.info("Building CLIP-G")
config = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
clip = CLIPTextModelWithProjection(config)
if clip_g_sd is None:
logger.info(f"Loading state dict from {clip_g_path}")
clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-G: {info}")
return clip
def load_t5xxl(
t5xxl_path: Optional[str],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
t5xxl_sd = None
if state_dict is not None:
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
# found t5xxl: remove prefix "text_encoders.t5xxl."
logger.info("t5xxl is included in the checkpoint")
t5xxl_sd = {}
prefix = "text_encoders.t5xxl."
for k in list(state_dict.keys()):
if k.startswith(prefix):
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
elif t5xxl_path is None:
logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
return None
return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)
def load_vae(
vae_path: Optional[str],
vae_dtype: Optional[Union[str, torch.dtype]],
device: Optional[Union[str, torch.device]],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
vae_sd = {}
if vae_path:
logger.info(f"Loading VAE from {vae_path}...")
vae_sd = load_safetensors(vae_path, device, disable_mmap)
else:
# remove prefix "first_stage_model."
vae_sd = {}
vae_prefix = "first_stage_model."
for k in list(state_dict.keys()):
if k.startswith(vae_prefix):
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
logger.info("Building VAE")
vae = sd3_models.SDVAE(vae_dtype, device)
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype
return vae
# endregion
class ModelSamplingDiscreteFlow:
"""Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models"""
def __init__(self, shift=1.0):
self.shift = shift
timesteps = 1000
self.sigmas = self.sigma(torch.arange(1, timesteps + 1, 1))
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
def sigma(self, timestep: torch.Tensor):
timestep = timestep / 1000.0
if self.shift == 1.0:
return timestep
return self.shift * timestep / (1 + (self.shift - 1) * timestep)
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
# assert max_denoise is False, "max_denoise not implemented"
# max_denoise is always True, I'm not sure why it's there
return sigma * noise + (1.0 - sigma) * latent_image

View File

@@ -13,12 +13,20 @@ from tqdm import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import SchedulerMixin, StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.models import AutoencoderKL
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.utils import logging
from PIL import Image
from library import sdxl_model_util, sdxl_train_util, train_util
from library import (
sdxl_model_util,
sdxl_train_util,
strategy_base,
strategy_sdxl,
train_util,
sdxl_original_unet,
sdxl_original_control_net,
)
try:
@@ -537,7 +545,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
vae: AutoencoderKL,
text_encoder: List[CLIPTextModel],
tokenizer: List[CLIPTokenizer],
unet: UNet2DConditionModel,
unet: Union[sdxl_original_unet.SdxlUNet2DConditionModel, sdxl_original_control_net.SdxlControlledUNet],
scheduler: SchedulerMixin,
# clip_skip: int,
safety_checker: StableDiffusionSafetyChecker,
@@ -594,74 +602,6 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
return torch.device(module._hf_hook.execution_device)
return self.device
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
is_sdxl_text_encoder2,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `list(int)`):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
"""
batch_size = len(prompt) if isinstance(prompt, list) else 1
if negative_prompt is None:
negative_prompt = [""] * batch_size
elif isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
text_embeddings, text_pool, uncond_embeddings, uncond_pool = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
is_sdxl_text_encoder2=is_sdxl_text_encoder2,
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) # ??
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if text_pool is not None:
text_pool = text_pool.repeat(1, num_images_per_prompt)
text_pool = text_pool.view(bs_embed * num_images_per_prompt, -1)
if do_classifier_free_guidance:
bs_embed, seq_len, _ = uncond_embeddings.shape
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
if uncond_pool is not None:
uncond_pool = uncond_pool.repeat(1, num_images_per_prompt)
uncond_pool = uncond_pool.view(bs_embed * num_images_per_prompt, -1)
return text_embeddings, text_pool, uncond_embeddings, uncond_pool
return text_embeddings, text_pool, None, None
def check_inputs(self, prompt, height, width, strength, callback_steps):
if not isinstance(prompt, str) and not isinstance(prompt, list):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@@ -792,7 +732,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
max_embeddings_multiples: Optional[int] = 3,
output_type: Optional[str] = "pil",
return_dict: bool = True,
controlnet=None,
controlnet: sdxl_original_control_net.SdxlControlNet = None,
controlnet_image=None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
is_cancelled_callback: Optional[Callable[[], bool]] = None,
@@ -896,32 +836,24 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
# 実装を簡単にするためにtokenzer/text encoderを切り替えて二回呼び出す
# To simplify the implementation, switch the tokenzer/text encoder and call it twice
text_embeddings_list = []
text_pool = None
uncond_embeddings_list = []
uncond_pool = None
for i in range(len(self.tokenizers)):
self.tokenizer = self.tokenizers[i]
self.text_encoder = self.text_encoders[i]
tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy: strategy_sdxl.SdxlTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_embeddings, tp1, uncond_embeddings, up1 = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
max_embeddings_multiples,
is_sdxl_text_encoder2=i == 1,
text_input_ids, text_weights = tokenize_strategy.tokenize_with_weights(prompt)
hidden_states_1, hidden_states_2, text_pool = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, self.text_encoders, text_input_ids, text_weights
)
text_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
if do_classifier_free_guidance:
input_ids, weights = tokenize_strategy.tokenize_with_weights(negative_prompt or "")
hidden_states_1, hidden_states_2, uncond_pool = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, self.text_encoders, input_ids, weights
)
text_embeddings_list.append(text_embeddings)
uncond_embeddings_list.append(uncond_embeddings)
if tp1 is not None:
text_pool = tp1
if up1 is not None:
uncond_pool = up1
uncond_embeddings = torch.cat([hidden_states_1, hidden_states_2], dim=-1)
else:
uncond_embeddings = None
uncond_pool = None
unet_dtype = self.unet.dtype
dtype = unet_dtype
@@ -970,23 +902,23 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# create size embs and concat embeddings for SDXL
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(dtype)
orig_size = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1).to(device, dtype)
crop_size = torch.zeros_like(orig_size)
target_size = orig_size
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(dtype)
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, device).to(device, dtype)
# make conditionings
text_pool = text_pool.to(device, dtype)
if do_classifier_free_guidance:
text_embeddings = torch.cat(text_embeddings_list, dim=2)
uncond_embeddings = torch.cat(uncond_embeddings_list, dim=2)
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(dtype)
text_embedding = torch.cat([uncond_embeddings, text_embeddings]).to(device, dtype)
cond_vector = torch.cat([text_pool, embs], dim=1)
uncond_vector = torch.cat([uncond_pool, embs], dim=1)
vector_embedding = torch.cat([uncond_vector, cond_vector]).to(dtype)
uncond_pool = uncond_pool.to(device, dtype)
cond_vector = torch.cat([text_pool, embs], dim=1).to(dtype)
uncond_vector = torch.cat([uncond_pool, embs], dim=1).to(dtype)
vector_embedding = torch.cat([uncond_vector, cond_vector])
else:
text_embedding = torch.cat(text_embeddings_list, dim=2).to(dtype)
vector_embedding = torch.cat([text_pool, embs], dim=1).to(dtype)
text_embedding = text_embeddings.to(device, dtype)
vector_embedding = torch.cat([text_pool, embs], dim=1)
# 8. Denoising loop
for i, t in enumerate(self.progress_bar(timesteps)):
@@ -994,22 +926,14 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
unet_additional_args = {}
if controlnet is not None:
down_block_res_samples, mid_block_res_sample = controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=controlnet_image,
conditioning_scale=1.0,
guess_mode=False,
return_dict=False,
)
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
# FIXME SD1 ControlNet is not working
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
if controlnet is not None:
input_resi_add, mid_add = controlnet(latent_model_input, t, text_embedding, vector_embedding, controlnet_image)
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding, input_resi_add, mid_add)
else:
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
# perform guidance

View File

@@ -1,4 +1,5 @@
import torch
import safetensors
from accelerate import init_empty_weights
from accelerate.utils.modeling import set_module_tensor_to_device
from safetensors.torch import load_file, save_file
@@ -7,9 +8,11 @@ from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
from .utils import setup_logging
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
VAE_SCALE_FACTOR = 0.13025
@@ -163,17 +166,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None):
raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs)))
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False):
# model_version is reserved for future use
# dtype is used for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching
# Load the state dict
if model_util.is_safetensors(ckpt_path):
checkpoint = None
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path) # prevent device invalid Error
if disable_mmap:
state_dict = safetensors.torch.load(open(ckpt_path, "rb").read())
else:
try:
state_dict = load_file(ckpt_path, device=map_location)
except:
state_dict = load_file(ckpt_path) # prevent device invalid Error
epoch = None
global_step = None
else:

View File

@@ -0,0 +1,272 @@
# some parts are modified from Diffusers library (Apache License 2.0)
import math
from types import SimpleNamespace
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sdxl_original_unet
from library.sdxl_model_util import convert_sdxl_unet_state_dict_to_diffusers, convert_diffusers_unet_state_dict_to_sdxl
class ControlNetConditioningEmbedding(nn.Module):
def __init__(self):
super().__init__()
dims = [16, 32, 96, 256]
self.conv_in = nn.Conv2d(3, dims[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
for i in range(len(dims) - 1):
channel_in = dims[i]
channel_out = dims[i + 1]
self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
self.conv_out = nn.Conv2d(dims[-1], 320, kernel_size=3, padding=1)
nn.init.zeros_(self.conv_out.weight) # zero module weight
nn.init.zeros_(self.conv_out.bias) # zero module bias
def forward(self, x):
x = self.conv_in(x)
x = F.silu(x)
for block in self.blocks:
x = block(x)
x = F.silu(x)
x = self.conv_out(x)
return x
class SdxlControlNet(sdxl_original_unet.SdxlUNet2DConditionModel):
def __init__(self, multiplier: Optional[float] = None, **kwargs):
super().__init__(**kwargs)
self.multiplier = multiplier
# remove unet layers
self.output_blocks = nn.ModuleList([])
del self.out
self.controlnet_cond_embedding = ControlNetConditioningEmbedding()
dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280]
self.controlnet_down_blocks = nn.ModuleList([])
for dim in dims:
self.controlnet_down_blocks.append(nn.Conv2d(dim, dim, kernel_size=1))
nn.init.zeros_(self.controlnet_down_blocks[-1].weight) # zero module weight
nn.init.zeros_(self.controlnet_down_blocks[-1].bias) # zero module bias
self.controlnet_mid_block = nn.Conv2d(1280, 1280, kernel_size=1)
nn.init.zeros_(self.controlnet_mid_block.weight) # zero module weight
nn.init.zeros_(self.controlnet_mid_block.bias) # zero module bias
def init_from_unet(self, unet: sdxl_original_unet.SdxlUNet2DConditionModel):
unet_sd = unet.state_dict()
unet_sd = {k: v for k, v in unet_sd.items() if not k.startswith("out")}
sd = super().state_dict()
sd.update(unet_sd)
info = super().load_state_dict(sd, strict=True, assign=True)
return info
def load_state_dict(self, state_dict: dict, strict: bool = True, assign: bool = True) -> Any:
# convert state_dict to SAI format
unet_sd = {}
for k in list(state_dict.keys()):
if not k.startswith("controlnet_"):
unet_sd[k] = state_dict.pop(k)
unet_sd = convert_diffusers_unet_state_dict_to_sdxl(unet_sd)
state_dict.update(unet_sd)
super().load_state_dict(state_dict, strict=strict, assign=assign)
def state_dict(self, destination=None, prefix="", keep_vars=False):
# convert state_dict to Diffusers format
state_dict = super().state_dict(destination, prefix, keep_vars)
control_net_sd = {}
for k in list(state_dict.keys()):
if k.startswith("controlnet_"):
control_net_sd[k] = state_dict.pop(k)
state_dict = convert_sdxl_unet_state_dict_to_diffusers(state_dict)
state_dict.update(control_net_sd)
return state_dict
def forward(
self,
x: torch.Tensor,
timesteps: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None,
cond_image: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
emb = emb + self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
h = x
multiplier = self.multiplier if self.multiplier is not None else 1.0
hs = []
for i, module in enumerate(self.input_blocks):
h = call_module(module, h, emb, context)
if i == 0:
h = self.controlnet_cond_embedding(cond_image) + h
hs.append(self.controlnet_down_blocks[i](h) * multiplier)
h = call_module(self.middle_block, h, emb, context)
h = self.controlnet_mid_block(h) * multiplier
return hs, h
class SdxlControlledUNet(sdxl_original_unet.SdxlUNet2DConditionModel):
"""
This class is for training purpose only.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = sdxl_original_unet.get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
emb = emb + self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
if isinstance(layer, sdxl_original_unet.ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, sdxl_original_unet.Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
h = x
for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
h = h + mid_add
for module in self.output_blocks:
resi = hs.pop() + input_resi_add.pop()
h = torch.cat([h, resi], dim=1)
h = call_module(module, h, emb, context)
h = h.type(x.dtype)
h = call_module(self.out, h, emb, context)
return h
if __name__ == "__main__":
import time
logger.info("create unet")
unet = SdxlControlledUNet()
unet.to("cuda", torch.bfloat16)
unet.set_use_sdpa(True)
unet.set_gradient_checkpointing(True)
unet.train()
logger.info("create control_net")
control_net = SdxlControlNet()
control_net.to("cuda")
control_net.set_use_sdpa(True)
control_net.set_gradient_checkpointing(True)
control_net.train()
logger.info("Initialize control_net from unet")
control_net.init_from_unet(unet)
unet.requires_grad_(False)
control_net.requires_grad_(True)
# 使用メモリ量確認用の疑似学習ループ
logger.info("preparing optimizer")
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
import bitsandbytes
optimizer = bitsandbytes.adam.Adam8bit(control_net.parameters(), lr=1e-3) # not working
# optimizer = bitsandbytes.optim.RMSprop8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
# optimizer=bitsandbytes.optim.Adagrad8bit(unet.parameters(), lr=1e-3) # working at 23.5 GB with torch2
# import transformers
# optimizer = transformers.optimization.Adafactor(unet.parameters(), relative_step=True) # working at 22.2GB with torch2
scaler = torch.cuda.amp.GradScaler(enabled=True)
logger.info("start training")
steps = 10
batch_size = 1
for step in range(steps):
logger.info(f"step {step}")
if step == 1:
time_start = time.perf_counter()
x = torch.randn(batch_size, 4, 128, 128).cuda() # 1024x1024
t = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda")
txt = torch.randn(batch_size, 77, 2048).cuda()
vector = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
cond_img = torch.rand(batch_size, 3, 1024, 1024).cuda()
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
input_resi_add, mid_add = control_net(x, t, txt, vector, cond_img)
output = unet(x, t, txt, vector, input_resi_add, mid_add)
target = torch.randn_like(output)
loss = torch.nn.functional.mse_loss(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
time_end = time.perf_counter()
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
logger.info("finish training")
sd = control_net.state_dict()
from safetensors.torch import save_file
save_file(sd, r"E:\Work\SD\Tmp\sdxl\ctrl\control_net.safetensors")

View File

@@ -30,7 +30,7 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
from library.utils import setup_logging
setup_logging()
import logging
@@ -1156,9 +1156,9 @@ class InferSdxlUNet2DConditionModel:
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
def forward(self, x, timesteps=None, context=None, y=None, input_resi_add=None, mid_add=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink and ControlNet.
"""
_self = self.delegate
@@ -1209,6 +1209,8 @@ class InferSdxlUNet2DConditionModel:
hs.append(h)
h = call_module(_self.middle_block, h, emb, context)
if mid_add is not None:
h = h + mid_add
for module in _self.output_blocks:
# Deep Shrink
@@ -1217,7 +1219,11 @@ class InferSdxlUNet2DConditionModel:
# print("upsample", h.shape, hs[-1].shape)
h = resize_like(h, hs[-1])
h = torch.cat([h, hs.pop()], dim=1)
resi = hs.pop()
if input_resi_add is not None:
resi = resi + input_resi_add.pop()
h = torch.cat([h, resi], dim=1)
h = call_module(module, h, emb, context)
# Deep Shrink: in case of depth 0

View File

@@ -5,16 +5,18 @@ from typing import Optional
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
@@ -24,7 +26,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
@@ -45,6 +46,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
weight_dtype,
accelerator.device if args.lowram else "cpu",
model_dtype,
args.disable_mmap_load_safetensors,
)
# work on low-ram device
@@ -61,7 +63,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
def _load_target_model(
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None
name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False
):
# model_dtype only work with full fp16/bf16
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
@@ -76,7 +78,7 @@ def _load_target_model(
unet,
logit_scale,
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype)
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline
@@ -324,7 +326,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
)
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)
@@ -333,6 +335,11 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
action="store_true",
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
)
parser.add_argument(
"--disable_mmap_load_safetensors",
action="store_true",
help="disable mmap load for safetensors. Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる",
)
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
@@ -356,9 +363,9 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
# )
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
# assert (
# not hasattr(args, "weighted_captions") or not args.weighted_captions
# ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
@@ -370,4 +377,6 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
def sample_images(*args, **kwargs):
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)

570
library/strategy_base.py Normal file
View File

@@ -0,0 +1,570 @@
# base class for platform strategies. this file defines the interface for strategies
import os
import re
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
_re_attention = re.compile(
r"""\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TokenizeStrategy"]:
return cls._strategy
def _load_tokenizer(
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
) -> Any:
tokenizer = None
if tokenizer_cache_dir:
local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2
if tokenizer is None:
tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)
if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
return tokenizer
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
raise NotImplementedError
def tokenize_with_weights(self, text: Union[str, List[str]]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""
returns: [tokens1, tokens2, ...], [weights1, weights2, ...]
"""
raise NotImplementedError
def _get_weighted_input_ids(
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
max_length includes starting and ending tokens.
"""
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in TokenizeStrategy._re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(text: str, max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token. max_length does not include starting and ending token.
No padding, starting or ending token is included.
"""
truncated = False
texts_and_weights = parse_prompt_attention(text)
tokens = []
weights = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = tokenizer(word).input_ids[1:-1]
tokens += token
# copy the weight by length of token
weights += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(tokens) > max_length:
truncated = True
break
# truncate
if len(tokens) > max_length:
truncated = True
tokens = tokens[:max_length]
weights = weights[:max_length]
if truncated:
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
tokens = [bos] + tokens + [eos] + [pad] * (max_length - 2 - len(tokens))
weights = [1.0] + weights + [1.0] * (max_length - 1 - len(weights))
return tokens, weights
if max_length is None:
max_length = tokenizer.model_max_length
tokens, weights = get_prompts_with_weights(text, max_length - 2)
tokens, weights = pad_tokens_and_weights(
tokens, weights, max_length, tokenizer.bos_token_id, tokenizer.eos_token_id, tokenizer.pad_token_id
)
return torch.tensor(tokens).unsqueeze(0), torch.tensor(weights).unsqueeze(0)
def _get_input_ids(
self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None, weighted: bool = False
) -> torch.Tensor:
"""
for SD1.5/2.0/SDXL
TODO support batch input
"""
if max_length is None:
max_length = tokenizer.model_max_length - 2
if weighted:
input_ids, weights = self._get_weighted_input_ids(tokenizer, text, max_length)
else:
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
if max_length > tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
iids_list = []
if tokenizer.pad_token_id == tokenizer.eos_token_id:
# v1
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75)
ids_chunk = (
input_ids[0].unsqueeze(0),
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
)
ids_chunk = torch.cat(ids_chunk)
iids_list.append(ids_chunk)
else:
# v2 or SDXL
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
ids_chunk = (
input_ids[0].unsqueeze(0), # BOS
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
) # PAD or EOS
ids_chunk = torch.cat(ids_chunk)
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変えるx <EOS> なら結果的に変化なし)
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
ids_chunk[-1] = tokenizer.eos_token_id
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
if ids_chunk[1] == tokenizer.pad_token_id:
ids_chunk[1] = tokenizer.eos_token_id
iids_list.append(ids_chunk)
input_ids = torch.stack(iids_list) # 3,77
if weighted:
weights = weights.squeeze(0)
new_weights = torch.ones(input_ids.shape)
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
b = i // (tokenizer.model_max_length - 2)
new_weights[b, 1 : 1 + tokenizer.model_max_length - 2] = weights[i : i + tokenizer.model_max_length - 2]
weights = new_weights
if weighted:
return input_ids, weights
return input_ids
class TextEncodingStrategy:
_strategy = None # strategy instance: actual strategy class
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
return cls._strategy
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
def encode_tokens_with_weights(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor], weights: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:param weights: list of weight tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
class TextEncoderOutputsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(
self,
cache_to_disk: bool,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self._is_partial = is_partial
self._is_weighted = is_weighted
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
@property
def is_partial(self):
return self._is_partial
@property
def is_weighted(self):
return self._is_weighted
def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
raise NotImplementedError
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
raise NotImplementedError
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
):
raise NotImplementedError
class LatentsCachingStrategy:
# TODO commonize utillity functions to this class, such as npz handling etc.
_strategy = None # strategy instance: actual strategy class
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
@property
def cache_suffix(self):
raise NotImplementedError
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
raise NotImplementedError
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
def _default_is_disk_cached_latents_expected(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
multi_resolution: bool = False,
):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
try:
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
encode_by_vae,
vae_device,
vae_dtype,
image_infos: List,
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
"""
from library import train_util # import here to avoid circular import
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
with torch.no_grad():
latents_tensors = encode_by_vae(img_tensor).to("cpu")
if flip_aug:
img_tensor = torch.flip(img_tensor, dims=[3])
with torch.no_grad():
flipped_latents = encode_by_vae(img_tensor).to("cpu")
else:
flipped_latents = [None] * len(latents_tensors)
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
for i in range(len(image_infos)):
info = image_infos[i]
latents = latents_tensors[i]
flipped_latent = flipped_latents[i]
alpha_mask = alpha_masks[i]
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
latents_size = latents.shape[1:3] # H, W
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
if self.cache_to_disk:
self.save_latents_to_disk(
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
info.latents = latents
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
if latents_stride is None:
key_reso_suffix = ""
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,
npz_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
kwargs = {}
if os.path.exists(npz_path):
# load existing npz and update it
npz = np.load(npz_path)
for key in npz.files:
kwargs[key] = npz[key]
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
if flipped_latents_tensor is not None:
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)

271
library/strategy_flux.py Normal file
View File

@@ -0,0 +1,271 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
class FluxTokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 512, tokenizer_cache_dir: Optional[str] = None) -> None:
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
t5_attn_mask = t5_tokens["attention_mask"]
l_tokens = l_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]
return [l_tokens, t5_tokens, t5_attn_mask]
class FluxTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, apply_t5_attn_mask: Optional[bool] = None) -> None:
"""
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_t5_attn_mask = apply_t5_attn_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_t5_attn_mask: Optional[bool] = None,
) -> List[torch.Tensor]:
# supports single model inference
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
clip_l, t5xxl = models if len(models) == 2 else (models[0], None)
l_tokens, t5_tokens = tokens[:2]
t5_attn_mask = tokens[2] if len(tokens) > 2 else None
# clip_l is None when using T5 only
if clip_l is not None and l_tokens is not None:
l_pooled = clip_l(l_tokens.to(clip_l.device))["pooler_output"]
else:
l_pooled = None
# t5xxl is None when using CLIP only
if t5xxl is not None and t5_tokens is not None:
# t5_out is [b, max length, 4096]
attention_mask = None if not apply_t5_attn_mask else t5_attn_mask.to(t5xxl.device)
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), attention_mask, return_dict=False, output_hidden_states=True)
# if zero_pad_t5_output:
# t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
txt_ids = torch.zeros(t5_out.shape[0], t5_out.shape[1], 3, device=t5_out.device)
else:
t5_out = None
txt_ids = None
t5_attn_mask = None # caption may be dropped/shuffled, so t5_attn_mask should not be used to make sure the mask is same as the cached one
return [l_pooled, t5_out, txt_ids, t5_attn_mask] # returns t5_attn_mask for attention mask in transformer
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
if "t5_out" not in npz:
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
l_pooled = data["l_pooled"]
t5_out = data["t5_out"]
txt_ids = data["txt_ids"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
if not self.warn_fp8_weights:
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
logger.warning(
"T5 model is using fp8 weights for caching. This may affect the quality of the cached outputs."
" / T5モデルはfp8の重みを使用しています。これはキャッシュの品質に影響を与える可能性があります。"
)
self.warn_fp8_weights = True
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
if __name__ == "__main__":
# test code for FluxTokenizeStrategy
# tokenizer = sd3_models.SD3Tokenizer()
strategy = FluxTokenizeStrategy(256)
text = "hello world"
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
# print(l_tokens.shape)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens_2 = strategy.t5xxl(
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
print(l_tokens_2)
print(g_tokens_2)
print(t5_tokens_2)
# compare
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
text = ",".join(["hello world! this is long text"] * 50)
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
print(f"model max length l: {strategy.clip_l.model_max_length}")
print(f"model max length g: {strategy.clip_g.model_max_length}")
print(f"model max length t5: {strategy.t5xxl.model_max_length}")

171
library/strategy_sd.py Normal file
View File

@@ -0,0 +1,171 @@
import glob
import os
from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import CLIPTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER_ID = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
class SdTokenizeStrategy(TokenizeStrategy):
def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
"""
max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
"""
logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
if v2:
self.tokenizer = self._load_tokenizer(
CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
)
else:
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
if max_length is None:
self.max_length = self.tokenizer.model_max_length
else:
self.max_length = max_length + 2
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens_list = []
weights_list = []
for t in text:
tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
tokens_list.append(tokens)
weights_list.append(weights)
return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
class SdTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, clip_skip: Optional[int] = None) -> None:
self.clip_skip = clip_skip
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
text_encoder = models[0]
tokens = tokens[0]
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
# tokens: b,n,77
b_size = tokens.size()[0]
max_token_length = tokens.size()[1] * tokens.size()[2]
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
tokens = tokens.to(text_encoder.device)
if self.clip_skip is None:
encoder_hidden_states = text_encoder(tokens)[0]
else:
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if max_token_length != model_max_length:
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
if not v1:
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0:
for j in range(len(chunk)):
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
# 空、つまり <BOS> <EOS> <PAD> ...のパターン
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
return [encoder_hidden_states]
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens_list: List[torch.Tensor],
weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]
weights = weights_list[0].to(encoder_hidden_states.device)
# apply weights
if weights.shape[1] == 1: # no max_token_length
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
else:
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
for i in range(weights.shape[1]):
encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
:, i, 1:-1
].unsqueeze(-1)
return [encoder_hidden_states]
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
# and we keep the old npz for the backward compatibility.
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
@property
def cache_suffix(self) -> str:
return self.suffix
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

420
library/strategy_sd3.py Normal file
View File

@@ -0,0 +1,420 @@
import os
import glob
import random
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
from library import sd3_utils, train_util
from library import sd3_models
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
class Sd3TokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
l_attn_mask = l_tokens["attention_mask"]
g_attn_mask = g_tokens["attention_mask"]
t5_attn_mask = t5_tokens["attention_mask"]
l_tokens = l_tokens["input_ids"]
g_tokens = g_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]
return [l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask]
class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(
self,
apply_lg_attn_mask: Optional[bool] = None,
apply_t5_attn_mask: Optional[bool] = None,
l_dropout_rate: float = 0.0,
g_dropout_rate: float = 0.0,
t5_dropout_rate: float = 0.0,
) -> None:
"""
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
self.l_dropout_rate = l_dropout_rate
self.g_dropout_rate = g_dropout_rate
self.t5_dropout_rate = t5_dropout_rate
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_lg_attn_mask: Optional[bool] = False,
apply_t5_attn_mask: Optional[bool] = False,
enable_dropout: bool = True,
) -> List[torch.Tensor]:
"""
returned embeddings are not masked
"""
clip_l, clip_g, t5xxl = models
clip_l: Optional[CLIPTextModel]
clip_g: Optional[CLIPTextModelWithProjection]
t5xxl: Optional[T5EncoderModel]
if apply_lg_attn_mask is None:
apply_lg_attn_mask = self.apply_lg_attn_mask
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens
# dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
if l_tokens is None or clip_l is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
lg_pooled = None
l_attn_mask = None
g_attn_mask = None
else:
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
# drop some members of the batch: we do not call clip_l and clip_g for dropped members
batch_size, l_seq_len = l_tokens.shape
g_seq_len = g_tokens.shape[1]
non_drop_l_indices = []
non_drop_g_indices = []
for i in range(l_tokens.shape[0]):
drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
if not drop_l:
non_drop_l_indices.append(i)
if not drop_g:
non_drop_g_indices.append(i)
# filter out dropped members
if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
l_tokens = l_tokens[non_drop_l_indices]
l_attn_mask = l_attn_mask[non_drop_l_indices]
if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
g_tokens = g_tokens[non_drop_g_indices]
g_attn_mask = g_attn_mask[non_drop_g_indices]
# call clip_l for non-dropped members
if len(non_drop_l_indices) > 0:
nd_l_attn_mask = l_attn_mask.to(clip_l.device)
prompt_embeds = clip_l(
l_tokens.to(clip_l.device), nd_l_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
)
nd_l_pooled = prompt_embeds[0]
nd_l_out = prompt_embeds.hidden_states[-2]
if len(non_drop_g_indices) > 0:
nd_g_attn_mask = g_attn_mask.to(clip_g.device)
prompt_embeds = clip_g(
g_tokens.to(clip_g.device), nd_g_attn_mask if apply_lg_attn_mask else None, output_hidden_states=True
)
nd_g_pooled = prompt_embeds[0]
nd_g_out = prompt_embeds.hidden_states[-2]
# fill in the dropped members
if len(non_drop_l_indices) == batch_size:
l_pooled = nd_l_pooled
l_out = nd_l_out
else:
# model output is always float32 because of the models are wrapped with Accelerator
l_pooled = torch.zeros((batch_size, 768), device=clip_l.device, dtype=torch.float32)
l_out = torch.zeros((batch_size, l_seq_len, 768), device=clip_l.device, dtype=torch.float32)
l_attn_mask = torch.zeros((batch_size, l_seq_len), device=clip_l.device, dtype=l_attn_mask.dtype)
if len(non_drop_l_indices) > 0:
l_pooled[non_drop_l_indices] = nd_l_pooled
l_out[non_drop_l_indices] = nd_l_out
l_attn_mask[non_drop_l_indices] = nd_l_attn_mask
if len(non_drop_g_indices) == batch_size:
g_pooled = nd_g_pooled
g_out = nd_g_out
else:
g_pooled = torch.zeros((batch_size, 1280), device=clip_g.device, dtype=torch.float32)
g_out = torch.zeros((batch_size, g_seq_len, 1280), device=clip_g.device, dtype=torch.float32)
g_attn_mask = torch.zeros((batch_size, g_seq_len), device=clip_g.device, dtype=g_attn_mask.dtype)
if len(non_drop_g_indices) > 0:
g_pooled[non_drop_g_indices] = nd_g_pooled
g_out[non_drop_g_indices] = nd_g_out
g_attn_mask[non_drop_g_indices] = nd_g_attn_mask
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1)
lg_out = torch.cat([l_out, g_out], dim=-1)
if t5xxl is None or t5_tokens is None:
t5_out = None
t5_attn_mask = None
else:
# drop some members of the batch: we do not call t5xxl for dropped members
batch_size, t5_seq_len = t5_tokens.shape
non_drop_t5_indices = []
for i in range(t5_tokens.shape[0]):
drop_t5 = enable_dropout and (self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate)
if not drop_t5:
non_drop_t5_indices.append(i)
# filter out dropped members
if len(non_drop_t5_indices) > 0 and len(non_drop_t5_indices) < batch_size:
t5_tokens = t5_tokens[non_drop_t5_indices]
t5_attn_mask = t5_attn_mask[non_drop_t5_indices]
# call t5xxl for non-dropped members
if len(non_drop_t5_indices) > 0:
nd_t5_attn_mask = t5_attn_mask.to(t5xxl.device)
nd_t5_out, _ = t5xxl(
t5_tokens.to(t5xxl.device),
nd_t5_attn_mask if apply_t5_attn_mask else None,
return_dict=False,
output_hidden_states=True,
)
# fill in the dropped members
if len(non_drop_t5_indices) == batch_size:
t5_out = nd_t5_out
else:
t5_out = torch.zeros((batch_size, t5_seq_len, 4096), device=t5xxl.device, dtype=torch.float32)
t5_attn_mask = torch.zeros((batch_size, t5_seq_len), device=t5xxl.device, dtype=t5_attn_mask.dtype)
if len(non_drop_t5_indices) > 0:
t5_out[non_drop_t5_indices] = nd_t5_out
t5_attn_mask[non_drop_t5_indices] = nd_t5_attn_mask
# masks are used for attention masking in transformer
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def drop_cached_text_encoder_outputs(
self,
lg_out: torch.Tensor,
t5_out: torch.Tensor,
lg_pooled: torch.Tensor,
l_attn_mask: torch.Tensor,
g_attn_mask: torch.Tensor,
t5_attn_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# dropout: if enable_dropout is True, dropout is not applied. dropout means zeroing out embeddings
if lg_out is not None:
for i in range(lg_out.shape[0]):
drop_l = self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate
if drop_l:
lg_out[i, :, :768] = torch.zeros_like(lg_out[i, :, :768])
lg_pooled[i, :768] = torch.zeros_like(lg_pooled[i, :768])
if l_attn_mask is not None:
l_attn_mask[i] = torch.zeros_like(l_attn_mask[i])
drop_g = self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate
if drop_g:
lg_out[i, :, 768:] = torch.zeros_like(lg_out[i, :, 768:])
lg_pooled[i, 768:] = torch.zeros_like(lg_pooled[i, 768:])
if g_attn_mask is not None:
g_attn_mask[i] = torch.zeros_like(g_attn_mask[i])
if t5_out is not None:
for i in range(t5_out.shape[0]):
drop_t5 = self.t5_dropout_rate > 0.0 and random.random() < self.t5_dropout_rate
if drop_t5:
t5_out[i] = torch.zeros_like(t5_out[i])
if t5_attn_mask is not None:
t5_attn_mask[i] = torch.zeros_like(t5_attn_mask[i])
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def concat_encodings(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
if t5_out is None:
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"]
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# always disable dropout during caching
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy,
models,
tokens_and_masks,
apply_lg_attn_mask=self.apply_lg_attn_mask,
apply_t5_attn_mask=self.apply_t5_attn_mask,
enable_dropout=False,
)
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
@property
def cache_suffix(self) -> str:
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

306
library/strategy_sdxl.py Normal file
View File

@@ -0,0 +1,306 @@
import os
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
class SdxlTokenizeStrategy(TokenizeStrategy):
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
if max_length is None:
self.max_length = self.tokenizer1.model_max_length
else:
self.max_length = max_length + 2
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return (
torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0),
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
)
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens1_list, tokens2_list = [], []
weights1_list, weights2_list = [], []
for t in text:
tokens1, weights1 = self._get_input_ids(self.tokenizer1, t, self.max_length, weighted=True)
tokens2, weights2 = self._get_input_ids(self.tokenizer2, t, self.max_length, weighted=True)
tokens1_list.append(tokens1)
tokens2_list.append(tokens2)
weights1_list.append(weights1)
weights2_list.append(weights2)
return [torch.stack(tokens1_list, dim=0), torch.stack(tokens2_list, dim=0)], [
torch.stack(weights1_list, dim=0),
torch.stack(weights2_list, dim=0),
]
class SdxlTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def _pool_workaround(
self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
):
r"""
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
instead of the hidden states for the EOS token
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
Original code from CLIP's pooling function:
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
\# take features from the eot embedding (eot_token is the highest number in each sequence)
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
"""
# input_ids: b*n,77
# find index for EOS token
# Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
# eos_token_index = torch.where(input_ids == eos_token_id)[1]
# eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# Create a mask where the EOS tokens are
eos_token_mask = (input_ids == eos_token_id).int()
# Use argmax to find the last index of the EOS token for each element in the batch
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# get hidden states for EOS token
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
]
# apply projection: projection may be of different dtype than last_hidden_state
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
pooled_output = pooled_output.to(last_hidden_state.dtype)
return pooled_output
def _get_hidden_states_sdxl(
self,
input_ids1: torch.Tensor,
input_ids2: torch.Tensor,
tokenizer1: CLIPTokenizer,
tokenizer2: CLIPTokenizer,
text_encoder1: Union[CLIPTextModel, torch.nn.Module],
text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
if input_ids1.size()[1] == 1:
max_token_length = None
else:
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
input_ids1 = input_ids1.to(text_encoder1.device)
input_ids2 = input_ids2.to(text_encoder2.device)
# text_encoder1
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
hidden_states1 = enc_out["hidden_states"][11]
# text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
# pool2 = enc_out["text_embeds"]
unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
if max_token_length is not None:
# bs*3, 77, 768 or 1024
# encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [hidden_states1[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, tokenizer1.model_max_length):
states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(hidden_states1[:, -1].unsqueeze(1)) # <EOS>
hidden_states1 = torch.cat(states_list, dim=1)
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, tokenizer2.model_max_length):
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
# this causes an error:
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# if i > 1:
# for j in range(len(chunk)): # batch_size
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
# chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
hidden_states2 = torch.cat(states_list, dim=1)
# pool はnの最初のものを使う
pool2 = pool2[::n_size]
return hidden_states1, hidden_states2, pool2
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Args:
tokenize_strategy: TokenizeStrategy
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)].
If text_encoder2 is wrapped by accelerate, unwrapped_text_encoder2 is required
tokens: List of tokens, for text_encoder1 and text_encoder2
"""
if len(models) == 2:
text_encoder1, text_encoder2 = models
unwrapped_text_encoder2 = None
else:
text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
tokens1, tokens2 = tokens
sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy
tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2
hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
)
return [hidden_states1, hidden_states2, pool2]
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens_list: List[torch.Tensor],
weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
hidden_states1, hidden_states2, pool2 = self.encode_tokens(tokenize_strategy, models, tokens_list)
weights_list = [weights.to(hidden_states1.device) for weights in weights_list]
# apply weights
if weights_list[0].shape[1] == 1: # no max_token_length
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
hidden_states1 = hidden_states1 * weights_list[0].squeeze(1).unsqueeze(2)
hidden_states2 = hidden_states2 * weights_list[1].squeeze(1).unsqueeze(2)
else:
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
for weight, hidden_states in zip(weights_list, [hidden_states1, hidden_states2]):
for i in range(weight.shape[1]):
hidden_states[:, i * 75 + 1 : i * 75 + 76] = hidden_states[:, i * 75 + 1 : i * 75 + 76] * weight[
:, i, 1:-1
].unsqueeze(-1)
return [hidden_states1, hidden_states2, pool2]
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_state1 = data["hidden_state1"]
hidden_state2 = data["hidden_state2"]
pool2 = data["pool2"]
return [hidden_state1, hidden_state2, pool2]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
captions = [info.caption for info in infos]
if self.is_weighted:
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
with torch.no_grad():
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, models, tokens_list, weights_list
)
else:
tokens1, tokens2 = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, [tokens1, tokens2]
)
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]

File diff suppressed because it is too large Load Diff

View File

@@ -1,18 +1,29 @@
import logging
import sys
import threading
import torch
from torchvision import transforms
from typing import *
import json
import struct
import torch
import torch.nn as nn
from torchvision import transforms
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
import cv2
from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
# region Logging
def add_logging_arguments(parser):
parser.add_argument(
"--console_log_level",
@@ -79,10 +90,315 @@ def setup_logging(args=None, log_level=None, reset=False):
logger.info(msg_init)
# endregion
# region PyTorch utils
def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__
weight_swap_jobs = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
def weighs_to_device(layer: nn.Module, device: torch.device):
for module in layer.modules():
if hasattr(module, "weight") and module.weight is not None:
module.weight.data = module.weight.data.to(device, non_blocking=True)
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
Convert a string to a torch.dtype
Args:
s: string representation of the dtype
default_dtype: default dtype to return if s is None
Returns:
torch.dtype: the corresponding torch.dtype
Raises:
ValueError: if the dtype is not supported
Examples:
>>> str_to_dtype("float32")
torch.float32
>>> str_to_dtype("fp32")
torch.float32
>>> str_to_dtype("float16")
torch.float16
>>> str_to_dtype("fp16")
torch.float16
>>> str_to_dtype("bfloat16")
torch.bfloat16
>>> str_to_dtype("bf16")
torch.bfloat16
>>> str_to_dtype("fp8")
torch.float8_e4m3fn
>>> str_to_dtype("fp8_e4m3fn")
torch.float8_e4m3fn
>>> str_to_dtype("fp8_e4m3fnuz")
torch.float8_e4m3fnuz
>>> str_to_dtype("fp8_e5m2")
torch.float8_e5m2
>>> str_to_dtype("fp8_e5m2fnuz")
torch.float8_e5m2fnuz
"""
if s is None:
return default_dtype
if s in ["bf16", "bfloat16"]:
return torch.bfloat16
elif s in ["fp16", "float16"]:
return torch.float16
elif s in ["fp32", "float32", "float"]:
return torch.float32
elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
return torch.float8_e4m3fn
elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
return torch.float8_e4m3fnuz
elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
return torch.float8_e5m2
elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
return torch.float8_e5m2fnuz
elif s in ["fp8", "float8"]:
return torch.float8_e4m3fn # default fp8
else:
raise ValueError(f"Unsupported dtype: {s}")
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
"""
_TYPES = {
torch.float64: "F64",
torch.float32: "F32",
torch.float16: "F16",
torch.bfloat16: "BF16",
torch.int64: "I64",
torch.int32: "I32",
torch.int16: "I16",
torch.int8: "I8",
torch.uint8: "U8",
torch.bool: "BOOL",
getattr(torch, "float8_e5m2", None): "F8_E5M2",
getattr(torch, "float8_e4m3fn", None): "F8_E4M3",
}
_ALIGN = 256
def validate_metadata(metadata: Dict[str, Any]) -> Dict[str, str]:
validated = {}
for key, value in metadata.items():
if not isinstance(key, str):
raise ValueError(f"Metadata key must be a string, got {type(key)}")
if not isinstance(value, str):
print(f"Warning: Metadata value for key '{key}' is not a string. Converting to string.")
validated[key] = str(value)
else:
validated[key] = value
return validated
print(f"Using memory efficient save file: {filename}")
header = {}
offset = 0
if metadata:
header["__metadata__"] = validate_metadata(metadata)
for k, v in tensors.items():
if v.numel() == 0: # empty tensor
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset]}
else:
size = v.numel() * v.element_size()
header[k] = {"dtype": _TYPES[v.dtype], "shape": list(v.shape), "data_offsets": [offset, offset + size]}
offset += size
hjson = json.dumps(header).encode("utf-8")
hjson += b" " * (-(len(hjson) + 8) % _ALIGN)
with open(filename, "wb") as f:
f.write(struct.pack("<Q", len(hjson)))
f.write(hjson)
for k, v in tensors.items():
if v.numel() == 0:
continue
if v.is_cuda:
# Direct GPU to disk save
with torch.cuda.device(v.device):
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
tensor_bytes = v.contiguous().view(torch.uint8)
tensor_bytes.cpu().numpy().tofile(f)
else:
# CPU tensor save
if v.dim() == 0: # if scalar, need to add a dimension to work with view
v = v.unsqueeze(0)
v.contiguous().view(torch.uint8).numpy().tofile(f)
class MemoryEfficientSafeOpen:
# does not support metadata loading
def __init__(self, filename):
self.filename = filename
self.header, self.header_size = self._read_header()
self.file = open(filename, "rb")
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.file.close()
def keys(self):
return [k for k in self.header.keys() if k != "__metadata__"]
def get_tensor(self, key):
if key not in self.header:
raise KeyError(f"Tensor '{key}' not found in the file")
metadata = self.header[key]
offset_start, offset_end = metadata["data_offsets"]
if offset_start == offset_end:
tensor_bytes = None
else:
# adjust offset by header size
self.file.seek(self.header_size + 8 + offset_start)
tensor_bytes = self.file.read(offset_end - offset_start)
return self._deserialize_tensor(tensor_bytes, metadata)
def _read_header(self):
with open(self.filename, "rb") as f:
header_size = struct.unpack("<Q", f.read(8))[0]
header_json = f.read(header_size).decode("utf-8")
return json.loads(header_json), header_size
def _deserialize_tensor(self, tensor_bytes, metadata):
dtype = self._get_torch_dtype(metadata["dtype"])
shape = metadata["shape"]
if tensor_bytes is None:
byte_tensor = torch.empty(0, dtype=torch.uint8)
else:
tensor_bytes = bytearray(tensor_bytes) # make it writable
byte_tensor = torch.frombuffer(tensor_bytes, dtype=torch.uint8)
# process float8 types
if metadata["dtype"] in ["F8_E5M2", "F8_E4M3"]:
return self._convert_float8(byte_tensor, metadata["dtype"], shape)
# convert to the target dtype and reshape
return byte_tensor.view(dtype).reshape(shape)
@staticmethod
def _get_torch_dtype(dtype_str):
dtype_map = {
"F64": torch.float64,
"F32": torch.float32,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I64": torch.int64,
"I32": torch.int32,
"I16": torch.int16,
"I8": torch.int8,
"U8": torch.uint8,
"BOOL": torch.bool,
}
# add float8 types if available
if hasattr(torch, "float8_e5m2"):
dtype_map["F8_E5M2"] = torch.float8_e5m2
if hasattr(torch, "float8_e4m3fn"):
dtype_map["F8_E4M3"] = torch.float8_e4m3fn
return dtype_map.get(dtype_str)
@staticmethod
def _convert_float8(byte_tensor, dtype_str, shape):
if dtype_str == "F8_E5M2" and hasattr(torch, "float8_e5m2"):
return byte_tensor.view(torch.float8_e5m2).reshape(shape)
elif dtype_str == "F8_E4M3" and hasattr(torch, "float8_e4m3fn"):
return byte_tensor.view(torch.float8_e4m3fn).reshape(shape)
else:
# # convert to float16 if float8 is not supported
# print(f"Warning: {dtype_str} is not supported in this PyTorch version. Converting to float16.")
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
return state_dict
else:
try:
state_dict = load_file(path, device=device)
except:
state_dict = load_file(path) # prevent device invalid Error
if dtype is not None:
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(dtype=dtype)
return state_dict
# endregion
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
if has_alpha:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
resized_pil = pil_image.resize(size, interpolation)
# Convert back to cv2 format
if has_alpha:
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA)
else:
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
return resized_cv2
# endregion
# TODO make inf_utils.py
# region Gradual Latent hires fix

View File

@@ -18,7 +18,7 @@ def main(file):
keys = list(sd.keys())
for key in keys:
if "lora_up" in key or "lora_down" in key:
if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key:
values.append((key, sd[key]))
print(f"number of LoRA modules: {len(values)}")

View File

@@ -7,8 +7,10 @@ from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# input_blocksに適用するかどうか / if True, input_blocks are not applied
@@ -103,19 +105,15 @@ class LLLiteLinear(ORIGINAL_LINEAR):
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
self.cond_image = None
self.cond_emb = None
def set_cond_image(self, cond_image):
self.cond_image = cond_image
self.cond_emb = None
def forward(self, x):
if not self.enabled:
return super().forward(x)
if self.cond_emb is None:
self.cond_emb = self.lllite_conditioning1(self.cond_image)
cx = self.cond_emb
cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible
# reshape / b,c,h,w -> b,h*w,c
n, c, h, w = cx.shape
@@ -159,9 +157,7 @@ class LLLiteConv2d(ORIGINAL_CONV2D):
if not self.enabled:
return super().forward(x)
if self.cond_emb is None:
self.cond_emb = self.lllite_conditioning1(self.cond_image)
cx = self.cond_emb
cx = self.lllite_conditioning1(self.cond_image)
cx = torch.cat([cx, self.down(x)], dim=1)
cx = self.mid(cx)

View File

@@ -0,0 +1,434 @@
# convert key mapping and data format from some LoRA format to another
"""
Original LoRA format: Based on Black Forest Labs, QKV and MLP are unified into one module
alpha is scalar for each LoRA module
0 to 18
lora_unet_double_blocks_0_img_attn_proj.alpha torch.Size([])
lora_unet_double_blocks_0_img_attn_proj.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_attn_proj.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_img_attn_qkv.alpha torch.Size([])
lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight torch.Size([9216, 4])
lora_unet_double_blocks_0_img_mlp_0.alpha torch.Size([])
lora_unet_double_blocks_0_img_mlp_0.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_mlp_0.lora_up.weight torch.Size([12288, 4])
lora_unet_double_blocks_0_img_mlp_2.alpha torch.Size([])
lora_unet_double_blocks_0_img_mlp_2.lora_down.weight torch.Size([4, 12288])
lora_unet_double_blocks_0_img_mlp_2.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_img_mod_lin.alpha torch.Size([])
lora_unet_double_blocks_0_img_mod_lin.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_img_mod_lin.lora_up.weight torch.Size([18432, 4])
lora_unet_double_blocks_0_txt_attn_proj.alpha torch.Size([])
lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_txt_attn_qkv.alpha torch.Size([])
lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight torch.Size([9216, 4])
lora_unet_double_blocks_0_txt_mlp_0.alpha torch.Size([])
lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight torch.Size([12288, 4])
lora_unet_double_blocks_0_txt_mlp_2.alpha torch.Size([])
lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight torch.Size([4, 12288])
lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight torch.Size([3072, 4])
lora_unet_double_blocks_0_txt_mod_lin.alpha torch.Size([])
lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight torch.Size([4, 3072])
lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight torch.Size([18432, 4])
0 to 37
lora_unet_single_blocks_0_linear1.alpha torch.Size([])
lora_unet_single_blocks_0_linear1.lora_down.weight torch.Size([4, 3072])
lora_unet_single_blocks_0_linear1.lora_up.weight torch.Size([21504, 4])
lora_unet_single_blocks_0_linear2.alpha torch.Size([])
lora_unet_single_blocks_0_linear2.lora_down.weight torch.Size([4, 15360])
lora_unet_single_blocks_0_linear2.lora_up.weight torch.Size([3072, 4])
lora_unet_single_blocks_0_modulation_lin.alpha torch.Size([])
lora_unet_single_blocks_0_modulation_lin.lora_down.weight torch.Size([4, 3072])
lora_unet_single_blocks_0_modulation_lin.lora_up.weight torch.Size([9216, 4])
"""
"""
ai-toolkit: Based on Diffusers, QKV and MLP are separated into 3 modules.
A is down, B is up. No alpha for each LoRA module.
0 to 18
transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight torch.Size([12288, 16])
transformer.transformer_blocks.0.ff.net.2.lora_A.weight torch.Size([16, 12288])
transformer.transformer_blocks.0.ff.net.2.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight torch.Size([12288, 16])
transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight torch.Size([16, 12288])
transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight torch.Size([3072, 16])
transformer.transformer_blocks.0.norm1.linear.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.norm1.linear.lora_B.weight torch.Size([18432, 16])
transformer.transformer_blocks.0.norm1_context.linear.lora_A.weight torch.Size([16, 3072])
transformer.transformer_blocks.0.norm1_context.linear.lora_B.weight torch.Size([18432, 16])
0 to 37
transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight torch.Size([3072, 16])
transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight torch.Size([3072, 16])
transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight torch.Size([3072, 16])
transformer.single_transformer_blocks.0.norm.linear.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.norm.linear.lora_B.weight torch.Size([9216, 16])
transformer.single_transformer_blocks.0.proj_mlp.lora_A.weight torch.Size([16, 3072])
transformer.single_transformer_blocks.0.proj_mlp.lora_B.weight torch.Size([12288, 16])
transformer.single_transformer_blocks.0.proj_out.lora_A.weight torch.Size([16, 15360])
transformer.single_transformer_blocks.0.proj_out.lora_B.weight torch.Size([3072, 16])
"""
"""
xlabs: Unknown format.
0 to 18
double_blocks.0.processor.proj_lora1.down.weight torch.Size([16, 3072])
double_blocks.0.processor.proj_lora1.up.weight torch.Size([3072, 16])
double_blocks.0.processor.proj_lora2.down.weight torch.Size([16, 3072])
double_blocks.0.processor.proj_lora2.up.weight torch.Size([3072, 16])
double_blocks.0.processor.qkv_lora1.down.weight torch.Size([16, 3072])
double_blocks.0.processor.qkv_lora1.up.weight torch.Size([9216, 16])
double_blocks.0.processor.qkv_lora2.down.weight torch.Size([16, 3072])
double_blocks.0.processor.qkv_lora2.up.weight torch.Size([9216, 16])
"""
import argparse
from safetensors.torch import save_file
from safetensors import safe_open
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def convert_to_sd_scripts(sds_sd, ait_sd, sds_key, ait_key):
ait_down_key = ait_key + ".lora_A.weight"
if ait_down_key not in ait_sd:
return
ait_up_key = ait_key + ".lora_B.weight"
down_weight = ait_sd.pop(ait_down_key)
sds_sd[sds_key + ".lora_down.weight"] = down_weight
sds_sd[sds_key + ".lora_up.weight"] = ait_sd.pop(ait_up_key)
rank = down_weight.shape[0]
sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(rank, dtype=down_weight.dtype, device=down_weight.device)
def convert_to_sd_scripts_cat(sds_sd, ait_sd, sds_key, ait_keys):
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
if ait_down_keys[0] not in ait_sd:
return
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
down_weights = [ait_sd.pop(k) for k in ait_down_keys]
up_weights = [ait_sd.pop(k) for k in ait_up_keys]
# lora_down is concatenated along dim=0, so rank is multiplied by the number of splits
rank = down_weights[0].shape[0]
num_splits = len(ait_keys)
sds_sd[sds_key + ".lora_down.weight"] = torch.cat(down_weights, dim=0)
merged_up_weights = torch.zeros(
(sum(w.shape[0] for w in up_weights), rank * num_splits),
dtype=up_weights[0].dtype,
device=up_weights[0].device,
)
i = 0
for j, up_weight in enumerate(up_weights):
merged_up_weights[i : i + up_weight.shape[0], j * rank : (j + 1) * rank] = up_weight
i += up_weight.shape[0]
sds_sd[sds_key + ".lora_up.weight"] = merged_up_weights
# set alpha to new_rank
new_rank = rank * num_splits
sds_sd[sds_key + ".alpha"] = torch.scalar_tensor(new_rank, dtype=down_weights[0].dtype, device=down_weights[0].device)
def convert_ai_toolkit_to_sd_scripts(ait_sd):
sds_sd = {}
for i in range(19):
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0"
)
convert_to_sd_scripts_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out"
)
convert_to_sd_scripts_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear"
)
for i in range(38):
convert_to_sd_scripts_cat(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
[
f"transformer.single_transformer_blocks.{i}.attn.to_q",
f"transformer.single_transformer_blocks.{i}.attn.to_k",
f"transformer.single_transformer_blocks.{i}.attn.to_v",
f"transformer.single_transformer_blocks.{i}.proj_mlp",
],
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out"
)
convert_to_sd_scripts(
sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear"
)
if len(ait_sd) > 0:
logger.warning(f"Unsuppored keys for sd-scripts: {ait_sd.keys()}")
return sds_sd
def convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
# scale weight by alpha and dim
rank = down_weight.shape[0]
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
# print(f"rank: {rank}, alpha: {alpha}, scale: {scale}")
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
# print(f"scale: {scale}, scale_down: {scale_down}, scale_up: {scale_up}")
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
def convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]
# scale weight by alpha and dim
alpha = sds_sd.pop(sds_key + ".alpha")
scale = alpha / sd_lora_rank
# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
# calculate dims if not provided
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]
# check upweight is sparse or not
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0)
i += dims[j]
if is_sparse:
logger.info(f"weight is sparse: {sds_key}")
# make ai-toolkit weight
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})
else:
# down_weight is chunked to each split
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})
# up_weight is sparse: only non-zero values are copied to each split
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]
def convert_sd_scripts_to_ai_toolkit(sds_sd):
ait_sd = {}
for i in range(19):
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_out.0"
)
convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_0", f"transformer.transformer_blocks.{i}.ff.net.0.proj"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mlp_2", f"transformer.transformer_blocks.{i}.ff.net.2"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_img_mod_lin", f"transformer.transformer_blocks.{i}.norm1.linear"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_attn_proj", f"transformer.transformer_blocks.{i}.attn.to_add_out"
)
convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_0", f"transformer.transformer_blocks.{i}.ff_context.net.0.proj"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mlp_2", f"transformer.transformer_blocks.{i}.ff_context.net.2"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_double_blocks_{i}_txt_mod_lin", f"transformer.transformer_blocks.{i}.norm1_context.linear"
)
for i in range(38):
convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
[
f"transformer.single_transformer_blocks.{i}.attn.to_q",
f"transformer.single_transformer_blocks.{i}.attn.to_k",
f"transformer.single_transformer_blocks.{i}.attn.to_v",
f"transformer.single_transformer_blocks.{i}.proj_mlp",
],
dims=[3072, 3072, 3072, 12288],
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_linear2", f"transformer.single_transformer_blocks.{i}.proj_out"
)
convert_to_ai_toolkit(
sds_sd, ait_sd, f"lora_unet_single_blocks_{i}_modulation_lin", f"transformer.single_transformer_blocks.{i}.norm.linear"
)
if len(sds_sd) > 0:
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
return ait_sd
def main(args):
# load source safetensors
logger.info(f"Loading source file {args.src_path}")
state_dict = {}
with safe_open(args.src_path, framework="pt") as f:
metadata = f.metadata()
for k in f.keys():
state_dict[k] = f.get_tensor(k)
logger.info(f"Converting {args.src} to {args.dst} format")
if args.src == "ai-toolkit" and args.dst == "sd-scripts":
state_dict = convert_ai_toolkit_to_sd_scripts(state_dict)
elif args.src == "sd-scripts" and args.dst == "ai-toolkit":
state_dict = convert_sd_scripts_to_ai_toolkit(state_dict)
# eliminate 'shared tensors'
for k in list(state_dict.keys()):
state_dict[k] = state_dict[k].detach().clone()
else:
raise NotImplementedError(f"Conversion from {args.src} to {args.dst} is not supported")
# save destination safetensors
logger.info(f"Saving destination file {args.dst_path}")
save_file(state_dict, args.dst_path, metadata=metadata)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert LoRA format")
parser.add_argument("--src", type=str, default="ai-toolkit", help="source format, ai-toolkit or sd-scripts")
parser.add_argument("--dst", type=str, default="sd-scripts", help="destination format, ai-toolkit or sd-scripts")
parser.add_argument("--src_path", type=str, default=None, help="source path")
parser.add_argument("--dst_path", type=str, default=None, help="destination path")
args = parser.parse_args()
main(args)

View File

@@ -18,10 +18,13 @@ from transformers import CLIPTextModel
import torch
from torch import nn
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class DyLoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -195,7 +198,7 @@ def create_network(
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
if unit is not None:
unit = int(unit)
else:
@@ -211,6 +214,16 @@ def create_network(
unit=unit,
varbose=True,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
@@ -280,6 +293,10 @@ class DyLoRANetwork(torch.nn.Module):
self.alpha = alpha
self.apply_to_conv = apply_to_conv
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info("create LoRA network from weights")
else:
@@ -320,9 +337,9 @@ class DyLoRANetwork(torch.nn.Module):
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
loras.append(lora)
return loras
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
self.text_encoder_loras = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
@@ -331,7 +348,7 @@ class DyLoRANetwork(torch.nn.Module):
else:
index = None
logger.info("create LoRA for Text Encoder")
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
@@ -346,6 +363,14 @@ class DyLoRANetwork(torch.nn.Module):
self.unet_loras = create_modules(True, unet, target_modules)
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
@@ -406,27 +431,53 @@ class DyLoRANetwork(torch.nn.Module):
logger.info(f"weights are merged")
"""
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(loras):
params = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
for name, param in lora.named_parameters():
if ratio is not None and "lora_B" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue
params.append(param_data)
return params
if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
if self.unet_loras:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params = assemble_params(
self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
)
all_params.extend(params)
return all_params

View File

@@ -0,0 +1,219 @@
# extract approximating LoRA by svd from two FLUX models
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo!
import argparse
import json
import os
import time
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
from library import flux_utils, sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import MemoryEfficientSafeOpen
from library.utils import setup_logging
from networks import lora_flux
setup_logging()
import logging
logger = logging.getLogger(__name__)
# CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1
def save_to_file(file_name, state_dict, metadata, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
save_file(state_dict, file_name, metadata=metadata)
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
mem_eff_safe_open=False,
):
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
calc_dtype = torch.float
save_dtype = str_to_dtype(save_precision)
store_device = "cpu"
# open models
lora_weights = {}
if not mem_eff_safe_open:
# use original safetensors.safe_open
open_fn = lambda fn: safe_open(fn, framework="pt")
else:
logger.info("Using memory efficient safe_open")
open_fn = lambda fn: MemoryEfficientSafeOpen(fn)
with open_fn(model_org) as f_org:
# filter keys
keys = []
for key in f_org.keys():
if not ("single_block" in key or "double_block" in key):
continue
if ".bias" in key:
continue
if "norm" in key:
continue
keys.append(key)
with open_fn(model_tuned) as f_tuned:
for key in tqdm(keys):
# get tensors and calculate difference
value_o = f_org.get_tensor(key)
value_t = f_tuned.get_tensor(key)
mat = value_t.to(calc_dtype) - value_o.to(calc_dtype)
del value_o, value_t
# extract LoRA weights
if device:
mat = mat.to(device)
out_dim, in_dim = mat.size()[0:2]
rank = min(dim, in_dim, out_dim) # LoRA rank cannot exceed the original dim
mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
U = U[:, :rank]
S = S[:rank]
U = U @ torch.diag(S)
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, clamp_quantile)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
U = U.to(store_device, dtype=save_dtype).contiguous()
Vh = Vh.to(store_device, dtype=save_dtype).contiguous()
# print(f"key: {key}, U: {U.size()}, Vh: {Vh.size()}")
lora_weights[key] = (U, Vh)
del mat, U, S, Vh
# make state dict for LoRA
lora_sd = {}
for key, (up_weight, down_weight) in lora_weights.items():
lora_name = key.replace(".weight", "").replace(".", "_")
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + lora_name
lora_sd[lora_name + ".lora_up.weight"] = up_weight
lora_sd[lora_name + ".lora_down.weight"] = down_weight
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0]) # same as rank
# minimum metadata
net_kwargs = {}
metadata = {
"ss_v2": str(False),
"ss_base_model_version": flux_utils.MODEL_VERSION_FLUX_V1,
"ss_network_module": "networks.lora_flux",
"ss_network_dim": str(dim),
"ss_network_alpha": str(float(dim)),
"ss_network_args": json.dumps(net_kwargs),
}
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(lora_sd, False, False, False, True, False, time.time(), title, flux="dev")
metadata.update(sai_metadata)
save_to_file(save_to, lora_sd, metadata, save_dtype)
logger.info(f"LoRA weights saved to {save_to}")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
)
parser.add_argument(
"--model_org",
type=str,
default=None,
required=True,
help="Original model: safetensors file / 元モデル、safetensors",
)
parser.add_argument(
"--model_tuned",
type=str,
default=None,
required=True,
help="Tuned model, LoRA is difference of `original to tuned`: safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors",
)
parser.add_argument(
"--mem_eff_safe_open",
action="store_true",
help="use memory efficient safe_open. This is an experimental feature, use only when memory is not enough."
" / メモリ効率の良いsafe_openを使用する。実装は実験的なものなので、メモリが足りない場合のみ使用してください。",
)
parser.add_argument(
"--save_to",
type=str,
default=None,
required=True,
help="destination file name: safetensors file / 保存先のファイル名、safetensors",
)
parser.add_argument(
"--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4"
)
parser.add_argument(
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
)
parser.add_argument(
"--clamp_quantile",
type=float,
default=0.99,
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
)
# parser.add_argument(
# "--min_diff",
# type=float,
# default=0.01,
# help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
# + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
# )
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
svd(**vars(args))

765
networks/flux_merge_lora.py Normal file
View File

@@ -0,0 +1,765 @@
import argparse
import math
import os
import time
from typing import Any, Dict, Union
import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging()
import logging
logger = logging.getLogger(__name__)
import lora_flux as lora_flux
from library import sai_model_spec, train_util
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name)
metadata = train_util.load_metadata_from_safetensors(file_name)
else:
sd = torch.load(file_name, map_location="cpu")
metadata = {}
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd, metadata
def save_to_file(file_name, state_dict: Dict[str, Union[Any, torch.Tensor]], dtype, metadata, mem_eff_save=False):
if dtype is not None:
logger.info(f"converting to {dtype}...")
for key in tqdm(list(state_dict.keys())):
if type(state_dict[key]) == torch.Tensor and state_dict[key].dtype.is_floating_point:
state_dict[key] = state_dict[key].to(dtype)
logger.info(f"saving to: {file_name}")
if mem_eff_save:
mem_eff_save_file(state_dict, file_name, metadata=metadata)
else:
save_file(state_dict, file_name, metadata=metadata)
def merge_to_flux_model(
loading_device,
working_device,
flux_path: str,
clip_l_path: str,
t5xxl_path: str,
models,
ratios,
merge_dtype,
save_dtype,
mem_eff_load_save=False,
):
# create module map without loading state_dict
lora_name_to_module_key = {}
if flux_path is not None:
logger.info(f"loading keys from FLUX.1 model: {flux_path}")
with safe_open(flux_path, framework="pt", device=loading_device) as flux_file:
keys = list(flux_file.keys())
for key in keys:
if key.endswith(".weight"):
module_name = ".".join(key.split(".")[:-1])
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_FLUX + "_" + module_name.replace(".", "_")
lora_name_to_module_key[lora_name] = key
lora_name_to_clip_l_key = {}
if clip_l_path is not None:
logger.info(f"loading keys from clip_l model: {clip_l_path}")
with safe_open(clip_l_path, framework="pt", device=loading_device) as clip_l_file:
keys = list(clip_l_file.keys())
for key in keys:
if key.endswith(".weight"):
module_name = ".".join(key.split(".")[:-1])
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP + "_" + module_name.replace(".", "_")
lora_name_to_clip_l_key[lora_name] = key
lora_name_to_t5xxl_key = {}
if t5xxl_path is not None:
logger.info(f"loading keys from t5xxl model: {t5xxl_path}")
with safe_open(t5xxl_path, framework="pt", device=loading_device) as t5xxl_file:
keys = list(t5xxl_file.keys())
for key in keys:
if key.endswith(".weight"):
module_name = ".".join(key.split(".")[:-1])
lora_name = lora_flux.LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5 + "_" + module_name.replace(".", "_")
lora_name_to_t5xxl_key[lora_name] = key
flux_state_dict = {}
clip_l_state_dict = {}
t5xxl_state_dict = {}
if mem_eff_load_save:
if flux_path is not None:
with MemoryEfficientSafeOpen(flux_path) as flux_file:
for key in tqdm(flux_file.keys()):
flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
if clip_l_path is not None:
with MemoryEfficientSafeOpen(clip_l_path) as clip_l_file:
for key in tqdm(clip_l_file.keys()):
clip_l_state_dict[key] = clip_l_file.get_tensor(key).to(loading_device)
if t5xxl_path is not None:
with MemoryEfficientSafeOpen(t5xxl_path) as t5xxl_file:
for key in tqdm(t5xxl_file.keys()):
t5xxl_state_dict[key] = t5xxl_file.get_tensor(key).to(loading_device)
else:
if flux_path is not None:
flux_state_dict = load_file(flux_path, device=loading_device)
if clip_l_path is not None:
clip_l_state_dict = load_file(clip_l_path, device=loading_device)
if t5xxl_path is not None:
t5xxl_state_dict = load_file(t5xxl_path, device=loading_device)
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype) # loading on CPU
logger.info(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if "lora_down" in key:
lora_name = key[: key.rfind(".lora_down")]
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
if lora_name in lora_name_to_module_key:
module_weight_key = lora_name_to_module_key[lora_name]
state_dict = flux_state_dict
elif lora_name in lora_name_to_clip_l_key:
module_weight_key = lora_name_to_clip_l_key[lora_name]
state_dict = clip_l_state_dict
elif lora_name in lora_name_to_t5xxl_key:
module_weight_key = lora_name_to_t5xxl_key[lora_name]
state_dict = t5xxl_state_dict
else:
logger.warning(
f"no module found for LoRA weight: {key}. Skipping..."
f"LoRAの重みに対応するモジュールが見つかりませんでした。スキップします。"
)
continue
down_weight = lora_sd.pop(key)
up_weight = lora_sd.pop(up_key)
dim = down_weight.size()[0]
alpha = lora_sd.pop(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = state_dict[module_weight_key]
weight = weight.to(working_device, merge_dtype)
up_weight = up_weight.to(working_device, merge_dtype)
down_weight = down_weight.to(working_device, merge_dtype)
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
state_dict[module_weight_key] = weight.to(loading_device, save_dtype)
del up_weight
del down_weight
del weight
if len(lora_sd) > 0:
logger.warning(f"Unused keys in LoRA model: {list(lora_sd.keys())}")
return flux_state_dict, clip_l_state_dict, t5xxl_state_dict
def merge_to_flux_model_diffusers(
loading_device, working_device, flux_model, models, ratios, merge_dtype, save_dtype, mem_eff_load_save=False
):
logger.info(f"loading keys from FLUX.1 model: {flux_model}")
if mem_eff_load_save:
flux_state_dict = {}
with MemoryEfficientSafeOpen(flux_model) as flux_file:
for key in tqdm(flux_file.keys()):
flux_state_dict[key] = flux_file.get_tensor(key).to(loading_device) # dtype is not changed
else:
flux_state_dict = load_file(flux_model, device=loading_device)
def create_key_map(n_double_layers, n_single_layers):
key_map = {}
for index in range(n_double_layers):
prefix_from = f"transformer_blocks.{index}"
prefix_to = f"double_blocks.{index}"
for end in ("weight", "bias"):
k = f"{prefix_from}.attn."
qkv_img = f"{prefix_to}.img_attn.qkv.{end}"
qkv_txt = f"{prefix_to}.txt_attn.qkv.{end}"
key_map[f"{k}to_q.{end}"] = qkv_img
key_map[f"{k}to_k.{end}"] = qkv_img
key_map[f"{k}to_v.{end}"] = qkv_img
key_map[f"{k}add_q_proj.{end}"] = qkv_txt
key_map[f"{k}add_k_proj.{end}"] = qkv_txt
key_map[f"{k}add_v_proj.{end}"] = qkv_txt
block_map = {
"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
"norm1.linear.weight": "img_mod.lin.weight",
"norm1.linear.bias": "img_mod.lin.bias",
"norm1_context.linear.weight": "txt_mod.lin.weight",
"norm1_context.linear.bias": "txt_mod.lin.bias",
"attn.to_add_out.weight": "txt_attn.proj.weight",
"attn.to_add_out.bias": "txt_attn.proj.bias",
"ff.net.0.proj.weight": "img_mlp.0.weight",
"ff.net.0.proj.bias": "img_mlp.0.bias",
"ff.net.2.weight": "img_mlp.2.weight",
"ff.net.2.bias": "img_mlp.2.bias",
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
}
for k, v in block_map.items():
key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}"
for index in range(n_single_layers):
prefix_from = f"single_transformer_blocks.{index}"
prefix_to = f"single_blocks.{index}"
for end in ("weight", "bias"):
k = f"{prefix_from}.attn."
qkv = f"{prefix_to}.linear1.{end}"
key_map[f"{k}to_q.{end}"] = qkv
key_map[f"{k}to_k.{end}"] = qkv
key_map[f"{k}to_v.{end}"] = qkv
key_map[f"{prefix_from}.proj_mlp.{end}"] = qkv
block_map = {
"norm.linear.weight": "modulation.lin.weight",
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
}
for k, v in block_map.items():
key_map[f"{prefix_from}.{k}"] = f"{prefix_to}.{v}"
# add as-is keys
values = list([(v if isinstance(v, str) else v[0]) for v in set(key_map.values())])
values.sort()
key_map.update({v: v for v in values})
return key_map
key_map = create_key_map(18, 38) # 18 double layers, 38 single layers
def find_matching_key(flux_dict, lora_key):
lora_key = lora_key.replace("diffusion_model.", "")
lora_key = lora_key.replace("transformer.", "")
lora_key = lora_key.replace("lora_A", "lora_down").replace("lora_B", "lora_up")
lora_key = lora_key.replace("single_transformer_blocks", "single_blocks")
lora_key = lora_key.replace("transformer_blocks", "double_blocks")
double_block_map = {
"attn.to_out.0": "img_attn.proj",
"norm1.linear": "img_mod.lin",
"norm1_context.linear": "txt_mod.lin",
"attn.to_add_out": "txt_attn.proj",
"ff.net.0.proj": "img_mlp.0",
"ff.net.2": "img_mlp.2",
"ff_context.net.0.proj": "txt_mlp.0",
"ff_context.net.2": "txt_mlp.2",
"attn.norm_q": "img_attn.norm.query_norm",
"attn.norm_k": "img_attn.norm.key_norm",
"attn.norm_added_q": "txt_attn.norm.query_norm",
"attn.norm_added_k": "txt_attn.norm.key_norm",
"attn.to_q": "img_attn.qkv",
"attn.to_k": "img_attn.qkv",
"attn.to_v": "img_attn.qkv",
"attn.add_q_proj": "txt_attn.qkv",
"attn.add_k_proj": "txt_attn.qkv",
"attn.add_v_proj": "txt_attn.qkv",
}
single_block_map = {
"norm.linear": "modulation.lin",
"proj_out": "linear2",
"attn.norm_q": "norm.query_norm",
"attn.norm_k": "norm.key_norm",
"attn.to_q": "linear1",
"attn.to_k": "linear1",
"attn.to_v": "linear1",
"proj_mlp": "linear1",
}
# same key exists in both single_block_map and double_block_map, so we must care about single/double
# print("lora_key before double_block_map", lora_key)
for old, new in double_block_map.items():
if "double" in lora_key:
lora_key = lora_key.replace(old, new)
# print("lora_key before single_block_map", lora_key)
for old, new in single_block_map.items():
if "single" in lora_key:
lora_key = lora_key.replace(old, new)
# print("lora_key after mapping", lora_key)
if lora_key in key_map:
flux_key = key_map[lora_key]
logger.info(f"Found matching key: {flux_key}")
return flux_key
# If not found in key_map, try partial matching
potential_key = lora_key + ".weight"
logger.info(f"Searching for key: {potential_key}")
matches = [k for k in flux_dict.keys() if potential_key in k]
if matches:
logger.info(f"Found matching key: {matches[0]}")
return matches[0]
return None
merged_keys = set()
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
logger.info("merging...")
for key in lora_sd.keys():
if "lora_down" in key or "lora_A" in key:
lora_name = key[: key.rfind(".lora_down" if "lora_down" in key else ".lora_A")]
up_key = key.replace("lora_down", "lora_up").replace("lora_A", "lora_B")
alpha_key = key[: key.index("lora_down" if "lora_down" in key else "lora_A")] + "alpha"
logger.info(f"Processing LoRA key: {lora_name}")
flux_key = find_matching_key(flux_state_dict, lora_name)
if flux_key is None:
logger.warning(f"no module found for LoRA weight: {key}")
continue
logger.info(f"Merging LoRA key {lora_name} into Flux key {flux_key}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
weight = flux_state_dict[flux_key]
weight = weight.to(working_device, merge_dtype)
up_weight = up_weight.to(working_device, merge_dtype)
down_weight = down_weight.to(working_device, merge_dtype)
# print(up_weight.size(), down_weight.size(), weight.size())
if lora_name.startswith("transformer."):
if "qkv" in flux_key or "linear1" in flux_key: # combined qkv or qkv+mlp
update = ratio * (up_weight @ down_weight) * scale
# print(update.shape)
if "img_attn" in flux_key or "txt_attn" in flux_key:
q, k, v = torch.chunk(weight, 3, dim=0)
if "to_q" in lora_name or "add_q_proj" in lora_name:
q += update.reshape(q.shape)
elif "to_k" in lora_name or "add_k_proj" in lora_name:
k += update.reshape(k.shape)
elif "to_v" in lora_name or "add_v_proj" in lora_name:
v += update.reshape(v.shape)
weight = torch.cat([q, k, v], dim=0)
elif "linear1" in flux_key:
q, k, v = torch.chunk(weight[: int(update.shape[-1] * 3)], 3, dim=0)
mlp = weight[int(update.shape[-1] * 3) :]
# print(q.shape, k.shape, v.shape, mlp.shape)
if "to_q" in lora_name:
q += update.reshape(q.shape)
elif "to_k" in lora_name:
k += update.reshape(k.shape)
elif "to_v" in lora_name:
v += update.reshape(v.shape)
elif "proj_mlp" in lora_name:
mlp += update.reshape(mlp.shape)
weight = torch.cat([q, k, v, mlp], dim=0)
else:
if len(weight.size()) == 2:
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
else:
if len(weight.size()) == 2:
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
flux_state_dict[flux_key] = weight.to(loading_device, save_dtype)
merged_keys.add(flux_key)
del up_weight
del down_weight
del weight
logger.info(f"Merged keys: {sorted(list(merged_keys))}")
return flux_state_dict
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
merged_sd = {}
base_model = None
for model, ratio in zip(models, ratios):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
if base_model is None:
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
# get alpha and dim
alphas = {} # alpha for current model
dims = {} # dims for current model
for key in lora_sd.keys():
if "alpha" in key:
lora_module_name = key[: key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
elif "lora_down" in key:
lora_module_name = key[: key.rfind(".lora_down")]
dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim
if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim
for lora_module_name in dims.keys():
if lora_module_name not in alphas:
alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
logger.info("merging...")
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None
lora_module_name = key[: key.rfind(".lora_")]
base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
), "weights shape mismatch, different dims? / 重みのサイズが合いません。dimが異なる可能性があります。"
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
# set alpha to sd
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:, perm]
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same
dims_list = list(set(base_dims.values()))
alphas_list = list(set(base_alphas.values()))
all_same_dims = True
all_same_alphas = True
for dims in dims_list:
if dims != dims_list[0]:
all_same_dims = False
break
for alphas in alphas_list:
if alphas != alphas_list[0]:
all_same_alphas = False
break
# build minimum metadata
dims = f"{dims_list[0]}" if all_same_dims else "Dynamic"
alphas = f"{alphas_list[0]}" if all_same_alphas else "Dynamic"
metadata = train_util.build_minimum_network_metadata(str(False), base_model, "networks.lora", dims, alphas, None)
return merged_sd, metadata
def merge(args):
if args.models is None:
args.models = []
if args.ratios is None:
args.ratios = []
assert len(args.models) == len(
args.ratios
), "number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
assert (
args.save_to or args.clip_l_save_to or args.t5xxl_save_to
), "save_to or clip_l_save_to or t5xxl_save_to must be specified / save_toまたはclip_l_save_toまたはt5xxl_save_toを指定してください"
dest_dir = os.path.dirname(args.save_to or args.clip_l_save_to or args.t5xxl_save_to)
if not os.path.exists(dest_dir):
logger.info(f"creating directory: {dest_dir}")
os.makedirs(dest_dir)
if args.flux_model is not None or args.clip_l is not None or args.t5xxl is not None:
if not args.diffusers:
assert (args.clip_l is None and args.clip_l_save_to is None) or (
args.clip_l is not None and args.clip_l_save_to is not None
), "clip_l_save_to must be specified if clip_l is specified / clip_lが指定されている場合はclip_l_save_toも指定してください"
assert (args.t5xxl is None and args.t5xxl_save_to is None) or (
args.t5xxl is not None and args.t5xxl_save_to is not None
), "t5xxl_save_to must be specified if t5xxl is specified / t5xxlが指定されている場合はt5xxl_save_toも指定してください"
flux_state_dict, clip_l_state_dict, t5xxl_state_dict = merge_to_flux_model(
args.loading_device,
args.working_device,
args.flux_model,
args.clip_l,
args.t5xxl,
args.models,
args.ratios,
merge_dtype,
save_dtype,
args.mem_eff_load_save,
)
else:
assert (
args.clip_l is None and args.t5xxl is None
), "clip_l and t5xxl are not supported with --diffusers / clip_l、t5xxlはDiffusersではサポートされていません"
flux_state_dict = merge_to_flux_model_diffusers(
args.loading_device,
args.working_device,
args.flux_model,
args.models,
args.ratios,
merge_dtype,
save_dtype,
args.mem_eff_load_save,
)
clip_l_state_dict = None
t5xxl_state_dict = None
if args.no_metadata or (flux_state_dict is None or len(flux_state_dict) == 0):
sai_metadata = None
else:
merged_from = sai_model_spec.build_merged_from([args.flux_model] + args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, False, False, False, False, False, time.time(), title=title, merged_from=merged_from, flux="dev"
)
if flux_state_dict is not None and len(flux_state_dict) > 0:
logger.info(f"saving FLUX model to: {args.save_to}")
save_to_file(args.save_to, flux_state_dict, save_dtype, sai_metadata, args.mem_eff_load_save)
if clip_l_state_dict is not None and len(clip_l_state_dict) > 0:
logger.info(f"saving clip_l model to: {args.clip_l_save_to}")
save_to_file(args.clip_l_save_to, clip_l_state_dict, save_dtype, None, args.mem_eff_load_save)
if t5xxl_state_dict is not None and len(t5xxl_state_dict) > 0:
logger.info(f"saving t5xxl model to: {args.t5xxl_save_to}")
save_to_file(args.t5xxl_save_to, t5xxl_state_dict, save_dtype, None, args.mem_eff_load_save)
else:
flux_state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
logger.info("calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(flux_state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
if not args.no_metadata:
merged_from = sai_model_spec.build_merged_from(args.models)
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
flux_state_dict, False, False, False, True, False, time.time(), title=title, merged_from=merged_from, flux="dev"
)
metadata.update(sai_metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, flux_state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_precision",
type=str,
default=None,
help="precision in saving, same to merging if omitted. supported types: "
"float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz"
" / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
)
parser.add_argument(
"--precision",
type=str,
default="float",
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--flux_model",
type=str,
default=None,
help="FLUX.1 model to load, merge LoRA models if omitted / 読み込むモデル、指定しない場合はLoRAモデルをマージする",
)
parser.add_argument(
"--clip_l",
type=str,
default=None,
help="path to clip_l (*.sft or *.safetensors), should be float16 / clip_lのパス*.sftまたは*.safetensors",
)
parser.add_argument(
"--t5xxl",
type=str,
default=None,
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス*.sftまたは*.safetensors",
)
parser.add_argument(
"--mem_eff_load_save",
action="store_true",
help="use custom memory efficient load and save functions for FLUX.1 model"
" / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する",
)
parser.add_argument(
"--loading_device",
type=str,
default="cpu",
help="device to load FLUX.1 model. LoRA models are loaded on CPU / FLUX.1モデルを読み込むデバイス。LoRAモデルはCPUで読み込まれます",
)
parser.add_argument(
"--working_device",
type=str,
default="cpu",
help="device to work (merge). Merging LoRA models are done on CPU."
+ " / 作業マージするデバイス。LoRAモデルのマージはCPUで行われます。",
)
parser.add_argument(
"--save_to",
type=str,
default=None,
help="destination file name: safetensors file / 保存先のファイル名、safetensorsファイル",
)
parser.add_argument(
"--clip_l_save_to",
type=str,
default=None,
help="destination file name for clip_l: safetensors file / clip_lの保存先のファイル名、safetensorsファイル",
)
parser.add_argument(
"--t5xxl_save_to",
type=str,
default=None,
help="destination file name for t5xxl: safetensors file / t5xxlの保存先のファイル名、safetensorsファイル",
)
parser.add_argument(
"--models",
type=str,
nargs="*",
help="LoRA models to merge: safetensors file / マージするLoRAモデル、safetensorsファイル",
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合するLoRAのdim(rank)は入力dimの合計になる",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
)
parser.add_argument(
"--diffusers",
action="store_true",
help="merge Diffusers (?) LoRA models / Diffusers (?) LoRAモデルをマージする",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
merge(args)

View File

@@ -12,6 +12,7 @@ import numpy as np
import torch
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
setup_logging()
import logging
@@ -247,14 +248,13 @@ class LoRAInfModule(LoRAModule):
area = x.size()[1]
mask = self.network.mask_dic.get(area, None)
if mask is None:
# raise ValueError(f"mask is None for resolution {area}")
if mask is None or len(x.size()) == 2:
# emb_layers in SDXL doesn't have mask
# if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4:
if len(x.size()) == 3:
mask = torch.reshape(mask, (1, -1, 1))
return mask
@@ -386,14 +386,14 @@ class LoRAInfModule(LoRAModule):
return out
def parse_block_lr_kwargs(nw_kwargs):
def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]:
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
# 以上のいずれにも設定がない場合は無効としてNoneを返す
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
return None, None, None
return None
# extract learning rate weight for each block
if down_lr_weight is not None:
@@ -402,18 +402,16 @@ def parse_block_lr_kwargs(nw_kwargs):
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
if mid_lr_weight is not None:
mid_lr_weight = float(mid_lr_weight)
mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")]
if up_lr_weight is not None:
if "," in up_lr_weight:
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
return get_block_lr_weight(
is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
)
return down_lr_weight, mid_lr_weight, up_lr_weight
def create_network(
multiplier: float,
@@ -425,6 +423,9 @@ def create_network(
neuron_dropout: Optional[float] = None,
**kwargs,
):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
@@ -442,21 +443,21 @@ def create_network(
# block dim/alpha/lr
block_dims = kwargs.get("block_dims", None)
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
if block_dims is not None or block_lr_weight is not None:
block_alphas = kwargs.get("block_alphas", None)
conv_block_dims = kwargs.get("conv_block_dims", None)
conv_block_alphas = kwargs.get("conv_block_alphas", None)
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
)
# remove block dim/alpha without learning rate
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight
)
else:
@@ -489,10 +490,20 @@ def create_network(
conv_block_dims=conv_block_dims,
conv_block_alphas=conv_block_alphas,
varbose=True,
is_sdxl=is_sdxl,
)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
if block_lr_weight is not None:
network.set_block_lr_weight(block_lr_weight)
return network
@@ -502,9 +513,13 @@ def create_network(
# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている
def get_block_dims_and_alphas(
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
):
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
if not is_sdxl:
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS
else:
# 1+9+3+9+1=23, no LoRA for emb_layers (0)
num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
def parse_ints(s):
return [int(i) for i in s.split(",")]
@@ -515,9 +530,10 @@ def get_block_dims_and_alphas(
# block_dimsとblock_alphasをパースする。必ず値が入る
if block_dims is not None:
block_dims = parse_ints(block_dims)
assert (
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
assert len(block_dims) == num_total_blocks, (
f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given"
+ f" / block_dims{num_total_blocks}個指定してください(指定された個数: {len(block_dims)}"
)
else:
logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
@@ -568,15 +584,25 @@ def get_block_dims_and_alphas(
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出せるようにclass外に出しておく
# 戻り値は block ごとの倍率のリスト
def get_block_lr_weight(
down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
) -> Tuple[List[float], List[float], List[float]]:
is_sdxl,
down_lr_weight: Union[str, List[float]],
mid_lr_weight: List[float],
up_lr_weight: Union[str, List[float]],
zero_threshold: float,
) -> Optional[List[float]]:
# パラメータ未指定時は何もせず、今までと同じ動作とする
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
return None, None, None
return None
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
if not is_sdxl:
max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS
max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS
else:
max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS
max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS
def get_list(name_with_suffix) -> List[float]:
import math
@@ -586,15 +612,18 @@ def get_block_lr_weight(
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
if name == "cosine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
return [
math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr
for i in reversed(range(max_len_for_down_or_up))
]
elif name == "sine":
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)]
elif name == "linear":
return [i / (max_len - 1) + base_lr for i in range(max_len)]
return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)]
elif name == "reverse_linear":
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))]
elif name == "zeros":
return [0.0 + base_lr] * max_len
return [0.0 + base_lr] * max_len_for_down_or_up
else:
logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
@@ -607,20 +636,36 @@ def get_block_lr_weight(
if type(up_lr_weight) == str:
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or (
down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up
):
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up)
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up)
up_lr_weight = up_lr_weight[:max_len_for_down_or_up]
down_lr_weight = down_lr_weight[:max_len_for_down_or_up]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
logger.warning("down_weightもしくはup_weightがすぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid:
logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid)
logger.warning("mid_weightがすぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid)
mid_lr_weight = mid_lr_weight[:max_len_for_mid]
if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
if up_lr_weight != None and len(up_lr_weight) < max_len:
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or (
down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up
):
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up)
logger.warning(
"down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up
)
if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up:
down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight))
if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up:
up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight))
if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid:
logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid)
logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid)
mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
logger.info("apply block learning rate / 階層別学習率を適用します。")
@@ -628,78 +673,139 @@ def get_block_lr_weight(
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
else:
down_lr_weight = [1.0] * max_len_for_down_or_up
logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight]
logger.info(f"mid_lr_weight: {mid_lr_weight}")
else:
logger.info("mid_lr_weight: 1.0")
mid_lr_weight = [1.0] * max_len_for_mid
logger.info("mid_lr_weight: all 1.0, すべて1.0")
if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
else:
up_lr_weight = [1.0] * max_len_for_down_or_up
logger.info("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight
lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight
if is_sdxl:
lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out
assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or (
is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
), f"lr_weight length is invalid: {len(lr_weight)}"
return lr_weight
# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく
def remove_block_dims_and_alphas(
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]]
):
# set 0 to block dim without learning rate to remove the block
if down_lr_weight != None:
for i, lr in enumerate(down_lr_weight):
if block_lr_weight is not None:
for i, lr in enumerate(block_lr_weight):
if lr == 0:
block_dims[i] = 0
if conv_block_dims is not None:
conv_block_dims[i] = 0
if mid_lr_weight != None:
if mid_lr_weight == 0:
block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
if conv_block_dims is not None:
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
if up_lr_weight != None:
for i, lr in enumerate(up_lr_weight):
if lr == 0:
block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
if conv_block_dims is not None:
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
# 外部から呼び出す可能性を考慮しておく
def get_block_index(lora_name: str) -> int:
def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
block_idx = -1 # invalid lora name
if not is_sdxl:
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2] == "resnets":
idx = 3 * i + j
elif g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2] == "resnets":
idx = 3 * i + j
elif g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers" or g[2] == "downsamplers":
idx = 3 * i + 2
if g[0] == "down":
block_idx = 1 + idx # 0に該当するLoRAは存在しない
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
if g[0] == "down":
block_idx = 1 + idx # 0に該当するLoRAは存在しない
elif g[0] == "up":
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
elif "mid_block_" in lora_name:
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
else:
# copy from sdxl_train
if lora_name.startswith("lora_unet_"):
name = lora_name[len("lora_unet_") :]
if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA
block_idx = 0 # 0
elif name.startswith("input_blocks_"): # 1-9
block_idx = 1 + int(name.split("_")[2])
elif name.startswith("middle_block_"): # 10-12
block_idx = 10 + int(name.split("_")[2])
elif name.startswith("output_blocks_"): # 13-21
block_idx = 13 + int(name.split("_")[2])
elif name.startswith("out_"): # 22, out, no LoRA
block_idx = 22
return block_idx
def convert_diffusers_to_sai_if_needed(weights_sd):
# only supports U-Net LoRA modules
found_up_down_blocks = False
for k in list(weights_sd.keys()):
if "down_blocks" in k:
found_up_down_blocks = True
break
if "up_blocks" in k:
found_up_down_blocks = True
break
if not found_up_down_blocks:
return
from library.sdxl_model_util import make_unet_conversion_map
unet_conversion_map = make_unet_conversion_map()
unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
# # add extra conversion
# unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
logger.info(f"Converting LoRA keys from Diffusers to SAI")
lora_unet_prefix = "lora_unet_"
for k in list(weights_sd.keys()):
if not k.startswith(lora_unet_prefix):
continue
unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]
# search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
for hf_module_name, sd_module_name in unet_conversion_map.items():
if hf_module_name in unet_module_name:
new_key = (
lora_unet_prefix
+ unet_module_name.replace(hf_module_name, sd_module_name)
+ k[len(lora_unet_prefix) + len(unet_module_name) :]
)
weights_sd[new_key] = weights_sd.pop(k)
found = True
break
if not found:
logger.warning(f"Key {k} is not found in unet_conversion_map")
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
@@ -708,6 +814,10 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
else:
weights_sd = torch.load(file, map_location="cpu")
# if keys are Diffusers based, convert to SAI based
if is_sdxl:
convert_diffusers_to_sai_if_needed(weights_sd)
# get dim/alpha mapping
modules_dim = {}
modules_alpha = {}
@@ -731,19 +841,28 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
text_encoder,
unet,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
is_sdxl=is_sdxl,
)
# block lr
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
if block_lr_weight is not None:
network.set_block_lr_weight(block_lr_weight)
return network, weights_sd
class LoRANetwork(torch.nn.Module):
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
NUM_OF_MID_BLOCKS = 1
SDXL_NUM_OF_BLOCKS = 9 # SDXLのモデルでのinput/outputの層の数 total=1(base) 9(input) + 3(mid) + 9(output) + 1(out) = 23
SDXL_NUM_OF_MID_BLOCKS = 3
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
@@ -775,6 +894,7 @@ class LoRANetwork(torch.nn.Module):
modules_alpha: Optional[Dict[str, int]] = None,
module_class: Type[object] = LoRAModule,
varbose: Optional[bool] = False,
is_sdxl: Optional[bool] = False,
) -> None:
"""
LoRA network: すごく引数が多いが、パターンは以下の通り
@@ -795,6 +915,10 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
@@ -856,7 +980,7 @@ class LoRANetwork(torch.nn.Module):
alpha = modules_alpha[lora_name]
elif is_unet and block_dims is not None:
# U-Netでblock_dims指定あり
block_idx = get_block_index(lora_name)
block_idx = get_block_index(lora_name, is_sdxl)
if is_linear or is_conv2d_1x1:
dim = block_dims[block_idx]
alpha = block_alphas[block_idx]
@@ -926,9 +1050,7 @@ class LoRANetwork(torch.nn.Module):
for name in skipped:
logger.info(f"\t{name}")
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
self.mid_lr_weight: float = None
self.block_lr_weight = None
self.block_lr = False
# assertion
@@ -959,12 +1081,12 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info("enable LoRA for text encoder")
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable LoRA for U-Net")
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
@@ -1005,81 +1127,117 @@ class LoRANetwork(torch.nn.Module):
logger.info(f"weights are merged")
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
def set_block_lr_weight(
self,
up_lr_weight: List[float] = None,
mid_lr_weight: float = None,
down_lr_weight: List[float] = None,
):
def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]):
self.block_lr = True
self.down_lr_weight = down_lr_weight
self.mid_lr_weight = mid_lr_weight
self.up_lr_weight = up_lr_weight
self.block_lr_weight = block_lr_weight
def get_lr_weight(self, lora: LoRAModule) -> float:
lr_weight = 1.0
block_idx = get_block_index(lora.lora_name)
if block_idx < 0:
return lr_weight
def get_lr_weight(self, block_idx: int) -> float:
if not self.block_lr or self.block_lr_weight is None:
return 1.0
return self.block_lr_weight[block_idx]
if block_idx < LoRANetwork.NUM_OF_BLOCKS:
if self.down_lr_weight != None:
lr_weight = self.down_lr_weight[block_idx]
elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
if self.mid_lr_weight != None:
lr_weight = self.mid_lr_weight
elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
if self.up_lr_weight != None:
lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
return lr_weight
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
# if (
# self.loraplus_lr_ratio is not None
# or self.loraplus_text_encoder_lr_ratio is not None
# or self.loraplus_unet_lr_ratio is not None
# ):
# assert (
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
def enumerate_params(loras):
params = []
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
return params
for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params, descriptions = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
if self.block_lr:
is_sdxl = False
for lora in self.unet_loras:
if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
is_sdxl = True
break
# 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
block_idx_to_lora = {}
for lora in self.unet_loras:
idx = get_block_index(lora.lora_name)
idx = get_block_index(lora.lora_name, is_sdxl)
if idx not in block_idx_to_lora:
block_idx_to_lora[idx] = []
block_idx_to_lora[idx].append(lora)
# blockごとにパラメータを設定する
for idx, block_loras in block_idx_to_lora.items():
param_data = {"params": enumerate_params(block_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
elif default_lr is not None:
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
if ("lr" in param_data) and (param_data["lr"] == 0):
continue
all_params.append(param_data)
params, descriptions = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
else:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params
return all_params, lr_descriptions
def enable_gradient_checkpointing(self):
# not supported

1157
networks/lora_flux.py Normal file

File diff suppressed because it is too large Load Diff

839
networks/lora_sd3.py Normal file
View File

@@ -0,0 +1,839 @@
# temporary minimum implementation of LoRA
# SD3 doesn't have Conv2d, so we ignore it
# TODO commonize with the original/SD3/FLUX implementation
# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from transformers import CLIPTextModelWithProjection, T5EncoderModel
import numpy as np
import torch
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from networks.lora_flux import LoRAModule, LoRAInfModule
from library import sd3_models
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: sd3_models.SDVAE,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
mmdit,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
# extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
if conv_dim is not None:
conv_dim = int(conv_dim)
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
# attn dim, mlp dim: only for DoubleStreamBlock. SingleStreamBlock is not supported because of combined qkv
context_attn_dim = kwargs.get("context_attn_dim", None)
context_mlp_dim = kwargs.get("context_mlp_dim", None)
context_mod_dim = kwargs.get("context_mod_dim", None)
x_attn_dim = kwargs.get("x_attn_dim", None)
x_mlp_dim = kwargs.get("x_mlp_dim", None)
x_mod_dim = kwargs.get("x_mod_dim", None)
if context_attn_dim is not None:
context_attn_dim = int(context_attn_dim)
if context_mlp_dim is not None:
context_mlp_dim = int(context_mlp_dim)
if context_mod_dim is not None:
context_mod_dim = int(context_mod_dim)
if x_attn_dim is not None:
x_attn_dim = int(x_attn_dim)
if x_mlp_dim is not None:
x_mlp_dim = int(x_mlp_dim)
if x_mod_dim is not None:
x_mod_dim = int(x_mod_dim)
type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim]
if all([d is None for d in type_dims]):
type_dims = None
# emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear]
emb_dims = kwargs.get("emb_dims", None)
if emb_dims is not None:
emb_dims = emb_dims.strip()
if emb_dims.startswith("[") and emb_dims.endswith("]"):
emb_dims = emb_dims[1:-1]
emb_dims = [int(d) for d in emb_dims.split(",")] # is it better to use ast.literal_eval?
assert len(emb_dims) == 6, f"invalid emb_dims: {emb_dims}, must be 6 dimensions (context, t, x, y, final_mod, final_linear)"
# double/single train blocks
def parse_block_selection(selection: str, total_blocks: int) -> List[bool]:
"""
Parse a block selection string and return a list of booleans.
Args:
selection (str): A string specifying which blocks to select.
total_blocks (int): The total number of blocks available.
Returns:
List[bool]: A list of booleans indicating which blocks are selected.
"""
if selection == "all":
return [True] * total_blocks
if selection == "none" or selection == "":
return [False] * total_blocks
selected = [False] * total_blocks
ranges = selection.split(",")
for r in ranges:
if "-" in r:
start, end = map(str.strip, r.split("-"))
start = int(start)
end = int(end)
assert 0 <= start < total_blocks, f"invalid start index: {start}"
assert 0 <= end < total_blocks, f"invalid end index: {end}"
assert start <= end, f"invalid range: {start}-{end}"
for i in range(start, end + 1):
selected[i] = True
else:
index = int(r)
assert 0 <= index < total_blocks, f"invalid index: {index}"
selected[index] = True
return selected
train_block_indices = kwargs.get("train_block_indices", None)
if train_block_indices is not None:
train_block_indices = parse_block_selection(train_block_indices, 999) # 999 is a dummy number
# rank/module dropout
rank_dropout = kwargs.get("rank_dropout", None)
if rank_dropout is not None:
rank_dropout = float(rank_dropout)
module_dropout = kwargs.get("module_dropout", None)
if module_dropout is not None:
module_dropout = float(module_dropout)
# split qkv
split_qkv = kwargs.get("split_qkv", False)
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
train_t5xxl = True if train_t5xxl == "True" else False
# verbose
verbose = kwargs.get("verbose", False)
if verbose is not None:
verbose = True if verbose == "True" else False
# すごく引数が多いな ( ^ω^)・・・
network = LoRANetwork(
text_encoders,
mmdit,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
dropout=neuron_dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
type_dims=type_dims,
emb_dims=emb_dims,
train_block_indices=train_block_indices,
verbose=verbose,
)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, ae, text_encoders, mmdit, weights_sd=None, for_inference=False, **kwargs):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# get dim/alpha mapping, and train t5xxl
modules_dim = {}
modules_alpha = {}
train_t5xxl = None
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# logger.info(lora_name, value.size(), dim)
if train_t5xxl is None or train_t5xxl is False:
train_t5xxl = "lora_te3" in lora_name
if train_t5xxl is None:
train_t5xxl = False
split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined
module_class = LoRAInfModule if for_inference else LoRAModule
network = LoRANetwork(
text_encoders,
mmdit,
multiplier=multiplier,
modules_dim=modules_dim,
modules_alpha=modules_alpha,
module_class=module_class,
split_qkv=split_qkv,
train_t5xxl=train_t5xxl,
)
return network, weights_sd
class LoRANetwork(torch.nn.Module):
SD3_TARGET_REPLACE_MODULE = ["SingleDiTBlock"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP", "T5Attention", "T5DenseGatedActDense"]
LORA_PREFIX_SD3 = "lora_unet" # make ComfyUI compatible
LORA_PREFIX_TEXT_ENCODER_CLIP_L = "lora_te1"
LORA_PREFIX_TEXT_ENCODER_CLIP_G = "lora_te2"
LORA_PREFIX_TEXT_ENCODER_T5 = "lora_te3" # make ComfyUI compatible
def __init__(
self,
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
unet: sd3_models.MMDiT,
multiplier: float = 1.0,
lora_dim: int = 4,
alpha: float = 1,
dropout: Optional[float] = None,
rank_dropout: Optional[float] = None,
module_dropout: Optional[float] = None,
conv_lora_dim: Optional[int] = None,
conv_alpha: Optional[float] = None,
module_class: Type[object] = LoRAModule,
modules_dim: Optional[Dict[str, int]] = None,
modules_alpha: Optional[Dict[str, int]] = None,
split_qkv: bool = False,
train_t5xxl: bool = False,
type_dims: Optional[List[int]] = None,
emb_dims: Optional[List[int]] = None,
train_block_indices: Optional[List[bool]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.split_qkv = split_qkv
self.train_t5xxl = train_t5xxl
self.type_dims = type_dims
self.emb_dims = emb_dims
self.train_block_indices = train_block_indices
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None:
logger.info(f"create LoRA network from weights")
self.emb_dims = [0] * 6 # create emb_dims
# verbose = True
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
# if self.conv_lora_dim is not None:
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
qkv_dim = 0
if self.split_qkv:
logger.info(f"split qkv for LoRA")
qkv_dim = unet.joint_blocks[0].context_block.attn.qkv.weight.size(0)
if train_t5xxl:
logger.info(f"train T5XXL as well")
# create module instances
def create_modules(
is_mmdit: bool,
text_encoder_idx: Optional[int],
root_module: torch.nn.Module,
target_replace_modules: List[str],
filter: Optional[str] = None,
default_dim: Optional[int] = None,
include_conv2d_if_filter: bool = False,
) -> List[LoRAModule]:
prefix = (
self.LORA_PREFIX_SD3
if is_mmdit
else [self.LORA_PREFIX_TEXT_ENCODER_CLIP_L, self.LORA_PREFIX_TEXT_ENCODER_CLIP_G, self.LORA_PREFIX_TEXT_ENCODER_T5][
text_encoder_idx
]
)
loras = []
skipped = []
for name, module in root_module.named_modules():
if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
if target_replace_modules is None: # dirty hack for all modules
module = root_module # search all modules
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + (name + "." if name else "") + child_name
lora_name = lora_name.replace(".", "_")
force_incl_conv2d = False
if filter is not None:
if not filter in lora_name:
continue
force_incl_conv2d = include_conv2d_if_filter
dim = None
alpha = None
if modules_dim is not None:
# モジュール指定あり
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
# 通常、すべて対象とする
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
if is_mmdit and type_dims is not None:
# type_dims = [context_attn_dim, context_mlp_dim, context_mod_dim, x_attn_dim, x_mlp_dim, x_mod_dim]
identifier = [
("context_block", "attn"),
("context_block", "mlp"),
("context_block", "adaLN_modulation"),
("x_block", "attn"),
("x_block", "mlp"),
("x_block", "adaLN_modulation"),
]
for i, d in enumerate(type_dims):
if d is not None and all([id in lora_name for id in identifier[i]]):
dim = d # may be 0 for skip
break
if is_mmdit and dim and self.train_block_indices is not None and "joint_blocks" in lora_name:
# "lora_unet_joint_blocks_0_x_block_attn_proj..."
block_index = int(lora_name.split("_")[4]) # bit dirty
if self.train_block_indices is not None and not self.train_block_indices[block_index]:
dim = 0
elif self.conv_lora_dim is not None:
dim = self.conv_lora_dim
alpha = self.conv_alpha
elif force_incl_conv2d:
# x_embedder
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
if dim is None or dim == 0:
# skipした情報を出力
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None):
skipped.append(lora_name)
continue
# qkv split
split_dims = None
if is_mmdit and split_qkv:
if "joint_blocks" in lora_name and "qkv" in lora_name:
split_dims = [qkv_dim // 3] * 3
lora = module_class(
lora_name,
child_module,
self.multiplier,
dim,
alpha,
dropout=dropout,
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
)
loras.append(lora)
if target_replace_modules is None:
break # all modules are searched
return loras, skipped
# create LoRA for text encoder
# 毎回すべてのモジュールを作るのは無駄なので要検討
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
skipped_te = []
for i, text_encoder in enumerate(text_encoders):
index = i
if not train_t5xxl and index >= 2: # 0: CLIP-L, 1: CLIP-G, 2: T5XXL, so we skip T5XXL if train_t5xxl is False
break
logger.info(f"create LoRA for Text Encoder {index+1}:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
# create LoRA for U-Net
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
self.unet_loras, skipped_un = create_modules(True, None, unet, LoRANetwork.SD3_TARGET_REPLACE_MODULE)
# emb_dims [context_embedder, t_embedder, x_embedder, y_embedder, final_mod, final_linear]
if self.emb_dims:
for filter, in_dim in zip(
[
"context_embedder",
"_t_embedder", # don't use "t_embedder" because it's used in "context_embedder"
"x_embedder",
"y_embedder",
"final_layer_adaLN_modulation",
"final_layer_linear",
],
self.emb_dims,
):
# x_embedder is conv2d, so we need to include it
loras, _ = create_modules(
True, None, unet, None, filter=filter, default_dim=in_dim, include_conv2d_if_filter=filter == "x_embedder"
)
# if len(loras) > 0:
# logger.info(f"create LoRA for {filter}: {len(loras)} modules.")
self.unet_loras.extend(loras)
logger.info(f"create LoRA for SD3 MMDiT: {len(self.unet_loras)} modules.")
if verbose:
for lora in self.unet_loras:
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
skipped = skipped_te + skipped_un
if verbose and len(skipped) > 0:
logger.warning(
f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
logger.info(f"\t{name}")
# assertion
names = set()
for lora in self.text_encoder_loras + self.unet_loras:
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def load_state_dict(self, state_dict, strict=True):
# override to convert original weight to split qkv
if not self.split_qkv:
return super().load_state_dict(state_dict, strict)
# split qkv
for key in list(state_dict.keys()):
if not ("joint_blocks" in key and "qkv" in key):
continue
weight = state_dict[key]
lora_name = key.split(".")[0]
if "lora_down" in key and "weight" in key:
# dense weight (rank*3, in_dim)
split_weight = torch.chunk(weight, 3, dim=0)
for i, split_w in enumerate(split_weight):
state_dict[f"{lora_name}.lora_down.{i}.weight"] = split_w
del state_dict[key]
# print(f"split {key}: {weight.shape} to {[w.shape for w in split_weight]}")
elif "lora_up" in key and "weight" in key:
# sparse weight (out_dim=sum(split_dims), rank*3)
rank = weight.size(1) // 3
i = 0
split_dim = weight.shape[0] // 3
for j in range(3):
state_dict[f"{lora_name}.lora_up.{j}.weight"] = weight[i : i + split_dim, j * rank : (j + 1) * rank]
i += split_dim
del state_dict[key]
# alpha is unchanged
return super().load_state_dict(state_dict, strict)
def state_dict(self, destination=None, prefix="", keep_vars=False):
if not self.split_qkv:
return super().state_dict(destination, prefix, keep_vars)
# merge qkv
state_dict = super().state_dict(destination, prefix, keep_vars)
new_state_dict = {}
for key in list(state_dict.keys()):
if not ("joint_blocks" in key and "qkv" in key):
new_state_dict[key] = state_dict[key]
continue
if key not in state_dict:
continue # already merged
lora_name = key.split(".")[0]
# (rank, in_dim) * 3
down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(3)]
# (split dim, rank) * 3
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(3)]
alpha = state_dict.pop(f"{lora_name}.alpha")
# merge down weight
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
# merge up weight (sum of split_dim, rank*3)
split_dim, rank = up_weights[0].size()
qkv_dim = split_dim * 3
up_weight = torch.zeros((qkv_dim, down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
i = 0
for j in range(3):
up_weight[i : i + split_dim, j * rank : (j + 1) * rank] = up_weights[j]
i += split_dim
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
new_state_dict[f"{lora_name}.alpha"] = alpha
# print(
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
# )
print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")
return new_state_dict
def apply_to(self, text_encoders, mmdit, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
# マージできるかどうかを返す
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoders, mmdit, weights_sd, dtype=None, device=None):
apply_text_encoder = apply_unet = False
for key in weights_sd.keys():
if (
key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_L)
or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_CLIP_G)
or key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER_T5)
):
apply_text_encoder = True
elif key.startswith(LoRANetwork.LORA_PREFIX_SD3):
apply_unet = True
if apply_text_encoder:
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
logger.info(f"weights are merged")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
# make sure text_encoder_lr as list of three elements
# if float, use the same value for all three
if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
text_encoder_lr = [default_lr, default_lr, default_lr]
elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr), float(text_encoder_lr)]
elif len(text_encoder_lr) == 1:
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0], text_encoder_lr[0]]
elif len(text_encoder_lr) == 2:
text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[1], text_encoder_lr[1]]
self.requires_grad_(True)
all_params = []
lr_descriptions = []
def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
for name, param in lora.named_parameters():
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
params = []
descriptions = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}
if len(param_data["params"]) == 0:
continue
if lr is not None:
if key == "plus":
param_data["lr"] = lr * loraplus_ratio
else:
param_data["lr"] = lr
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
logger.info("NO LR skipping!")
continue
params.append(param_data)
descriptions.append("plus" if key == "plus" else "")
return params, descriptions
if self.text_encoder_loras:
loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio
# split text encoder loras for te1 and te3
te1_loras = [
lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_L)
]
te2_loras = [
lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_CLIP_G)
]
te3_loras = [lora for lora in self.text_encoder_loras if lora.lora_name.startswith(self.LORA_PREFIX_TEXT_ENCODER_T5)]
if len(te1_loras) > 0:
logger.info(f"Text Encoder 1 (CLIP-L): {len(te1_loras)} modules, LR {text_encoder_lr[0]}")
params, descriptions = assemble_params(te1_loras, text_encoder_lr[0], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
if len(te2_loras) > 0:
logger.info(f"Text Encoder 2 (CLIP-G): {len(te2_loras)} modules, LR {text_encoder_lr[1]}")
params, descriptions = assemble_params(te2_loras, text_encoder_lr[1], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 1 " + (" " + d if d else "") for d in descriptions])
if len(te3_loras) > 0:
logger.info(f"Text Encoder 3 (T5XXL): {len(te3_loras)} modules, LR {text_encoder_lr[2]}")
params, descriptions = assemble_params(te3_loras, text_encoder_lr[2], loraplus_lr_ratio)
all_params.extend(params)
lr_descriptions.extend(["textencoder 3 " + (" " + d if d else "") for d in descriptions])
if self.unet_loras:
params, descriptions = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
)
all_params.extend(params)
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
return all_params, lr_descriptions
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# 重みのバックアップを行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# 重みのリストアを行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# 事前計算を行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False
def apply_max_norm_regularization(self, max_norm_value, device):
downkeys = []
upkeys = []
alphakeys = []
norms = []
keys_scaled = 0
state_dict = self.state_dict()
for key in state_dict.keys():
if "lora_down" in key and "weight" in key:
downkeys.append(key)
upkeys.append(key.replace("lora_down", "lora_up"))
alphakeys.append(key.replace("lora_down.weight", "alpha"))
for i in range(len(downkeys)):
down = state_dict[downkeys[i]].to(device)
up = state_dict[upkeys[i]].to(device)
alpha = state_dict[alphakeys[i]].to(device)
dim = down.shape[0]
scale = alpha / dim
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
updown *= scale
norm = updown.norm().clamp(min=max_norm_value / 2)
desired = torch.clamp(norm, max=max_norm_value)
ratio = desired.cpu() / norm.cpu()
sqrt_ratio = ratio**0.5
if ratio != 1:
keys_scaled += 1
state_dict[upkeys[i]] *= sqrt_ratio
state_dict[downkeys[i]] *= sqrt_ratio
scalednorm = updown.norm() * ratio
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)

View File

@@ -4,13 +4,17 @@ import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
import einops
from transformers import CLIPTextModel
import numpy as np
import torch
import torch.nn.functional as F
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -45,11 +49,16 @@ class OFTModule(torch.nn.Module):
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
self.constraint = alpha * out_dim
# constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
# original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha
self.constraint = alpha * out_dim
self.register_buffer("alpha", torch.tensor(alpha))
self.block_size = out_dim // self.num_blocks
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) # cpu
self.out_dim = out_dim
self.shape = org_module.weight.shape
@@ -69,27 +78,36 @@ class OFTModule(torch.nn.Module):
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
R = torch.block_diag(*block_R_weighted)
return R
if self.I.device != block_Q.device:
self.I = self.I.to(block_Q.device)
I = self.I
block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse())
block_R_weighted = self.multiplier * (block_R - I) + I
return block_R_weighted
def forward(self, x, scale=None):
x = self.org_forward(x)
if self.multiplier == 0.0:
return x
return self.org_forward(x)
org_module = self.org_module[0]
org_dtype = x.dtype
R = self.get_weight().to(x.device, dtype=x.dtype)
if x.dim() == 4:
x = x.permute(0, 2, 3, 1)
x = torch.matmul(x, R)
x = x.permute(0, 3, 1, 2)
else:
x = torch.matmul(x, R)
return x
R = self.get_weight().to(torch.float32)
W = org_module.weight.to(torch.float32)
if len(W.shape) == 4: # Conv2d
W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size)
RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped)
RW = einops.rearrange(RW, "k m ... -> (k m) ...")
result = F.conv2d(
x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups
)
else: # Linear
W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size)
RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped)
RW = einops.rearrange(RW, "k m p -> (k m) p")
result = F.linear(x, RW.to(org_dtype), org_module.bias)
return result
class OFTInfModule(OFTModule):
@@ -115,18 +133,19 @@ class OFTInfModule(OFTModule):
return self.org_forward(x)
return super().forward(x, scale)
def merge_to(self, multiplier=None, sign=1):
R = self.get_weight(multiplier) * sign
def merge_to(self, multiplier=None):
# get org weight
org_sd = self.org_module[0].state_dict()
org_weight = org_sd["weight"]
R = R.to(org_weight.device, dtype=org_weight.dtype)
org_weight = org_sd["weight"].to(torch.float32)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
else:
weight = torch.einsum("oi, op -> pi", org_weight, R)
R = self.get_weight(multiplier).to(torch.float32)
weight = org_weight.reshape(self.num_blocks, self.block_size, -1)
weight = torch.einsum("k n m, k n ... -> k m ...", R, weight)
weight = weight.reshape(org_weight.shape)
# convert back to original dtype
weight = weight.to(org_sd["weight"].dtype)
# set weight to org_module
org_sd["weight"] = weight
@@ -145,8 +164,16 @@ def create_network(
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
if network_alpha is None: # should be set
logger.info(
"network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します"
)
network_alpha = 1e-3
elif network_alpha >= 1:
logger.warning(
"network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3"
" / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨"
)
enable_all_linear = kwargs.get("enable_all_linear", None)
enable_conv = kwargs.get("enable_conv", None)
@@ -190,12 +217,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
else:
if dim is None:
dim = param.size()[0]
if has_conv2d is None and param.dim() == 4:
if has_conv2d is None and "in_layers_2" in name:
has_conv2d = True
if all_linear is None:
if param.dim() == 3 and "attn" not in name:
all_linear = True
if dim is not None and alpha is not None and has_conv2d is not None:
if all_linear is None and "_ff_" in name:
all_linear = True
if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None:
break
if has_conv2d is None:
has_conv2d = False
@@ -241,7 +267,7 @@ class OFTNetwork(torch.nn.Module):
self.alpha = alpha
logger.info(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}"
)
# create module instances

482
networks/oft_flux.py Normal file
View File

@@ -0,0 +1,482 @@
# OFT network module
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
import einops
from transformers import CLIPTextModel
import numpy as np
import torch
import torch.nn.functional as F
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class OFTModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
split_dims: Optional[List[int]] = None,
):
"""
dim -> num blocks
alpha -> constraint
split_dims is used to mimic the split qkv of FLUX as same as Diffusers
"""
super().__init__()
self.oft_name = oft_name
self.num_blocks = dim
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
self.register_buffer("alpha", torch.tensor(alpha))
# No conv2d in FLUX
# if "Linear" in org_module.__class__.__name__:
self.out_dim = org_module.out_features
# elif "Conv" in org_module.__class__.__name__:
# out_dim = org_module.out_channels
if split_dims is None:
split_dims = [self.out_dim]
else:
assert sum(split_dims) == self.out_dim, "sum of split_dims must be equal to out_dim"
self.split_dims = split_dims
# assert all dim is divisible by num_blocks
for split_dim in self.split_dims:
assert split_dim % self.num_blocks == 0, "split_dim must be divisible by num_blocks"
self.constraint = [alpha * split_dim for split_dim in self.split_dims]
self.block_size = [split_dim // self.num_blocks for split_dim in self.split_dims]
self.oft_blocks = torch.nn.ParameterList(
[torch.nn.Parameter(torch.zeros(self.num_blocks, block_size, block_size)) for block_size in self.block_size]
)
self.I = [torch.eye(block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) for block_size in self.block_size]
self.shape = org_module.weight.shape
self.multiplier = multiplier
self.org_module = [org_module] # moduleにならないようにlistに入れる
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
if self.I[0].device != self.oft_blocks[0].device:
self.I = [I.to(self.oft_blocks[0].device) for I in self.I]
block_R_weighted_list = []
for i in range(len(self.oft_blocks)):
block_Q = self.oft_blocks[i] - self.oft_blocks[i].transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint[i])
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = self.I[i]
block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse())
block_R_weighted = self.multiplier * (block_R - I) + I
block_R_weighted_list.append(block_R_weighted)
return block_R_weighted_list
def forward(self, x, scale=None):
if self.multiplier == 0.0:
return self.org_forward(x)
org_module = self.org_module[0]
org_dtype = x.dtype
R = self.get_weight()
W = org_module.weight.to(torch.float32)
B = org_module.bias.to(torch.float32)
# split W to match R
results = []
d2 = 0
for i in range(len(R)):
d1 = d2
d2 += self.split_dims[i]
W1 = W[d1:d2]
W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i])
RW_1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped)
RW_1 = einops.rearrange(RW_1, "k m p -> (k m) p")
B1 = B[d1:d2]
result = F.linear(x, RW_1.to(org_dtype), B1.to(org_dtype))
results.append(result)
result = torch.cat(results, dim=-1)
return result
class OFTInfModule(OFTModule):
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
split_dims: Optional[List[int]] = None,
**kwargs,
):
# no dropout for inference
super().__init__(oft_name, org_module, multiplier, dim, alpha, split_dims)
self.enabled = True
self.network: OFTNetwork = None
def set_network(self, network):
self.network = network
def forward(self, x, scale=None):
if not self.enabled:
return self.org_forward(x)
return super().forward(x, scale)
def merge_to(self, multiplier=None):
# get org weight
org_sd = self.org_module[0].state_dict()
W = org_sd["weight"].to(torch.float32)
R = self.get_weight(multiplier).to(torch.float32)
d2 = 0
W_list = []
for i in range(len(self.oft_blocks)):
d1 = d2
d2 += self.split_dims[i]
W1 = W[d1:d2]
W_reshaped = einops.rearrange(W1, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size[i])
W1 = torch.einsum("k n m, k n p -> k m p", R[i], W_reshaped)
W1 = einops.rearrange(W1, "k m p -> (k m) p")
W_list.append(W1)
W = torch.cat(W_list, dim=-1)
# convert back to original dtype
W = W.to(org_sd["weight"].dtype)
# set weight to org_module
org_sd["weight"] = W
self.org_module[0].load_state_dict(org_sd)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None: # should be set
logger.info(
"network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します"
)
network_alpha = 1e-3
elif network_alpha >= 1:
logger.warning(
"network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3"
" / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨"
)
# attn only or all linear (FFN) layers
enable_all_linear = kwargs.get("enable_all_linear", None)
# enable_conv = kwargs.get("enable_conv", None)
if enable_all_linear is not None:
enable_all_linear = bool(enable_all_linear)
# if enable_conv is not None:
# enable_conv = bool(enable_conv)
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=network_dim,
alpha=network_alpha,
enable_all_linear=enable_all_linear,
varbose=True,
)
return network
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# check dim, alpha and if weights have for conv2d
dim = None
alpha = None
all_linear = None
for name, param in weights_sd.items():
if name.endswith(".alpha"):
if alpha is None:
alpha = param.item()
elif "qkv" in name:
continue # ignore qkv
else:
if dim is None:
dim = param.size()[0]
if all_linear is None and "_mlp" in name:
all_linear = True
if dim is not None and alpha is not None and all_linear is not None:
break
if all_linear is None:
all_linear = False
module_class = OFTInfModule if for_inference else OFTModule
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=dim,
alpha=alpha,
enable_all_linear=all_linear,
module_class=module_class,
)
return network, weights_sd
class OFTNetwork(torch.nn.Module):
FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR = ["DoubleStreamBlock", "SingleStreamBlock"]
FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY = ["SelfAttention"]
OFT_PREFIX_UNET = "oft_unet"
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier: float = 1.0,
dim: int = 4,
alpha: float = 1,
enable_all_linear: Optional[bool] = False,
module_class: Union[Type[OFTModule], Type[OFTInfModule]] = OFTModule,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
self.train_t5xxl = False # make compatible with LoRA
self.multiplier = multiplier
self.dim = dim
self.alpha = alpha
logger.info(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_all_linear: {enable_all_linear}"
)
# create module instances
def create_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[OFTModule]:
prefix = self.OFT_PREFIX_UNET
ofts = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = "Linear" in child_module.__class__.__name__
if is_linear:
oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_")
# logger.info(oft_name)
if "double" in oft_name and "qkv" in oft_name:
split_dims = [3072] * 3
elif "single" in oft_name and "linear1" in oft_name:
split_dims = [3072] * 3 + [12288]
else:
split_dims = None
oft = module_class(oft_name, child_module, self.multiplier, dim, alpha, split_dims)
ofts.append(oft)
return ofts
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
if enable_all_linear:
target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ALL_LINEAR
else:
target_modules = OFTNetwork.FLUX_TARGET_REPLACE_MODULE_ATTN_ONLY
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
logger.info(f"create OFT for Flux: {len(self.unet_ofts)} modules.")
# assertion
names = set()
for oft in self.unet_ofts:
assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
names.add(oft.oft_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for oft in self.unet_ofts:
oft.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
assert apply_unet, "apply_unet must be True"
for oft in self.unet_ofts:
oft.apply_to()
self.add_module(oft.oft_name, oft)
# マージできるかどうかを返す
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
logger.info("enable OFT for U-Net")
for oft in self.unet_ofts:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(oft.oft_name):
sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
oft.load_state_dict(sd_for_lora, False)
oft.merge_to()
logger.info(f"weights are merged")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(ofts):
params = []
for oft in ofts:
params.extend(oft.parameters())
# logger.info num of params
num_params = 0
for p in params:
num_params += p.numel()
logger.info(f"OFT params: {num_params}")
return params
param_data = {"params": enumerate_params(self.unet_ofts)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
return all_params
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# 重みのバックアップを行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# 重みのリストアを行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# 事前計算を行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
oft.merge_to()
# sd = org_module.state_dict()
# org_weight = sd["weight"]
# lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
# sd["weight"] = org_weight + lora_weight
# assert sd["weight"].shape == org_weight.shape
# org_module.load_state_dict(sd)
org_module._lora_restored = False
oft.enabled = False

View File

@@ -39,12 +39,7 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, state_dict, metadata):
if model_util.is_safetensors(file_name):
save_file(state_dict, file_name, metadata)
else:
@@ -349,12 +344,18 @@ def resize(args):
metadata["ss_network_dim"] = "Dynamic"
metadata["ss_network_alpha"] = "Dynamic"
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
save_to_file(args.save_to, state_dict, metadata)
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -1,18 +1,25 @@
import itertools
import math
import argparse
import os
import time
import concurrent.futures
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora
import oft
from svd_merge_lora import format_lbws, get_lbw_block_index, LAYER26
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name)
@@ -28,36 +35,58 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, model, metadata):
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name, metadata=metadata)
else:
torch.save(model, file_name)
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
text_encoder1.to(merge_dtype)
def detect_method_from_training_model(models, dtype):
for model in models:
# TODO It is better to use key names to detect the method
lora_sd, _ = load_state_dict(model, dtype)
for key in tqdm(lora_sd.keys()):
if "lora_up" in key or "lora_down" in key:
return "LoRA"
elif "oft_blocks" in key:
return "OFT"
def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, lbws, merge_dtype):
text_encoder1.to(merge_dtype)
text_encoder2.to(merge_dtype)
unet.to(merge_dtype)
# detect the method: OFT or LoRA_module
method = detect_method_from_training_model(models, merge_dtype)
logger.info(f"method:{method}")
if lbws:
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
else:
LBW_TARGET_IDX = []
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
if i <= 1:
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
if method == "LoRA":
if i <= 1:
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
else:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
elif method == "OFT":
prefix = oft.OFTNetwork.OFT_PREFIX_UNET
# ALL_LINEAR includes ATTN_ONLY, so we don't need to specify ATTN_ONLY
target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
for name, module in root_module.named_modules():
@@ -68,65 +97,172 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
logger.info(f"merging...")
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if method == "LoRA":
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
if lbw:
index = get_lbw_block_index(key, True)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
# W <- W + U * D
weight = module.weight
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
elif method == "OFT":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for key in tqdm(lora_sd.keys()):
if "oft_blocks" in key:
oft_blocks = lora_sd[key]
dim = oft_blocks.shape[0]
break
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
oft_blocks = lora_sd[key]
alpha = oft_blocks.item()
break
def merge_to(key):
if "alpha" in key:
return
# find original module for this OFT
module_name = ".".join(key.split(".")[:-1])
if module_name not in name_to_module:
logger.info(f"no module found for LoRA weight: {key}")
continue
logger.info(f"no module found for OFT weight: {key}")
return
module = name_to_module[module_name]
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
oft_blocks = lora_sd[key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
if isinstance(module, torch.nn.Linear):
out_dim = module.out_features
elif isinstance(module, torch.nn.Conv2d):
out_dim = module.out_channels
# W <- W + U * D
weight = module.weight
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
num_blocks = dim
block_size = out_dim // dim
constraint = (0 if alpha is None else alpha) * out_dim
multiplier = 1
if lbw:
index = get_lbw_block_index(key, False)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
multiplier *= lbw_weights[index]
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = multiplier * block_R + (1 - multiplier) * I
R = torch.block_diag(*block_R_weighted)
# get org weight
org_sd = module.state_dict()
org_weight = org_sd["weight"].to(device)
R = R.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
weight = torch.einsum("oi, op -> pi", org_weight, R)
weight = weight.contiguous() # Make Tensor contiguous; required due to ThreadPoolExecutor
module.weight = torch.nn.Parameter(weight)
# TODO multi-threading may cause OOM on CPU if cpu_count is too high and RAM is not enough
max_workers = 1 if device.type != "cpu" else None # avoid OOM on GPU
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
def merge_lora_models(models, ratios, lbws, merge_dtype, concat=False, shuffle=False):
base_alphas = {} # alpha for merged model
base_dims = {}
# detect the method: OFT or LoRA_module
method = detect_method_from_training_model(models, merge_dtype)
if method == "OFT":
raise ValueError(
"OFT model is not supported for merging OFT models. / OFTモデルはOFTモデル同士のマージには対応していません"
)
if lbws:
lbws, _, LBW_TARGET_IDX = format_lbws(lbws)
else:
LBW_TARGET_IDX = []
merged_sd = {}
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
if lora_metadata is not None:
if v2 is None:
v2 = lora_metadata.get(train_util.SS_METADATA_KEY_V2, None) # returns string, SDXLはv2がないのでFalseのはず
@@ -164,7 +300,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
@@ -178,8 +314,14 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if lbw:
index = get_lbw_block_index(key, True)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
@@ -201,7 +343,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
merged_sd[key_up] = merged_sd[key_up][:, perm]
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -229,7 +371,15 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
assert len(args.models) == len(
args.ratios
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
if args.lbws:
assert len(args.models) == len(
args.lbws
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
else:
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
def str_to_dtype(p):
if p == "float":
@@ -257,7 +407,7 @@ def merge(args):
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.sd_model, "cpu")
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, args.lbws, merge_dtype)
if args.no_metadata:
sai_metadata = None
@@ -273,7 +423,13 @@ def merge(args):
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
state_dict, metadata = merge_lora_models(args.models, args.ratios, args.lbws, merge_dtype, args.concat, args.shuffle)
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
logger.info(f"calculating hashes and creating metadata...")
@@ -290,7 +446,7 @@ def merge(args):
metadata.update(sai_metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
save_to_file(args.save_to, state_dict, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -316,12 +472,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
"--models",
type=str,
nargs="*",
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
parser.add_argument(
"--no_metadata",
action="store_true",
@@ -337,8 +500,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
help="shuffle lora weight./ " + "LoRAの重みをシャッフルする",
)
return parser

View File

@@ -1,5 +1,8 @@
import argparse
import itertools
import json
import os
import re
import time
import torch
from safetensors.torch import load_file, save_file
@@ -8,12 +11,195 @@ from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLAMP_QUANTILE = 0.99
ACCEPTABLE = [12, 17, 20, 26]
SDXL_LAYER_NUM = [12, 20]
LAYER12 = {
"BASE": True,
"IN00": False,
"IN01": False,
"IN02": False,
"IN03": False,
"IN04": True,
"IN05": True,
"IN06": False,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": False,
"OUT07": False,
"OUT08": False,
"OUT09": False,
"OUT10": False,
"OUT11": False,
}
LAYER17 = {
"BASE": True,
"IN00": False,
"IN01": True,
"IN02": True,
"IN03": False,
"IN04": True,
"IN05": True,
"IN06": False,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": False,
"OUT01": False,
"OUT02": False,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": True,
"OUT10": True,
"OUT11": True,
}
LAYER20 = {
"BASE": True,
"IN00": True,
"IN01": True,
"IN02": True,
"IN03": True,
"IN04": True,
"IN05": True,
"IN06": True,
"IN07": True,
"IN08": True,
"IN09": False,
"IN10": False,
"IN11": False,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": False,
"OUT10": False,
"OUT11": False,
}
LAYER26 = {
"BASE": True,
"IN00": True,
"IN01": True,
"IN02": True,
"IN03": True,
"IN04": True,
"IN05": True,
"IN06": True,
"IN07": True,
"IN08": True,
"IN09": True,
"IN10": True,
"IN11": True,
"MID": True,
"OUT00": True,
"OUT01": True,
"OUT02": True,
"OUT03": True,
"OUT04": True,
"OUT05": True,
"OUT06": True,
"OUT07": True,
"OUT08": True,
"OUT09": True,
"OUT10": True,
"OUT11": True,
}
assert len([v for v in LAYER12.values() if v]) == 12
assert len([v for v in LAYER17.values() if v]) == 17
assert len([v for v in LAYER20.values() if v]) == 20
assert len([v for v in LAYER26.values() if v]) == 26
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
def get_lbw_block_index(lora_name: str, is_sdxl: bool = False) -> int:
# lbw block index is 0-based, but 0 for text encoder, so we return 0 for text encoder
if "text_model_encoder_" in lora_name: # LoRA for text encoder
return 0
# lbw block index is 1-based for U-Net, and no "input_blocks.0" in CompVis SD, so "input_blocks.1" have index 2
block_idx = -1 # invalid lora name
if not is_sdxl:
NUM_OF_BLOCKS = 12 # up/down blocks
m = RE_UPDOWN.search(lora_name)
if m:
g = m.groups()
up_down = g[0]
i = int(g[1])
j = int(g[3])
if up_down == "down":
if g[2] == "resnets" or g[2] == "attentions":
idx = 3 * i + j + 1
elif g[2] == "downsamplers":
idx = 3 * (i + 1)
else:
return block_idx # invalid lora name
elif up_down == "up":
if g[2] == "resnets" or g[2] == "attentions":
idx = 3 * i + j
elif g[2] == "upsamplers":
idx = 3 * i + 2
else:
return block_idx # invalid lora name
if g[0] == "down":
block_idx = 1 + idx # 1-based index, down block index
elif g[0] == "up":
block_idx = 1 + NUM_OF_BLOCKS + 1 + idx # 1-based index, num blocks, mid block, up block index
elif "mid_block_" in lora_name:
block_idx = 1 + NUM_OF_BLOCKS # 1-based index, num blocks, mid block
else:
# SDXL: some numbers are skipped
if lora_name.startswith("lora_unet_"):
name = lora_name[len("lora_unet_") :]
if name.startswith("time_embed_") or name.startswith("label_emb_"): # 1, No LoRA in sd-scripts
block_idx = 1
elif name.startswith("input_blocks_"): # 1-8 to 2-9
block_idx = 1 + int(name.split("_")[2])
elif name.startswith("middle_block_"): # 13
block_idx = 13
elif name.startswith("output_blocks_"): # 0-8 to 14-22
block_idx = 14 + int(name.split("_")[2])
elif name.startswith("out_"): # 23, No LoRA in sd-scripts
block_idx = 23
return block_idx
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -30,24 +216,53 @@ def load_state_dict(file_name, dtype):
return sd, metadata
def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, state_dict, metadata):
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(state_dict, file_name, metadata=metadata)
else:
torch.save(state_dict, file_name)
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
def format_lbws(lbws):
try:
# lbwは"[1,1,1,1,1,1,1,1,1,1,1,1]"のような文字列で与えられることを期待している
lbws = [json.loads(lbw) for lbw in lbws]
except Exception:
raise ValueError(f"format of lbws are must be json / 層別適用率はJSON形式で書いてください")
assert all(isinstance(lbw, list) for lbw in lbws), f"lbws are must be list / 層別適用率はリストにしてください"
assert len(set(len(lbw) for lbw in lbws)) == 1, "all lbws should have the same length / 層別適用率は同じ長さにしてください"
assert all(
len(lbw) in ACCEPTABLE for lbw in lbws
), f"length of lbw are must be in {ACCEPTABLE} / 層別適用率の長さは{ACCEPTABLE}のいずれかにしてください"
assert all(
all(isinstance(weight, (int, float)) for weight in lbw) for lbw in lbws
), f"values of lbs are must be numbers / 層別適用率の値はすべて数値にしてください"
layer_num = len(lbws[0])
is_sdxl = True if layer_num in SDXL_LAYER_NUM else False
FLAGS = {
"12": LAYER12.values(),
"17": LAYER17.values(),
"20": LAYER20.values(),
"26": LAYER26.values(),
}[str(layer_num)]
LBW_TARGET_IDX = [i for i, flag in enumerate(FLAGS) if flag]
return lbws, is_sdxl, LBW_TARGET_IDX
def merge_lora_models(models, ratios, lbws, new_rank, new_conv_rank, device, merge_dtype):
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
merged_sd = {}
v2 = None
v2 = None # This is meaning LoRA Metadata v2, Not meaning SD2
base_model = None
for model, ratio in zip(models, ratios):
if lbws:
lbws, is_sdxl, LBW_TARGET_IDX = format_lbws(lbws)
else:
is_sdxl = False
LBW_TARGET_IDX = []
for model, ratio, lbw in itertools.zip_longest(models, ratios, lbws):
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
@@ -57,6 +272,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
if base_model is None:
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
if lbw:
lbw_weights = [1] * 26
for index, value in zip(LBW_TARGET_IDX, lbw):
lbw_weights[index] = value
logger.info(f"lbw: {dict(zip(LAYER26.keys(), lbw_weights))}")
# merge
logger.info(f"merging...")
for key in tqdm(list(lora_sd.keys())):
@@ -80,10 +301,10 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
if device:
weight = weight.to(device)
# merge to weight
if device:
@@ -93,6 +314,12 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
# W <- W + U * D
scale = alpha / network_dim
if lbw:
index = get_lbw_block_index(key, is_sdxl)
is_lbw_target = index in LBW_TARGET_IDX
if is_lbw_target:
scale *= lbw_weights[index] # keyがlbwの対象であれば、lbwの重みを掛ける
if device: # and isinstance(scale, torch.Tensor):
scale = scale.to(device)
@@ -109,13 +336,16 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = weight + ratio * conved * scale
merged_sd[lora_module_name] = weight
merged_sd[lora_module_name] = weight.to("cpu")
# extract from merged weights
logger.info("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
if device:
mat = mat.to(device)
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
@@ -154,7 +384,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
merged_lora_sd[lora_module_name + ".lora_up.weight"] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".lora_down.weight"] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank)
merged_lora_sd[lora_module_name + ".alpha"] = torch.tensor(module_new_rank, device="cpu")
# build minimum metadata
dims = f"{new_rank}"
@@ -169,7 +399,15 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
assert len(args.models) == len(
args.ratios
), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
if args.lbws:
assert len(args.models) == len(
args.lbws
), f"number of models must be equal to number of ratios / モデルの数と層別適用率の数は合わせてください"
else:
args.lbws = [] # zip_longestで扱えるようにlbws未使用時には空のリストにしておく
def str_to_dtype(p):
if p == "float":
@@ -187,9 +425,15 @@ def merge(args):
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
state_dict, metadata, v2, base_model = merge_lora_models(
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
args.models, args.ratios, args.lbws, args.new_rank, new_conv_rank, args.device, merge_dtype
)
# cast to save_dtype before calculating hashes
for key in list(state_dict.keys()):
value = state_dict[key]
if type(value) == torch.Tensor and value.dtype.is_floating_point and value.dtype != save_dtype:
state_dict[key] = value.to(save_dtype)
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
@@ -211,7 +455,7 @@ def merge(args):
metadata.update(sai_metadata)
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
save_to_file(args.save_to, state_dict, metadata)
def setup_parser() -> argparse.ArgumentParser:
@@ -231,12 +475,19 @@ def setup_parser() -> argparse.ArgumentParser:
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
"--models",
type=str,
nargs="*",
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--lbws", type=str, nargs="*", help="lbw for each model / それぞれのLoRAモデルの層別適用率")
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument(
"--new_conv_rank",
@@ -244,7 +495,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
)
parser.add_argument(
"--no_metadata",
action="store_true",

View File

@@ -1,20 +1,25 @@
accelerate==0.25.0
transformers==4.36.2
accelerate==0.33.0
transformers==4.44.0
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.7.0.68
opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
# bitsandbytes==0.39.1
tensorboard==2.10.1
safetensors==0.4.2
bitsandbytes==0.44.0
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.2.7
tensorboard
safetensors==0.4.4
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.20.1
huggingface-hub==0.24.5
# for Image utils
imagesize==1.4.1
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
@@ -31,8 +36,10 @@ huggingface-hub==0.20.1
# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
open-clip-torch==2.20.0
# open-clip-torch==2.20.0
# For logging
rich==13.7.0
# for T5XXL tokenizer (SD3/FLUX)
sentencepiece==0.2.0
# for kohya_ss library
-e .

407
sd3_minimal_inference.py Normal file
View File

@@ -0,0 +1,407 @@
# Minimum Inference Code for SD3
import argparse
import datetime
import math
import os
import random
from typing import Optional, Tuple
import numpy as np
import torch
from safetensors.torch import safe_open, load_file
import torch.amp
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library.device_utils import init_ipex, get_preferred_device
from networks import lora_sd3
init_ipex()
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils, strategy_sd3
from library.utils import load_safetensors
def get_noise(seed, latent, device="cpu"):
# generator = torch.manual_seed(seed)
generator = torch.Generator(device)
generator.manual_seed(seed)
return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device)
def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
start = sampling.timestep(sampling.sigma_max)
end = sampling.timestep(sampling.sigma_min)
timesteps = torch.linspace(start, end, steps)
sigs = []
for x in range(len(timesteps)):
ts = timesteps[x]
sigs.append(sampling.sigma(ts))
sigs += [0.0]
return torch.FloatTensor(sigs)
def max_denoise(model_sampling, sigmas):
max_sigma = float(model_sampling.sigma_max)
sigma = float(sigmas[0])
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
def do_sample(
height: int,
width: int,
initial_latent: Optional[torch.Tensor],
seed: int,
cond: Tuple[torch.Tensor, torch.Tensor],
neg_cond: Tuple[torch.Tensor, torch.Tensor],
mmdit: sd3_models.MMDiT,
steps: int,
cfg_scale: float,
dtype: torch.dtype,
device: str,
):
if initial_latent is None:
# latent = torch.ones(1, 16, height // 8, width // 8, device=device) * 0.0609 # this seems to be a bug in the original code. thanks to furusu for pointing it out
latent = torch.zeros(1, 16, height // 8, width // 8, device=device)
else:
latent = initial_latent
latent = latent.to(dtype).to(device)
noise = get_noise(seed, latent, device)
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
sigmas = get_sigmas(model_sampling, steps).to(device)
# sigmas = sigmas[int(steps * (1 - denoise)) :] # do not support i2i
# conditioning = fix_cond(conditioning)
# neg_cond = fix_cond(neg_cond)
# extra_args = {"cond": cond, "uncond": neg_cond, "cond_scale": guidance_scale}
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
c_crossattn = torch.cat([cond[0], neg_cond[0]]).to(device).to(dtype)
y = torch.cat([cond[1], neg_cond[1]]).to(device).to(dtype)
x = noise_scaled.to(device).to(dtype)
# print(x.shape)
with torch.no_grad():
for i in tqdm(range(len(sigmas) - 1)):
sigma_hat = sigmas[i]
timestep = model_sampling.timestep(sigma_hat).float()
timestep = torch.FloatTensor([timestep, timestep]).to(device)
x_c_nc = torch.cat([x, x], dim=0)
# print(x_c_nc.shape, timestep.shape, c_crossattn.shape, y.shape)
with torch.autocast(device_type=device.type, dtype=dtype):
model_output = mmdit(x_c_nc, timestep, context=c_crossattn, y=y)
model_output = model_output.float()
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * cfg_scale
# print(denoised.shape)
# d = to_d(x, sigma_hat, denoised)
dims_to_append = x.ndim - sigma_hat.ndim
sigma_hat_dims = sigma_hat[(...,) + (None,) * dims_to_append]
# print(dims_to_append, x.shape, sigma_hat.shape, denoised.shape, sigma_hat_dims.shape)
"""Converts a denoiser output to a Karras ODE derivative."""
d = (x - denoised) / sigma_hat_dims
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
x = x.to(dtype)
latent = x
latent = vae.process_out(latent)
return latent
def generate_image(
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
clip_l: CLIPTextModelWithProjection,
clip_g: CLIPTextModelWithProjection,
t5xxl: T5EncoderModel,
steps: int,
prompt: str,
seed: int,
target_width: int,
target_height: int,
device: str,
negative_prompt: str,
cfg_scale: float,
):
# prepare embeddings
logger.info("Encoding prompts...")
# TODO support one-by-one offloading
clip_l.to(device)
clip_g.to(device)
t5xxl.to(device)
with torch.autocast(device_type=device.type, dtype=mmdit.dtype), torch.no_grad():
tokens_and_masks = tokenize_strategy.tokenize(prompt)
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# attn masks are not used currently
if args.offload:
clip_l.to("cpu")
clip_g.to("cpu")
t5xxl.to("cpu")
# generate image
logger.info("Generating image...")
mmdit.to(device)
latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device)
if args.offload:
mmdit.to("cpu")
# latent to image
vae.to(device)
with torch.no_grad():
image = vae.decode(latent_sampled)
if args.offload:
vae.to("cpu")
image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
decoded_np = decoded_np.astype(np.uint8)
out_image = Image.fromarray(decoded_np)
# save image
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
out_image.save(output_path)
logger.info(f"Saved image to {output_path}")
if __name__ == "__main__":
target_height = 1024
target_width = 1024
# steps = 50 # 28 # 50
# cfg_scale = 5
# seed = 1 # None # 1
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--clip_g", type=str, required=False)
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256")
parser.add_argument("--apply_lg_attn_mask", action="store_true")
parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--cfg_scale", type=float, default=5.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument("--output_dir", type=str, default=".")
# parser.add_argument("--do_not_use_t5xxl", action="store_true")
# parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument(
"--lora_weights",
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora_sd3, each argument is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
seed = args.seed
steps = args.steps
sd3_dtype = torch.float32
if args.fp16:
sd3_dtype = torch.float16
elif args.bf16:
sd3_dtype = torch.bfloat16
loading_device = "cpu" if args.offload else device
# load state dict
logger.info(f"Loading SD3 models from {args.ckpt_path}...")
# state_dict = load_file(args.ckpt_path)
state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype)
# load text encoders
clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict)
clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict)
t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict)
# MMDiT and VAE
vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict)
mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device)
clip_l.to(sd3_dtype)
clip_g.to(sd3_dtype)
t5xxl.to(sd3_dtype)
vae.to(sd3_dtype)
mmdit.to(sd3_dtype)
if not args.offload:
# make sure to move to the device: some tensors are created in the constructor on the CPU
clip_l.to(device)
clip_g.to(device)
t5xxl.to(device)
vae.to(device)
mmdit.to(device)
clip_l.eval()
clip_g.eval()
t5xxl.eval()
mmdit.eval()
vae.eval()
# load tokenizers
logger.info("Loading tokenizers...")
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
# LoRA
lora_models: list[lora_sd3.LoRANetwork] = []
for weights_file in args.lora_weights:
if ";" in weights_file:
weights_file, multiplier = weights_file.split(";")
multiplier = float(multiplier)
else:
multiplier = 1.0
weights_sd = load_file(weights_file)
module = lora_sd3
lora_model, _ = module.create_network_from_weights(multiplier, None, vae, [clip_l, clip_g, t5xxl], mmdit, weights_sd, True)
if args.merge_lora_weights:
lora_model.merge_to([clip_l, clip_g, t5xxl], mmdit, weights_sd)
else:
lora_model.apply_to([clip_l, clip_g, t5xxl], mmdit)
info = lora_model.load_state_dict(weights_sd, strict=True)
logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
lora_model.eval()
lora_model.to(device)
lora_models.append(lora_model)
if not args.interactive:
generate_image(
mmdit,
vae,
clip_l,
clip_g,
t5xxl,
args.steps,
args.prompt,
args.seed,
args.width,
args.height,
device,
args.negative_prompt,
args.cfg_scale,
)
else:
# loop for interactive
width = args.width
height = args.height
steps = None
cfg_scale = args.cfg_scale
while True:
print(
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed>"
" --n <negative prompt>, `--n -` for empty negative prompt"
"Options are kept for the next prompt. Current options:"
f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}"
)
prompt = input()
if prompt == "":
break
# parse options
options = prompt.split("--")
prompt = options[0].strip()
seed = None
negative_prompt = None
for opt in options[1:]:
try:
opt = opt.strip()
if opt.startswith("w"):
width = int(opt[1:].strip())
elif opt.startswith("h"):
height = int(opt[1:].strip())
elif opt.startswith("s"):
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
elif opt.startswith("m"):
mutipliers = opt[1:].strip().split(",")
if len(mutipliers) != len(lora_models):
logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
continue
for i, lora_model in enumerate(lora_models):
lora_model.set_multiplier(float(mutipliers[i]))
elif opt.startswith("n"):
negative_prompt = opt[1:].strip()
if negative_prompt == "-":
negative_prompt = ""
elif opt.startswith("c"):
cfg_scale = float(opt[1:].strip())
except ValueError as e:
logger.error(f"Invalid option: {opt}, {e}")
generate_image(
mmdit,
vae,
clip_l,
clip_g,
t5xxl,
steps if steps is not None else args.steps,
prompt,
seed if seed is not None else args.seed,
width,
height,
device,
negative_prompt if negative_prompt is not None else args.negative_prompt,
cfg_scale,
)
logger.info("Done!")

1073
sd3_train.py Normal file

File diff suppressed because it is too large Load Diff

480
sd3_train_network.py Normal file
View File

@@ -0,0 +1,480 @@
import argparse
import copy
import math
import random
from typing import Any, Optional
import torch
from accelerate import Accelerator
from library import sd3_models, strategy_sd3, utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class Sd3NetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
# super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for SD3
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
logger.warning(
"cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
)
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# prepare CLIP-L/CLIP-G/T5XXL training flags
self.train_clip = not args.network_train_unet_only
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
assert (
args.blocks_to_swap is None or args.blocks_to_swap == 0
) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
# enumerate resolutions from dataset for positional embeddings
self.resolutions = train_dataset_group.get_resolutions()
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype
# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
state_dict = utils.load_safetensors(
args.pretrained_model_name_or_path, "cpu", disable_mmap=args.disable_mmap_load_safetensors, dtype=loading_dtype
)
mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")
self.model_type = mmdit.model_type
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
# set resolutions for positional embeddings
if args.enable_scaled_pos_embed:
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
latent_sizes = list(set(latent_sizes)) # remove duplicates
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
mmdit.enable_scaled_pos_embed(True, latent_sizes)
if args.fp8_base:
# check dtype of model
if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {mmdit.dtype}")
elif mmdit.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 SD3 model")
else:
logger.info(
"Cast SD3 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
" / SD3モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
)
mmdit.to(torch.float8_e4m3fn)
self.is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if self.is_swapping_blocks:
# Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
mmdit.enable_block_swap(args.blocks_to_swap, accelerator.device)
clip_l = sd3_utils.load_clip_l(
args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
clip_l.eval()
clip_g = sd3_utils.load_clip_g(
args.clip_g, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
clip_g.eval()
# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
if args.fp8_base and not args.fp8_base_unet:
loading_dtype = None # as is
else:
loading_dtype = weight_dtype
# loading t5xxl to cpu takes a long time, so we should load to gpu in future
t5xxl = sd3_utils.load_t5xxl(
args.t5xxl, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
t5xxl.eval()
if args.fp8_base and not args.fp8_base_unet:
# check dtype of model
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")
vae = sd3_utils.load_vae(
args.vae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors, state_dict=state_dict
)
return mmdit.model_type, [clip_l, clip_g, t5xxl], vae, mmdit
def get_tokenize_strategy(self, args):
logger.info(f"t5xxl_max_token_length: {args.t5xxl_max_token_length}")
return strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy):
return [tokenize_strategy.clip_l, tokenize_strategy.clip_g, tokenize_strategy.t5xxl]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd3.Sd3LatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sd3.Sd3TextEncodingStrategy(
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
args.clip_l_dropout_rate,
args.clip_g_dropout_rate,
args.t5_dropout_rate,
)
def post_process_network(self, args, accelerator, network, text_encoders, unet):
# check t5xxl is trained or not
self.train_t5xxl = network.train_t5xxl
if self.train_t5xxl and args.cache_text_encoder_outputs:
raise ValueError(
"T5XXL is trained, so cache_text_encoder_outputs cannot be used / T5XXL学習時はcache_text_encoder_outputsは使用できません"
)
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
if args.cache_text_encoder_outputs:
if self.train_clip and not self.train_t5xxl:
return text_encoders[0:2] + [None] # only CLIP-L/CLIP-G is needed for encoding because T5XXL is cached
else:
return None # no text encoders are needed for encoding because both are cached
else:
return text_encoders # CLIP-L, CLIP-G and T5XXL are needed for encoding
def get_text_encoders_train_flags(self, args, text_encoders):
return [self.train_clip, self.train_clip, self.train_t5xxl]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_sd3.Sd3TextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip or self.train_t5xxl,
apply_lg_attn_mask=args.apply_lg_attn_mask,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[2].to(accelerator.device) # may be fp8
if text_encoders[2].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(2, text_encoders[2], text_encoders[2].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[2].to(weight_dtype)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")
tokenize_strategy: strategy_sd3.Sd3TokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
for p in [prompt_dict.get("prompt", ""), prompt_dict.get("negative_prompt", "")]:
if p not in sample_prompts_te_outputs:
logger.info(f"cache Text Encoder outputs for prompt: {p}")
tokens_and_masks = tokenize_strategy.tokenize(p)
sample_prompts_te_outputs[p] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
)
self.sample_prompts_te_outputs = sample_prompts_te_outputs
accelerator.wait_for_everyone()
# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move CLIP-L back to cpu")
text_encoders[0].to("cpu")
logger.info("move CLIP-G back to cpu")
text_encoders[1].to("cpu")
logger.info("move t5XXL back to cpu")
text_encoders[2].to("cpu")
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
text_encoders[2].to(accelerator.device)
# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# # get size embeddings
# orig_size = batch["original_sizes_hw"]
# crop_size = batch["crop_top_lefts"]
# target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# # concat embeddings
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, mmdit):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
sd3_train_utils.sample_images(
accelerator, args, epoch, global_step, mmdit, vae, text_encoders, self.sample_prompts_te_outputs
)
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
# this scheduler is not used in training, but used to get num_train_timesteps etc.
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
return sd3_models.SDVAE.process_in(latents)
def get_noise_pred_and_target(
self,
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet: flux_models.Flux,
network,
weight_dtype,
train_unet,
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = sd3_train_utils.get_noisy_model_input_and_timesteps(
args, latents, noise, accelerator.device, weight_dtype
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
if t is not None and t.dtype.is_floating_point:
t.requires_grad_(True)
# Predict the noise residual
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = text_encoder_conds
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
context, lg_pooled = text_encoding_strategy.concat_encodings(lg_out, t5_out, lg_pooled)
if not args.apply_lg_attn_mask:
l_attn_mask = None
g_attn_mask = None
if not args.apply_t5_attn_mask:
t5_attn_mask = None
# call model
with accelerator.autocast():
# TODO support attention mask
model_pred = unet(noisy_model_input, timesteps, context=context, y=lg_pooled)
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
# Preconditioning of the model outputs.
model_pred = model_pred * (-sigmas) + noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = sd3_train_utils.compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# flow matching loss
target = latents
# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)
if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
model_pred_prior = unet(
noisy_model_input[diff_output_pr_indices],
timesteps[diff_output_pr_indices],
context=context[diff_output_pr_indices],
y=lg_pooled[diff_output_pr_indices],
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
model_pred_prior = model_pred_prior * (-sigmas[diff_output_pr_indices]) + noisy_model_input[diff_output_pr_indices]
# weighting for differential output preservation is not needed because it is already applied
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, sd3=self.model_type)
def update_metadata(self, metadata, args):
metadata["ss_apply_lg_attn_mask"] = args.apply_lg_attn_mask
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)
def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
if index == 0 or index == 1: # CLIP-L/CLIP-G
return super().prepare_text_encoder_grad_ckpt_workaround(index, text_encoder)
else: # T5XXL
text_encoder.encoder.embed_tokens.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
if index == 0 or index == 1: # CLIP-L/CLIP-G
clip_type = "CLIP-L" if index == 0 else "CLIP-G"
logger.info(f"prepare CLIP-{clip_type} for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
else: # T5XXL
def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)
hidden_states = module.wo(hidden_states)
return hidden_states
return forward
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
if flux_utils.get_t5xxl_actual_dtype(text_encoder) == torch.float8_e4m3fn and text_encoder.dtype == weight_dtype:
logger.info(f"T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
# drop cached text encoder outputs
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:
if not self.is_swapping_blocks:
return super().prepare_unet_with_accelerator(args, accelerator, unet)
# if we doesn't swap blocks, we can move the model to device
mmdit: sd3_models.MMDiT = unet
mmdit = accelerator.prepare(mmdit, device_placement=[not self.is_swapping_blocks])
accelerator.unwrap_model(mmdit).move_to_device_except_swap_blocks(accelerator.device) # reduce peak memory usage
accelerator.unwrap_model(mmdit).prepare_block_swap_before_forward()
return mmdit
def setup_parser() -> argparse.ArgumentParser:
parser = train_network.setup_parser()
train_util.add_dit_training_arguments(parser)
sd3_train_utils.add_sd3_training_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = Sd3NetworkTrainer()
trainer.train(args)

View File

@@ -11,20 +11,24 @@ import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler
from PIL import Image
import open_clip
# import open_clip
from safetensors.torch import load_file
from library import model_util, sdxl_model_util
import networks.lora as lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
@@ -154,12 +158,13 @@ if __name__ == "__main__":
text_model2.eval()
unet.set_use_memory_efficient_attention(True, False)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(True)
# Tokenizers
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
# tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name)
# LoRA
for weights_file in args.lora_weights:
@@ -192,7 +197,9 @@ if __name__ == "__main__":
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
# logger.info("emb1", emb1.shape)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
uc_vector = c_vector.clone().to(
DEVICE, dtype=DTYPE
) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
# crossattn
@@ -215,13 +222,22 @@ if __name__ == "__main__":
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
# text encoder 2
with torch.no_grad():
tokens = tokenizer2(text2).to(DEVICE)
# tokens = tokenizer2(text2).to(DEVICE)
tokens = tokenizer2(
text,
truncation=True,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(DEVICE)
with torch.no_grad():
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
# logger.info("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# 連結して終了 concat and finish
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)

View File

@@ -11,11 +11,13 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl
import library.train_util as train_util
@@ -39,6 +41,7 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -97,11 +100,12 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
assert (
not args.weighted_captions
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
not args.weighted_captions or not args.cache_text_encoder_outputs
), "weighted_captions is not supported when caching text encoder outputs / cache_text_encoder_outputsを使うときはweighted_captionsはサポートされていません"
assert (
not args.train_text_encoder or not args.cache_text_encoder_outputs
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
@@ -120,11 +124,20 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizers = [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2] # will be removed in the future
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
if args.cache_latents:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -162,10 +175,10 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer1, tokenizer2])
train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -258,8 +271,9 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -268,10 +282,13 @@ def train(args):
# 学習を準備する:モデルを適切な状態にする
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
train_unet = args.learning_rate > 0
train_unet = args.learning_rate != 0
train_text_encoder1 = False
train_text_encoder2 = False
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
if args.train_text_encoder:
# TODO each option for two text encoders?
accelerator.print("enable text encoder training")
@@ -280,8 +297,8 @@ def train(args):
text_encoder2.gradient_checkpointing_enable()
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 > 0
train_text_encoder2 = lr_te2 > 0
train_text_encoder1 = lr_te1 != 0
train_text_encoder2 = lr_te2 != 0
# caching one text encoder output is not supported
if not train_text_encoder1:
@@ -303,16 +320,17 @@ def train(args):
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
with torch.no_grad(), accelerator.autocast():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
accelerator.wait_for_everyone()
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()
if not cache_latents:
vae.requires_grad_(False)
@@ -341,8 +359,8 @@ def train(args):
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
for group in params_to_optimize:
for p in group["params"]:
n_params += p.numel()
accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
@@ -351,9 +369,59 @@ def train(args):
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# dataloaderを準備する
if args.fused_optimizer_groups:
# fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
# Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters.
# This balances memory usage and management complexity.
# calculate total number of parameters
n_total_params = sum(len(params["params"]) for params in params_to_optimize)
params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
# split params into groups, keeping the learning rate the same for all params in a group
# this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders)
grouped_params = []
param_group = []
param_group_lr = -1
for group in params_to_optimize:
lr = group["lr"]
for p in group["params"]:
# if the learning rate is different for different params, start a new group
if lr != param_group_lr:
if param_group:
grouped_params.append({"params": param_group, "lr": param_group_lr})
param_group = []
param_group_lr = lr
param_group.append(p)
# if the group has enough parameters, start a new group
if len(param_group) == params_per_group:
grouped_params.append({"params": param_group, "lr": param_group_lr})
param_group = []
param_group_lr = -1
if param_group:
grouped_params.append({"params": param_group, "lr": param_group_lr})
# prepare optimizers for each group
optimizers = []
for group in grouped_params:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
optimizers.append(optimizer)
optimizer = optimizers[0] # avoid error in the following code
logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
else:
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -378,7 +446,12 @@ def train(args):
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
if args.fused_optimizer_groups:
# prepare lr schedulers for each optimizer
lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
lr_scheduler = lr_schedulers[0] # avoid error in the following code
else:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
@@ -398,18 +471,33 @@ def train(args):
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
if train_text_encoder1:
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoder1 if train_text_encoder1 else None,
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
@@ -424,11 +512,64 @@ def train(args):
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
parameter.register_post_accumulate_grad_hook(__grad_hook)
elif args.fused_optimizer_groups:
# prepare for additional optimizers and lr schedulers
for i in range(1, len(optimizers)):
optimizers[i] = accelerator.prepare(optimizers[i])
lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
# counters are used to determine when to step the optimizer
global optimizer_hooked_count
global num_parameters_per_group
global parameter_optimizer_map
optimizer_hooked_count = {}
num_parameters_per_group = [0] * len(optimizers)
parameter_optimizer_map = {}
for opt_idx, optimizer in enumerate(optimizers):
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def optimizer_hook(parameter: torch.Tensor):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
i = parameter_optimizer_map[parameter]
optimizer_hooked_count[i] += 1
if optimizer_hooked_count[i] == num_parameters_per_group[i]:
optimizers[i].step()
optimizers[i].zero_grad(set_to_none=True)
parameter.register_post_accumulate_grad_hook(optimizer_hook)
parameter_optimizer_map[parameter] = opt_idx
num_parameters_per_group[opt_idx] += 1
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -466,12 +607,19 @@ def train(args):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"finetuning" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, [text_encoder1, text_encoder2], unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -483,6 +631,10 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
if args.fused_optimizer_groups:
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset counter for each step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
@@ -497,57 +649,39 @@ def train(args):
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
else:
input_ids1, input_ids2 = batch["input_ids_list"]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
# TODO support weighted captions
# if args.weighted_captions:
# encoder_hidden_states = get_weighted_text_embeddings(
# tokenizer,
# text_encoder,
# batch["captions"],
# accelerator.device,
# args.max_token_length // 75 if args.max_token_length else 1,
# clip_skip=args.clip_skip,
# )
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
# # verify that the text encoder outputs are correct
# ehs1, ehs2, p2 = train_util.get_hidden_states_sdxl(
# args.max_token_length,
# batch["input_ids"].to(text_encoder1.device),
# batch["input_ids2"].to(text_encoder1.device),
# tokenizer1,
# tokenizer2,
# text_encoder1,
# text_encoder2,
# None if not args.full_fp16 else weight_dtype,
# )
# b_size = encoder_hidden_states1.shape[0]
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# logger.info("text encoder outputs verified")
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states1, encoder_hidden_states2, pool2 = (
text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
[text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
input_ids_list,
weights_list,
)
)
else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy,
[text_encoder1, text_encoder2, accelerator.unwrap_model(text_encoder2)],
[input_ids1, input_ids2],
)
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
pool2 = pool2.to(weight_dtype)
# get size embeddings
orig_size = batch["original_sizes_hw"]
@@ -569,41 +703,57 @@ def train(args):
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
target = noise
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
if (
args.min_snr_gamma
or args.scale_v_pred_loss_like_noise_pred
or args.v_pred_like_loss
or args.debiased_estimation_loss
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "mean", huber_c)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if not (args.fused_backward_pass or args.fused_optimizer_groups):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
if args.fused_optimizer_groups:
for i in range(1, len(optimizers)):
lr_schedulers[i].step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -617,7 +767,7 @@ def train(args):
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
tokenizers,
[text_encoder1, text_encoder2],
unet,
)
@@ -647,7 +797,7 @@ def train(args):
)
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
if block_lrs is None:
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
@@ -664,7 +814,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -699,7 +849,7 @@ def train(args):
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
tokenizers,
[text_encoder1, text_encoder2],
unet,
)
@@ -712,7 +862,7 @@ def train(args):
accelerator.end_training()
if args.save_state or args.save_state_on_train_end:
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -744,6 +894,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
@@ -779,7 +931,12 @@ def setup_parser() -> argparse.ArgumentParser:
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
)
parser.add_argument(
"--fused_optimizer_groups",
type=int,
default=None,
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
)
return parser
@@ -787,6 +944,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

719
sdxl_train_control_net.py Normal file
View File

@@ -0,0 +1,719 @@
import argparse
import math
import os
import random
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from accelerate import init_empty_weights
from diffusers import DDPMScheduler
from diffusers.utils.torch_utils import is_compiled_module
from safetensors.torch import load_file
from library import (
deepspeed_utils,
sai_model_spec,
sdxl_model_util,
sdxl_train_util,
strategy_base,
strategy_sd,
strategy_sdxl,
)
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
add_v_prediction_like_loss,
apply_snr_weight,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
)
from library.sdxl_original_control_net import SdxlControlNet, SdxlControlledUNet
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO 他のスクリプトと共通化する
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {
"loss/current": current_loss,
"loss/average": avr_loss,
"lr": lr_scheduler.get_last_lr()[0],
}
if args.optimizer_type.lower().startswith("DAdapt".lower()):
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
return logs
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizer1, tokenizer2 = tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2 # this is used for sampling images
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
user_config = {
"datasets": [
{
"subsets": config_util.generate_controlnet_subsets_config_by_subdirs(
args.train_data_dir,
args.conditioning_data_dir,
args.caption_extension,
)
}
]
}
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
if args.debug_dataset:
train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else:
logger.warning(
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
)
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
(
load_stable_diffusion_format,
text_encoder1,
text_encoder2,
vae,
unet,
logit_scale,
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
unet.to(accelerator.device) # reduce main memory usage
# convert U-Net to Controlled U-Net
logger.info("convert U-Net to Controlled U-Net")
unet_sd = unet.state_dict()
with init_empty_weights():
unet = SdxlControlledUNet()
unet.load_state_dict(unet_sd, strict=True, assign=True)
del unet_sd
# make control net
logger.info("make ControlNet")
if args.controlnet_model_path:
with init_empty_weights():
control_net = SdxlControlNet()
logger.info(f"load ControlNet from {args.controlnet_model_path}")
filename = args.controlnet_model_path
if os.path.splitext(filename)[1] == ".safetensors":
state_dict = load_file(filename)
else:
state_dict = torch.load(filename)
info = control_net.load_state_dict(state_dict, strict=True, assign=True)
logger.info(f"ControlNet loaded from {filename}: {info}")
else:
control_net = SdxlControlNet()
logger.info("initialize ControlNet from U-Net")
info = control_net.init_from_unet(unet)
logger.info(f"ControlNet initialized from U-Net: {info}")
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()
# モデルに xformers とか memory efficient attention を組み込む
# train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if args.xformers:
unet.set_use_memory_efficient_attention(True, False)
control_net.set_use_memory_efficient_attention(True, False)
elif args.sdpa:
unet.set_use_sdpa(True)
control_net.set_use_sdpa(True)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
control_net.enable_gradient_checkpointing()
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = []
ctrlnet_params = []
unet_params = []
for name, param in control_net.named_parameters():
if name.startswith("controlnet_"):
ctrlnet_params.append(param)
else:
unet_params.append(param)
trainable_params.append({"params": ctrlnet_params, "lr": args.control_net_lr})
trainable_params.append({"params": unet_params, "lr": args.learning_rate})
all_params = ctrlnet_params + unet_params
logger.info(f"trainable params count: {len(all_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in all_params)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
control_net.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
control_net.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
control_net, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
control_net, optimizer, train_dataloader, lr_scheduler
)
if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
parameter.register_post_accumulate_grad_hook(__grad_hook)
unet.requires_grad_(False)
text_encoder1.requires_grad_(False)
text_encoder2.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
unet.eval()
control_net.train()
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
# TODO: find a way to handle total batch size when there are multiple datasets
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
prepare_scheduler_for_custom_training(noise_scheduler, accelerator.device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
("sdxl_control_net_train" if args.log_tracker_name is None else args.log_tracker_name),
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_recorder = train_util.LossRecorder()
del train_dataset_group
# function for saving/removing
def save_model(ckpt_name, model, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/controlnet"
state_dict = model.state_dict()
if save_dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(save_dtype)
state_dict[key] = v
if os.path.splitext(ckpt_file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, ckpt_file, sai_metadata)
else:
torch.save(state_dict, ckpt_file)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator,
args,
0,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
control_net.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(control_net):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
else:
input_ids1, input_ids2 = batch["input_ids_list"]
with torch.no_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder1, text_encoder2, unwrap_model(text_encoder2)], [input_ids1, input_ids2]
)
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
pool2 = pool2.to(weight_dtype)
# get size embeddings
orig_size = batch["original_sizes_hw"]
crop_size = batch["crop_top_lefts"]
target_size = batch["target_sizes_hw"]
embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)
# concat embeddings
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
controlnet_image = batch["conditioning_images"].to(dtype=weight_dtype)
# '-1 to +1' to '0 to 1'
controlnet_image = (controlnet_image + 1) / 2
with accelerator.autocast():
input_resi_add, mid_add = control_net(
noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image
)
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, input_resi_add, mid_add)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = control_net.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
sdxl_train_util.sample_images(
accelerator,
args,
None,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, unwrap_model(control_net))
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
remove_step_no = train_util.get_remove_step_no(args, global_step)
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
if args.save_every_n_epochs is not None:
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, unwrap_model(control_net))
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
remove_model(remove_ckpt_name)
if args.save_state:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)
sdxl_train_util.sample_images(
accelerator,
args,
epoch + 1,
global_step,
accelerator.device,
vae,
[tokenizer1, tokenizer2],
[text_encoder1, text_encoder2, unwrap_model(text_encoder2)],
unet,
controlnet=control_net,
)
# end of epoch
if is_main_process:
control_net = unwrap_model(control_net)
accelerator.end_training()
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, control_net, force_sync_upload=True)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
# train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
# train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument(
"--controlnet_model_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
)
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
parser.add_argument(
"--save_model_as",
type=str,
default="safetensors",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--control_net_lr",
type=float,
default=1e-4,
help="learning rate for controlnet modules / controlnetモジュールの学習率",
)
return parser
if __name__ == "__main__":
# sdxl_original_unet.USE_REENTRANT = False
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -15,6 +15,7 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -22,7 +23,16 @@ from accelerate.utils import set_seed
import accelerate
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
from library import (
deepspeed_utils,
sai_model_spec,
sdxl_model_util,
sdxl_original_unet,
sdxl_train_util,
strategy_base,
strategy_sd,
strategy_sdxl,
)
import library.model_util as model_util
import library.train_util as train_util
@@ -78,7 +88,14 @@ def train(args):
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
@@ -105,7 +122,7 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value("i", 0)
@@ -163,30 +180,30 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
vae,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
with torch.no_grad():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer1, tokenizer2),
(text_encoder1, text_encoder2),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
text_encoder1.to(accelerator.device)
text_encoder2.to(accelerator.device)
with accelerator.autocast():
train_dataset_group.new_cache_text_encoder_outputs([text_encoder1, text_encoder2], accelerator)
accelerator.wait_for_everyone()
# prepare ControlNet-LLLite
@@ -241,7 +258,11 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
@@ -288,6 +309,9 @@ def train(args):
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if isinstance(unet, DDP):
unet._set_static_graph() # avoid error for multiple use of the parameter
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
@@ -353,7 +377,9 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_recorder = train_util.LossRecorder()
@@ -394,10 +420,10 @@ def train(args):
with accelerator.accumulate(unet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
@@ -405,27 +431,25 @@ def train(args):
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
# Text Encoder outputs are cached
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoder_outputs_list
encoder_hidden_states1 = encoder_hidden_states1.to(accelerator.device, dtype=weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(accelerator.device, dtype=weight_dtype)
pool2 = pool2.to(accelerator.device, dtype=weight_dtype)
else:
input_ids1, input_ids2 = batch["input_ids_list"]
with torch.no_grad():
# Get the text embedding for conditioning
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizer1,
tokenizer2,
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
encoder_hidden_states1, encoder_hidden_states2, pool2 = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder1, text_encoder2], [input_ids1, input_ids2]
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
encoder_hidden_states2 = batch["text_encoder_outputs2_list"].to(accelerator.device).to(weight_dtype)
pool2 = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
if args.full_fp16:
encoder_hidden_states1 = encoder_hidden_states1.to(weight_dtype)
encoder_hidden_states2 = encoder_hidden_states2.to(weight_dtype)
pool2 = pool2.to(weight_dtype)
# get size embeddings
orig_size = batch["original_sizes_hw"]
@@ -458,7 +482,8 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -471,13 +496,13 @@ def train(args):
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = unet.get_trainable_params()
params_to_clip = accelerator.unwrap_model(unet).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -512,14 +537,14 @@ def train(args):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -566,6 +591,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -611,6 +637,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -12,13 +12,14 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
import library.model_util as model_util
import library.train_util as train_util
@@ -324,7 +325,9 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_recorder = train_util.LossRecorder()
@@ -361,10 +364,10 @@ def train(args):
with accelerator.accumulate(network):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
@@ -426,7 +429,8 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -439,7 +443,7 @@ def train(args):
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -480,14 +484,14 @@ def train(args):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -534,6 +538,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -579,6 +584,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,16 +1,22 @@
import argparse
from typing import List, Optional
import torch
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util
from library import sdxl_model_util, sdxl_train_util, strategy_base, strategy_sd, strategy_sdxl, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
@@ -47,17 +53,41 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.logit_scale = logit_scale
self.ckpt_info = ckpt_info
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
def load_tokenizer(self, args):
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer
def get_tokenize_strategy(self, args):
return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
def is_text_encoder_outputs_cached(self, args):
return args.cache_text_encoder_outputs
def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sdxl.SdxlTextEncodingStrategy()
def get_models_for_text_encoding(self, args, accelerator, text_encoders):
return text_encoders + [accelerator.unwrap_model(text_encoders[-1])]
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
)
else:
return None
def cache_text_encoder_outputs_if_needed(
self, args, accelerator, unet, vae, tokenizers, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
):
if args.cache_text_encoder_outputs:
if not args.lowram:
@@ -70,15 +100,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
clean_memory_on_device(accelerator.device)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
with accelerator.autocast():
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
accelerator.device,
weight_dtype,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
)
dataset.new_cache_text_encoder_outputs(text_encoders + [accelerator.unwrap_model(text_encoders[-1])], accelerator)
accelerator.wait_for_everyone()
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
@@ -147,7 +173,18 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
return encoder_hidden_states1, encoder_hidden_states2, pool2
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
def call_unet(
self,
args,
accelerator,
unet,
noisy_latents,
timesteps,
text_conds,
batch,
weight_dtype,
indices: Optional[List[int]] = None,
):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
# get size embeddings
@@ -161,6 +198,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
if indices is not None and len(indices) > 0:
noisy_latents = noisy_latents[indices]
timesteps = timesteps[indices]
text_embedding = text_embedding[indices]
vector_embedding = vector_embedding[indices]
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
@@ -178,6 +221,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = SdxlNetworkTrainer()

View File

@@ -5,11 +5,10 @@ import regex
import torch
from library.device_utils import init_ipex
init_ipex()
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util
from library import sdxl_model_util, sdxl_train_util, strategy_sd, strategy_sdxl, train_util
import train_textual_inversion
@@ -42,28 +41,20 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
return sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, [text_encoder1, text_encoder2], vae, unet
def load_tokenizer(self, args):
tokenizer = sdxl_train_util.load_tokenizers(args)
return tokenizer
def get_tokenize_strategy(self, args):
return strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids1 = batch["input_ids"]
input_ids2 = batch["input_ids2"]
with torch.enable_grad():
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2
def get_tokenizers(self, tokenize_strategy: strategy_sdxl.SdxlTokenizeStrategy):
return [tokenize_strategy.tokenizer1, tokenize_strategy.tokenizer2]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
def get_text_encoding_strategy(self, args):
return strategy_sdxl.SdxlTextEncodingStrategy()
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -82,9 +73,11 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
def sample_images(
self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
):
sdxl_train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -123,8 +116,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
def setup_parser() -> argparse.ArgumentParser:
parser = train_textual_inversion.setup_parser()
# don't add sdxl_train_util.add_sdxl_training_arguments(parser): because it only adds text encoder caching
# sdxl_train_util.add_sdxl_training_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser, support_text_encoder_caching=False)
return parser
@@ -132,6 +124,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = SdxlTextualInversionTrainer()

View File

@@ -9,47 +9,82 @@ from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl
from library import train_util
from library import sdxl_train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
# check cache latents arg
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
def set_tokenize_strategy(is_sd: bool, is_sdxl: bool, is_flux: bool, args: argparse.Namespace) -> None:
if is_flux:
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
else:
is_schnell = False
if is_sd:
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
elif is_sdxl:
tokenize_strategy = strategy_sdxl.SdxlTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
else:
if args.t5xxl_max_token_length is None:
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
train_util.enable_high_vram(args)
# assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
args.cache_latents = True
args.cache_latents_to_disk = True
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
# tokenizerを準備するdatasetを動かすために必要
if args.sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
is_sd = not args.sdxl and not args.flux
is_sdxl = args.sdxl
is_flux = args.flux
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
if is_sd or is_sdxl:
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(is_sd, True, args.vae_batch_size, args.skip_cache_check)
else:
tokenizer = train_util.load_tokenizer(args)
tokenizers = [tokenizer]
latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(True, args.vae_batch_size, args.skip_cache_check)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
use_user_config = args.dataset_config is not None
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
@@ -80,20 +115,15 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
# datasetのcache_latentsを呼ばなければ、生の画像が返る
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
# acceleratorを準備する
logger.info("prepare accelerator")
args.deepspeed = False
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -102,79 +132,43 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# モデルを読み込む
logger.info("load model")
if args.sdxl:
if is_sd:
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
elif is_sdxl:
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
else:
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
vae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
if is_sd or is_sdxl:
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(args.xformers)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
# dataloaderを準備する
train_dataset_group.set_caching_mode("latents")
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# acceleratorを使ってモデルを準備するマルチGPUで使えるようになるはず
train_dataloader = accelerator.prepare(train_dataloader)
# データ取得のためのループ
for batch in tqdm(train_dataloader):
b_size = len(batch["images"])
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
flip_aug = batch["flip_aug"]
random_crop = batch["random_crop"]
bucket_reso = batch["bucket_reso"]
# バッチを分割して処理する
for i in range(0, b_size, vae_batch_size):
images = batch["images"][i : i + vae_batch_size]
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
image_infos = []
for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.image = image
image_info.bucket_reso = bucket_reso
image_info.resized_size = resized_size
image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)
if len(image_infos) > 0:
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)
# cache latents with dataset
# TODO use DataLoader to speed up
train_dataset_group.new_cache_latents(vae, accelerator)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
accelerator.print(f"Finished caching latents to disk.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)
config_util.add_config_arguments(parser)
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
parser.add_argument(
"--no_half_vae",
action="store_true",
@@ -183,7 +177,8 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check."
" / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。",
)
return parser

View File

@@ -9,54 +9,69 @@ from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import (
config_util,
flux_train_utils,
flux_utils,
sdxl_model_util,
strategy_base,
strategy_flux,
strategy_sd,
strategy_sdxl,
)
from library import train_util
from library import sdxl_train_util
from library import utils
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
from library.utils import setup_logging, add_logging_arguments
from cache_latents import set_tokenize_strategy
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
setup_logging(args, reset=True)
train_util.prepare_dataset_args(args, True)
train_util.enable_high_vram(args)
# check cache arg
assert (
args.cache_text_encoder_outputs_to_disk
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
# できるだけ準備はしておくが今のところSDXLのみしか動かない
assert (
args.sdxl
), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"
args.cache_text_encoder_outputs = True
args.cache_text_encoder_outputs_to_disk = True
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
# tokenizerを準備するdatasetを動かすために必要
if args.sdxl:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizers = [tokenizer1, tokenizer2]
else:
tokenizer = train_util.load_tokenizer(args)
tokenizers = [tokenizer]
is_sd = not args.sdxl and not args.flux
is_sdxl = args.sdxl
is_flux = args.flux
assert (
is_sdxl or is_flux
), "Cache text encoder outputs to disk is only supported for SDXL and FLUX models / テキストエンコーダ出力のディスクキャッシュはSDXLまたはFLUXでのみ有効です"
assert (
is_sdxl or args.weighted_captions is None
), "Weighted captions are only supported for SDXL models / 重み付きキャプションはSDXLモデルでのみ有効です"
set_tokenize_strategy(is_sd, is_sdxl, is_flux, args)
# データセットを準備する
use_user_config = args.dataset_config is not None
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
@@ -87,100 +102,117 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
# acceleratorを準備する
logger.info("prepare accelerator")
args.deepspeed = False
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, _ = train_util.prepare_dtype(args)
t5xxl_dtype = utils.str_to_dtype(args.t5xxl_dtype, weight_dtype)
# モデルを読み込む
logger.info("load model")
if args.sdxl:
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
if is_sdxl:
_, text_encoder1, text_encoder2, _, _, _, _ = sdxl_train_util.load_target_model(
args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
)
text_encoder1.to(accelerator.device, weight_dtype)
text_encoder2.to(accelerator.device, weight_dtype)
text_encoders = [text_encoder1, text_encoder2]
else:
text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
text_encoders = [text_encoder1]
clip_l = flux_utils.load_clip_l(
args.clip_l, weight_dtype, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors
)
t5xxl = flux_utils.load_t5xxl(args.t5xxl, None, accelerator.device, disable_mmap=args.disable_mmap_load_safetensors)
if t5xxl.dtype == torch.float8_e4m3fnuz or t5xxl.dtype == torch.float8_e5m2 or t5xxl.dtype == torch.float8_e5m2fnuz:
raise ValueError(f"Unsupported fp8 model dtype: {t5xxl.dtype}")
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")
if t5xxl_dtype != t5xxl_dtype:
if t5xxl.dtype == torch.float8_e4m3fn and t5xxl_dtype.itemsize() >= 2:
logger.warning(
"The loaded model is fp8, but the specified T5XXL dtype is larger than fp8. This may cause a performance drop."
" / ロードされたモデルはfp8ですが、指定されたT5XXLのdtypeがfp8より高精度です。精度低下が発生する可能性があります。"
)
logger.info(f"Casting T5XXL model to {t5xxl_dtype}")
t5xxl.to(t5xxl_dtype)
text_encoders = [clip_l, t5xxl]
for text_encoder in text_encoders:
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
# dataloaderを準備する
train_dataset_group.set_caching_mode("text")
# build text encoder outputs caching strategy
if is_sdxl:
text_encoder_outputs_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
)
else:
text_encoder_outputs_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=False,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_outputs_caching_strategy)
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
# build text encoding strategy
if is_sdxl:
text_encoding_strategy = strategy_sdxl.SdxlTextEncodingStrategy()
else:
text_encoding_strategy = strategy_flux.FluxTextEncodingStrategy(args.apply_t5_attn_mask)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# acceleratorを使ってモデルを準備するマルチGPUで使えるようになるはず
train_dataloader = accelerator.prepare(train_dataloader)
# データ取得のためのループ
for batch in tqdm(train_dataloader):
absolute_paths = batch["absolute_paths"]
input_ids1_list = batch["input_ids1_list"]
input_ids2_list = batch["input_ids2_list"]
image_infos = []
for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
image_info
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1
image_info.input_ids2 = input_ids2
image_infos.append(image_info)
if len(image_infos) > 0:
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
train_util.cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
)
# cache text encoder outputs
train_dataset_group.new_cache_text_encoder_outputs(text_encoders, accelerator)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
accelerator.print(f"Finished caching text encoder outputs to disk.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)
config_util.add_config_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
train_util.add_dit_training_arguments(parser)
flux_train_utils.add_flux_train_arguments(parser)
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
parser.add_argument("--flux", action="store_true", help="Use FLUX model / FLUXモデルを使用する")
parser.add_argument(
"--t5xxl_dtype",
type=str,
default=None,
help="T5XXL model dtype, default: None (use mixed precision dtype) / T5XXLモデルのdtype, デフォルト: None (mixed precisionのdtypeを使用)",
)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
help="[Deprecated] This option does not work. Existing .npz files are always checked. Use `--skip_cache_check` to skip the check."
" / [非推奨] このオプションは機能しません。既存の .npz は常に検証されます。`--skip_cache_check` で検証をスキップできます。",
)
parser.add_argument(
"--weighted_captions",
action="store_true",
default=False,
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
)
return parser

View File

@@ -0,0 +1,149 @@
# This script converts the diffusers of a Flux model to a safetensors file of a Flux.1 model.
# It is based on the implementation by 2kpr. Thanks to 2kpr!
# Major changes:
# - Iterates over three safetensors files to reduce memory usage, not loading all tensors at once.
# - Makes reverse map from diffusers map to avoid loading all tensors.
# - Removes dependency on .json file for weights mapping.
# - Adds support for custom memory efficient load and save functions.
# - Supports saving with different precision.
# - Supports .safetensors file as input.
# Copyright 2024 2kpr. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse
import os
from pathlib import Path
import safetensors
from safetensors.torch import safe_open
import torch
from tqdm import tqdm
from library import flux_utils
from library.utils import setup_logging, str_to_dtype, MemoryEfficientSafeOpen, mem_eff_save_file
setup_logging()
import logging
logger = logging.getLogger(__name__)
def convert(args):
# if diffusers_path is folder, get safetensors file
diffusers_path = Path(args.diffusers_path)
if diffusers_path.is_dir():
diffusers_path = Path.joinpath(diffusers_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
flux_path = Path(args.save_to)
if not os.path.exists(flux_path.parent):
os.makedirs(flux_path.parent)
if not diffusers_path.exists():
logger.error(f"Error: Missing transformer safetensors file: {diffusers_path}")
return
mem_eff_flag = args.mem_eff_load_save
save_dtype = str_to_dtype(args.save_precision) if args.save_precision is not None else None
# make reverse map from diffusers map
diffusers_to_bfl_map = flux_utils.make_diffusers_to_bfl_map()
# iterate over three safetensors files to reduce memory usage
flux_sd = {}
for i in range(3):
# replace 00001 with 0000i
current_diffusers_path = Path(str(diffusers_path).replace("00001", f"0000{i+1}"))
logger.info(f"Loading diffusers file: {current_diffusers_path}")
open_func = MemoryEfficientSafeOpen if mem_eff_flag else (lambda x: safe_open(x, framework="pt"))
with open_func(current_diffusers_path) as f:
for diffusers_key in tqdm(f.keys()):
if diffusers_key in diffusers_to_bfl_map:
tensor = f.get_tensor(diffusers_key).to("cpu")
if save_dtype is not None:
tensor = tensor.to(save_dtype)
index, bfl_key = diffusers_to_bfl_map[diffusers_key]
if bfl_key not in flux_sd:
flux_sd[bfl_key] = []
flux_sd[bfl_key].append((index, tensor))
else:
logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
return
# concat tensors if multiple tensors are mapped to a single key, sort by index
for key, values in flux_sd.items():
if len(values) == 1:
flux_sd[key] = values[0][1]
else:
flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])
# special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
if "final_layer.adaLN_modulation.1.weight" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
if "final_layer.adaLN_modulation.1.bias" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])
# save flux_sd to safetensors file
logger.info(f"Saving Flux safetensors file: {flux_path}")
if mem_eff_flag:
mem_eff_save_file(flux_sd, flux_path)
else:
safetensors.torch.save_file(flux_sd, flux_path)
logger.info("Conversion completed.")
def setup_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--diffusers_path",
default=None,
type=str,
required=True,
help="Path to the original Flux diffusers folder or *-00001-of-00003.safetensors file."
" / 元のFlux diffusersフォルダーまたは*-00001-of-00003.safetensorsファイルへのパス",
)
parser.add_argument(
"--save_to",
default=None,
type=str,
required=True,
help="Output path for the Flux safetensors file. / Flux safetensorsファイルの出力先",
)
parser.add_argument(
"--mem_eff_load_save",
action="store_true",
help="use custom memory efficient load and save functions for FLUX.1 model"
" / カスタムのメモリ効率の良い読み込みと保存関数をFLUX.1モデルに使用する",
)
parser.add_argument(
"--save_precision",
type=str,
default=None,
help="precision in saving, default is same as loading precision"
"float32, fp16, bf16, fp8 (same as fp8_e4m3fn), fp8_e4m3fn, fp8_e4m3fnuz, fp8_e5m2, fp8_e5m2fnuz"
" / 保存時に精度を変更して保存する、デフォルトは読み込み時と同じ精度",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
convert(args)

View File

@@ -15,7 +15,7 @@ import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -172,7 +172,10 @@ def process(args):
if scale != 1.0:
w = int(w * scale + .5)
h = int(h * scale + .5)
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
if scale < 1.0:
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
else:
face_img = pil_resize(face_img, (w, h))
cx = int(cx * scale + .5)
cy = int(cy * scale + .5)
fw = int(fw * scale + .5)

View File

@@ -6,7 +6,7 @@ import shutil
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging
from library.utils import setup_logging, pil_resize
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
# Select interpolation method
if interpolation == 'lanczos4':
cv2_interpolation = cv2.INTER_LANCZOS4
pil_interpolation = Image.LANCZOS
elif interpolation == 'cubic':
cv2_interpolation = cv2.INTER_CUBIC
pil_interpolation = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA
@@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
if cv2_interpolation:
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
else:
new_height, new_width = img.shape[0:2]

View File

@@ -5,13 +5,16 @@ import os
import random
import time
from multiprocessing import Value
from types import SimpleNamespace
# from omegaconf import OmegaConf
import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -104,6 +107,8 @@ def train(args):
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
@@ -147,8 +152,10 @@ def train(args):
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": [5, 10, 20, 20],
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 4,
@@ -178,8 +185,10 @@ def train(args):
"in_channels": 4,
"layers_per_block": 2,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": 8,
"out_channels": 4,
"sample_size": 64,
"up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
@@ -192,7 +201,23 @@ def train(args):
"resnet_time_scale_shift": "default",
"projection_class_embeddings_input_dim": None,
}
unet.config = SimpleNamespace(**unet.config)
# unet.config = OmegaConf.create(unet.config)
# make unet.config iterable and accessible by attribute
class CustomConfig:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __contains__(self, name):
return name in self.__dict__
unet.config = CustomConfig(**unet.config)
controlnet = ControlNetModel.from_unet(unet)
@@ -225,16 +250,17 @@ def train(args):
)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
controlnet.enable_gradient_checkpointing()
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = controlnet.parameters()
trainable_params = list(controlnet.parameters())
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
@@ -279,6 +305,22 @@ def train(args):
controlnet, optimizer, train_dataloader, lr_scheduler
)
if args.fused_backward_pass:
import library.adafactor_fused
library.adafactor_fused.patch_adafactor_fused(optimizer)
for param_group in optimizer.param_groups:
for parameter in param_group["params"]:
if parameter.requires_grad:
def __grad_hook(tensor: torch.Tensor, param_group=param_group):
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
optimizer.step_param(tensor, param_group)
tensor.grad = None
parameter.register_post_accumulate_grad_hook(__grad_hook)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.to(accelerator.device)
@@ -343,7 +385,9 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
loss_recorder = train_util.LossRecorder()
@@ -384,6 +428,9 @@ def train(args):
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
# training loop
for epoch in range(num_train_epochs):
@@ -396,7 +443,7 @@ def train(args):
with accelerator.accumulate(controlnet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -419,13 +466,8 @@ def train(args):
)
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(b_size,),
device=latents.device,
)
timesteps = timesteps.long()
timesteps = train_util.get_timesteps(0, noise_scheduler.config.num_train_timesteps, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -456,7 +498,8 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -468,13 +511,17 @@ def train(args):
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = controlnet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if not args.fused_backward_pass:
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = controlnet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
lr_scheduler.step()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -518,14 +565,14 @@ def train(args):
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -584,6 +631,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -615,6 +663,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -11,7 +11,10 @@ import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils, strategy_base
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
@@ -32,8 +35,10 @@ from library.custom_train_functions import (
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
import library.strategy_sd as strategy_sd
setup_logging()
import logging
@@ -46,6 +51,7 @@ logger = logging.getLogger(__name__)
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -53,11 +59,18 @@ def train(args):
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
tokenizer = train_util.load_tokenizer(args)
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
False, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -75,10 +88,10 @@ def train(args):
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
train_dataset_group = train_util.load_arbitrary_dataset(args)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -88,6 +101,8 @@ def train(args):
if args.no_token_padding:
train_dataset_group.disable_token_padding()
train_dataset_group.verify_bucket_reso_steps(64)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
return
@@ -140,13 +155,17 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
text_encoding_strategy = strategy_sd.SdTextEncodingStrategy(args.clip_skip)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
# 学習を準備する:モデルを適切な状態にする
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # 念のため追加
@@ -179,8 +198,11 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -219,12 +241,25 @@ def train(args):
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
training_models = [unet, text_encoder]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
training_models = [unet]
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
@@ -272,10 +307,19 @@ def train(args):
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"dreambooth" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
@@ -296,12 +340,14 @@ def train(args):
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
if len(training_models) == 2:
training_models = training_models[0] # remove text_encoder from training_models
with accelerator.accumulate(unet):
with accelerator.accumulate(*training_models):
with torch.no_grad():
# latentに変換
if cache_latents:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
@@ -310,19 +356,17 @@ def train(args):
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoder_hidden_states = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [text_encoder], input_ids_list, weights_list
)[0]
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
input_ids = batch["input_ids_list"][0].to(accelerator.device)
encoder_hidden_states = text_encoding_strategy.encode_tokens(
tokenize_strategy, [text_encoder], [input_ids]
)[0]
if args.full_fp16:
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -338,7 +382,10 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -349,7 +396,7 @@ def train(args):
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -371,7 +418,7 @@ def train(args):
global_step += 1
train_util.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
accelerator, args, None, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
# 指定ステップごとにモデルを保存
@@ -396,7 +443,7 @@ def train(args):
)
current_loss = loss.detach().item()
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
@@ -409,7 +456,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
@@ -435,7 +482,9 @@ def train(args):
vae,
)
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
train_util.sample_images(
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenize_strategy.tokenizer, text_encoder, unet
)
is_main_process = accelerator.is_main_process
if is_main_process:
@@ -464,6 +513,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
@@ -499,6 +550,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

File diff suppressed because it is too large Load Diff

View File

@@ -2,18 +2,21 @@ import argparse
import math
import os
from multiprocessing import Value
from typing import Any, List
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import model_util
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
import library.train_util as train_util
import library.huggingface_util as huggingface_util
@@ -29,6 +32,7 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
@@ -96,32 +100,42 @@ class TextualInversionTrainer:
self.is_sdxl = False
def assert_extra_args(self, args, train_dataset_group):
pass
train_dataset_group.verify_bucket_reso_steps(64)
def load_target_model(self, args, weight_dtype, accelerator):
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), text_encoder, vae, unet
return model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization), [text_encoder], vae, unet
def load_tokenizer(self, args):
tokenizer = train_util.load_tokenizer(args)
return tokenizer
def get_tokenize_strategy(self, args):
return strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_sd.SdTokenizeStrategy) -> List[Any]:
return [tokenize_strategy.tokenizer]
def get_latents_caching_strategy(self, args):
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
True, args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
)
return latents_caching_strategy
def assert_token_string(self, token_string, tokenizers: CLIPTokenizer):
pass
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
with torch.enable_grad():
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizers[0], text_encoders[0], None)
return encoder_hidden_states
def get_text_encoding_strategy(self, args):
return strategy_sd.SdTextEncodingStrategy(args.clip_skip)
def get_models_for_text_encoding(self, args, accelerator, text_encoders) -> List[Any]:
return text_encoders
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement):
def sample_images(
self, accelerator, args, epoch, global_step, device, vae, tokenizers, text_encoders, unet, prompt_replacement
):
train_util.sample_images(
accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet, prompt_replacement
accelerator, args, epoch, global_step, device, vae, tokenizers[0], text_encoders[0], unet, prompt_replacement
)
def save_weights(self, file, updated_embs, save_dtype, metadata):
@@ -179,8 +193,13 @@ class TextualInversionTrainer:
if args.seed is not None:
set_seed(args.seed)
tokenizer_or_list = self.load_tokenizer(args) # list of tokenizer or tokenizer
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
tokenize_strategy = self.get_tokenize_strategy(args)
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
tokenizers = self.get_tokenizers(tokenize_strategy) # will be removed after sample_image is refactored
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
latents_caching_strategy = self.get_latents_caching_strategy(args)
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# acceleratorを準備する
logger.info("prepare accelerator")
@@ -191,14 +210,7 @@ class TextualInversionTrainer:
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
text_encoders = [text_encoder_or_list] if not isinstance(text_encoder_or_list, list) else text_encoder_or_list
if len(text_encoders) > 1 and args.gradient_accumulation_steps > 1:
accelerator.print(
"accelerate doesn't seem to support gradient_accumulation_steps for multiple models (text encoders) / "
+ "accelerateでは複数のモデルテキストエンコーダーのgradient_accumulation_stepsはサポートされていないようです"
)
model_version, text_encoders, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
# Convert the init_word to token_id
init_token_ids_list = []
@@ -268,7 +280,7 @@ class TextualInversionTrainer:
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
if args.dataset_config is not None:
accelerator.print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -307,10 +319,10 @@ class TextualInversionTrainer:
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer_or_list)
blueprint = blueprint_generator.generate(user_config, args)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer_or_list)
train_dataset_group = train_util.load_arbitrary_dataset(args)
self.assert_extra_args(args, train_dataset_group)
@@ -365,11 +377,10 @@ class TextualInversionTrainer:
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
train_dataset_group.new_cache_latents(vae, accelerator)
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
@@ -384,7 +395,11 @@ class TextualInversionTrainer:
trainable_params += text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# prepare dataloader
# strategies are set here because they cannot be referenced in another process. Copy them with the dataset
# some strategies can be None
train_dataset_group.set_current_strategies()
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
@@ -412,20 +427,8 @@ class TextualInversionTrainer:
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# acceleratorがなんかよろしくやってくれるらしい
if len(text_encoders) == 1:
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
)
elif len(text_encoders) == 2:
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
)
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
else:
raise NotImplementedError()
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
text_encoders = [accelerator.prepare(text_encoder) for text_encoder in text_encoders]
index_no_updates_list = []
orig_embeds_params_list = []
@@ -453,6 +456,9 @@ class TextualInversionTrainer:
else:
unet.eval()
text_encoding_strategy = self.get_text_encoding_strategy(args)
strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
vae.requires_grad_(False)
vae.eval()
@@ -507,7 +513,9 @@ class TextualInversionTrainer:
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
# function for saving/removing
@@ -537,11 +545,14 @@ class TextualInversionTrainer:
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
tokenizers,
text_encoders,
unet,
prompt_replacement,
)
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
# training loop
for epoch in range(num_train_epochs):
@@ -558,14 +569,19 @@ class TextualInversionTrainer:
with accelerator.accumulate(text_encoders[0]):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
latents = latents * self.vae_scale_factor
# Get the text embedding for conditioning
text_encoder_conds = self.get_text_cond(args, accelerator, batch, tokenizers, text_encoders, weight_dtype)
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy, self.get_models_for_text_encoding(args, accelerator, text_encoders), input_ids
)
if args.full_fp16:
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -585,7 +601,10 @@ class TextualInversionTrainer:
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -598,7 +617,7 @@ class TextualInversionTrainer:
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -634,8 +653,8 @@ class TextualInversionTrainer:
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
tokenizers,
text_encoders,
unet,
prompt_replacement,
)
@@ -667,7 +686,7 @@ class TextualInversionTrainer:
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
@@ -685,7 +704,7 @@ class TextualInversionTrainer:
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
@@ -717,11 +736,12 @@ class TextualInversionTrainer:
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
tokenizers,
text_encoders,
unet,
prompt_replacement,
)
accelerator.log({})
# end of epoch
@@ -749,6 +769,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
@@ -799,6 +821,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = TextualInversionTrainer()

View File

@@ -8,7 +8,9 @@ from multiprocessing import Value
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
@@ -31,6 +33,7 @@ from library.custom_train_functions import (
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
apply_masked_loss,
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -200,7 +203,7 @@ def train(args):
logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -404,7 +407,9 @@ def train(args):
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name,
config=train_util.get_sanitized_config_or_none(args),
init_kwargs=init_kwargs,
)
# function for saving/removing
@@ -439,7 +444,7 @@ def train(args):
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -470,7 +475,10 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
@@ -481,7 +489,7 @@ def train(args):
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -533,7 +541,7 @@ def train(args):
remove_model(remove_ckpt_name)
current_loss = loss.detach().item()
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if (
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
@@ -551,7 +559,7 @@ def train(args):
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
if len(accelerator.trackers) > 0:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
@@ -662,6 +670,8 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
@@ -707,6 +717,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)