Compare commits

..

109 Commits

Author SHA1 Message Date
kohya-ss
dbd835ee4b train: Optimize VAE encoding by handling batch sizes for images 2025-04-08 21:57:16 +09:00
Kohya S
5a18a03ffc Merge branch 'dev' into sd3 2025-04-07 21:55:17 +09:00
Kohya S
572cc3efb8 Merge branch 'main' into dev 2025-04-07 21:48:45 +09:00
Kohya S.
52c8dec953 Merge pull request #2015 from DKnight54/uncache_vae_batch
Using --vae_batch_size to set batch size for dynamic latent generation
2025-04-07 21:48:02 +09:00
Kohya S
4589262f8f README.md: Update recent updates section to include IP noise gamma feature for FLUX.1 2025-04-06 21:34:27 +09:00
Kohya S.
c56dc90b26 Merge pull request #1992 from rockerBOO/flux-ip-noise-gamma
Add IP noise gamma for Flux
2025-04-06 21:29:26 +09:00
Kohya S.
ee0f754b08 Merge pull request #2028 from rockerBOO/patch-5
Fix resize PR link
2025-04-05 20:15:13 +09:00
Kohya S.
606e6875d2 Merge pull request #2022 from LexSong/fix-resize-issue
Fix size parameter types and improve resize_image interpolation
2025-04-05 19:28:25 +09:00
Dave Lage
fd36fd1aa9 Fix resize PR link 2025-04-03 16:09:45 -04:00
Kohya S.
92845e8806 Merge pull request #2026 from kohya-ss/fix-finetune-dataset-resize-interpolation
fix: add resize_interpolation parameter to FineTuningDataset constructor
2025-04-03 21:52:14 +09:00
Kohya S
f1423a7229 fix: add resize_interpolation parameter to FineTuningDataset constructor 2025-04-03 21:48:51 +09:00
Lex Song
b822b7e60b Fix the interpolation logic error in resize_image()
The original code had a mistake. It used 'lanczos' when the image got smaller (width > resized_width and height > resized_height) and 'area' when it stayed the same or got bigger. This was the wrong way. 'area' is better for big shrinking.
2025-04-02 22:04:37 +08:00
Lex Song
ede3470260 Ensure all size parameters are integers to prevent type errors 2025-04-02 03:50:33 +08:00
Kohya S
b3c56b22bd Merge branch 'dev' into sd3 2025-03-31 22:05:40 +09:00
Kohya S
583ab27b3c doc: update license information in jpeg_xl_util.py 2025-03-31 22:02:25 +09:00
Kohya S.
aa5978dffd Merge pull request #1955 from Disty0/dev
Fast image size reading support for JPEG XL
2025-03-31 22:00:31 +09:00
Kohya S
aaa26bb882 docs: update README to include LoRA-GGPO details for FLUX.1 training 2025-03-30 21:18:05 +09:00
Kohya S
d0b5c0e5cf chore: formatting, add TODO comment 2025-03-30 21:15:37 +09:00
Kohya S.
59d98e45a9 Merge pull request #1974 from rockerBOO/lora-ggpo
Add LoRA-GGPO for Flux
2025-03-30 21:07:31 +09:00
Kohya S.
3149b2771f Merge pull request #2018 from kohya-ss/resize-interpolation-small-fix
Resize interpolation small fix
2025-03-30 20:52:25 +09:00
Kohya S
96a133c998 README.md: update recent updates section to include new interpolation method for resizing images 2025-03-30 20:45:06 +09:00
Kohya S
1f432e2c0e use PIL for lanczos and box 2025-03-30 20:40:29 +09:00
Kohya S.
9e9a13aa8a Merge pull request #1936 from rockerBOO/resize-interpolation
Add resize interpolation parameter
2025-03-30 20:37:34 +09:00
Kohya S.
93a4efabb5 Merge branch 'sd3' into resize-interpolation 2025-03-30 19:30:56 +09:00
DKnight54
381303d64f Update train_network.py 2025-03-29 02:26:18 +08:00
rockerBOO
0181b7a042 Remove progress bar avg norms 2025-03-27 03:28:33 -04:00
rockerBOO
182544dcce Remove pertubation seed 2025-03-26 14:23:04 -04:00
Kohya S
8ebe858f89 Merge branch 'dev' into sd3 2025-03-24 22:02:16 +09:00
Kohya S.
a0f11730f7 Merge pull request #1966 from sdbds/faster_fix_sdxl
Fatser fix bug for SDXL super SD1.5 assert cant use 32
2025-03-24 21:53:42 +09:00
Kohya S
6364379f17 Merge branch 'dev' into sd3 2025-03-21 22:07:50 +09:00
Kohya S
5253a38783 Merge branch 'main' into dev 2025-03-21 22:07:03 +09:00
Kohya S
8f4ee8fc34 doc: update README for latest 2025-03-21 22:05:48 +09:00
Kohya S.
367f348430 Merge pull request #1964 from Nekotekina/main
Fix missing text encoder attn modules
2025-03-21 21:59:03 +09:00
rockerBOO
89f0d27a59 Set sigmoid_scale to default 1.0 2025-03-20 15:10:33 -04:00
rockerBOO
d40f5b1e4e Revert "Scale sigmoid to default 1.0"
This reverts commit 8aa126582e.
2025-03-20 15:09:50 -04:00
rockerBOO
8aa126582e Scale sigmoid to default 1.0 2025-03-20 15:09:11 -04:00
rockerBOO
e8b3254858 Add flux_train_utils tests for get get_noisy_model_input_and_timesteps 2025-03-20 15:01:15 -04:00
rockerBOO
16cef81aea Refactor sigmas and timesteps 2025-03-20 14:32:56 -04:00
Kohya S
d151833526 docs: update README with recent changes and specify version for pytorch-optimizer 2025-03-20 22:05:29 +09:00
Kohya S.
936d333ff4 Merge pull request #1985 from gesen2egee/pytorch-optimizer
Support pytorch_optimizer
2025-03-20 22:01:03 +09:00
rockerBOO
f974c6b257 change order to match upstream 2025-03-19 14:27:43 -04:00
rockerBOO
5d5a7d2acf Fix IP noise calculation 2025-03-19 13:50:04 -04:00
rockerBOO
1eddac26b0 Separate random to a variable, and make sure on device 2025-03-19 00:49:42 -04:00
rockerBOO
8e6817b0c2 Remove double noise 2025-03-19 00:45:13 -04:00
rockerBOO
d93ad90a71 Add perturbation on noisy_model_input if needed 2025-03-19 00:37:27 -04:00
rockerBOO
7197266703 Perturbed noise should be separate of input noise 2025-03-19 00:25:51 -04:00
gesen2egee
5b210ad717 update prodigyopt and prodigy-plus-schedule-free 2025-03-19 10:49:06 +08:00
rockerBOO
b81bcd0b01 Move IP noise gamma to noise creation to remove complexity and align noise for target loss 2025-03-18 21:36:55 -04:00
rockerBOO
6f4d365775 zeros_like because we are adding 2025-03-18 18:53:34 -04:00
rockerBOO
a4f3a9fc1a Use ones_like 2025-03-18 18:44:21 -04:00
rockerBOO
b425466e7b Fix IP noise gamma to use random values 2025-03-18 18:42:35 -04:00
rockerBOO
c8be141ae0 Apply IP gamma to noise fix 2025-03-18 15:42:18 -04:00
rockerBOO
0b25a05e3c Add IP noise gamma for Flux 2025-03-18 15:40:40 -04:00
rockerBOO
3647d065b5 Cache weight norms estimate on initialization. Move to update norms every step 2025-03-18 14:25:09 -04:00
Disty0
620a06f517 Check for uppercase file extension too 2025-03-17 17:44:29 +03:00
Disty0
564ec5fb7f use extend instead of += 2025-03-17 17:41:03 +03:00
Disty0
7e90cdd47a use bytearray and add typing hints 2025-03-17 17:26:08 +03:00
gesen2egee
e5b5c7e1db Update requirements.txt 2025-03-15 13:29:32 +08:00
rockerBOO
ea53290f62 Add LoRA-GGPO for Flux 2025-03-06 00:00:38 -05:00
Kohya S.
75933d70a1 Merge pull request #1960 from kohya-ss/sd3_safetensors_merge
Sd3 safetensors merge
2025-03-05 23:28:38 +09:00
Kohya S
aa2bde7ece docs: add utility script for merging SD3 weights into a single .safetensors file 2025-03-05 23:24:52 +09:00
sdbds
3f49053c90 fatser fix bug for SDXL super SD1.5 assert cant use 32 2025-03-02 19:32:06 +08:00
Ivan Chikish
acdca2abb7 Fix [occasionally] missing text encoder attn modules
Should fix #1952
I added alternative name for CLIPAttention.
I have no idea why this name changed.
Now it should accept both names.
2025-03-01 20:35:45 +03:00
Kohya S
ba5251168a fix: save tensors as is dtype, add save_precision option 2025-03-01 10:31:39 +09:00
Kohya S
272f4c3775 Merge branch 'sd3' into sd3_safetensors_merge 2025-02-28 23:52:36 +09:00
Kohya S
734333d0c9 feat: enhance merging logic for safetensors models to handle key prefixes correctly 2025-02-28 23:52:29 +09:00
Disty0
2f69f4dbdb fix typo 2025-02-27 00:30:19 +03:00
Disty0
9a415ba965 JPEG XL support 2025-02-27 00:21:57 +03:00
Kohya S
3d79239be4 docs: update README to include recent improvements in validation loss calculation 2025-02-26 21:21:04 +09:00
Kohya S
ec350c83eb Merge branch 'dev' into sd3 2025-02-26 21:17:29 +09:00
Kohya S.
49651892ce Merge pull request #1903 from kohya-ss/val-loss-improvement
Val loss improvement
2025-02-26 21:15:14 +09:00
Kohya S
1fcac98280 Merge branch 'sd3' into val-loss-improvement 2025-02-26 21:09:10 +09:00
Kohya S.
b286304e5f Merge pull request #1953 from Disty0/dev
Update IPEX libs
2025-02-26 21:03:09 +09:00
Kohya S
ae409e83c9 fix: FLUX/SD3 network training not working without caching latents closes #1954 2025-02-26 20:56:32 +09:00
Kohya S
5228db1548 feat: add script to merge multiple safetensors files into a single file for SD3 2025-02-26 20:50:58 +09:00
Kohya S
f4a0047865 feat: support metadata loading in MemoryEfficientSafeOpen 2025-02-26 20:50:44 +09:00
Disty0
f68702f71c Update IPEX libs 2025-02-25 21:27:41 +03:00
Kohya S.
6e90c0f86c Merge pull request #1909 from rockerBOO/progress_bar
Move progress bar to account for sampling image first
2025-02-24 18:57:44 +09:00
Kohya S
67fde015f7 Merge branch 'dev' into sd3 2025-02-24 18:56:15 +09:00
Kohya S.
386b7332c6 Merge pull request #1918 from tsukimiya/fix_vperd_warning
Remove v-pred warning.
2025-02-24 18:55:25 +09:00
Kohya S
905f081798 Merge branch 'dev' into sd3 2025-02-24 18:54:28 +09:00
Kohya S.
59ae9ea20c Merge pull request #1945 from yidiq7/dev
Remove position_ids for V2
2025-02-24 18:53:46 +09:00
Kohya S
efb2a128cd fix wandb val logging 2025-02-21 22:07:35 +09:00
Yidi
13df47516d Remove position_ids for V2
The postions_ids cause errors for the newer version of transformer.
This has already been fixed in convert_ldm_clip_checkpoint_v1() but
not in v2.
The new code applies the same fix to convert_ldm_clip_checkpoint_v2().
2025-02-20 04:49:51 -05:00
rockerBOO
7f2747176b Use resize_image where resizing is required 2025-02-19 14:20:40 -05:00
rockerBOO
ca1c129ffd Fix metadata 2025-02-19 14:20:40 -05:00
rockerBOO
545425c13e Typo 2025-02-19 14:20:40 -05:00
rockerBOO
7729c4c8f9 Add metadata 2025-02-19 14:20:40 -05:00
rockerBOO
d0128d18be Add resize interpolation CLI option 2025-02-19 14:20:40 -05:00
rockerBOO
58e9e146a3 Add resize interpolation configuration 2025-02-19 14:20:40 -05:00
Kohya S
4a36996134 modify log step calculation 2025-02-18 22:05:08 +09:00
Kohya S
dc7d5fb459 Merge branch 'sd3' into val-loss-improvement 2025-02-18 21:34:30 +09:00
Kohya S
63337d9fe4 Merge branch 'sd3' into val-loss-improvement 2025-02-15 21:41:07 +09:00
Kohya S
76b761943b fix: simplify validation step condition in NetworkTrainer 2025-02-11 21:53:57 +09:00
Kohya S
cd80752175 fix: remove unused parameter 'accelerator' from encode_images_to_latents method 2025-02-11 21:42:58 +09:00
Kohya S
177203818a fix: unpause training progress bar after vaidation 2025-02-11 21:42:46 +09:00
Kohya S
344845b429 fix: validation with block swap 2025-02-09 21:25:40 +09:00
Kohya S
0911683717 set python random state 2025-02-09 20:53:49 +09:00
Kohya S
a24db1d532 fix: validation timestep generation fails on SD/SDXL training 2025-02-04 22:02:42 +09:00
Kohya S
c5b803ce94 rng state management: Implement functions to get and set RNG states for consistent validation 2025-02-04 21:59:09 +09:00
tsukimiya
4a71687d20 不要な警告の削除
(おそらく be14c06267 の修正漏れ )
2025-02-04 00:42:27 +09:00
rockerBOO
de830b8941 Move progress bar to account for sampling image first 2025-01-29 00:02:45 -05:00
Kohya S
45ec02b2a8 use same noise for every validation 2025-01-27 22:10:38 +09:00
Kohya S
42c0a9e1fc Merge branch 'sd3' into val-loss-improvement 2025-01-27 22:06:18 +09:00
Kohya S
0750859133 validation: Implement timestep-based validation processing 2025-01-27 21:56:59 +09:00
Kohya S
29f31d005f add network.train()/eval() for validation 2025-01-27 21:35:43 +09:00
Kohya S
b6a3093216 call optimizer eval/train fn before/after validation 2025-01-27 21:22:11 +09:00
Kohya S
86a2f3fd26 Fix gradient handling when Text Encoders are trained 2025-01-27 21:10:52 +09:00
Kohya S
532f5c58a6 formatting 2025-01-27 20:50:42 +09:00
49 changed files with 3019 additions and 2491 deletions

View File

@@ -14,6 +14,28 @@ The command to install PyTorch is as follows:
### Recent Updates
Apr 6, 2025:
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.
Mar 30, 2025:
- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details.
- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936).
Mar 20, 2025:
- `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985).
- For example, you can use CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`.
Mar 6, 2025:
- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960)
Feb 26, 2025:
- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)
- The validation loss uses the fixed timestep sampling and the fixed random seed. This is to ensure that the validation loss is not fluctuated by the random values.
Jan 25, 2025:
- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
@@ -739,6 +761,8 @@ Not available yet.
[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。
Latest update: 2025-03-21 (Version 0.9.1)
[日本語版READMEはこちら](./README-ja.md)
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
@@ -882,6 +906,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
### Mar 21, 2025 / 2025-03-21 Version 0.9.1
- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
- The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version.
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.

View File

@@ -152,6 +152,7 @@ These options are related to subset configuration.
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` | (not specified) | o | o | o |
* `num_repeats`
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
@@ -165,6 +166,8 @@ These options are related to subset configuration.
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
* `enable_wildcard`
* Enables wildcard notation. This will be explained later.
* `resize_interpolation`
* Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used.
### DreamBooth-specific options

View File

@@ -144,6 +144,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` |(通常は設定しません) | o | o | o |
* `num_repeats`
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
@@ -162,6 +163,9 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
* `enable_wildcard`
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
* `resize_interpolation`
* 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos``box`を指定するとPILが、それ以外を指定するとOpenCVが使用されます。
### DreamBooth 方式専用のオプション
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。

View File

@@ -178,7 +178,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -1,232 +0,0 @@
# add caption to images by Florence-2
import argparse
import json
import os
import glob
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM
from library import device_utils, train_util, dataset_metadata_utils
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
import tagger_utils
TASK_PROMPT = "<MORE_DETAILED_CAPTION>"
def main(args):
assert args.load_archive == (
args.metadata is not None
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
device = args.device if args.device is not None else device_utils.get_preferred_device()
if type(device) is str:
device = torch.device(device)
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
logger.info(f"device: {device}, dtype: {torch_dtype}")
logger.info("Loading Florence-2-large model / Florence-2-largeモデルをロード中")
support_flash_attn = False
try:
import flash_attn
support_flash_attn = True
except ImportError:
pass
if support_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
else:
logger.info(
"flash_attn is not available. Trying to load without it / flash_attnが利用できません。flash_attnを使わずにロードを試みます"
)
# https://github.com/huggingface/transformers/issues/31793#issuecomment-2295797330
# Removing the unnecessary flash_attn import which causes issues on CPU or MPS backends
from transformers.dynamic_module_utils import get_imports
from unittest.mock import patch
def fixed_get_imports(filename) -> list[str]:
if not str(filename).endswith("modeling_florence2.py"):
return get_imports(filename)
imports = get_imports(filename)
imports.remove("flash_attn")
return imports
# workaround for unnecessary flash_attn requirement
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
model = AutoModelForCausalLM.from_pretrained(
"microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True
).to(device)
model.eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
# 画像を読み込む
if not args.load_archive:
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
else:
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
os.path.join(args.train_data_dir, "*.tar")
)
image_paths = [Path(archive_file) for archive_file in archive_files]
# load metadata if needed
if args.metadata is not None:
metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
images_metadata = metadata["images"]
else:
images_metadata = metadata = None
# define preprocess_image function
def preprocess_image(image: Image.Image):
inputs = processor(text=TASK_PROMPT, images=image, return_tensors="pt").to(device, torch_dtype)
return inputs
# prepare DataLoader or something similar :)
# Loader returns: list of (image_path, processed_image_or_something, image_size)
if args.load_archive:
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
else:
# we cannot use DataLoader with ImageLoadingPrepDataset because processor is not pickleable
loader = tagger_utils.ImageLoader(image_paths, args.batch_size, preprocess_image, args.debug)
def run_batch(
list_of_path_inputs_size: list[tuple[str, dict[str, torch.Tensor], tuple[int, int]]],
images_metadata: Optional[dict[str, Any]],
caption_index: Optional[int] = None,
):
input_ids = torch.cat([inputs["input_ids"] for _, inputs, _ in list_of_path_inputs_size])
pixel_values = torch.cat([inputs["pixel_values"] for _, inputs, _ in list_of_path_inputs_size])
if args.debug:
logger.info(f"input_ids: {input_ids.shape}, pixel_values: {pixel_values.shape}")
with torch.no_grad():
generated_ids = model.generate(
input_ids=input_ids,
pixel_values=pixel_values,
max_new_tokens=args.max_new_tokens,
num_beams=args.num_beams,
)
if args.debug:
logger.info(f"generate done: {generated_ids.shape}")
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=False)
if args.debug:
logger.info(f"decode done: {len(generated_texts)}")
for generated_text, (image_path, _, image_size) in zip(generated_texts, list_of_path_inputs_size):
parsed_answer = processor.post_process_generation(generated_text, task=TASK_PROMPT, image_size=image_size)
caption_text = parsed_answer["<MORE_DETAILED_CAPTION>"]
caption_text = caption_text.strip().replace("<pad>", "")
original_caption_text = caption_text
if args.remove_mood:
p = caption_text.find("The overall ")
if p != -1:
caption_text = caption_text[:p].strip()
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
if images_metadata is None:
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(caption_text + "\n")
else:
image_md = images_metadata.get(image_path, None)
if image_md is None:
image_md = {"image_size": list(image_size)}
images_metadata[image_path] = image_md
if "caption" not in image_md:
image_md["caption"] = []
if caption_index is None:
image_md["caption"].append(caption_text)
else:
while len(image_md["caption"]) <= caption_index:
image_md["caption"].append("")
image_md["caption"][caption_index] = caption_text
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tCaption: {caption_text}")
if args.remove_mood and original_caption_text != caption_text:
logger.info(f"\tCaption (prior to removing mood): {original_caption_text}")
for data_entry in tqdm(loader, smoothing=0.0):
b_imgs = data_entry
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
run_batch(b_imgs, images_metadata, args.caption_index)
if args.metadata is not None:
logger.info(f"saving metadata file: {args.metadata}")
with open(args.metadata, "wt", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
)
parser.add_argument("--recursive", action="store_true", help="search images recursively / 画像を再帰的に検索する")
parser.add_argument(
"--remove_mood", action="store_true", help="remove mood from the caption / キャプションからムードを削除する"
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=1024,
help="maximum number of tokens to generate. default is 1024 / 生成するトークンの最大数。デフォルトは1024",
)
parser.add_argument(
"--num_beams",
type=int,
default=3,
help="number of beams for beam search. default is 3 / ビームサーチのビーム数。デフォルトは3",
)
parser.add_argument(
"--device",
type=str,
default=None,
help="device for model. default is None, which means using an appropriate device / モデルのデバイス。デフォルトはNoneで、適切なデバイスを使用する",
)
parser.add_argument(
"--caption_index",
type=int,
default=None,
help="index of the caption in the metadata file. default is None, which means adding caption to the existing captions. 0>= to replace the caption"
" / メタデータファイル内のキャプションのインデックス。デフォルトはNoneで、新しく追加する。0以上でキャプションを置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
tagger_utils.add_archive_arguments(parser)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
main(args)

View File

@@ -180,7 +180,7 @@ def main(args):
# バッチへ追加
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
image_info.latents_cache_path = npz_file_name
image_info.latents_npz = npz_file_name
image_info.bucket_reso = reso
image_info.resized_size = resized_size
image_info.image = image

View File

@@ -1,10 +1,7 @@
import argparse
import csv
import glob
import json
import os
from pathlib import Path
from typing import Any, Optional
import cv2
import numpy as np
@@ -13,18 +10,14 @@ from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
from library import dataset_metadata_utils
from library.utils import setup_logging
import library.train_util as train_util
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.train_util as train_util
from library.utils import pil_resize
import tagger_utils
# from wd14 tagger
IMAGE_SIZE = 448
@@ -49,10 +42,7 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)
image = image.astype(np.float32)
return image
@@ -70,14 +60,13 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
try:
image = Image.open(img_path).convert("RGB")
size = image.size
image = preprocess_image(image)
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (image, img_path, size)
return (image, img_path)
def collate_fn_remove_corrupted(batch):
@@ -91,10 +80,6 @@ def collate_fn_remove_corrupted(batch):
def main(args):
assert args.load_archive == (
args.metadata is not None
), "load_archive must be used with metadata / load_archiveはmetadataと一緒に使う必要があります"
# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
@@ -161,19 +146,15 @@ def main(args):
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{"device_type": "GPU_FP32"}],
provider_options=[{'device_type' : "GPU_FP32"}],
)
else:
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else (
["ROCMExecutionProvider"]
if "ROCMExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"]
)
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
),
)
else:
@@ -219,9 +200,7 @@ def main(args):
tag_replacements = escaped_tag_replacements.split(";")
for tag_replacement in tag_replacements:
tags = tag_replacement.split(",") # source, target
assert (
len(tags) == 2
), f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
logger.info(f"replacing tag: {source} -> {target}")
@@ -234,15 +213,9 @@ def main(args):
rating_tags[rating_tags.index(source)] = target
# 画像を読み込む
if not args.load_archive:
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
else:
archive_files = glob.glob(os.path.join(args.train_data_dir, "*.zip")) + glob.glob(
os.path.join(args.train_data_dir, "*.tar")
)
image_paths = [Path(archive_file) for archive_file in archive_files]
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
tag_freq = {}
@@ -255,23 +228,19 @@ def main(args):
if args.always_first_tags is not None:
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
def run_batch(
list_of_path_img_size: list[tuple[str, np.ndarray, tuple[int, int]]],
images_metadata: Optional[dict[str, Any]],
tags_index: Optional[int] = None,
):
imgs = np.array([im for _, im, _ in list_of_path_img_size])
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
if args.onnx:
# if len(imgs) < args.batch_size:
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(list_of_path_img_size)]
probs = probs[: len(path_imgs)]
else:
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _, image_size), prob in zip(list_of_path_img_size, probs):
for (image_path, _), prob in zip(path_imgs, probs):
combined_tags = []
rating_tag_text = ""
character_tag_text = ""
@@ -293,7 +262,7 @@ def main(args):
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # insert to the beginning
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
combined_tags.append(tag_name)
@@ -309,7 +278,7 @@ def main(args):
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
@@ -332,24 +301,12 @@ def main(args):
tag_text = caption_separator.join(combined_tags)
if args.append_tags:
existing_content = None
if images_metadata is None:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
else:
image_md = images_metadata.get(image_path, None)
if image_md is not None:
tags = image_md.get("tags", None)
if tags is not None:
if tags_index is None and len(tags) > 0:
existing_content = tags[-1]
elif tags_index is not None and tags_index < len(tags):
existing_content = tags[tags_index]
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
if existing_content is not None:
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
@@ -359,46 +316,19 @@ def main(args):
# Create new tag_text
tag_text = caption_separator.join(existing_tags + new_tags)
if images_metadata is None:
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
else:
image_md = images_metadata.get(image_path, None)
if image_md is None:
image_md = {"image_size": list(image_size)}
images_metadata[image_path] = image_md
if "tags" not in image_md:
image_md["tags"] = []
if tags_index is None:
image_md["tags"].append(tag_text)
else:
while len(image_md["tags"]) <= tags_index:
image_md["tags"].append("")
image_md["tags"][tags_index] = tag_text
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
if args.debug:
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
# load metadata if needed
if args.metadata is not None:
metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
images_metadata = metadata["images"]
else:
images_metadata = metadata = None
# prepare DataLoader or something similar :)
use_loader = False
if args.load_archive:
loader = tagger_utils.ArchiveImageLoader([str(p) for p in image_paths], args.batch_size, preprocess_image, args.debug)
use_loader = True
elif args.max_data_loader_n_workers is not None:
# 読み込みの高速化のためにDataLoaderを使うオプション
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths)
loader = torch.utils.data.DataLoader(
data = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
@@ -406,37 +336,35 @@ def main(args):
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
use_loader = True
else:
# make batch of image paths
loader = []
for i in range(0, len(image_paths), args.batch_size):
loader.append(image_paths[i : i + args.batch_size])
data = [[(None, ip)] for ip in image_paths]
for data_entry in tqdm(loader, smoothing=0.0):
if use_loader:
b_imgs = data_entry
else:
b_imgs = []
for image_path in data_entry:
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
image, image_path = data
if image is None:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
size = image.size
image = preprocess_image(image)
except Exception as e:
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image, size))
b_imgs.append((image_path, image))
b_imgs = [(str(image_path), image, size) for image_path, image, size in b_imgs] # Convert image_path to string
run_batch(b_imgs, images_metadata, args.tags_index)
if len(b_imgs) >= args.batch_size:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
b_imgs.clear()
if args.metadata is not None:
logger.info(f"saving metadata file: {args.metadata}")
with open(args.metadata, "wt", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=2)
if len(b_imgs) > 0:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
@@ -449,7 +377,9 @@ def main(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--repo_id",
type=str,
@@ -467,7 +397,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
@@ -506,7 +438,9 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument(
"--undesired_tags",
type=str,
@@ -516,24 +450,20 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
)
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags",
action="store_true",
help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_rating_tags_as_last_tag",
action="store_true",
help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first",
action="store_true",
help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
)
parser.add_argument(
"--always_first_tags",
@@ -562,15 +492,6 @@ def setup_parser() -> argparse.ArgumentParser:
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
)
parser.add_argument(
"--tags_index",
type=int,
default=None,
help="index of the tags in the metadata file. default is None, which means adding tags to the existing tags. 0>= to replace the tags"
" / メタデータファイル内のタグのインデックス。デフォルトはNoneで、既存のタグにタグを追加する。0以上でタグを置き換える",
)
tagger_utils.add_archive_arguments(parser)
return parser

View File

@@ -1,150 +0,0 @@
import argparse
import json
import math
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Union
import zipfile
import tarfile
from PIL import Image
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import dataset_metadata_utils, train_util
class ArchiveImageLoader:
def __init__(self, archive_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
self.archive_paths = archive_paths
self.batch_size = batch_size
self.preprocess = preprocess
self.debug = debug
self.current_archive = None
self.archive_index = 0
self.image_index = 0
self.files = None
self.executor = ThreadPoolExecutor()
self.image_exts = set(train_util.IMAGE_EXTENSIONS)
def __iter__(self):
return self
def __next__(self):
images = []
while len(images) < self.batch_size:
if self.current_archive is None:
if self.archive_index >= len(self.archive_paths):
if len(images) == 0:
raise StopIteration
else:
break # return the remaining images
if self.debug:
logger.info(f"loading archive: {self.archive_paths[self.archive_index]}")
current_archive_path = self.archive_paths[self.archive_index]
if current_archive_path.endswith(".zip"):
self.current_archive = zipfile.ZipFile(current_archive_path)
self.files = self.current_archive.namelist()
elif current_archive_path.endswith(".tar"):
self.current_archive = tarfile.open(current_archive_path, "r")
self.files = self.current_archive.getnames()
else:
raise ValueError(f"unsupported archive file: {self.current_archive_path}")
self.image_index = 0
# filter by image extensions
self.files = [file for file in self.files if os.path.splitext(file)[1].lower() in self.image_exts]
if self.debug:
logger.info(f"found {len(self.files)} images in the archive")
new_images = []
while len(images) + len(new_images) < self.batch_size:
if self.image_index >= len(self.files):
break
file = self.files[self.image_index]
archive_and_image_path = (
f"{self.archive_paths[self.archive_index]}{dataset_metadata_utils.ARCHIVE_PATH_SEPARATOR}{file}"
)
self.image_index += 1
def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
with archive.open(file) as f:
image = Image.open(f).convert("RGB")
size = image.size
image = self.preprocess(image)
return image, size
new_images.append((archive_and_image_path, self.executor.submit(load_image, file, self.current_archive)))
# wait for all new_images to load to close the archive
new_images = [(image_path, future.result()) for image_path, future in new_images]
if self.image_index >= len(self.files):
self.current_archive.close()
self.current_archive = None
self.archive_index += 1
images.extend(new_images)
return [(image_path, image, size) for image_path, (image, size) in images]
class ImageLoader:
def __init__(self, image_paths: list[str], batch_size: int, preprocess: Callable, debug: bool = False):
self.image_paths = image_paths
self.batch_size = batch_size
self.preprocess = preprocess
self.debug = debug
self.image_index = 0
self.executor = ThreadPoolExecutor()
def __len__(self):
return math.ceil(len(self.image_paths) / self.batch_size)
def __iter__(self):
return self
def __next__(self):
if self.image_index >= len(self.image_paths):
raise StopIteration
images = []
while len(images) < self.batch_size and self.image_index < len(self.image_paths):
def load_image(file):
image = Image.open(file).convert("RGB")
size = image.size
image = self.preprocess(image)
return image, size
image_path = self.image_paths[self.image_index]
images.append((image_path, self.executor.submit(load_image, image_path)))
self.image_index += 1
images = [(image_path, future.result()) for image_path, future in images]
return [(image_path, image, size) for image_path, (image, size) in images]
def add_archive_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--metadata",
type=str,
default=None,
help="metadata file for the dataset. write tags to this file instead of the caption file / データセットのメタデータファイル。キャプションファイルの代わりにこのファイルにタグを書き込む",
)
parser.add_argument(
"--load_archive",
action="store_true",
help="load archive file such as .zip instead of image files. currently .zip and .tar are supported. must be used with --metadata"
" / 画像ファイルではなく.zipなどのアーカイブファイルを読み込む。現在.zipと.tarをサポート。--metadataと一緒に使う必要があります",
)

View File

@@ -152,20 +152,15 @@ def train(args):
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.debug_dataset:
t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
)
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
False,
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
)
t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
)
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
train_dataset_group.set_current_strategies()
@@ -204,7 +199,7 @@ def train(args):
ae.requires_grad_(False)
ae.eval()
train_dataset_group.new_cache_latents(ae, accelerator, args.force_cache_precision)
train_dataset_group.new_cache_latents(ae, accelerator)
ae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)
@@ -242,12 +237,7 @@ def train(args):
t5xxl.to(accelerator.device)
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
False,
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

View File

@@ -11,6 +11,16 @@ from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
import train_network
from library import (
flux_models,
flux_train_utils,
flux_utils,
sd3_train_utils,
strategy_base,
strategy_flux,
train_util,
)
from library.utils import setup_logging
setup_logging()
@@ -18,9 +28,6 @@ import logging
logger = logging.getLogger(__name__)
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
import train_network
class FluxNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
@@ -29,7 +36,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
@@ -178,17 +190,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None
@@ -320,7 +328,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
def encode_images_to_latents(self, args, vae, images):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
@@ -338,7 +346,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
network,
weight_dtype,
train_unet,
is_train=True
is_train=True,
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -373,8 +381,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode:
# normal forward
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet(
@@ -387,44 +394,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
"""
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
with accelerator.autocast():
# move flux lower to cpu, and then move flux upper to gpu
unet.to("cpu")
clean_memory_on_device(accelerator.device)
self.flux_upper.to(accelerator.device)
# upper model does not require grad
with torch.no_grad():
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# move flux upper back to cpu, and then move flux lower to gpu
self.flux_upper.to("cpu")
clean_memory_on_device(accelerator.device)
unet.to(accelerator.device)
# lower model requires grad
intermediate_img.requires_grad_(True)
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""
return model_pred
model_pred = call_dit(
@@ -543,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:

View File

@@ -75,6 +75,7 @@ class BaseSubsetParams:
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
@dataclass
@@ -106,7 +107,7 @@ class BaseDatasetParams:
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -196,6 +197,7 @@ class ConfigSanitizer:
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
"resize_interpolation": str,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -241,6 +243,7 @@ class ConfigSanitizer:
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"resize_interpolation": str,
}
# options handled by argparse but not handled by user config
@@ -525,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
""")
@@ -558,6 +562,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
resize_interpolation: {subset.resize_interpolation}
custom_attributes: {subset.custom_attributes}
"""), " ")

View File

@@ -1,58 +0,0 @@
import os
import json
from typing import Any, Optional
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
METADATA_VERSION = [1, 0, 0]
VERSION_STRING = ".".join(str(v) for v in METADATA_VERSION)
ARCHIVE_PATH_SEPARATOR = "////"
def load_metadata(metadata_file: str, create_new: bool = False) -> Optional[dict[str, Any]]:
if os.path.exists(metadata_file):
logger.info(f"loading metadata file: {metadata_file}")
with open(metadata_file, "rt", encoding="utf-8") as f:
metadata = json.load(f)
# version check
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
major, minor, patch = int(major), int(minor), int(patch)
if major > METADATA_VERSION[0] or (major == METADATA_VERSION[0] and minor > METADATA_VERSION[1]):
logger.warning(
f"metadata format version {major}.{minor}.{patch} is higher than supported version {VERSION_STRING}. Some features may not work."
)
if "images" not in metadata:
metadata["images"] = {}
else:
if not create_new:
return None
logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
metadata = {"format_version": VERSION_STRING, "images": {}}
return metadata
def is_archive_path(archive_and_image_path: str) -> bool:
return archive_and_image_path.count(ARCHIVE_PATH_SEPARATOR) == 1
def get_inner_path(archive_and_image_path: str) -> str:
return archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[1]
def get_archive_digest(archive_and_image_path: str) -> str:
"""
calculate a 8-digits hex digest for the archive path to avoid collisions for different archives with the same name.
"""
archive_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[0]
return f"{hash(archive_path) & 0xFFFFFFFF:08x}"

View File

@@ -2,6 +2,13 @@ import functools
import gc
import torch
try:
# intel gpu support for pytorch older than 2.5
# ipex is not needed after pytorch 2.5
import intel_extension_for_pytorch as ipex # noqa
except Exception:
pass
try:
HAS_CUDA = torch.cuda.is_available()
@@ -14,8 +21,6 @@ except Exception:
HAS_MPS = False
try:
import intel_extension_for_pytorch as ipex # noqa
HAS_XPU = torch.xpu.is_available()
except Exception:
HAS_XPU = False
@@ -69,7 +74,7 @@ def init_ipex():
This function should run right after importing torch and before doing anything else.
If IPEX is not available, this function does nothing.
If xpu is not available, this function does nothing.
"""
try:
if HAS_XPU:

View File

@@ -366,8 +366,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
@@ -410,42 +408,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -456,12 +446,24 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas

View File

@@ -2,7 +2,11 @@ import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
legacy = True
except Exception:
legacy = False
from .hijacks import ipex_hijacks
# pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -12,6 +16,13 @@ def ipex_init(): # pylint: disable=too-many-statements
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
return True, "Skipping IPEX hijack"
else:
try: # force xpu device on torch compile and triton
torch._inductor.utils.GPU_TYPES = ["xpu"]
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
from triton import backends as triton_backends # pylint: disable=import-error
triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False
except Exception:
pass
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
@@ -26,84 +37,99 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.nn.Module.cuda = torch.nn.Module.xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if legacy:
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if float(ipex.__version__[:3]) < 2.3:
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
if not legacy or float(ipex.__version__[:3]) >= 2.3:
torch.cuda._initialization_lock = torch.xpu._initialization_lock
torch.cuda._initialized = torch.xpu._initialized
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu._queued_calls
torch.cuda._tls = torch.xpu._tls
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback
# Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
if legacy:
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory = torch.xpu.memory
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
@@ -128,32 +154,44 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed
# AMP:
torch.cuda.amp = torch.xpu.amp
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
if legacy:
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
torch.cuda.amp = torch.xpu.amp
if float(ipex.__version__[:3]) < 2.3:
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 2024
ipex._C._DeviceProperties.minor = 0
if legacy and float(ipex.__version__[:3]) < 2.3:
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12
ipex._C._DeviceProperties.minor = 1
else:
torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
torch._C._XpuDeviceProperties.major = 12
torch._C._XpuDeviceProperties.minor = 1
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
# torch.xpu.mem_get_info always returns the total memory as free memory
torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch.cuda.mem_get_info = torch.xpu.mem_get_info
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
@@ -161,19 +199,19 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.version.cuda = "12.1"
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
ipex_hijacks()
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy)
try:
from .diffusers import ipex_diffusers
ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb)
except Exception: # pylint: disable=broad-exception-caught
pass
torch.cuda.is_xpu_hijacked = True
except Exception as e:
return False, e

View File

@@ -1,177 +1,119 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from functools import cache
from functools import cache, wraps
# pylint: disable=protected-access, missing-function-docstring, line-too-long
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5))
# Find something divisible with the input_tokens
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
def find_split_size(original_size, slice_block_size, slice_rate=2):
split_size = original_size
while True:
if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0:
return split_size
split_size = split_size - 1
if split_size <= 1:
return 1
return split_size
# Find slice sizes for SDPA
@cache
def find_sdpa_slice_sizes(query_shape, query_element_size):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3):
batch_size, attn_heads, query_len, _ = query_shape
_, _, key_len, _ = key_shape
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
split_slice_size = batch_size_attention
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
split_batch_size = batch_size
split_head_size = attn_heads
split_query_size = query_len
do_split = False
do_split_2 = False
do_split_3 = False
do_batch_split = False
do_head_split = False
do_query_split = False
if block_size > sdpa_slice_trigger_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
if batch_size * slice_batch_size >= trigger_rate:
do_batch_split = True
split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
if split_batch_size * slice_batch_size > slice_rate:
slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
do_head_split = True
split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate)
# Find slice sizes for BMM
@cache
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
block_size = batch_size_attention * slice_block_size
if split_head_size * slice_head_size > slice_rate:
slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024
do_query_split = True
split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate)
split_slice_size = batch_size_attention
split_2_slice_size = input_tokens
split_3_slice_size = mat2_atten_shape
return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
if input.device.type != "xpu":
return original_torch_bmm(input, mat2, out=out)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
# Slice BMM
if do_split:
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
out=out
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
)
torch.xpu.synchronize(input.device)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
@wraps(torch.nn.functional.scaled_dot_product_attention)
def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
is_unsqueezed = False
if len(query.shape) == 3:
query = query.unsqueeze(0)
is_unsqueezed = True
if len(key.shape) == 3:
key = key.unsqueeze(0)
if len(value.shape) == 3:
value = value.unsqueeze(0)
do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)
# Slice SDPA
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
if do_batch_split:
batch_size, attn_heads, query_len, _ = query.shape
_, _, _, head_dim = value.shape
hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype)
if attn_mask is not None:
attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
for ib in range(batch_size // split_batch_size):
start_idx = ib * split_batch_size
end_idx = (ib + 1) * split_batch_size
if do_head_split:
for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name
start_idx_h = ih * split_head_size
end_idx_h = (ih + 1) * split_head_size
if do_query_split:
for iq in range(query_len // split_query_size): # pylint: disable=invalid-name
start_idx_q = iq * split_query_size
end_idx_q = (iq + 1) * split_query_size
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :],
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx, :, :, :],
key[start_idx:end_idx, :, :, :],
value[start_idx:end_idx, :, :, :],
attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
torch.xpu.synchronize(query.device)
else:
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
if is_unsqueezed:
hidden_states.squeeze(0)
return hidden_states

View File

@@ -1,312 +1,47 @@
import os
from functools import wraps
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND
from functools import cache
import diffusers # pylint: disable=import-error
# pylint: disable=protected-access, missing-function-docstring, line-too-long
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
@cache
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
if slice_size is not None:
batch_size_attention = slice_size
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
do_split = False
do_split_2 = False
do_split_3 = False
if query_device_type != "xpu":
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
r"""
Processor for implementing sliced attention.
Args:
slice_size (`int`, *optional*):
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
`attention_head_dim` must be a multiple of the `slice_size`.
"""
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, shape_three = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None,
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
if do_split:
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def ipex_diffusers():
#ARC GPUs can't allocate more than 4GB to a single block:
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
# Diffusers FreeU
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
@wraps(diffusers.utils.torch_utils.fourier_filter)
def fourier_filter(x_in, threshold, scale):
return_dtype = x_in.dtype
return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype)
# fp64 error
class FluxPosEmbed(torch.nn.Module):
def __init__(self, theta: int, axes_dim):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
for i in range(n_axes):
cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=torch.float32,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
diffusers.utils.torch_utils.fourier_filter = fourier_filter
if not device_supports_fp64:
diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed

View File

@@ -5,7 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
# pylint: disable=protected-access, missing-function-docstring, line-too-long
device_supports_fp64 = torch.xpu.has_fp64_dtype()
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state

View File

@@ -2,10 +2,19 @@ import os
from functools import wraps
from contextlib import nullcontext
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import numpy as np
device_supports_fp64 = torch.xpu.has_fp64_dtype()
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
try:
x = torch.ones((33000,33000), dtype=torch.float32, device="xpu")
del x
torch.xpu.empty_cache()
can_allocate_plus_4gb = True
except Exception:
can_allocate_plus_4gb = False
else:
can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
@@ -26,7 +35,7 @@ def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
# Autocast
@@ -42,7 +51,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None or mode == 'bicubic':
if mode in {'bicubic', 'bilinear'}:
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
@@ -73,35 +82,46 @@ def as_tensor(data, dtype=None, device=None):
return original_as_tensor(data, dtype=dtype, device=device)
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
original_torch_bmm = torch.bmm
if can_allocate_plus_4gb:
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
# 32 bit attention workarounds for Alchemist:
try:
from .attention import torch_bmm_32_bit as original_torch_bmm
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention
except Exception: # pylint: disable=broad-exception-caught
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
# Data Type Errors:
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
@wraps(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.dtype != key.dtype:
key = key.to(dtype=query.dtype)
if query.dtype != value.dtype:
value = value.to(dtype=query.dtype)
if attn_mask is not None and query.dtype != attn_mask.dtype:
attn_mask = attn_mask.to(dtype=query.dtype)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
# Data Type Errors:
original_torch_bmm = torch.bmm
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
# Diffusers FreeU
original_fft_fftn = torch.fft.fftn
@wraps(torch.fft.fftn)
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
return_dtype = input.dtype
return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
# Diffusers FreeU
original_fft_ifftn = torch.fft.ifftn
@wraps(torch.fft.ifftn)
def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None):
return_dtype = input.dtype
return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
# A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm
@@ -133,6 +153,15 @@ def functional_linear(input, weight, bias=None):
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_linear(input, weight, bias=bias)
original_functional_conv1d = torch.nn.functional.conv1d
@wraps(torch.nn.functional.conv1d)
def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
@@ -142,14 +171,15 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
# A1111 Embedding BF16
original_torch_cat = torch.cat
@wraps(torch.cat)
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
# LTX Video
original_functional_conv3d = torch.nn.functional.conv3d
@wraps(torch.nn.functional.conv3d)
def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
# SwinIR BF16:
original_functional_pad = torch.nn.functional.pad
@@ -164,6 +194,7 @@ def functional_pad(input, pad, mode='constant', value=None):
original_torch_tensor = torch.tensor
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
global device_supports_fp64
if check_device(device):
device = return_xpu(device)
if not device_supports_fp64:
@@ -227,7 +258,7 @@ def torch_empty(*args, device=None, **kwargs):
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, dtype=None, **kwargs):
if dtype == bytes:
if dtype is bytes:
dtype = None
if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
@@ -250,6 +281,14 @@ def torch_zeros(*args, device=None, **kwargs):
else:
return original_torch_zeros(*args, device=device, **kwargs)
original_torch_full = torch.full
@wraps(torch.full)
def torch_full(*args, device=None, **kwargs):
if check_device(device):
return original_torch_full(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_full(*args, device=device, **kwargs)
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
@@ -258,14 +297,6 @@ def torch_linspace(*args, device=None, **kwargs):
else:
return original_torch_linspace(*args, device=device, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
@@ -276,9 +307,27 @@ def torch_load(f, map_location=None, *args, **kwargs):
else:
return original_torch_load(f, *args, map_location=map_location, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
@wraps(torch.cuda.synchronize)
def torch_cuda_synchronize(device=None):
if check_device(device):
return torch.xpu.synchronize(return_xpu(device))
else:
return torch.xpu.synchronize(device)
# Hijack Functions:
def ipex_hijacks():
def ipex_hijacks(legacy=True):
global device_supports_fp64, can_allocate_plus_4gb
if legacy and float(torch.__version__[:3]) < 2.5:
torch.nn.functional.interpolate = interpolate
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
@@ -289,9 +338,11 @@ def ipex_hijacks():
torch.randn = torch_randn
torch.ones = torch_ones
torch.zeros = torch_zeros
torch.full = torch_full
torch.linspace = torch_linspace
torch.Generator = torch_Generator
torch.load = torch_load
torch.Generator = torch_Generator
torch.cuda.synchronize = torch_cuda_synchronize
torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel
@@ -302,12 +353,15 @@ def ipex_hijacks():
torch.nn.functional.group_norm = functional_group_norm
torch.nn.functional.layer_norm = functional_layer_norm
torch.nn.functional.linear = functional_linear
torch.nn.functional.conv1d = functional_conv1d
torch.nn.functional.conv2d = functional_conv2d
torch.nn.functional.interpolate = interpolate
torch.nn.functional.conv3d = functional_conv3d
torch.nn.functional.pad = functional_pad
torch.bmm = torch_bmm
torch.cat = torch_cat
torch.fft.fftn = fft_fftn
torch.fft.ifftn = fft_ifftn
if not device_supports_fp64:
torch.from_numpy = from_numpy
torch.as_tensor = as_tensor
return device_supports_fp64, can_allocate_plus_4gb

186
library/jpeg_xl_util.py Normal file
View File

@@ -0,0 +1,186 @@
# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT
# Added partial read support for up to 200x speedup
import os
from typing import List, Tuple
class JXLBitstream:
"""
A stream of bits with methods for easy handling.
"""
def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
self.shift = 0
self.bitstream = bytearray()
self.file = file
self.offset = offset
self.offsets = offsets
if self.offsets:
self.offset = self.offsets[0][1]
self.previous_data_len = 0
self.index = 0
self.file.seek(self.offset)
def get_bits(self, length: int = 1) -> int:
if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
self.partial_to_read_length = length
if self.shift < self.previous_data_len + self.offsets[self.index][2]:
self.partial_read(0, length)
self.bitstream.extend(self.file.read(self.partial_to_read_length))
else:
self.bitstream.extend(self.file.read(length))
bitmask = 2**length - 1
bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
self.shift += length
return bits
def partial_read(self, current_length: int, length: int) -> None:
self.previous_data_len += self.offsets[self.index][2]
to_read_length = self.previous_data_len - (self.shift + current_length)
self.bitstream.extend(self.file.read(to_read_length))
current_length += to_read_length
self.partial_to_read_length -= to_read_length
self.index += 1
self.file.seek(self.offsets[self.index][1])
if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
self.partial_read(current_length, length)
def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]:
"""
Decodes the actual codestream.
JXL codestream specification: http://www-internal/2022/18181-1
"""
# Convert codestream to int within an object to get some handy methods.
codestream = JXLBitstream(file, offset=offset, offsets=offsets)
# Skip signature
codestream.get_bits(16)
# SizeHeader
div8 = codestream.get_bits(1)
if div8:
height = 8 * (1 + codestream.get_bits(5))
else:
distribution = codestream.get_bits(2)
match distribution:
case 0:
height = 1 + codestream.get_bits(9)
case 1:
height = 1 + codestream.get_bits(13)
case 2:
height = 1 + codestream.get_bits(18)
case 3:
height = 1 + codestream.get_bits(30)
ratio = codestream.get_bits(3)
if div8 and not ratio:
width = 8 * (1 + codestream.get_bits(5))
elif not ratio:
distribution = codestream.get_bits(2)
match distribution:
case 0:
width = 1 + codestream.get_bits(9)
case 1:
width = 1 + codestream.get_bits(13)
case 2:
width = 1 + codestream.get_bits(18)
case 3:
width = 1 + codestream.get_bits(30)
else:
match ratio:
case 1:
width = height
case 2:
width = (height * 12) // 10
case 3:
width = (height * 4) // 3
case 4:
width = (height * 3) // 2
case 5:
width = (height * 16) // 9
case 6:
width = (height * 5) // 4
case 7:
width = (height * 2) // 1
return width, height
def decode_container(file) -> Tuple[int,int]:
"""
Parses the ISOBMFF container, extracts the codestream, and decodes it.
JXL container specification: http://www-internal/2022/18181-2
"""
def parse_box(file, file_start: int) -> dict:
file.seek(file_start)
LBox = int.from_bytes(file.read(4), "big")
XLBox = None
if 1 < LBox <= 8:
raise ValueError(f"Invalid LBox at byte {file_start}.")
if LBox == 1:
file.seek(file_start + 8)
XLBox = int.from_bytes(file.read(8), "big")
if XLBox <= 16:
raise ValueError(f"Invalid XLBox at byte {file_start}.")
if XLBox:
header_length = 16
box_length = XLBox
else:
header_length = 8
if LBox == 0:
box_length = os.fstat(file.fileno()).st_size - file_start
else:
box_length = LBox
file.seek(file_start + 4)
box_type = file.read(4)
file.seek(file_start)
return {
"length": box_length,
"type": box_type,
"offset": header_length,
}
file.seek(0)
# Reject files missing required boxes. These two boxes are required to be at
# the start and contain no values, so we can manually check there presence.
# Signature box. (Redundant as has already been checked.)
if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
raise ValueError("Invalid signature box.")
# File Type box.
if file.read(20) != bytes.fromhex(
"00000014 66747970 6A786C20 00000000 6A786C20"
):
raise ValueError("Invalid file type box.")
offset = 0
offsets = []
data_offset_not_found = True
container_pointer = 32
file_size = os.fstat(file.fileno()).st_size
while data_offset_not_found:
box = parse_box(file, container_pointer)
match box["type"]:
case b"jxlc":
offset = container_pointer + box["offset"]
data_offset_not_found = False
case b"jxlp":
file.seek(container_pointer + box["offset"])
index = int.from_bytes(file.read(4), "big")
offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4])
container_pointer += box["length"]
if container_pointer >= file_size:
data_offset_not_found = False
if offsets:
offsets.sort(key=lambda i: i[0])
file.seek(0)
return decode_codestream(file, offset=offset, offsets=offsets)
def get_jxl_size(path: str) -> Tuple[int,int]:
with open(path, "rb") as file:
if file.read(2) == bytes.fromhex("FF0A"):
return decode_codestream(file)
return decode_container(file)

View File

@@ -643,16 +643,15 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
# rename or add position_ids
# remove position_ids for newer transformer, which causes error :(
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
if ANOTHER_POSITION_IDS_KEY in new_sd:
# waifu diffusion v1.4
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
del new_sd[ANOTHER_POSITION_IDS_KEY]
else:
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
if "text_model.embeddings.position_ids" in new_sd:
del new_sd["text_model.embeddings.position_ids"]
return new_sd

View File

@@ -344,8 +344,6 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
if args.clip_skip is not None:
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")

View File

@@ -2,14 +2,16 @@
import os
import re
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
import numpy as np
from safetensors.torch import safe_open, save_file
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
from library.utils import setup_logging
setup_logging()
@@ -17,81 +19,6 @@ import logging
logger = logging.getLogger(__name__)
from library import dataset_metadata_utils, utils
def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
if dtype is None:
# all dtypes are acceptable
return get_available_dtypes()
dtype = utils.str_to_dtype(dtype) if isinstance(dtype, str) else dtype
compatible_dtypes = [torch.float32]
if dtype.itemsize == 1: # fp8
compatible_dtypes.append(torch.bfloat16)
compatible_dtypes.append(torch.float16)
compatible_dtypes.append(dtype) # add the specified: bf16, fp16, one of fp8
return compatible_dtypes
def get_available_dtypes() -> List[torch.dtype]:
"""
Returns the list of available dtypes for latents caching. Higher precision is preferred.
"""
return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
def remove_lower_precision_values(tensor_dict: Dict[str, torch.Tensor], keys_without_dtype: list[str]) -> None:
"""
Removes lower precision values from tensor_dict.
"""
available_dtypes = get_available_dtypes()
available_dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dtype)}" for dtype in available_dtypes]
for key_without_dtype in keys_without_dtype:
available_itemsize = None
for dtype, dtype_suffix in zip(available_dtypes, available_dtype_suffixes):
key = key_without_dtype + dtype_suffix
if key in tensor_dict:
if available_itemsize is None:
available_itemsize = dtype.itemsize
elif available_itemsize > dtype.itemsize:
# if higher precision latents are already cached, remove lower precision latents
del tensor_dict[key]
def get_compatible_dtype_keys(
dict_keys: set[str], keys_without_dtype: list[str], dtype: Optional[Union[str, torch.dtype]]
) -> list[Optional[str]]:
"""
Returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
If the key is not found, it returns None.
If the key in dict_keys doesn't have dtype suffix, it is acceptable, because it it long tensor.
:param dict_keys: set of keys in the dictionary
:param keys_without_dtype: list of keys without dtype suffix to check
:param dtype: dtype to check, or None for any dtype
:return: list of keys with the specified dtype or higher precision dtype. If the key is not found, it returns None for that key.
"""
compatible_dtypes = get_compatible_dtypes(dtype)
dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dt)}" for dt in compatible_dtypes]
available_keys = []
for key_without_dtype in keys_without_dtype:
available_key = None
if key_without_dtype in dict_keys:
available_key = key_without_dtype
else:
for dtype_suffix in dtype_suffixes:
key = key_without_dtype + dtype_suffix
if key in dict_keys:
available_key = key
break
available_keys.append(available_key)
return available_keys
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
@@ -397,26 +324,17 @@ class TextEncoderOutputsCachingStrategy:
def __init__(
self,
architecture: str,
cache_to_disk: bool,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
max_token_length: int,
masked: bool = False,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
"""
max_token_length: maximum token length for the model. Including/excluding starting and ending tokens depends on the model.
"""
self._architecture = architecture
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self._max_token_length = max_token_length
self._masked = masked
self._is_partial = is_partial
self._is_weighted = is_weighted # enable weighting by `()` or `[]` in the prompt
self._is_weighted = is_weighted
@classmethod
def set_strategy(cls, strategy):
@@ -428,18 +346,6 @@ class TextEncoderOutputsCachingStrategy:
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
return cls._strategy
@property
def architecture(self):
return self._architecture
@property
def max_token_length(self):
return self._max_token_length
@property
def masked(self):
return self._masked
@property
def cache_to_disk(self):
return self._cache_to_disk
@@ -448,11 +354,6 @@ class TextEncoderOutputsCachingStrategy:
def batch_size(self):
return self._batch_size
@property
def cache_suffix(self):
suffix_masked = "_m" if self.masked else ""
return f"_{self.architecture.lower()}_{self.max_token_length}{suffix_masked}_te.safetensors"
@property
def is_partial(self):
return self._is_partial
@@ -461,159 +362,31 @@ class TextEncoderOutputsCachingStrategy:
def is_weighted(self):
return self._is_weighted
def get_cache_path(self, absolute_path: str) -> str:
return os.path.splitext(absolute_path)[0] + self.cache_suffix
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError
def load_from_disk_for_keys(self, cache_path: str, caption_index: int, base_keys: list[str]) -> list[Optional[torch.Tensor]]:
"""
get tensors for keys_without_dtype, without dtype suffix. if the key is not found, it returns None.
all dtype tensors are returned, because cache validation is done in advance.
"""
with safe_open(cache_path, framework="pt") as f:
metadata = f.metadata()
version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, version.split("."))
if major > 1: # or (major == 1 and minor > 0):
if not self.load_version_warning_printed:
self.load_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
)
dict_keys = f.keys()
results = []
compatible_keys = self.get_compatible_output_keys(dict_keys, caption_index, base_keys, None)
for key in compatible_keys:
results.append(f.get_tensor(key) if key is not None else None)
return results
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
raise NotImplementedError
def get_key_suffix(self, prompt_id: int, dtype: Optional[Union[str, torch.dtype]] = None) -> str:
"""
masked: may be False even if self.masked is True. It is False for some outputs.
"""
key_suffix = f"_{prompt_id}"
if dtype is not None and dtype.is_floating_point: # float tensor only
key_suffix += "_" + utils.dtype_to_normalized_str(dtype)
return key_suffix
def get_compatible_output_keys(
self, dict_keys: set[str], caption_index: int, base_keys: list[str], dtype: Optional[Union[str, torch.dtype]]
) -> list[Optional[str], Optional[str]]:
"""
returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
"""
key_suffix = self.get_key_suffix(caption_index, None)
keys_without_dtype = [k + key_suffix for k in base_keys]
return get_compatible_dtype_keys(dict_keys, keys_without_dtype, dtype)
def _default_is_disk_cached_outputs_expected(
self,
cache_path: str,
captions: list[str],
base_keys: list[tuple[str, bool]],
preferred_dtype: Optional[Union[str, torch.dtype]],
):
if not self.cache_to_disk:
return False
if not os.path.exists(cache_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
with utils.MemoryEfficientSafeOpen(cache_path) as f:
keys = f.keys()
metadata = f.metadata()
# check captions in metadata
for i, caption in enumerate(captions):
if metadata.get(f"caption{i+1}") != caption:
return False
compatible_keys = self.get_compatible_output_keys(keys, i, base_keys, preferred_dtype)
if any(key is None for key in compatible_keys):
return False
except Exception as e:
logger.error(f"Error loading file: {cache_path}")
raise e
return True
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
raise NotImplementedError
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: list[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
):
raise NotImplementedError
def save_outputs_to_disk(self, cache_path: str, caption_index: int, caption: str, keys: list[str], outputs: list[torch.Tensor]):
tensor_dict = {}
overwrite = False
if os.path.exists(cache_path):
# load existing safetensors and update it
overwrite = True
with utils.MemoryEfficientSafeOpen(cache_path) as f:
metadata = f.metadata()
keys = f.keys()
for key in keys:
tensor_dict[key] = f.get_tensor(key)
assert metadata["architecture"] == self.architecture
file_version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, file_version.split("."))
if major > 1 or (major == 1 and minor > 0):
self.save_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
)
else:
metadata = {}
metadata["architecture"] = self.architecture
metadata["format_version"] = "1.0.0"
metadata[f"caption{caption_index+1}"] = caption
for key, output in zip(keys, outputs):
dtype = output.dtype # long or one of float
key_suffix = self.get_key_suffix(caption_index, dtype)
tensor_dict[key + key_suffix] = output
# remove lower precision latents if higher precision latents are already cached
if overwrite:
suffix_without_dtype = self.get_key_suffix(caption_index, None)
remove_lower_precision_values(tensor_dict, [key + suffix_without_dtype])
save_file(tensor_dict, cache_path, metadata=metadata)
class LatentsCachingStrategy:
# TODO commonize utillity functions to this class, such as npz handling etc.
_strategy = None # strategy instance: actual strategy class
def __init__(
self, architecture: str, latents_stride: int, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
) -> None:
self._architecture = architecture
self._latents_stride = latents_stride
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self.load_version_warning_printed = False
self.save_version_warning_printed = False
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
@@ -624,14 +397,6 @@ class LatentsCachingStrategy:
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def architecture(self):
return self._architecture
@property
def latents_stride(self):
return self._latents_stride
@property
def cache_to_disk(self):
return self._cache_to_disk
@@ -642,126 +407,54 @@ class LatentsCachingStrategy:
@property
def cache_suffix(self):
return f"_{self.architecture.lower()}.safetensors"
raise NotImplementedError
def get_image_size_from_disk_cache_path(self, absolute_path: str, cache_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x")
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_cache_path_from_info(self, info: utils.ImageInfo) -> str:
return self.get_latents_cache_path(info.absolute_path, info.image_size, info.latents_cache_dir)
def get_latents_cache_path(
self, absolute_path_or_archive_img_path: str, image_size: Tuple[int, int], cache_dir: Optional[str] = None
) -> str:
if cache_dir is not None:
if dataset_metadata_utils.is_archive_path(absolute_path_or_archive_img_path):
inner_path = dataset_metadata_utils.get_inner_path(absolute_path_or_archive_img_path)
archive_digest = dataset_metadata_utils.get_archive_digest(absolute_path_or_archive_img_path)
cache_file_base = os.path.join(cache_dir, f"{archive_digest}_{inner_path}")
else:
cache_file_base = os.path.join(cache_dir, os.path.basename(absolute_path_or_archive_img_path))
else:
cache_file_base = absolute_path_or_archive_img_path
return os.path.splitext(cache_file_base)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
raise NotImplementedError
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[Union[str, torch.dtype]],
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
def get_key_suffix(
self,
bucket_reso: Optional[Tuple[int, int]] = None,
latents_size: Optional[Tuple[int, int]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
) -> str:
"""
if dtype is None, it returns "_32x64" for example.
"""
if latents_size is not None:
expected_latents_size = latents_size # H, W
else:
# bucket_reso is (W, H)
expected_latents_size = (bucket_reso[1] // self.latents_stride, bucket_reso[0] // self.latents_stride) # H, W
if dtype is None:
dtype_suffix = ""
else:
dtype_suffix = "_" + utils.dtype_to_normalized_str(dtype)
# e.g. "_32x64_float16", HxW, dtype
key_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}{dtype_suffix}"
return key_suffix
def get_compatible_latents_keys(
self,
keys: set[str],
dtype: Optional[Union[str, torch.dtype]],
flip_aug: bool,
bucket_reso: Optional[Tuple[int, int]] = None,
latents_size: Optional[Tuple[int, int]] = None,
) -> list[Optional[str], Optional[str]]:
"""
bucket_reso is (W, H), latents_size is (H, W)
"""
key_suffix = self.get_key_suffix(bucket_reso, latents_size, None)
keys_without_dtype = ["latents" + key_suffix]
if flip_aug:
keys_without_dtype.append("latents_flipped" + key_suffix)
compatible_keys = get_compatible_dtype_keys(keys, keys_without_dtype, dtype)
return compatible_keys if flip_aug else compatible_keys[0] + [None]
def _default_is_disk_cached_latents_expected(
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
latents_cache_path: str,
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[Union[str, torch.dtype]],
multi_resolution: bool = False,
):
# multi_resolution is always enabled for any strategy
if not self.cache_to_disk:
return False
if not os.path.exists(latents_cache_path):
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
try:
# safe_open locks the file, so we cannot use it for checking keys
# with safe_open(latents_cache_path, framework="pt") as f:
# keys = f.keys()
with utils.MemoryEfficientSafeOpen(latents_cache_path) as f:
keys = f.keys()
if alpha_mask and "alpha_mask" + key_suffix_without_dtype not in keys:
# print(f"alpha_mask not found: {latents_cache_path}")
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
return False
# preferred_dtype is None if any dtype is acceptable
latents_key, flipped_latents_key = self.get_compatible_latents_keys(
keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
)
if latents_key is None or (flip_aug and flipped_latents_key is None):
# print(f"Precise dtype not found: {latents_cache_path}")
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {latents_cache_path}")
logger.error(f"Error loading file: {npz_path}")
raise e
return True
@@ -772,10 +465,11 @@ class LatentsCachingStrategy:
encode_by_vae,
vae_device,
vae_dtype,
image_infos: List[utils.ImageInfo],
image_infos: List,
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
@@ -805,8 +499,13 @@ class LatentsCachingStrategy:
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
latents_size = latents.shape[1:3] # H, W
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
if self.cache_to_disk:
self.save_latents_to_disk(info.latents_cache_path, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
self.save_latents_to_disk(
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -816,96 +515,56 @@ class LatentsCachingStrategy:
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
raise NotImplementedError
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
def _default_load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
with safe_open(cache_path, framework="pt") as f:
metadata = f.metadata()
version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, version.split("."))
if major > 1: # or (major == 1 and minor > 0):
if not self.load_version_warning_printed:
self.load_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
)
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
if latents_stride is None:
key_reso_suffix = ""
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
keys = f.keys()
latents_key, flipped_latents_key = self.get_compatible_latents_keys(keys, None, flip_aug=True, bucket_reso=bucket_reso)
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
alpha_mask_key = "alpha_mask" + key_suffix_without_dtype
latents = f.get_tensor(latents_key)
flipped_latents = f.get_tensor(flipped_latents_key) if flipped_latents_key is not None else None
alpha_mask = f.get_tensor(alpha_mask_key) if alpha_mask_key in keys else None
original_size = [int(metadata["width"]), int(metadata["height"])]
crop_ltrb = metadata[f"crop_ltrb" + key_suffix_without_dtype]
crop_ltrb = list(map(int, crop_ltrb.split(",")))
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self,
cache_path: str,
latents_tensor: torch.Tensor,
original_size: Tuple[int, int],
crop_ltrb: List[int],
flipped_latents_tensor: Optional[torch.Tensor] = None,
alpha_mask: Optional[torch.Tensor] = None,
npz_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
dtype = latents_tensor.dtype
latents_size = latents_tensor.shape[1:3] # H, W
tensor_dict = {}
kwargs = {}
overwrite = False
if os.path.exists(cache_path):
# load existing safetensors and update it
overwrite = True
if os.path.exists(npz_path):
# load existing npz and update it
npz = np.load(npz_path)
for key in npz.files:
kwargs[key] = npz[key]
# we cannot use safe_open here because it locks the file
# with safe_open(cache_path, framework="pt") as f:
with utils.MemoryEfficientSafeOpen(cache_path) as f:
metadata = f.metadata()
keys = f.keys()
for key in keys:
tensor_dict[key] = f.get_tensor(key)
assert metadata["architecture"] == self.architecture
file_version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, file_version.split("."))
if major > 1 or (major == 1 and minor > 0):
self.save_version_warning_printed = True
logger.warning(
f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
)
else:
metadata = {}
metadata["architecture"] = self.architecture
metadata["width"] = f"{original_size[0]}"
metadata["height"] = f"{original_size[1]}"
metadata["format_version"] = "1.0.0"
metadata[f"crop_ltrb_{latents_size[0]}x{latents_size[1]}"] = ",".join(map(str, crop_ltrb))
key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
if latents_tensor is not None:
tensor_dict["latents" + key_suffix] = latents_tensor
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
if flipped_latents_tensor is not None:
tensor_dict["latents_flipped" + key_suffix] = flipped_latents_tensor
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
key_suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
tensor_dict["alpha_mask" + key_suffix_without_dtype] = alpha_mask
# remove lower precision latents if higher precision latents are already cached
if overwrite:
suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
remove_lower_precision_values(tensor_dict, ["latents" + suffix_without_dtype, "latents_flipped" + suffix_without_dtype])
save_file(tensor_dict, cache_path, metadata=metadata)
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)

View File

@@ -5,6 +5,9 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
@@ -12,8 +15,6 @@ import logging
logger = logging.getLogger(__name__)
from library import flux_utils, train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
@@ -85,56 +86,64 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
KEYS = ["l_pooled", "t5_out", "txt_ids"]
KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"]
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
max_token_length: int,
masked: bool,
is_partial: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(
FluxLatentsCachingStrategy.ARCHITECTURE,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
max_token_length,
masked,
is_partial,
)
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
self.warn_fp8_weights = False
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
):
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
l_pooled, t5_out, txt_ids = self.load_from_disk_for_keys(
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS
)
if self.masked:
t5_attn_mask = self.load_from_disk_for_keys(
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
)[0]
else:
t5_attn_mask = None
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
if "t5_out" not in npz:
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
l_pooled = data["l_pooled"]
t5_out = data["t5_out"]
txt_ids = data["txt_ids"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
if not self.warn_fp8_weights:
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
@@ -145,67 +154,80 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.warn_fp8_weights = True
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
captions = [caption for _, _, caption in batch]
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
l_pooled = l_pooled.cpu()
t5_out = t5_out.cpu()
txt_ids = txt_ids.cpu()
t5_attn_mask = tokens_and_masks[2].cpu()
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
for i, (info, caption_index, caption) in enumerate(batch):
for i, info in enumerate(infos):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
outputs = [l_pooled_i, t5_out_i, txt_ids_i]
if self.masked:
outputs += [t5_attn_mask_i]
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i]
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
ARCHITECTURE = "flux"
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(FluxLatentsCachingStrategy.ARCHITECTURE, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
):
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
@property
def cache_suffix(self) -> str:
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

View File

@@ -4,6 +4,8 @@ from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import CLIPTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
from library.utils import setup_logging
setup_logging()
@@ -11,8 +13,6 @@ import logging
logger = logging.getLogger(__name__)
from library import train_util, utils
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
TOKENIZER_ID = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
@@ -134,30 +134,33 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
# and we keep the old npz for the backward compatibility.
ARCHITECTURE_SD = "sd"
ARCHITECTURE_SDXL = "sdxl"
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
arch = SdSdxlLatentsCachingStrategy.ARCHITECTURE_SD if sd else SdSdxlLatentsCachingStrategy.ARCHITECTURE_SDXL
super().__init__(arch, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
@property
def cache_suffix(self) -> str:
return self.suffix
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
) -> bool:
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
def load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
vae_device = vae.device
vae_dtype = vae.dtype

View File

@@ -6,6 +6,10 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
from library import sd3_utils, train_util
from library import sd3_models
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
@@ -13,9 +17,6 @@ import logging
logger = logging.getLogger(__name__)
from library import train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -253,8 +254,7 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
KEYS = ["lg_out", "t5_out", "lg_pooled"]
KEYS_MASKED = ["clip_l_attn_mask", "clip_g_attn_mask", "t5_attn_mask"]
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
def __init__(
self,
@@ -262,51 +262,70 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
max_token_length: int = 256,
masked: bool = False,
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
"""
apply_lg_attn_mask and apply_t5_attn_mask must be same
"""
super().__init__(
Sd3LatentsCachingStrategy.ARCHITECTURE_SD3,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
max_token_length,
masked=masked,
is_partial=is_partial,
)
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
lg_out, lg_pooled, t5_out = self.load_from_disk_for_keys(
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS
)
if self.masked:
l_attn_mask, g_attn_mask, t5_attn_mask = self.load_from_disk_for_keys(
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
)
else:
l_attn_mask = g_attn_mask = t5_attn_mask = None
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"]
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
captions = [caption for _, _, caption in batch]
captions = [info.caption for info in infos]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
@@ -315,76 +334,87 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy,
models,
tokens_and_masks,
apply_lg_attn_mask=self.masked,
apply_t5_attn_mask=self.masked,
apply_lg_attn_mask=self.apply_lg_attn_mask,
apply_t5_attn_mask=self.apply_t5_attn_mask,
enable_dropout=False,
)
lg_out = lg_out.cpu()
lg_pooled = lg_pooled.cpu()
t5_out = t5_out.cpu()
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
l_attn_mask = tokens_and_masks[3].cpu()
g_attn_mask = tokens_and_masks[4].cpu()
t5_attn_mask = tokens_and_masks[5].cpu()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
for i, (info, caption_index, caption) in enumerate(batch):
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
outputs = [lg_out_i, t5_out_i, lg_pooled_i]
if self.masked:
outputs += [l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i]
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [
lg_out_i,
t5_out_i,
lg_pooled_i,
l_attn_mask_i,
g_attn_mask_i,
t5_attn_mask_i,
]
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
ARCHITECTURE_SD3 = "sd3"
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(Sd3LatentsCachingStrategy.ARCHITECTURE_SD3, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
def is_disk_cached_latents_expected(
self,
bucket_reso: Tuple[int, int],
cache_path: str,
flip_aug: bool,
alpha_mask: bool,
preferred_dtype: Optional[torch.dtype] = None,
):
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
@property
def cache_suffix(self) -> str:
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
def load_latents_from_disk(
self, cache_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

View File

@@ -4,6 +4,8 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
@@ -12,8 +14,6 @@ import logging
logger = logging.getLogger(__name__)
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library import utils
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -21,9 +21,6 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
class SdxlTokenizeStrategy(TokenizeStrategy):
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
"""
max_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
"""
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
@@ -223,51 +220,51 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
ARCHITECTURE_SDXL = "sdxl"
KEYS = ["hidden_state1", "hidden_state2", "pool2"]
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
def __init__(
self,
cache_to_disk: bool,
batch_size: Optional[int],
batch_size: int,
skip_disk_cache_validity_check: bool,
max_token_length: Optional[int] = None,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
"""
max_token_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
"""
max_token_length = max_token_length or 75
super().__init__(
SdxlTextEncoderOutputsCachingStrategy.ARCHITECTURE_SDXL,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
is_partial,
is_weighted,
max_token_length=max_token_length,
)
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
# SDXL does not support attn mask
base_keys = SdxlTextEncoderOutputsCachingStrategy.KEYS
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, base_keys, preferred_dtype)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
return self.load_from_disk_for_keys(cache_path, caption_index, SdxlTextEncoderOutputsCachingStrategy.KEYS)
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_state1 = data["hidden_state1"]
hidden_state2 = data["hidden_state2"]
pool2 = data["pool2"]
return [hidden_state1, hidden_state2, pool2]
def cache_batch_outputs(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
captions = [caption for _, _, caption in batch]
captions = [info.caption for info in infos]
if self.is_weighted:
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
@@ -282,24 +279,28 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, [tokens1, tokens2]
)
hidden_state1 = hidden_state1.cpu()
hidden_state2 = hidden_state2.cpu()
pool2 = pool2.cpu()
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
for i, (info, caption_index, caption) in enumerate(batch):
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
self.save_outputs_to_disk(
info.text_encoder_outputs_cache_path,
caption_index,
caption,
SdxlTextEncoderOutputsCachingStrategy.KEYS,
[hidden_state1_i, hidden_state2_i, pool2_i],
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
)
else:
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [hidden_state1_i, hidden_state2_i, pool2_i]
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]

File diff suppressed because it is too large Load Diff

View File

@@ -16,67 +16,10 @@ from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, is_reg: bool, absolute_path: str) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.captions: Optional[list[str]] = None
self.caption_weights: Optional[list[float]] = None # weights for each caption in sampling
self.list_of_tags: Optional[list[str]] = None
self.tags_weights: Optional[list[float]] = None
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.latents_cache_dir: Optional[str] = None
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
self.latents: Optional[torch.Tensor] = None
self.latents_flipped: Optional[torch.Tensor] = None
self.latents_cache_path: Optional[str] = None # set in cache_latents
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
# crop left top right bottom in original pixel size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int]] = None
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image. None if not the latents is cached
self.text_encoder_outputs_cache_path: Optional[str] = None # set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[list[list[torch.Tensor]]] = None
# old
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
def __str__(self) -> str:
return f"ImageInfo(image_key={self.image_key}, num_repeats={self.num_repeats}, captions={self.captions}, is_reg={self.is_reg}, absolute_path={self.absolute_path})"
def set_dreambooth_info(self, list_of_tags: list[str]) -> None:
self.list_of_tags = list_of_tags
def set_fine_tuning_info(
self,
captions: Optional[list[str]],
caption_weights: Optional[list[float]],
list_of_tags: Optional[list[str]],
tags_weights: Optional[list[float]],
image_size: Tuple[int, int],
latents_cache_dir: Optional[str],
):
self.captions = captions
self.caption_weights = caption_weights
self.list_of_tags = list_of_tags
self.tags_weights = tags_weights
self.image_size = image_size
self.latents_cache_dir = latents_cache_dir
# region Logging
@@ -145,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)
setup_logging()
logger = logging.getLogger(__name__)
# endregion
@@ -245,15 +190,6 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None)
raise ValueError(f"Unsupported dtype: {s}")
def dtype_to_normalized_str(dtype: Union[str, torch.dtype]) -> str:
dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
# get name of the dtype
dtype_name = str(dtype).split(".")[-1]
return dtype_name
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
"""
memory efficient save file
@@ -326,7 +262,6 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
class MemoryEfficientSafeOpen:
# does not support metadata loading
def __init__(self, filename):
self.filename = filename
self.file = open(filename, "rb")
@@ -444,7 +379,7 @@ def load_safetensors(
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
def pil_resize(image, size, interpolation):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
if has_alpha:
@@ -452,7 +387,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
resized_pil = pil_image.resize(size, interpolation)
resized_pil = pil_image.resize(size, resample=interpolation)
# Convert back to cv2 format
if has_alpha:
@@ -463,6 +398,117 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
"""
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
Args:
image: numpy.ndarray
width: int Original image width
height: int Original image height
resized_width: int Resized image width
resized_height: int Resized image height
resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box"
Returns:
image
"""
# Ensure all size parameters are actual integers
width = int(width)
height = int(height)
resized_width = int(resized_width)
resized_height = int(resized_height)
if resize_interpolation is None:
if width >= resized_width and height >= resized_height:
resize_interpolation = "area"
else:
resize_interpolation = "lanczos"
# we use PIL for lanczos (for backward compatibility) and box, cv2 for others
use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"]
resized_size = (resized_width, resized_height)
if use_pil:
interpolation = get_pil_interpolation(resize_interpolation)
image = pil_resize(image, resized_size, interpolation=interpolation)
logger.debug(f"resize image using {resize_interpolation} (PIL)")
else:
interpolation = get_cv2_interpolation(resize_interpolation)
image = cv2.resize(image, resized_size, interpolation=interpolation)
logger.debug(f"resize image using {resize_interpolation} (cv2)")
return image
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
"""
Convert interpolation value to cv2 interpolation integer
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
"""
if interpolation is None:
return None
if interpolation == "lanczos" or interpolation == "lanczos4":
# Lanczos interpolation over 8x8 neighborhood
return cv2.INTER_LANCZOS4
elif interpolation == "nearest":
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
return cv2.INTER_NEAREST_EXACT
elif interpolation == "bilinear" or interpolation == "linear":
# bilinear interpolation
return cv2.INTER_LINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# bicubic interpolation
return cv2.INTER_CUBIC
elif interpolation == "area":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
elif interpolation == "box":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
else:
return None
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
"""
Convert interpolation value to PIL interpolation
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
"""
if interpolation is None:
return None
if interpolation == "lanczos":
return Image.Resampling.LANCZOS
elif interpolation == "nearest":
# Pick one nearest pixel from the input image. Ignore all other input pixels.
return Image.Resampling.NEAREST
elif interpolation == "bilinear" or interpolation == "linear":
# For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
return Image.Resampling.BILINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
return Image.Resampling.BICUBIC
elif interpolation == "area":
# Image.Resampling.BOX may be more appropriate if upscaling
# Area interpolation is related to cv2.INTER_AREA
# Produces a sharper image than Resampling.BILINEAR, doesnt have dislocations on local level like with Resampling.BOX.
return Image.Resampling.HAMMING
elif interpolation == "box":
# Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
return Image.Resampling.BOX
else:
return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
"""
Check if a interpolation function is supported
"""
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
# endregion
# TODO make inf_utils.py

View File

@@ -268,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
class DyLoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -866,7 +866,7 @@ class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -278,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -755,7 +755,7 @@ class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -9,11 +9,13 @@
import math
import os
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
from torch import Tensor
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -44,6 +46,8 @@ class LoRAModule(torch.nn.Module):
rank_dropout=None,
module_dropout=None,
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
):
"""
if alpha == 0 or None, alpha is rank (no scaling).
@@ -103,9 +107,20 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
@@ -140,7 +155,17 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx)
return org_forwarded + lx * self.multiplier * scale
# LoRA Gradient-Guided Perturbation Optimization
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
with torch.no_grad():
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(perturbation_scale_factor)
perturbation_output = x @ perturbation.T # Result: (batch × n)
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
else:
return org_forwarded + lx * self.multiplier * scale
else:
lxs = [lora_down(x) for lora_down in self.lora_down]
@@ -167,6 +192,116 @@ class LoRAModule(torch.nn.Module):
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
@torch.no_grad()
def initialize_norm_cache(self, org_module_weight: Tensor):
# Choose a reasonable sample size
n_rows = org_module_weight.shape[0]
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
# Sample random indices across all rows
indices = torch.randperm(n_rows)[:sample_size]
# Convert to a supported data type first, then index
# Use float32 for indexing operations
weights_float32 = org_module_weight.to(dtype=torch.float32)
sampled_weights = weights_float32[indices].to(device=self.device)
# Calculate sampled norms
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
# Store the mean norm as our estimate
self.org_weight_norm_estimate = sampled_norms.mean()
# Optional: store standard deviation for confidence intervals
self.org_weight_norm_std = sampled_norms.std()
# Free memory
del sampled_weights, weights_float32
@torch.no_grad()
def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
# Calculate the true norm (this will be slow but it's just for validation)
true_norms = []
chunk_size = 1024 # Process in chunks to avoid OOM
for i in range(0, org_module_weight.shape[0], chunk_size):
end_idx = min(i + chunk_size, org_module_weight.shape[0])
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
true_norms.append(chunk_norms.cpu())
del chunk
true_norms = torch.cat(true_norms, dim=0)
true_mean_norm = true_norms.mean().item()
# Compare with our estimate
estimated_norm = self.org_weight_norm_estimate.item()
# Calculate error metrics
absolute_error = abs(true_mean_norm - estimated_norm)
relative_error = absolute_error / true_mean_norm * 100 # as percentage
if verbose:
logger.info(f"True mean norm: {true_mean_norm:.6f}")
logger.info(f"Estimated norm: {estimated_norm:.6f}")
logger.info(f"Absolute error: {absolute_error:.6f}")
logger.info(f"Relative error: {relative_error:.2f}%")
return {
'true_mean_norm': true_mean_norm,
'estimated_norm': estimated_norm,
'absolute_error': absolute_error,
'relative_error': relative_error
}
@torch.no_grad()
def update_norms(self):
# Not running GGPO so not currently running update norms
if self.ggpo_beta is None or self.ggpo_sigma is None:
return
# only update norms when we are training
if self.training is False:
return
module_weights = self.lora_up.weight @ self.lora_down.weight
module_weights.mul(self.scale)
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
torch.sum(module_weights**2, dim=1, keepdim=True))
@torch.no_grad()
def update_grad_norms(self):
if self.training is False:
print(f"skipping update_grad_norms for {self.lora_name}")
return
lora_down_grad = None
lora_up_grad = None
for name, param in self.named_parameters():
if name == "lora_down.weight":
lora_down_grad = param.grad
elif name == "lora_up.weight":
lora_up_grad = param.grad
# Calculate gradient norms if we have both gradients
if lora_down_grad is not None and lora_up_grad is not None:
with torch.autocast(self.device.type):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
class LoRAInfModule(LoRAModule):
def __init__(
@@ -420,6 +555,16 @@ def create_network(
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
ggpo_beta = kwargs.get("ggpo_beta", None)
ggpo_sigma = kwargs.get("ggpo_sigma", None)
if ggpo_beta is not None:
ggpo_beta = float(ggpo_beta)
if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
@@ -449,6 +594,8 @@ def create_network(
in_dims=in_dims,
train_double_block_indices=train_double_block_indices,
train_single_block_indices=train_single_block_indices,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
verbose=verbose,
)
@@ -561,6 +708,8 @@ class LoRANetwork(torch.nn.Module):
in_dims: Optional[List[int]] = None,
train_double_block_indices: Optional[List[bool]] = None,
train_single_block_indices: Optional[List[bool]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -599,10 +748,16 @@ class LoRANetwork(torch.nn.Module):
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
if ggpo_beta is not None and ggpo_sigma is not None:
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
if self.split_qkv:
logger.info(f"split qkv for LoRA")
if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only")
if train_t5xxl:
logger.info(f"train T5XXL as well")
@@ -722,6 +877,8 @@ class LoRANetwork(torch.nn.Module):
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
)
loras.append(lora)
@@ -790,6 +947,36 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def update_norms(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_norms()
def update_grad_norms(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_grad_norms()
def grad_norms(self) -> Tensor:
grad_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
grad_norms.append(lora.grad_norms.mean(dim=0))
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
def weight_norms(self) -> Tensor:
weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
weight_norms.append(lora.weight_norms.mean(dim=0))
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
def combined_weight_norms(self) -> Tensor:
combined_weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file

View File

@@ -7,9 +7,11 @@ opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.44.0
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.4
pytorch-optimizer==3.5.0
prodigy-plus-schedule-free==1.9.0
prodigyopt==1.1.2
tensorboard
safetensors==0.4.4
# gradio==3.16.2

View File

@@ -75,12 +75,6 @@ def train(args):
)
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
" / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
)
assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
"when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
+ " / text encoderの学習時はtext encoderの出力はキャッシュできませんt5xxlのみキャッシュすることは可能です"
@@ -175,8 +169,8 @@ def train(args):
args.text_encoder_batch_size,
False,
False,
args.t5xxl_max_token_length,
args.apply_lg_attn_mask,
False,
False,
)
)
train_dataset_group.set_current_strategies()
@@ -285,8 +279,8 @@ def train(args):
args.text_encoder_batch_size,
args.skip_cache_check,
train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial
args.t5xxl_max_token_length,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
@@ -337,7 +331,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)

View File

@@ -26,7 +26,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
super().__init__()
self.sample_prompts_te_outputs = None
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
# super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)
@@ -43,10 +48,6 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
" / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
)
# prepare CLIP-L/CLIP-G/T5XXL training flags
self.train_clip = not args.network_train_unet_only
@@ -192,8 +193,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip or self.train_t5xxl,
max_token_length=args.t5xxl_max_token_length,
apply_lg_attn_mask=args.apply_lg_attn_mask,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None
@@ -303,7 +304,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images):
def encode_images_to_latents(self, args, vae, images):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
@@ -321,7 +322,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
network,
weight_dtype,
train_unet,
is_train=True
is_train=True,
):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -449,14 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
# drop cached text encoder outputs
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
# drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None:
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module:

View File

@@ -273,7 +273,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -322,11 +322,7 @@ def train(args):
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
@@ -644,14 +640,23 @@ def train(args):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(vae_dtype)).latent_dist.sample().to(weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
with torch.no_grad():
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
else:
chunks = [
batch["images"][i : i + args.vae_batch_size]
for i in range(0, len(batch["images"]), args.vae_batch_size)
]
list_latents = []
for chunk in chunks:
with torch.no_grad():
# latentに変換
list_latents.append(
vae.encode(chunk.to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
)
latents = torch.cat(list_latents, dim=0)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)

View File

@@ -209,7 +209,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -223,11 +223,7 @@ def train(args):
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
args.cache_text_encoder_outputs_to_disk, None, False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)

View File

@@ -181,7 +181,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -195,11 +195,7 @@ def train(args):
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
args.cache_text_encoder_outputs_to_disk, None, False
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)

View File

@@ -24,7 +24,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:
@@ -83,11 +82,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
)
else:
return None

View File

@@ -0,0 +1,220 @@
import pytest
import torch
from unittest.mock import MagicMock, patch
from library.flux_train_utils import (
get_noisy_model_input_and_timesteps,
)
# Mock classes and functions
class MockNoiseScheduler:
def __init__(self, num_train_timesteps=1000):
self.config = MagicMock()
self.config.num_train_timesteps = num_train_timesteps
self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
# Create fixtures for commonly used objects
@pytest.fixture
def args():
args = MagicMock()
args.timestep_sampling = "uniform"
args.weighting_scheme = "uniform"
args.logit_mean = 0.0
args.logit_std = 1.0
args.mode_scale = 1.0
args.sigmoid_scale = 1.0
args.discrete_flow_shift = 3.1582
args.ip_noise_gamma = None
args.ip_noise_gamma_random_strength = False
return args
@pytest.fixture
def noise_scheduler():
return MockNoiseScheduler(num_train_timesteps=1000)
@pytest.fixture
def latents():
return torch.randn(2, 4, 8, 8)
@pytest.fixture
def noise():
return torch.randn(2, 4, 8, 8)
@pytest.fixture
def device():
# return "cuda" if torch.cuda.is_available() else "cpu"
return "cpu"
# Mock the required functions
@pytest.fixture(autouse=True)
def mock_functions():
with (
patch("torch.sigmoid", side_effect=torch.sigmoid),
patch("torch.rand", side_effect=torch.rand),
patch("torch.randn", side_effect=torch.randn),
):
yield
# Test different timestep sampling methods
def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "uniform"
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "sigmoid"
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "shift"
args.sigmoid_scale = 1.0
args.discrete_flow_shift = 3.1582
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "flux_shift"
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
# Mock the necessary functions for this specific test
with patch("library.flux_train_utils.compute_density_for_timestep_sampling",
return_value=torch.tensor([0.3, 0.7], device=device)), \
patch("library.flux_train_utils.get_sigmas",
return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)):
args.timestep_sampling = "other" # Will trigger the weighting scheme path
args.weighting_scheme = "uniform"
args.logit_mean = 0.0
args.logit_std = 1.0
args.mode_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
# Test IP noise options
def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma = 0.5
args.ip_noise_gamma_random_strength = False
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma = 0.1
args.ip_noise_gamma_random_strength = True
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
# Test different data types
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
dtype = torch.float16
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
# Test different batch sizes
def test_different_batch_size(args, noise_scheduler, device):
latents = torch.randn(5, 4, 8, 8) # batch size of 5
noise = torch.randn(5, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (5,)
assert sigmas.shape == (5, 1, 1, 1)
# Test different image sizes
def test_different_image_size(args, noise_scheduler, device):
latents = torch.randn(2, 4, 16, 16) # larger image size
noise = torch.randn(2, 4, 16, 16)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
assert sigmas.shape == (2, 1, 1, 1)
# Test edge cases
def test_zero_batch_size(args, noise_scheduler, device):
with pytest.raises(AssertionError): # expecting an error with zero batch size
latents = torch.randn(0, 4, 8, 8)
noise = torch.randn(0, 4, 8, 8)
dtype = torch.float32
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
def test_different_timestep_count(args, device):
noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count
latents = torch.randn(2, 4, 8, 8)
noise = torch.randn(2, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
# Check that timesteps are within the proper range
assert torch.all(timesteps < 500)

View File

@@ -150,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# cache latents with dataset
# TODO use DataLoader to speed up
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
train_dataset_group.new_cache_latents(vae, accelerator)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents to disk.")

View File

@@ -15,7 +15,7 @@ import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -170,12 +170,9 @@ def process(args):
scale = max(cur_crop_width / w, cur_crop_height / h)
if scale != 1.0:
w = int(w * scale + .5)
h = int(h * scale + .5)
if scale < 1.0:
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
else:
face_img = pil_resize(face_img, (w, h))
rw = int(w * scale + .5)
rh = int(h * scale + .5)
face_img = resize_image(face_img, w, h, rw, rh)
cx = int(cx * scale + .5)
cy = int(cy * scale + .5)
fw = int(fw * scale + .5)

View File

@@ -0,0 +1,166 @@
import argparse
import os
import gc
from typing import Dict, Optional, Union
import torch
from safetensors.torch import safe_open
from library.utils import setup_logging
from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype
setup_logging()
import logging
logger = logging.getLogger(__name__)
def merge_safetensors(
dit_path: str,
vae_path: Optional[str] = None,
clip_l_path: Optional[str] = None,
clip_g_path: Optional[str] = None,
t5xxl_path: Optional[str] = None,
output_path: str = "merged_model.safetensors",
device: str = "cpu",
save_precision: Optional[str] = None,
):
"""
Merge multiple safetensors files into a single file
Args:
dit_path: Path to the DiT/MMDiT model
vae_path: Path to the VAE model
clip_l_path: Path to the CLIP-L model
clip_g_path: Path to the CLIP-G model
t5xxl_path: Path to the T5-XXL model
output_path: Path to save the merged model
device: Device to load tensors to
save_precision: Target dtype for model weights (e.g. 'fp16', 'bf16')
"""
logger.info("Starting to merge safetensors files...")
# Convert save_precision string to torch dtype if specified
if save_precision:
target_dtype = str_to_dtype(save_precision)
else:
target_dtype = None
# 1. Get DiT metadata if available
metadata = None
try:
with safe_open(dit_path, framework="pt") as f:
metadata = f.metadata() # may be None
if metadata:
logger.info(f"Found metadata in DiT model: {metadata}")
except Exception as e:
logger.warning(f"Failed to read metadata from DiT model: {e}")
# 2. Create empty merged state dict
merged_state_dict = {}
# 3. Load and merge each model with memory management
# DiT/MMDiT - prefix: model.diffusion_model.
# This state dict may have VAE keys.
logger.info(f"Loading DiT model from {dit_path}")
dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True, dtype=target_dtype)
logger.info(f"Adding DiT model with {len(dit_state_dict)} keys")
for key, value in dit_state_dict.items():
if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."):
merged_state_dict[key] = value
else:
merged_state_dict[f"model.diffusion_model.{key}"] = value
# Free memory
del dit_state_dict
gc.collect()
# VAE - prefix: first_stage_model.
# May be omitted if VAE is already included in DiT model.
if vae_path:
logger.info(f"Loading VAE model from {vae_path}")
vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True, dtype=target_dtype)
logger.info(f"Adding VAE model with {len(vae_state_dict)} keys")
for key, value in vae_state_dict.items():
if key.startswith("first_stage_model."):
merged_state_dict[key] = value
else:
merged_state_dict[f"first_stage_model.{key}"] = value
# Free memory
del vae_state_dict
gc.collect()
# CLIP-L - prefix: text_encoders.clip_l.
if clip_l_path:
logger.info(f"Loading CLIP-L model from {clip_l_path}")
clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True, dtype=target_dtype)
logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys")
for key, value in clip_l_state_dict.items():
if key.startswith("text_encoders.clip_l.transformer."):
merged_state_dict[key] = value
else:
merged_state_dict[f"text_encoders.clip_l.transformer.{key}"] = value
# Free memory
del clip_l_state_dict
gc.collect()
# CLIP-G - prefix: text_encoders.clip_g.
if clip_g_path:
logger.info(f"Loading CLIP-G model from {clip_g_path}")
clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True, dtype=target_dtype)
logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys")
for key, value in clip_g_state_dict.items():
if key.startswith("text_encoders.clip_g.transformer."):
merged_state_dict[key] = value
else:
merged_state_dict[f"text_encoders.clip_g.transformer.{key}"] = value
# Free memory
del clip_g_state_dict
gc.collect()
# T5-XXL - prefix: text_encoders.t5xxl.
if t5xxl_path:
logger.info(f"Loading T5-XXL model from {t5xxl_path}")
t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True, dtype=target_dtype)
logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys")
for key, value in t5xxl_state_dict.items():
if key.startswith("text_encoders.t5xxl.transformer."):
merged_state_dict[key] = value
else:
merged_state_dict[f"text_encoders.t5xxl.transformer.{key}"] = value
# Free memory
del t5xxl_state_dict
gc.collect()
# 4. Save merged state dict
logger.info(f"Saving merged model to {output_path} with {len(merged_state_dict)} keys total")
mem_eff_save_file(merged_state_dict, output_path, metadata)
logger.info("Successfully merged safetensors files")
def main():
parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file")
parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model")
parser.add_argument("--vae", help="Path to the VAE model. May be omitted if VAE is included in DiT model")
parser.add_argument("--clip_l", help="Path to the CLIP-L model")
parser.add_argument("--clip_g", help="Path to the CLIP-G model")
parser.add_argument("--t5xxl", help="Path to the T5-XXL model")
parser.add_argument("--output", default="merged_model.safetensors", help="Path to save the merged model")
parser.add_argument("--device", default="cpu", help="Device to load tensors to")
parser.add_argument("--save_precision", type=str, help="Precision to save the model in (e.g., 'fp16', 'bf16', 'float16', etc.)")
args = parser.parse_args()
merge_safetensors(
dit_path=args.dit,
vae_path=args.vae,
clip_l_path=args.clip_l,
clip_g_path=args.clip_g,
t5xxl_path=args.t5xxl,
output_path=args.output,
device=args.device,
save_precision=args.save_precision,
)
if __name__ == "__main__":
main()

View File

@@ -6,7 +6,7 @@ import shutil
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Select interpolation method
if interpolation == 'lanczos4':
pil_interpolation = Image.LANCZOS
elif interpolation == 'cubic':
pil_interpolation = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA
# Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
for filename in os.listdir(src_img_folder):
@@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
if cv2_interpolation:
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation)
else:
new_height, new_width = img.shape[0:2]
@@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser:
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
default='area', help='Interpolation method for resizing / サイズの補方法')
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'],
default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。')
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')

View File

@@ -157,7 +157,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
train_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)

View File

@@ -9,6 +9,7 @@ import random
import time
import json
from multiprocessing import Value
import numpy as np
import toml
from tqdm import tqdm
@@ -68,13 +69,20 @@ class NetworkTrainer:
keys_scaled=None,
mean_norm=None,
maximum_norm=None,
mean_grad_norm=None,
mean_combined_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled
logs["max_norm/average_key_norm"] = mean_norm
logs["max_norm/max_key_norm"] = maximum_norm
if mean_norm is not None:
logs["norm/avg_key_norm"] = mean_norm
if mean_grad_norm is not None:
logs["norm/avg_grad_norm"] = mean_grad_norm
if mean_combined_norm is not None:
logs["norm/avg_combined_norm"] = mean_combined_norm
lrs = lr_scheduler.get_last_lr()
for i, lr in enumerate(lrs):
@@ -100,9 +108,7 @@ class NetworkTrainer:
if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
): # tracking d*lr value of unet.
logs["lr/d*lr"] = (
optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
)
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
else:
idx = 0
if not args.network_train_unet_only:
@@ -115,16 +121,56 @@ class NetworkTrainer:
logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
):
logs[f"lr/d*lr/group{i}"] = (
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
)
if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
return logs
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
self.accelerator_logging(accelerator, logs, global_step, global_step, epoch)
def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
self.accelerator_logging(accelerator, logs, epoch, global_step, epoch)
def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int):
self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step)
def accelerator_logging(
self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None
):
"""
step_value is for tensorboard, other values are for wandb
"""
tensorboard_tracker = None
wandb_tracker = None
other_trackers = []
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
tensorboard_tracker = accelerator.get_tracker("tensorboard")
elif tracker.name == "wandb":
wandb_tracker = accelerator.get_tracker("wandb")
else:
other_trackers.append(accelerator.get_tracker(tracker.name))
if tensorboard_tracker is not None:
tensorboard_tracker.log(logs, step=step_value)
if wandb_tracker is not None:
logs["global_step"] = global_step
logs["epoch"] = epoch
if val_step is not None:
logs["val_step"] = val_step
wandb_tracker.log(logs)
for tracker in other_trackers:
tracker.log(logs, step=step_value)
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
train_dataset_group.verify_bucket_reso_steps(64)
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)
@@ -219,7 +265,7 @@ class NetworkTrainer:
network,
weight_dtype,
train_unet,
is_train=True
is_train=True,
):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
@@ -309,28 +355,31 @@ class NetworkTrainer:
) -> torch.nn.Module:
return accelerator.prepare(unet)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
pass
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
pass
# endregion
def process_batch(
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy: strategy_base.TextEncodingStrategy,
tokenize_strategy: strategy_base.TokenizeStrategy,
is_train=True,
train_text_encoder=True,
train_unet=True
self,
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy: strategy_base.TextEncodingStrategy,
tokenize_strategy: strategy_base.TokenizeStrategy,
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
"""
Process a batch for the network
@@ -340,7 +389,18 @@ class NetworkTrainer:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
else:
# latentに変換
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
else:
chunks = [
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
]
list_latents = []
for chunk in chunks:
with torch.no_grad():
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
list_latents.append(chunk)
latents = torch.cat(list_latents, dim=0)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
@@ -397,7 +457,7 @@ class NetworkTrainer:
network,
weight_dtype,
train_unet,
is_train=is_train
is_train=is_train,
)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
@@ -484,7 +544,7 @@ class NetworkTrainer:
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args)
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -559,9 +619,9 @@ class NetworkTrainer:
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
train_dataset_group.new_cache_latents(vae, accelerator)
if val_dataset_group is not None:
val_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
val_dataset_group.new_cache_latents(vae, accelerator)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -609,6 +669,10 @@ class NetworkTrainer:
return
network_has_multiplier = hasattr(network, "set_multiplier")
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
# if not hasattr(network, "prepare_network"):
# network.prepare_network = lambda args: None
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
@@ -701,7 +765,7 @@ class NetworkTrainer:
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset_group if val_dataset_group is not None else [],
shuffle=False,
@@ -900,7 +964,9 @@ class NetworkTrainer:
accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
accelerator.print(
f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}"
)
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
@@ -968,11 +1034,12 @@ class NetworkTrainer:
"ss_huber_c": args.huber_c,
"ss_fp8_base": bool(args.fp8_base),
"ss_fp8_base_unet": bool(args.fp8_base_unet),
"ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_validation_seed": args.validation_seed,
"ss_validation_split": args.validation_split,
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -998,6 +1065,7 @@ class NetworkTrainer:
"max_bucket_reso": dataset.max_bucket_reso,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
"resize_interpolation": dataset.resize_interpolation,
}
subsets_metadata = []
@@ -1015,6 +1083,7 @@ class NetworkTrainer:
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
"resize_interpolation": subset.resize_interpolation,
}
image_dir_or_metadata_file = None
@@ -1163,10 +1232,6 @@ class NetworkTrainer:
args.max_train_steps > initial_step
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
progress_bar = tqdm(
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
)
epoch_to_start = 0
if initial_step > 0:
if args.skip_until_initial_step:
@@ -1247,12 +1312,6 @@ class NetworkTrainer:
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
validation_steps = (
min(args.max_validation_steps, len(val_dataloader))
if args.max_validation_steps is not None
else len(val_dataloader)
)
# training loop
if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs
@@ -1271,13 +1330,57 @@ class NetworkTrainer:
clean_memory_on_device(accelerator.device)
progress_bar = tqdm(
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
)
validation_steps = (
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
)
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
validation_total_steps = validation_steps * len(validation_timesteps)
original_args_min_timestep = args.min_timestep
original_args_max_timestep = args.max_timestep
def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
cpu_rng_state = torch.get_rng_state()
if accelerator.device.type == "cuda":
gpu_rng_state = torch.cuda.get_rng_state()
elif accelerator.device.type == "xpu":
gpu_rng_state = torch.xpu.get_rng_state()
elif accelerator.device.type == "mps":
gpu_rng_state = torch.cuda.get_rng_state()
else:
gpu_rng_state = None
python_rng_state = random.getstate()
torch.manual_seed(seed)
random.seed(seed)
return (cpu_rng_state, gpu_rng_state, python_rng_state)
def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
cpu_rng_state, gpu_rng_state, python_rng_state = rng_states
torch.set_rng_state(cpu_rng_state)
if gpu_rng_state is not None:
if accelerator.device.type == "cuda":
torch.cuda.set_rng_state(gpu_rng_state)
elif accelerator.device.type == "xpu":
torch.xpu.set_rng_state(gpu_rng_state)
elif accelerator.device.type == "mps":
torch.cuda.set_rng_state(gpu_rng_state)
random.setstate(python_rng_state)
for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here
# TRAINING
skipped_dataloader = None
@@ -1294,25 +1397,25 @@ class NetworkTrainer:
with accelerator.accumulate(training_model):
on_step_start_for_network(text_encoder, unet)
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
# preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=True,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
)
accelerator.backward(loss)
@@ -1322,6 +1425,11 @@ class NetworkTrainer:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if hasattr(network, "update_grad_norms"):
network.update_grad_norms()
if hasattr(network, "update_norms"):
network.update_norms()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
@@ -1330,9 +1438,23 @@ class NetworkTrainer:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
if hasattr(network, "weight_norms"):
mean_norm = network.weight_norms().mean().item()
mean_grad_norm = network.grad_norms().mean().item()
mean_combined_norm = network.combined_weight_norms().mean().item()
weight_norms = network.weight_norms()
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
keys_scaled = None
max_mean_logs = {}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {}
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -1364,153 +1486,179 @@ class NetworkTrainer:
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if is_tracking:
logs = self.generate_step_logs(
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm,
mean_grad_norm,
mean_combined_norm,
)
accelerator.log(logs, step=global_step)
self.step_logging(accelerator, logs, global_step, epoch + 1)
# VALIDATION PER STEP
should_validate_step = (
args.validate_every_n_steps is not None
and global_step != 0 # Skip first step
and global_step % args.validate_every_n_steps == 0
)
# VALIDATION PER STEP: global_step is already incremented
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="validation steps"
range(validation_total_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="validation steps",
)
val_timesteps_step = 0
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
for timestep in validation_timesteps:
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
)
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
train_unet=train_unet,
)
if is_tracking:
logs = {
"loss/validation/step_current": current_loss,
"val_step": (epoch * validation_steps) + val_step,
}
accelerator.log(logs, step=global_step)
current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix(
{"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
)
# if is_tracking:
# logs = {f"loss/validation/step_current_{timestep}": current_loss}
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_timesteps_step += 1
if is_tracking:
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
logs = {
"loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence,
"loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence,
}
accelerator.log(logs, step=global_step)
self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()
accelerator.unwrap_model(network).train()
progress_bar.unpause()
if global_step >= args.max_train_steps:
break
# EPOCH VALIDATION
should_validate_epoch = (
(epoch + 1) % args.validate_every_n_epochs == 0
if args.validate_every_n_epochs is not None
else True
(epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
)
if should_validate_epoch and len(val_dataloader) > 0:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm(
range(validation_steps), smoothing=0,
disable=not accelerator.is_local_main_process,
desc="epoch validation steps"
range(validation_total_steps),
smoothing=0,
disable=not accelerator.is_local_main_process,
desc="epoch validation steps",
)
val_timesteps_step = 0
for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps:
break
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
for timestep in validation_timesteps:
args.min_timestep = args.max_timestep = timestep
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=False,
train_unet=False
)
# temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
current_loss = loss.detach().item()
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
loss = self.process_batch(
batch,
text_encoders,
unet,
network,
vae,
noise_scheduler,
vae_dtype,
weight_dtype,
accelerator,
args,
text_encoding_strategy,
tokenize_strategy,
is_train=False,
train_text_encoder=train_text_encoder,
train_unet=train_unet,
)
if is_tracking:
logs = {
"loss/validation/epoch_current": current_loss,
"epoch": epoch + 1,
"val_step": (epoch * validation_steps) + val_step
}
accelerator.log(logs, step=global_step)
current_loss = loss.detach().item()
val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_progress_bar.update(1)
val_progress_bar.set_postfix(
{"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
)
# if is_tracking:
# logs = {f"loss/validation/epoch_current_{timestep}": current_loss}
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
val_timesteps_step += 1
if is_tracking:
avr_loss: float = val_epoch_loss_recorder.moving_average
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
logs = {
"loss/validation/epoch_average": avr_loss,
"loss/validation/epoch_divergence": loss_validation_divergence,
"epoch": epoch + 1
"loss/validation/epoch_average": avr_loss,
"loss/validation/epoch_divergence": loss_validation_divergence,
}
accelerator.log(logs, step=global_step)
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()
accelerator.unwrap_model(network).train()
progress_bar.unpause()
# END OF EPOCH
if is_tracking:
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
accelerator.log(logs, step=global_step)
logs = {"loss/epoch_average": loss_recorder.moving_average}
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
@@ -1696,31 +1844,31 @@ def setup_parser() -> argparse.ArgumentParser:
"--validation_seed",
type=int,
default=None,
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する"
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する",
)
parser.add_argument(
"--validation_split",
type=float,
default=0.0,
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合",
)
parser.add_argument(
"--validate_every_n_steps",
type=int,
default=None,
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます"
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます",
)
parser.add_argument(
"--validate_every_n_epochs",
type=int,
default=None,
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます"
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます",
)
parser.add_argument(
"--max_validation_steps",
type=int,
default=None,
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します"
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
)
return parser

View File

@@ -382,7 +382,7 @@ class TextualInversionTrainer:
vae.requires_grad_(False)
vae.eval()
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
train_dataset_group.new_cache_latents(vae, accelerator)
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()