mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
152 Commits
new_cache
...
no-grad-bl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0e8b5d4af9 | ||
|
|
5753b8ff6b | ||
|
|
2bfda1271b | ||
|
|
5b38d07f03 | ||
|
|
e2ed265104 | ||
|
|
e85813200a | ||
|
|
a27ace74d9 | ||
|
|
865c8d55e2 | ||
|
|
7c075a9c8d | ||
|
|
b4a89c3cdf | ||
|
|
f62c68df3c | ||
|
|
a4fae93dce | ||
|
|
1684ababcd | ||
|
|
64430eb9b2 | ||
|
|
d8717a3d1c | ||
|
|
a21b6a917e | ||
|
|
4625b34f4e | ||
|
|
80320d21fe | ||
|
|
29523c9b68 | ||
|
|
fd3a445769 | ||
|
|
13296ae93b | ||
|
|
0e8ac43760 | ||
|
|
bc9252cc1b | ||
|
|
3b25de1f17 | ||
|
|
f0b07c52ab | ||
|
|
309c44bdf2 | ||
|
|
8387e0b95c | ||
|
|
5c50cdbb44 | ||
|
|
46ad3be059 | ||
|
|
abf2c44bc5 | ||
|
|
adb775c616 | ||
|
|
b11c053b8f | ||
|
|
c46f08a87a | ||
|
|
0d9da0ea71 | ||
|
|
f501209c37 | ||
|
|
c8af252a44 | ||
|
|
7f984f4775 | ||
|
|
d33d5eccd1 | ||
|
|
7c61c0dfe0 | ||
|
|
26db64be17 | ||
|
|
629073cd9d | ||
|
|
06df0377f9 | ||
|
|
5a18a03ffc | ||
|
|
572cc3efb8 | ||
|
|
52c8dec953 | ||
|
|
4589262f8f | ||
|
|
c56dc90b26 | ||
|
|
ee0f754b08 | ||
|
|
606e6875d2 | ||
|
|
fd36fd1aa9 | ||
|
|
92845e8806 | ||
|
|
f1423a7229 | ||
|
|
b822b7e60b | ||
|
|
ede3470260 | ||
|
|
b3c56b22bd | ||
|
|
583ab27b3c | ||
|
|
aa5978dffd | ||
|
|
aaa26bb882 | ||
|
|
d0b5c0e5cf | ||
|
|
59d98e45a9 | ||
|
|
3149b2771f | ||
|
|
96a133c998 | ||
|
|
1f432e2c0e | ||
|
|
9e9a13aa8a | ||
|
|
93a4efabb5 | ||
|
|
381303d64f | ||
|
|
0181b7a042 | ||
|
|
182544dcce | ||
|
|
8ebe858f89 | ||
|
|
a0f11730f7 | ||
|
|
6364379f17 | ||
|
|
5253a38783 | ||
|
|
8f4ee8fc34 | ||
|
|
367f348430 | ||
|
|
89f0d27a59 | ||
|
|
d40f5b1e4e | ||
|
|
8aa126582e | ||
|
|
e8b3254858 | ||
|
|
16cef81aea | ||
|
|
d151833526 | ||
|
|
936d333ff4 | ||
|
|
f974c6b257 | ||
|
|
5d5a7d2acf | ||
|
|
1eddac26b0 | ||
|
|
8e6817b0c2 | ||
|
|
d93ad90a71 | ||
|
|
7197266703 | ||
|
|
5b210ad717 | ||
|
|
b81bcd0b01 | ||
|
|
6f4d365775 | ||
|
|
a4f3a9fc1a | ||
|
|
b425466e7b | ||
|
|
c8be141ae0 | ||
|
|
0b25a05e3c | ||
|
|
3647d065b5 | ||
|
|
620a06f517 | ||
|
|
564ec5fb7f | ||
|
|
7e90cdd47a | ||
|
|
c898e4e536 | ||
|
|
e5b5c7e1db | ||
|
|
ea53290f62 | ||
|
|
75933d70a1 | ||
|
|
aa2bde7ece | ||
|
|
3f49053c90 | ||
|
|
acdca2abb7 | ||
|
|
ba5251168a | ||
|
|
272f4c3775 | ||
|
|
734333d0c9 | ||
|
|
2f69f4dbdb | ||
|
|
9a415ba965 | ||
|
|
3d79239be4 | ||
|
|
ec350c83eb | ||
|
|
49651892ce | ||
|
|
1fcac98280 | ||
|
|
b286304e5f | ||
|
|
ae409e83c9 | ||
|
|
5228db1548 | ||
|
|
f4a0047865 | ||
|
|
f68702f71c | ||
|
|
6e90c0f86c | ||
|
|
67fde015f7 | ||
|
|
386b7332c6 | ||
|
|
905f081798 | ||
|
|
59ae9ea20c | ||
|
|
efb2a128cd | ||
|
|
13df47516d | ||
|
|
7f2747176b | ||
|
|
ca1c129ffd | ||
|
|
545425c13e | ||
|
|
7729c4c8f9 | ||
|
|
d0128d18be | ||
|
|
58e9e146a3 | ||
|
|
4a36996134 | ||
|
|
dc7d5fb459 | ||
|
|
63337d9fe4 | ||
|
|
76b761943b | ||
|
|
cd80752175 | ||
|
|
177203818a | ||
|
|
344845b429 | ||
|
|
0911683717 | ||
|
|
a24db1d532 | ||
|
|
c5b803ce94 | ||
|
|
4a71687d20 | ||
|
|
de830b8941 | ||
|
|
45ec02b2a8 | ||
|
|
42c0a9e1fc | ||
|
|
0750859133 | ||
|
|
29f31d005f | ||
|
|
b6a3093216 | ||
|
|
86a2f3fd26 | ||
|
|
532f5c58a6 | ||
|
|
a9c5aa1f93 |
3
.github/FUNDING.yml
vendored
Normal file
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: kohya-ss
|
||||
5
.github/workflows/tests.yml
vendored
5
.github/workflows/tests.yml
vendored
@@ -12,6 +12,9 @@ on:
|
||||
- dev
|
||||
- sd3
|
||||
|
||||
# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
|
||||
permissions: read-all
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
@@ -40,7 +43,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4
|
||||
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Test with pytest
|
||||
|
||||
3
.github/workflows/typos.yml
vendored
3
.github/workflows/typos.yml
vendored
@@ -12,6 +12,9 @@ on:
|
||||
- synchronize
|
||||
- reopened
|
||||
|
||||
# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
|
||||
permissions: read-all
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
52
README.md
52
README.md
@@ -9,11 +9,44 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
|
||||
- [FLUX.1 training](#flux1-training)
|
||||
- [SD3 training](#sd3-training)
|
||||
|
||||
### Recent Updates
|
||||
|
||||
May 1, 2025:
|
||||
- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details.
|
||||
- If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
|
||||
Apr 27, 2025:
|
||||
- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064).
|
||||
- See [here](#sample-image-generation-during-training) for details.
|
||||
- If you have any issues with this, please let us know.
|
||||
|
||||
Apr 6, 2025:
|
||||
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
|
||||
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.
|
||||
|
||||
Mar 30, 2025:
|
||||
- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
|
||||
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details.
|
||||
- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936).
|
||||
|
||||
Mar 20, 2025:
|
||||
- `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985).
|
||||
- For example, you can use CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`.
|
||||
|
||||
Mar 6, 2025:
|
||||
|
||||
- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960)
|
||||
|
||||
Feb 26, 2025:
|
||||
|
||||
- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)
|
||||
- The validation loss uses the fixed timestep sampling and the fixed random seed. This is to ensure that the validation loss is not fluctuated by the random values.
|
||||
|
||||
Jan 25, 2025:
|
||||
|
||||
- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
|
||||
@@ -739,6 +772,8 @@ Not available yet.
|
||||
[__Change History__](#change-history) is moved to the bottom of the page.
|
||||
更新履歴は[ページ末尾](#change-history)に移しました。
|
||||
|
||||
Latest update: 2025-03-21 (Version 0.9.1)
|
||||
|
||||
[日本語版READMEはこちら](./README-ja.md)
|
||||
|
||||
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
|
||||
@@ -846,6 +881,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
|
||||
|
||||
(Single GPU with id `0` will be used.)
|
||||
|
||||
## DeepSpeed installation (experimental, Linux or WSL2 only)
|
||||
|
||||
To install DeepSpeed, run the following command in your activated virtual environment:
|
||||
|
||||
```bash
|
||||
pip install deepspeed==0.16.7
|
||||
```
|
||||
|
||||
## Upgrade
|
||||
|
||||
When a new release comes out you can upgrade your repo with the following command:
|
||||
@@ -882,6 +925,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
## Change History
|
||||
|
||||
### Mar 21, 2025 / 2025-03-21 Version 0.9.1
|
||||
|
||||
- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
|
||||
- The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version.
|
||||
|
||||
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
|
||||
|
||||
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
|
||||
@@ -1315,11 +1363,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility.
|
||||
* `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
The prompt weighting such as `( )` and `[ ]` are working.
|
||||
|
||||
@@ -152,6 +152,7 @@ These options are related to subset configuration.
|
||||
| `keep_tokens_separator` | `“|||”` | o | o | o |
|
||||
| `secondary_separator` | `“;;;”` | o | o | o |
|
||||
| `enable_wildcard` | `true` | o | o | o |
|
||||
| `resize_interpolation` | (not specified) | o | o | o |
|
||||
|
||||
* `num_repeats`
|
||||
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
|
||||
@@ -165,6 +166,8 @@ These options are related to subset configuration.
|
||||
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
|
||||
* `enable_wildcard`
|
||||
* Enables wildcard notation. This will be explained later.
|
||||
* `resize_interpolation`
|
||||
* Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used.
|
||||
|
||||
### DreamBooth-specific options
|
||||
|
||||
|
||||
@@ -144,6 +144,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
|
||||
| `keep_tokens_separator` | `“|||”` | o | o | o |
|
||||
| `secondary_separator` | `“;;;”` | o | o | o |
|
||||
| `enable_wildcard` | `true` | o | o | o |
|
||||
| `resize_interpolation` |(通常は設定しません) | o | o | o |
|
||||
|
||||
* `num_repeats`
|
||||
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
|
||||
@@ -162,6 +163,9 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
|
||||
* `enable_wildcard`
|
||||
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
|
||||
|
||||
* `resize_interpolation`
|
||||
* 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos`、`box`を指定するとPILが、それ以外を指定するとOpenCVが使用されます。
|
||||
|
||||
### DreamBooth 方式専用のオプション
|
||||
|
||||
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
|
||||
|
||||
@@ -11,7 +11,7 @@ from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging, resize_image
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -42,10 +42,7 @@ def preprocess_image(image):
|
||||
pad_t = pad_y // 2
|
||||
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
|
||||
|
||||
if size > IMAGE_SIZE:
|
||||
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
|
||||
else:
|
||||
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
|
||||
image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)
|
||||
|
||||
image = image.astype(np.float32)
|
||||
return image
|
||||
@@ -100,15 +97,19 @@ def main(args):
|
||||
else:
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
args.repo_id,
|
||||
file,
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
subfolder=SUB_DIR,
|
||||
cache_dir=os.path.join(model_location, SUB_DIR),
|
||||
local_dir=os.path.join(model_location, SUB_DIR),
|
||||
force_download=True,
|
||||
force_filename=file,
|
||||
)
|
||||
for file in files:
|
||||
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
local_dir=model_location,
|
||||
force_download=True,
|
||||
)
|
||||
else:
|
||||
logger.info("using existing wd14 tagger model")
|
||||
|
||||
|
||||
@@ -36,7 +36,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.is_schnell: Optional[bool] = None
|
||||
self.is_swapping_blocks: bool = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
def assert_extra_args(
|
||||
self,
|
||||
args,
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
@@ -323,7 +328,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, accelerator, vae, images):
|
||||
def encode_images_to_latents(self, args, vae, images):
|
||||
return vae.encode(images)
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
@@ -341,7 +346,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
is_train=True,
|
||||
):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
@@ -376,8 +381,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
t5_attn_mask = None
|
||||
|
||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||
# if not args.split_mode:
|
||||
# normal forward
|
||||
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = unet(
|
||||
@@ -390,44 +394,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
"""
|
||||
else:
|
||||
# split forward to reduce memory usage
|
||||
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
||||
with accelerator.autocast():
|
||||
# move flux lower to cpu, and then move flux upper to gpu
|
||||
unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
self.flux_upper.to(accelerator.device)
|
||||
|
||||
# upper model does not require grad
|
||||
with torch.no_grad():
|
||||
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
txt=t5_out,
|
||||
txt_ids=txt_ids,
|
||||
y=l_pooled,
|
||||
timesteps=timesteps / 1000,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
# move flux upper back to cpu, and then move flux lower to gpu
|
||||
self.flux_upper.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
|
||||
# lower model requires grad
|
||||
intermediate_img.requires_grad_(True)
|
||||
intermediate_txt.requires_grad_(True)
|
||||
vec.requires_grad_(True)
|
||||
pe.requires_grad_(True)
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_unet):
|
||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
||||
"""
|
||||
|
||||
return model_pred
|
||||
|
||||
model_pred = call_dit(
|
||||
@@ -546,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
prepare_fp8(text_encoder, weight_dtype)
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
if self.is_swapping_blocks:
|
||||
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
|
||||
def prepare_unet_with_accelerator(
|
||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||
) -> torch.nn.Module:
|
||||
|
||||
@@ -75,6 +75,7 @@ class BaseSubsetParams:
|
||||
custom_attributes: Optional[Dict[str, Any]] = None
|
||||
validation_seed: int = 0
|
||||
validation_split: float = 0.0
|
||||
resize_interpolation: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -106,7 +107,7 @@ class BaseDatasetParams:
|
||||
debug_dataset: bool = False
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
|
||||
resize_interpolation: Optional[str] = None
|
||||
|
||||
@dataclass
|
||||
class DreamBoothDatasetParams(BaseDatasetParams):
|
||||
@@ -196,6 +197,7 @@ class ConfigSanitizer:
|
||||
"caption_prefix": str,
|
||||
"caption_suffix": str,
|
||||
"custom_attributes": dict,
|
||||
"resize_interpolation": str,
|
||||
}
|
||||
# DO means DropOut
|
||||
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
||||
@@ -241,6 +243,7 @@ class ConfigSanitizer:
|
||||
"validation_split": float,
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
"resize_interpolation": str,
|
||||
}
|
||||
|
||||
# options handled by argparse but not handled by user config
|
||||
@@ -525,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
[{dataset_type} {i}]
|
||||
batch_size: {dataset.batch_size}
|
||||
resolution: {(dataset.width, dataset.height)}
|
||||
resize_interpolation: {dataset.resize_interpolation}
|
||||
enable_bucket: {dataset.enable_bucket}
|
||||
""")
|
||||
|
||||
@@ -558,6 +562,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
token_warmup_min: {subset.token_warmup_min},
|
||||
token_warmup_step: {subset.token_warmup_step},
|
||||
alpha_mask: {subset.alpha_mask}
|
||||
resize_interpolation: {subset.resize_interpolation}
|
||||
custom_attributes: {subset.custom_attributes}
|
||||
"""), " ")
|
||||
|
||||
|
||||
@@ -42,19 +42,20 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
|
||||
stream.synchronize()
|
||||
stream.synchronize()
|
||||
|
||||
# cpu to cuda
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
# cpu to cuda
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
stream.synchronize()
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
@@ -75,14 +76,14 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
|
||||
synchronize_device()
|
||||
synchronize_device(device)
|
||||
|
||||
# cpu to device
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
synchronize_device()
|
||||
synchronize_device(device)
|
||||
|
||||
|
||||
def weighs_to_device(layer: nn.Module, device: torch.device):
|
||||
|
||||
@@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
from .device_utils import get_preferred_device
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
|
||||
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
||||
)
|
||||
|
||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp16":
|
||||
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
||||
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
class DeepSpeedWrapper(torch.nn.Module):
|
||||
def __init__(self, **kw_models) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.models = torch.nn.ModuleDict()
|
||||
|
||||
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
|
||||
|
||||
for key, model in kw_models.items():
|
||||
if isinstance(model, list):
|
||||
model = torch.nn.ModuleList(model)
|
||||
|
||||
if wrap_model_forward_with_torch_autocast:
|
||||
model = self.__wrap_model_with_torch_autocast(model)
|
||||
|
||||
assert isinstance(
|
||||
model, torch.nn.Module
|
||||
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||
|
||||
self.models.update(torch.nn.ModuleDict({key: model}))
|
||||
|
||||
def __wrap_model_with_torch_autocast(self, model):
|
||||
if isinstance(model, torch.nn.ModuleList):
|
||||
model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
|
||||
else:
|
||||
model = self.__wrap_model_forward_with_torch_autocast(model)
|
||||
return model
|
||||
|
||||
def __wrap_model_forward_with_torch_autocast(self, model):
|
||||
|
||||
assert hasattr(model, "forward"), f"model must have a forward method."
|
||||
|
||||
forward_fn = model.forward
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
try:
|
||||
device_type = model.device.type
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
|
||||
"to determine the device_type for torch.autocast()."
|
||||
)
|
||||
device_type = get_preferred_device().type
|
||||
|
||||
with torch.autocast(device_type = device_type):
|
||||
return forward_fn(*args, **kwargs)
|
||||
|
||||
model.forward = forward
|
||||
return model
|
||||
|
||||
def get_models(self):
|
||||
return self.models
|
||||
|
||||
|
||||
ds_model = DeepSpeedWrapper(**models)
|
||||
return ds_model
|
||||
|
||||
@@ -2,6 +2,13 @@ import functools
|
||||
import gc
|
||||
|
||||
import torch
|
||||
try:
|
||||
# intel gpu support for pytorch older than 2.5
|
||||
# ipex is not needed after pytorch 2.5
|
||||
import intel_extension_for_pytorch as ipex # noqa
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
@@ -14,8 +21,6 @@ except Exception:
|
||||
HAS_MPS = False
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa
|
||||
|
||||
HAS_XPU = torch.xpu.is_available()
|
||||
except Exception:
|
||||
HAS_XPU = False
|
||||
@@ -69,7 +74,7 @@ def init_ipex():
|
||||
|
||||
This function should run right after importing torch and before doing anything else.
|
||||
|
||||
If IPEX is not available, this function does nothing.
|
||||
If xpu is not available, this function does nothing.
|
||||
"""
|
||||
try:
|
||||
if HAS_XPU:
|
||||
|
||||
@@ -40,7 +40,7 @@ def sample_images(
|
||||
text_encoders,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement=None,
|
||||
controlnet=None
|
||||
controlnet=None,
|
||||
):
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
@@ -101,7 +101,7 @@ def sample_images(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
controlnet,
|
||||
)
|
||||
else:
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
@@ -125,7 +125,7 @@ def sample_images(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
controlnet,
|
||||
)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
@@ -147,14 +147,16 @@ def sample_image_inference(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
controlnet,
|
||||
):
|
||||
assert isinstance(prompt_dict, dict)
|
||||
# negative_prompt = prompt_dict.get("negative_prompt")
|
||||
negative_prompt = prompt_dict.get("negative_prompt")
|
||||
sample_steps = prompt_dict.get("sample_steps", 20)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
scale = prompt_dict.get("scale", 3.5)
|
||||
# TODO refactor variable names
|
||||
cfg_scale = prompt_dict.get("guidance_scale", 1.0)
|
||||
emb_guidance_scale = prompt_dict.get("scale", 3.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
@@ -162,8 +164,8 @@ def sample_image_inference(
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
# if negative_prompt is not None:
|
||||
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
@@ -173,16 +175,21 @@ def sample_image_inference(
|
||||
torch.seed()
|
||||
torch.cuda.seed()
|
||||
|
||||
# if negative_prompt is None:
|
||||
# negative_prompt = ""
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
height = max(64, height - height % 16) # round to divisible by 16
|
||||
width = max(64, width - width % 16) # round to divisible by 16
|
||||
logger.info(f"prompt: {prompt}")
|
||||
# logger.info(f"negative_prompt: {negative_prompt}")
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"negative_prompt: {negative_prompt}")
|
||||
elif negative_prompt != "":
|
||||
logger.info(f"negative prompt is ignored because scale is 1.0")
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
logger.info(f"scale: {scale}")
|
||||
logger.info(f"embedded guidance scale: {emb_guidance_scale}")
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"CFG scale: {cfg_scale}")
|
||||
# logger.info(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
logger.info(f"seed: {seed}")
|
||||
@@ -191,26 +198,37 @@ def sample_image_inference(
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
def encode_prompt(prpt):
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prpt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prpt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prpt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
return text_encoder_conds
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
|
||||
# encode negative prompts
|
||||
if cfg_scale != 1.0:
|
||||
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
|
||||
neg_t5_attn_mask = (
|
||||
neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
|
||||
)
|
||||
neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
|
||||
else:
|
||||
neg_cond = None
|
||||
|
||||
# sample image
|
||||
weight_dtype = ae.dtype # TOFO give dtype as argument
|
||||
@@ -235,7 +253,20 @@ def sample_image_inference(
|
||||
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
|
||||
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
||||
x = denoise(
|
||||
flux,
|
||||
noise,
|
||||
img_ids,
|
||||
t5_out,
|
||||
txt_ids,
|
||||
l_pooled,
|
||||
timesteps=timesteps,
|
||||
guidance=emb_guidance_scale,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
controlnet=controlnet,
|
||||
controlnet_img=controlnet_image,
|
||||
neg_cond=neg_cond,
|
||||
)
|
||||
|
||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
|
||||
@@ -305,22 +336,24 @@ def denoise(
|
||||
model: flux_models.Flux,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt: torch.Tensor, # t5_out
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
vec: torch.Tensor, # l_pooled
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||
controlnet: Optional[flux_models.ControlNetFlux] = None,
|
||||
controlnet_img: Optional[torch.Tensor] = None,
|
||||
neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
|
||||
do_cfg = neg_cond is not None
|
||||
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
model.prepare_block_swap_before_forward()
|
||||
|
||||
if controlnet is not None:
|
||||
block_samples, block_single_samples = controlnet(
|
||||
img=img,
|
||||
@@ -336,20 +369,48 @@ def denoise(
|
||||
else:
|
||||
block_samples = None
|
||||
block_single_samples = None
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
if not do_cfg:
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
else:
|
||||
cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond
|
||||
nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
||||
|
||||
# TODO is it ok to use the same block samples for both cond and uncond?
|
||||
block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0)
|
||||
block_single_samples = (
|
||||
None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0)
|
||||
)
|
||||
|
||||
nc_c_pred = model(
|
||||
img=torch.cat([img, img], dim=0),
|
||||
img_ids=torch.cat([img_ids, img_ids], dim=0),
|
||||
txt=torch.cat([neg_t5_out, txt], dim=0),
|
||||
txt_ids=torch.cat([txt_ids, txt_ids], dim=0),
|
||||
y=torch.cat([neg_l_pooled, vec], dim=0),
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=nc_c_t5_attn_mask,
|
||||
)
|
||||
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
|
||||
pred = neg_pred + (pred - neg_pred) * cfg_scale
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
@@ -366,8 +427,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
|
||||
@@ -410,42 +469,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
# Simple random t-based noise sampling
|
||||
# Simple random sigma-based noise sampling
|
||||
if args.timestep_sampling == "sigmoid":
|
||||
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
else:
|
||||
t = torch.rand((bsz,), device=device)
|
||||
sigmas = torch.rand((bsz,), device=device)
|
||||
|
||||
timesteps = t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "shift":
|
||||
shift = args.discrete_flow_shift
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "flux_shift":
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
@@ -456,12 +507,24 @@ def get_noisy_model_input_and_timesteps(
|
||||
logit_std=args.logit_std,
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
||||
indices = (u * num_timesteps).long()
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||
else:
|
||||
ip_noise_gamma = args.ip_noise_gamma
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||
else:
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
@@ -567,7 +630,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
|
||||
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5xxl_max_token_length",
|
||||
|
||||
@@ -2,7 +2,11 @@ import os
|
||||
import sys
|
||||
import contextlib
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
legacy = True
|
||||
except Exception:
|
||||
legacy = False
|
||||
from .hijacks import ipex_hijacks
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
@@ -12,6 +16,13 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
|
||||
return True, "Skipping IPEX hijack"
|
||||
else:
|
||||
try: # force xpu device on torch compile and triton
|
||||
torch._inductor.utils.GPU_TYPES = ["xpu"]
|
||||
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
|
||||
from triton import backends as triton_backends # pylint: disable=import-error
|
||||
triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False
|
||||
except Exception:
|
||||
pass
|
||||
# Replace cuda with xpu:
|
||||
torch.cuda.current_device = torch.xpu.current_device
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
@@ -26,84 +37,99 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.is_current_stream_capturing = lambda: False
|
||||
torch.cuda.set_device = torch.xpu.set_device
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
torch.cuda.synchronize = torch.xpu.synchronize
|
||||
torch.cuda.Event = torch.xpu.Event
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||
torch.Tensor.cuda = torch.Tensor.xpu
|
||||
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
||||
torch.nn.Module.cuda = torch.nn.Module.xpu
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
||||
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
||||
torch.cuda._tls = torch.xpu.lazy_init._tls
|
||||
torch.cuda.threading = torch.xpu.lazy_init.threading
|
||||
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
||||
torch.cuda.Optional = torch.xpu.Optional
|
||||
torch.cuda.__cached__ = torch.xpu.__cached__
|
||||
torch.cuda.__loader__ = torch.xpu.__loader__
|
||||
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
||||
torch.cuda.Tuple = torch.xpu.Tuple
|
||||
torch.cuda.streams = torch.xpu.streams
|
||||
torch.cuda._lazy_new = torch.xpu._lazy_new
|
||||
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
||||
torch.cuda.Any = torch.xpu.Any
|
||||
torch.cuda.__doc__ = torch.xpu.__doc__
|
||||
torch.cuda.default_generators = torch.xpu.default_generators
|
||||
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
||||
torch.cuda._get_device_index = torch.xpu._get_device_index
|
||||
torch.cuda.__path__ = torch.xpu.__path__
|
||||
torch.cuda.Device = torch.xpu.Device
|
||||
torch.cuda.IntTensor = torch.xpu.IntTensor
|
||||
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
||||
torch.cuda.set_stream = torch.xpu.set_stream
|
||||
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
||||
torch.cuda.os = torch.xpu.os
|
||||
torch.cuda.torch = torch.xpu.torch
|
||||
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
||||
torch.cuda.Union = torch.xpu.Union
|
||||
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
||||
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
||||
torch.cuda.LongTensor = torch.xpu.LongTensor
|
||||
torch.cuda.IntStorage = torch.xpu.IntStorage
|
||||
torch.cuda.LongStorage = torch.xpu.LongStorage
|
||||
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
||||
torch.cuda.__package__ = torch.xpu.__package__
|
||||
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
||||
torch.cuda.CharTensor = torch.xpu.CharTensor
|
||||
torch.cuda.List = torch.xpu.List
|
||||
torch.cuda._lazy_init = torch.xpu._lazy_init
|
||||
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
||||
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
||||
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
||||
torch.cuda.StreamContext = torch.xpu.StreamContext
|
||||
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
||||
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
||||
torch.cuda._lazy_call = torch.xpu._lazy_call
|
||||
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
||||
torch.cuda.random = torch.xpu.random
|
||||
torch.cuda._device = torch.xpu._device
|
||||
torch.cuda.classproperty = torch.xpu.classproperty
|
||||
torch.cuda.__name__ = torch.xpu.__name__
|
||||
torch.cuda._device_t = torch.xpu._device_t
|
||||
torch.cuda.warnings = torch.xpu.warnings
|
||||
torch.cuda.__spec__ = torch.xpu.__spec__
|
||||
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||
torch.cuda.__file__ = torch.xpu.__file__
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
|
||||
if legacy:
|
||||
torch.cuda.os = torch.xpu.os
|
||||
torch.cuda.Device = torch.xpu.Device
|
||||
torch.cuda.warnings = torch.xpu.warnings
|
||||
torch.cuda.classproperty = torch.xpu.classproperty
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
if float(ipex.__version__[:3]) < 2.3:
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
||||
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
||||
torch.cuda._tls = torch.xpu.lazy_init._tls
|
||||
torch.cuda.threading = torch.xpu.lazy_init.threading
|
||||
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
||||
torch.cuda._lazy_new = torch.xpu._lazy_new
|
||||
|
||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
||||
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
||||
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
||||
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
||||
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
||||
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
||||
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
||||
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
||||
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
||||
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
||||
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
||||
torch.cuda.LongTensor = torch.xpu.LongTensor
|
||||
torch.cuda.LongStorage = torch.xpu.LongStorage
|
||||
torch.cuda.IntTensor = torch.xpu.IntTensor
|
||||
torch.cuda.IntStorage = torch.xpu.IntStorage
|
||||
torch.cuda.CharTensor = torch.xpu.CharTensor
|
||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
||||
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
||||
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
||||
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
||||
|
||||
if not legacy or float(ipex.__version__[:3]) >= 2.3:
|
||||
torch.cuda._initialization_lock = torch.xpu._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu._initialized
|
||||
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
|
||||
torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker
|
||||
torch.cuda._queued_calls = torch.xpu._queued_calls
|
||||
torch.cuda._tls = torch.xpu._tls
|
||||
torch.cuda.threading = torch.xpu.threading
|
||||
torch.cuda.traceback = torch.xpu.traceback
|
||||
|
||||
# Memory:
|
||||
torch.cuda.memory = torch.xpu.memory
|
||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||
torch.xpu.empty_cache = lambda: None
|
||||
torch.cuda.empty_cache = torch.xpu.empty_cache
|
||||
|
||||
if legacy:
|
||||
torch.cuda.memory_summary = torch.xpu.memory_summary
|
||||
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
|
||||
torch.cuda.memory = torch.xpu.memory
|
||||
torch.cuda.memory_stats = torch.xpu.memory_stats
|
||||
torch.cuda.memory_summary = torch.xpu.memory_summary
|
||||
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
|
||||
torch.cuda.memory_allocated = torch.xpu.memory_allocated
|
||||
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
|
||||
torch.cuda.memory_reserved = torch.xpu.memory_reserved
|
||||
@@ -128,32 +154,44 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||
|
||||
# AMP:
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
||||
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
||||
if legacy:
|
||||
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
|
||||
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
if float(ipex.__version__[:3]) < 2.3:
|
||||
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
||||
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
||||
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = contextlib.nullcontext()
|
||||
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = contextlib.nullcontext()
|
||||
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
||||
|
||||
try:
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
try:
|
||||
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
|
||||
gradscaler_init()
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||
try:
|
||||
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
|
||||
gradscaler_init()
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||
|
||||
# C
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
||||
ipex._C._DeviceProperties.major = 2024
|
||||
ipex._C._DeviceProperties.minor = 0
|
||||
if legacy and float(ipex.__version__[:3]) < 2.3:
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
||||
ipex._C._DeviceProperties.major = 12
|
||||
ipex._C._DeviceProperties.minor = 1
|
||||
else:
|
||||
torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
|
||||
torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
|
||||
torch._C._XpuDeviceProperties.major = 12
|
||||
torch._C._XpuDeviceProperties.minor = 1
|
||||
|
||||
# Fix functions with ipex:
|
||||
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||
# torch.xpu.mem_get_info always returns the total memory as free memory
|
||||
torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||
torch.cuda.mem_get_info = torch.xpu.mem_get_info
|
||||
torch._utils._get_available_device_type = lambda: "xpu"
|
||||
torch.has_cuda = True
|
||||
torch.cuda.has_half = True
|
||||
@@ -161,19 +199,19 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
||||
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
||||
torch.version.cuda = "12.1"
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
|
||||
torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
|
||||
torch.cuda.get_device_properties.major = 12
|
||||
torch.cuda.get_device_properties.minor = 1
|
||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||
|
||||
ipex_hijacks()
|
||||
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy)
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
torch.cuda.is_xpu_hijacked = True
|
||||
except Exception as e:
|
||||
return False, e
|
||||
|
||||
@@ -1,177 +1,119 @@
|
||||
import os
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
from functools import cache
|
||||
from functools import cache, wraps
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
|
||||
|
||||
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1))
|
||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5))
|
||||
|
||||
# Find something divisible with the input_tokens
|
||||
@cache
|
||||
def find_slice_size(slice_size, slice_block_size):
|
||||
while (slice_size * slice_block_size) > attention_slice_rate:
|
||||
slice_size = slice_size // 2
|
||||
if slice_size <= 1:
|
||||
slice_size = 1
|
||||
break
|
||||
return slice_size
|
||||
def find_split_size(original_size, slice_block_size, slice_rate=2):
|
||||
split_size = original_size
|
||||
while True:
|
||||
if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0:
|
||||
return split_size
|
||||
split_size = split_size - 1
|
||||
if split_size <= 1:
|
||||
return 1
|
||||
return split_size
|
||||
|
||||
|
||||
# Find slice sizes for SDPA
|
||||
@cache
|
||||
def find_sdpa_slice_sizes(query_shape, query_element_size):
|
||||
if len(query_shape) == 3:
|
||||
batch_size_attention, query_tokens, shape_three = query_shape
|
||||
shape_four = 1
|
||||
else:
|
||||
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
||||
def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3):
|
||||
batch_size, attn_heads, query_len, _ = query_shape
|
||||
_, _, key_len, _ = key_shape
|
||||
|
||||
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
split_2_slice_size = query_tokens
|
||||
split_3_slice_size = shape_three
|
||||
split_batch_size = batch_size
|
||||
split_head_size = attn_heads
|
||||
split_query_size = query_len
|
||||
|
||||
do_split = False
|
||||
do_split_2 = False
|
||||
do_split_3 = False
|
||||
do_batch_split = False
|
||||
do_head_split = False
|
||||
do_query_split = False
|
||||
|
||||
if block_size > sdpa_slice_trigger_rate:
|
||||
do_split = True
|
||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_2 = True
|
||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_3 = True
|
||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||
if batch_size * slice_batch_size >= trigger_rate:
|
||||
do_batch_split = True
|
||||
split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate)
|
||||
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
if split_batch_size * slice_batch_size > slice_rate:
|
||||
slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
|
||||
do_head_split = True
|
||||
split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate)
|
||||
|
||||
# Find slice sizes for BMM
|
||||
@cache
|
||||
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
|
||||
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
|
||||
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
if split_head_size * slice_head_size > slice_rate:
|
||||
slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024
|
||||
do_query_split = True
|
||||
split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate)
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
split_2_slice_size = input_tokens
|
||||
split_3_slice_size = mat2_atten_shape
|
||||
return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size
|
||||
|
||||
do_split = False
|
||||
do_split_2 = False
|
||||
do_split_3 = False
|
||||
|
||||
if block_size > attention_slice_rate:
|
||||
do_split = True
|
||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
|
||||
do_split_2 = True
|
||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
|
||||
do_split_3 = True
|
||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
|
||||
original_torch_bmm = torch.bmm
|
||||
def torch_bmm_32_bit(input, mat2, *, out=None):
|
||||
if input.device.type != "xpu":
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
|
||||
|
||||
# Slice BMM
|
||||
if do_split:
|
||||
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
||||
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if do_split_3:
|
||||
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
|
||||
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
out=out
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
||||
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
out=out
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx] = original_torch_bmm(
|
||||
input[start_idx:end_idx],
|
||||
mat2[start_idx:end_idx],
|
||||
out=out
|
||||
)
|
||||
torch.xpu.synchronize(input.device)
|
||||
else:
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
return hidden_states
|
||||
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
||||
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
||||
def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
||||
if query.device.type != "xpu":
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
||||
is_unsqueezed = False
|
||||
if len(query.shape) == 3:
|
||||
query = query.unsqueeze(0)
|
||||
is_unsqueezed = True
|
||||
if len(key.shape) == 3:
|
||||
key = key.unsqueeze(0)
|
||||
if len(value.shape) == 3:
|
||||
value = value.unsqueeze(0)
|
||||
do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)
|
||||
|
||||
# Slice SDPA
|
||||
if do_split:
|
||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if do_split_3:
|
||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
||||
if do_batch_split:
|
||||
batch_size, attn_heads, query_len, _ = query.shape
|
||||
_, _, _, head_dim = value.shape
|
||||
hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype)
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
|
||||
for ib in range(batch_size // split_batch_size):
|
||||
start_idx = ib * split_batch_size
|
||||
end_idx = (ib + 1) * split_batch_size
|
||||
if do_head_split:
|
||||
for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name
|
||||
start_idx_h = ih * split_head_size
|
||||
end_idx_h = (ih + 1) * split_head_size
|
||||
if do_query_split:
|
||||
for iq in range(query_len // split_query_size): # pylint: disable=invalid-name
|
||||
start_idx_q = iq * split_query_size
|
||||
end_idx_q = (iq + 1) * split_query_size
|
||||
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :],
|
||||
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
||||
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
value[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
||||
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
||||
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
||||
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||
hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, :, :, :],
|
||||
key[start_idx:end_idx, :, :, :],
|
||||
value[start_idx:end_idx, :, :, :],
|
||||
attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||
)
|
||||
torch.xpu.synchronize(query.device)
|
||||
else:
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||
hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||
if is_unsqueezed:
|
||||
hidden_states.squeeze(0)
|
||||
return hidden_states
|
||||
|
||||
@@ -1,312 +1,47 @@
|
||||
import os
|
||||
from functools import wraps
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import diffusers #0.24.0 # pylint: disable=import-error
|
||||
from diffusers.models.attention_processor import Attention
|
||||
from diffusers.utils import USE_PEFT_BACKEND
|
||||
from functools import cache
|
||||
import diffusers # pylint: disable=import-error
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
||||
|
||||
@cache
|
||||
def find_slice_size(slice_size, slice_block_size):
|
||||
while (slice_size * slice_block_size) > attention_slice_rate:
|
||||
slice_size = slice_size // 2
|
||||
if slice_size <= 1:
|
||||
slice_size = 1
|
||||
break
|
||||
return slice_size
|
||||
|
||||
@cache
|
||||
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
|
||||
if len(query_shape) == 3:
|
||||
batch_size_attention, query_tokens, shape_three = query_shape
|
||||
shape_four = 1
|
||||
else:
|
||||
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
||||
if slice_size is not None:
|
||||
batch_size_attention = slice_size
|
||||
|
||||
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
split_2_slice_size = query_tokens
|
||||
split_3_slice_size = shape_three
|
||||
|
||||
do_split = False
|
||||
do_split_2 = False
|
||||
do_split_3 = False
|
||||
|
||||
if query_device_type != "xpu":
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
if block_size > attention_slice_rate:
|
||||
do_split = True
|
||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
||||
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_2 = True
|
||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
||||
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
||||
do_split_3 = True
|
||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
||||
|
||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
||||
|
||||
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
r"""
|
||||
Processor for implementing sliced attention.
|
||||
|
||||
Args:
|
||||
slice_size (`int`, *optional*):
|
||||
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
||||
`attention_head_dim` must be a multiple of the `slice_size`.
|
||||
"""
|
||||
|
||||
def __init__(self, slice_size):
|
||||
self.slice_size = slice_size
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
batch_size_attention, query_tokens, shape_three = query.shape
|
||||
hidden_states = torch.zeros(
|
||||
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
||||
)
|
||||
|
||||
####################################################################
|
||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
|
||||
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if do_split_3:
|
||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||
del attn_slice
|
||||
torch.xpu.synchronize(query.device)
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
del attn_slice
|
||||
####################################################################
|
||||
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
r"""
|
||||
Default processor for performing attention-related computations.
|
||||
"""
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states=None, attention_mask=None,
|
||||
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
####################################################################
|
||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
|
||||
|
||||
if do_split:
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
start_idx = i * split_slice_size
|
||||
end_idx = (i + 1) * split_slice_size
|
||||
if do_split_2:
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if do_split_3:
|
||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||
del attn_slice
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
||||
|
||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
||||
del query_slice
|
||||
del key_slice
|
||||
del attn_mask_slice
|
||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
del attn_slice
|
||||
torch.xpu.synchronize(query.device)
|
||||
else:
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
####################################################################
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def ipex_diffusers():
|
||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
||||
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
||||
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
|
||||
# Diffusers FreeU
|
||||
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
|
||||
@wraps(diffusers.utils.torch_utils.fourier_filter)
|
||||
def fourier_filter(x_in, threshold, scale):
|
||||
return_dtype = x_in.dtype
|
||||
return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype)
|
||||
|
||||
|
||||
# fp64 error
|
||||
class FluxPosEmbed(torch.nn.Module):
|
||||
def __init__(self, theta: int, axes_dim):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
for i in range(n_axes):
|
||||
cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i],
|
||||
pos[:, i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=torch.float32,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
|
||||
diffusers.utils.torch_utils.fourier_filter = fourier_filter
|
||||
if not device_supports_fp64:
|
||||
diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
|
||||
|
||||
@@ -5,7 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
|
||||
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
||||
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
||||
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
||||
|
||||
@@ -2,10 +2,19 @@ import os
|
||||
from functools import wraps
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import numpy as np
|
||||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
|
||||
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
|
||||
try:
|
||||
x = torch.ones((33000,33000), dtype=torch.float32, device="xpu")
|
||||
del x
|
||||
torch.xpu.empty_cache()
|
||||
can_allocate_plus_4gb = True
|
||||
except Exception:
|
||||
can_allocate_plus_4gb = False
|
||||
else:
|
||||
can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||
|
||||
@@ -26,7 +35,7 @@ def check_device(device):
|
||||
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
||||
|
||||
def return_xpu(device):
|
||||
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
||||
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
|
||||
|
||||
|
||||
# Autocast
|
||||
@@ -42,7 +51,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non
|
||||
original_interpolate = torch.nn.functional.interpolate
|
||||
@wraps(torch.nn.functional.interpolate)
|
||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||
if antialias or align_corners is not None or mode == 'bicubic':
|
||||
if mode in {'bicubic', 'bilinear'}:
|
||||
return_device = tensor.device
|
||||
return_dtype = tensor.dtype
|
||||
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
||||
@@ -73,35 +82,46 @@ def as_tensor(data, dtype=None, device=None):
|
||||
return original_as_tensor(data, dtype=dtype, device=device)
|
||||
|
||||
|
||||
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
||||
original_torch_bmm = torch.bmm
|
||||
if can_allocate_plus_4gb:
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
else:
|
||||
# 32 bit attention workarounds for Alchemist:
|
||||
try:
|
||||
from .attention import torch_bmm_32_bit as original_torch_bmm
|
||||
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
||||
from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
original_torch_bmm = torch.bmm
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
|
||||
|
||||
# Data Type Errors:
|
||||
@wraps(torch.bmm)
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
|
||||
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
||||
if query.dtype != key.dtype:
|
||||
key = key.to(dtype=query.dtype)
|
||||
if query.dtype != value.dtype:
|
||||
value = value.to(dtype=query.dtype)
|
||||
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
||||
attn_mask = attn_mask.to(dtype=query.dtype)
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||
|
||||
# Data Type Errors:
|
||||
original_torch_bmm = torch.bmm
|
||||
@wraps(torch.bmm)
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
|
||||
# Diffusers FreeU
|
||||
original_fft_fftn = torch.fft.fftn
|
||||
@wraps(torch.fft.fftn)
|
||||
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
|
||||
return_dtype = input.dtype
|
||||
return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
|
||||
|
||||
# Diffusers FreeU
|
||||
original_fft_ifftn = torch.fft.ifftn
|
||||
@wraps(torch.fft.ifftn)
|
||||
def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None):
|
||||
return_dtype = input.dtype
|
||||
return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
|
||||
|
||||
# A1111 FP16
|
||||
original_functional_group_norm = torch.nn.functional.group_norm
|
||||
@@ -133,6 +153,15 @@ def functional_linear(input, weight, bias=None):
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_linear(input, weight, bias=bias)
|
||||
|
||||
original_functional_conv1d = torch.nn.functional.conv1d
|
||||
@wraps(torch.nn.functional.conv1d)
|
||||
def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
if input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
|
||||
original_functional_conv2d = torch.nn.functional.conv2d
|
||||
@wraps(torch.nn.functional.conv2d)
|
||||
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
@@ -142,14 +171,15 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
|
||||
# A1111 Embedding BF16
|
||||
original_torch_cat = torch.cat
|
||||
@wraps(torch.cat)
|
||||
def torch_cat(tensor, *args, **kwargs):
|
||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
||||
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
||||
else:
|
||||
return original_torch_cat(tensor, *args, **kwargs)
|
||||
# LTX Video
|
||||
original_functional_conv3d = torch.nn.functional.conv3d
|
||||
@wraps(torch.nn.functional.conv3d)
|
||||
def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
if input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||
return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||
|
||||
# SwinIR BF16:
|
||||
original_functional_pad = torch.nn.functional.pad
|
||||
@@ -164,6 +194,7 @@ def functional_pad(input, pad, mode='constant', value=None):
|
||||
original_torch_tensor = torch.tensor
|
||||
@wraps(torch.tensor)
|
||||
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
||||
global device_supports_fp64
|
||||
if check_device(device):
|
||||
device = return_xpu(device)
|
||||
if not device_supports_fp64:
|
||||
@@ -227,7 +258,7 @@ def torch_empty(*args, device=None, **kwargs):
|
||||
original_torch_randn = torch.randn
|
||||
@wraps(torch.randn)
|
||||
def torch_randn(*args, device=None, dtype=None, **kwargs):
|
||||
if dtype == bytes:
|
||||
if dtype is bytes:
|
||||
dtype = None
|
||||
if check_device(device):
|
||||
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -250,6 +281,14 @@ def torch_zeros(*args, device=None, **kwargs):
|
||||
else:
|
||||
return original_torch_zeros(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_full = torch.full
|
||||
@wraps(torch.full)
|
||||
def torch_full(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_full(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_full(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_linspace = torch.linspace
|
||||
@wraps(torch.linspace)
|
||||
def torch_linspace(*args, device=None, **kwargs):
|
||||
@@ -258,14 +297,6 @@ def torch_linspace(*args, device=None, **kwargs):
|
||||
else:
|
||||
return original_torch_linspace(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_Generator = torch.Generator
|
||||
@wraps(torch.Generator)
|
||||
def torch_Generator(device=None):
|
||||
if check_device(device):
|
||||
return original_torch_Generator(return_xpu(device))
|
||||
else:
|
||||
return original_torch_Generator(device)
|
||||
|
||||
original_torch_load = torch.load
|
||||
@wraps(torch.load)
|
||||
def torch_load(f, map_location=None, *args, **kwargs):
|
||||
@@ -276,9 +307,27 @@ def torch_load(f, map_location=None, *args, **kwargs):
|
||||
else:
|
||||
return original_torch_load(f, *args, map_location=map_location, **kwargs)
|
||||
|
||||
original_torch_Generator = torch.Generator
|
||||
@wraps(torch.Generator)
|
||||
def torch_Generator(device=None):
|
||||
if check_device(device):
|
||||
return original_torch_Generator(return_xpu(device))
|
||||
else:
|
||||
return original_torch_Generator(device)
|
||||
|
||||
@wraps(torch.cuda.synchronize)
|
||||
def torch_cuda_synchronize(device=None):
|
||||
if check_device(device):
|
||||
return torch.xpu.synchronize(return_xpu(device))
|
||||
else:
|
||||
return torch.xpu.synchronize(device)
|
||||
|
||||
|
||||
# Hijack Functions:
|
||||
def ipex_hijacks():
|
||||
def ipex_hijacks(legacy=True):
|
||||
global device_supports_fp64, can_allocate_plus_4gb
|
||||
if legacy and float(torch.__version__[:3]) < 2.5:
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
torch.tensor = torch_tensor
|
||||
torch.Tensor.to = Tensor_to
|
||||
torch.Tensor.cuda = Tensor_cuda
|
||||
@@ -289,9 +338,11 @@ def ipex_hijacks():
|
||||
torch.randn = torch_randn
|
||||
torch.ones = torch_ones
|
||||
torch.zeros = torch_zeros
|
||||
torch.full = torch_full
|
||||
torch.linspace = torch_linspace
|
||||
torch.Generator = torch_Generator
|
||||
torch.load = torch_load
|
||||
torch.Generator = torch_Generator
|
||||
torch.cuda.synchronize = torch_cuda_synchronize
|
||||
|
||||
torch.backends.cuda.sdp_kernel = return_null_context
|
||||
torch.nn.DataParallel = DummyDataParallel
|
||||
@@ -302,12 +353,15 @@ def ipex_hijacks():
|
||||
torch.nn.functional.group_norm = functional_group_norm
|
||||
torch.nn.functional.layer_norm = functional_layer_norm
|
||||
torch.nn.functional.linear = functional_linear
|
||||
torch.nn.functional.conv1d = functional_conv1d
|
||||
torch.nn.functional.conv2d = functional_conv2d
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
torch.nn.functional.conv3d = functional_conv3d
|
||||
torch.nn.functional.pad = functional_pad
|
||||
|
||||
torch.bmm = torch_bmm
|
||||
torch.cat = torch_cat
|
||||
torch.fft.fftn = fft_fftn
|
||||
torch.fft.ifftn = fft_ifftn
|
||||
if not device_supports_fp64:
|
||||
torch.from_numpy = from_numpy
|
||||
torch.as_tensor = as_tensor
|
||||
return device_supports_fp64, can_allocate_plus_4gb
|
||||
|
||||
186
library/jpeg_xl_util.py
Normal file
186
library/jpeg_xl_util.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT
|
||||
# Added partial read support for up to 200x speedup
|
||||
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
class JXLBitstream:
|
||||
"""
|
||||
A stream of bits with methods for easy handling.
|
||||
"""
|
||||
|
||||
def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
|
||||
self.shift = 0
|
||||
self.bitstream = bytearray()
|
||||
self.file = file
|
||||
self.offset = offset
|
||||
self.offsets = offsets
|
||||
if self.offsets:
|
||||
self.offset = self.offsets[0][1]
|
||||
self.previous_data_len = 0
|
||||
self.index = 0
|
||||
self.file.seek(self.offset)
|
||||
|
||||
def get_bits(self, length: int = 1) -> int:
|
||||
if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
|
||||
self.partial_to_read_length = length
|
||||
if self.shift < self.previous_data_len + self.offsets[self.index][2]:
|
||||
self.partial_read(0, length)
|
||||
self.bitstream.extend(self.file.read(self.partial_to_read_length))
|
||||
else:
|
||||
self.bitstream.extend(self.file.read(length))
|
||||
bitmask = 2**length - 1
|
||||
bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
|
||||
self.shift += length
|
||||
return bits
|
||||
|
||||
def partial_read(self, current_length: int, length: int) -> None:
|
||||
self.previous_data_len += self.offsets[self.index][2]
|
||||
to_read_length = self.previous_data_len - (self.shift + current_length)
|
||||
self.bitstream.extend(self.file.read(to_read_length))
|
||||
current_length += to_read_length
|
||||
self.partial_to_read_length -= to_read_length
|
||||
self.index += 1
|
||||
self.file.seek(self.offsets[self.index][1])
|
||||
if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
|
||||
self.partial_read(current_length, length)
|
||||
|
||||
|
||||
def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]:
|
||||
"""
|
||||
Decodes the actual codestream.
|
||||
JXL codestream specification: http://www-internal/2022/18181-1
|
||||
"""
|
||||
|
||||
# Convert codestream to int within an object to get some handy methods.
|
||||
codestream = JXLBitstream(file, offset=offset, offsets=offsets)
|
||||
|
||||
# Skip signature
|
||||
codestream.get_bits(16)
|
||||
|
||||
# SizeHeader
|
||||
div8 = codestream.get_bits(1)
|
||||
if div8:
|
||||
height = 8 * (1 + codestream.get_bits(5))
|
||||
else:
|
||||
distribution = codestream.get_bits(2)
|
||||
match distribution:
|
||||
case 0:
|
||||
height = 1 + codestream.get_bits(9)
|
||||
case 1:
|
||||
height = 1 + codestream.get_bits(13)
|
||||
case 2:
|
||||
height = 1 + codestream.get_bits(18)
|
||||
case 3:
|
||||
height = 1 + codestream.get_bits(30)
|
||||
ratio = codestream.get_bits(3)
|
||||
if div8 and not ratio:
|
||||
width = 8 * (1 + codestream.get_bits(5))
|
||||
elif not ratio:
|
||||
distribution = codestream.get_bits(2)
|
||||
match distribution:
|
||||
case 0:
|
||||
width = 1 + codestream.get_bits(9)
|
||||
case 1:
|
||||
width = 1 + codestream.get_bits(13)
|
||||
case 2:
|
||||
width = 1 + codestream.get_bits(18)
|
||||
case 3:
|
||||
width = 1 + codestream.get_bits(30)
|
||||
else:
|
||||
match ratio:
|
||||
case 1:
|
||||
width = height
|
||||
case 2:
|
||||
width = (height * 12) // 10
|
||||
case 3:
|
||||
width = (height * 4) // 3
|
||||
case 4:
|
||||
width = (height * 3) // 2
|
||||
case 5:
|
||||
width = (height * 16) // 9
|
||||
case 6:
|
||||
width = (height * 5) // 4
|
||||
case 7:
|
||||
width = (height * 2) // 1
|
||||
return width, height
|
||||
|
||||
|
||||
def decode_container(file) -> Tuple[int,int]:
|
||||
"""
|
||||
Parses the ISOBMFF container, extracts the codestream, and decodes it.
|
||||
JXL container specification: http://www-internal/2022/18181-2
|
||||
"""
|
||||
|
||||
def parse_box(file, file_start: int) -> dict:
|
||||
file.seek(file_start)
|
||||
LBox = int.from_bytes(file.read(4), "big")
|
||||
XLBox = None
|
||||
if 1 < LBox <= 8:
|
||||
raise ValueError(f"Invalid LBox at byte {file_start}.")
|
||||
if LBox == 1:
|
||||
file.seek(file_start + 8)
|
||||
XLBox = int.from_bytes(file.read(8), "big")
|
||||
if XLBox <= 16:
|
||||
raise ValueError(f"Invalid XLBox at byte {file_start}.")
|
||||
if XLBox:
|
||||
header_length = 16
|
||||
box_length = XLBox
|
||||
else:
|
||||
header_length = 8
|
||||
if LBox == 0:
|
||||
box_length = os.fstat(file.fileno()).st_size - file_start
|
||||
else:
|
||||
box_length = LBox
|
||||
file.seek(file_start + 4)
|
||||
box_type = file.read(4)
|
||||
file.seek(file_start)
|
||||
return {
|
||||
"length": box_length,
|
||||
"type": box_type,
|
||||
"offset": header_length,
|
||||
}
|
||||
|
||||
file.seek(0)
|
||||
# Reject files missing required boxes. These two boxes are required to be at
|
||||
# the start and contain no values, so we can manually check there presence.
|
||||
# Signature box. (Redundant as has already been checked.)
|
||||
if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
|
||||
raise ValueError("Invalid signature box.")
|
||||
# File Type box.
|
||||
if file.read(20) != bytes.fromhex(
|
||||
"00000014 66747970 6A786C20 00000000 6A786C20"
|
||||
):
|
||||
raise ValueError("Invalid file type box.")
|
||||
|
||||
offset = 0
|
||||
offsets = []
|
||||
data_offset_not_found = True
|
||||
container_pointer = 32
|
||||
file_size = os.fstat(file.fileno()).st_size
|
||||
while data_offset_not_found:
|
||||
box = parse_box(file, container_pointer)
|
||||
match box["type"]:
|
||||
case b"jxlc":
|
||||
offset = container_pointer + box["offset"]
|
||||
data_offset_not_found = False
|
||||
case b"jxlp":
|
||||
file.seek(container_pointer + box["offset"])
|
||||
index = int.from_bytes(file.read(4), "big")
|
||||
offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4])
|
||||
container_pointer += box["length"]
|
||||
if container_pointer >= file_size:
|
||||
data_offset_not_found = False
|
||||
|
||||
if offsets:
|
||||
offsets.sort(key=lambda i: i[0])
|
||||
file.seek(0)
|
||||
|
||||
return decode_codestream(file, offset=offset, offsets=offsets)
|
||||
|
||||
|
||||
def get_jxl_size(path: str) -> Tuple[int,int]:
|
||||
with open(path, "rb") as file:
|
||||
if file.read(2) == bytes.fromhex("FF0A"):
|
||||
return decode_codestream(file)
|
||||
return decode_container(file)
|
||||
@@ -643,16 +643,15 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
|
||||
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
|
||||
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
|
||||
|
||||
# rename or add position_ids
|
||||
# remove position_ids for newer transformer, which causes error :(
|
||||
ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
|
||||
if ANOTHER_POSITION_IDS_KEY in new_sd:
|
||||
# waifu diffusion v1.4
|
||||
position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
|
||||
del new_sd[ANOTHER_POSITION_IDS_KEY]
|
||||
else:
|
||||
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
|
||||
|
||||
new_sd["text_model.embeddings.position_ids"] = position_ids
|
||||
if "text_model.embeddings.position_ids" in new_sd:
|
||||
del new_sd["text_model.embeddings.position_ids"]
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
|
||||
@@ -344,8 +344,6 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_en
|
||||
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||
if args.v_parameterization:
|
||||
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
||||
|
||||
if args.clip_skip is not None:
|
||||
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
||||
|
||||
@@ -13,17 +13,7 @@ import re
|
||||
import shutil
|
||||
import time
|
||||
import typing
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
|
||||
import glob
|
||||
import math
|
||||
@@ -84,7 +74,7 @@ import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
import library.deepspeed_utils as deepspeed_utils
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging, resize_image, validate_interpolation_fn
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -123,14 +113,16 @@ except:
|
||||
# JPEG-XL on Linux
|
||||
try:
|
||||
from jxlpy import JXLImagePlugin
|
||||
from library.jpeg_xl_util import get_jxl_size
|
||||
|
||||
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
||||
except:
|
||||
pass
|
||||
|
||||
# JPEG-XL on Windows
|
||||
# JPEG-XL on Linux and Windows
|
||||
try:
|
||||
import pillow_jxl
|
||||
from library.jpeg_xl_util import get_jxl_size
|
||||
|
||||
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
||||
except:
|
||||
@@ -146,12 +138,13 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
|
||||
|
||||
def split_train_val(
|
||||
paths: List[str],
|
||||
paths: List[str],
|
||||
sizes: List[Optional[Tuple[int, int]]],
|
||||
is_training_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: int | None
|
||||
is_training_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: int | None,
|
||||
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
|
||||
"""
|
||||
Split the dataset into train and validation
|
||||
@@ -214,6 +207,7 @@ class ImageInfo:
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
self.resize_interpolation: Optional[str] = None
|
||||
|
||||
|
||||
class BucketManager:
|
||||
@@ -438,6 +432,7 @@ class BaseSubset:
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
self.image_dir = image_dir
|
||||
self.alpha_mask = alpha_mask if alpha_mask is not None else False
|
||||
@@ -468,6 +463,8 @@ class BaseSubset:
|
||||
self.validation_seed = validation_seed
|
||||
self.validation_split = validation_split
|
||||
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
|
||||
class DreamBoothSubset(BaseSubset):
|
||||
def __init__(
|
||||
@@ -499,6 +496,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -526,6 +524,7 @@ class DreamBoothSubset(BaseSubset):
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
resize_interpolation=resize_interpolation,
|
||||
)
|
||||
|
||||
self.is_reg = is_reg
|
||||
@@ -568,6 +567,7 @@ class FineTuningSubset(BaseSubset):
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
|
||||
|
||||
@@ -595,6 +595,7 @@ class FineTuningSubset(BaseSubset):
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
resize_interpolation=resize_interpolation,
|
||||
)
|
||||
|
||||
self.metadata_file = metadata_file
|
||||
@@ -633,6 +634,7 @@ class ControlNetSubset(BaseSubset):
|
||||
custom_attributes: Optional[Dict[str, Any]] = None,
|
||||
validation_seed: Optional[int] = None,
|
||||
validation_split: Optional[float] = 0.0,
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
|
||||
|
||||
@@ -660,6 +662,7 @@ class ControlNetSubset(BaseSubset):
|
||||
custom_attributes=custom_attributes,
|
||||
validation_seed=validation_seed,
|
||||
validation_split=validation_split,
|
||||
resize_interpolation=resize_interpolation,
|
||||
)
|
||||
|
||||
self.conditioning_data_dir = conditioning_data_dir
|
||||
@@ -680,6 +683,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
resize_interpolation: Optional[str] = None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -714,6 +718,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
if resize_interpolation is not None:
|
||||
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
self.image_data: Dict[str, ImageInfo] = {}
|
||||
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
|
||||
|
||||
@@ -1052,8 +1060,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
||||
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
if len(img_ar_errors) == 0:
|
||||
mean_img_ar_error = 0 # avoid NaN
|
||||
else:
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
@@ -1457,6 +1468,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
if image_path.endswith(".jxl") or image_path.endswith(".JXL"):
|
||||
return get_jxl_size(image_path)
|
||||
# return imagesize.get(image_path)
|
||||
image_size = imagesize.get(image_path)
|
||||
if image_size[0] <= 0:
|
||||
@@ -1503,7 +1516,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
nh = int(height * scale + 0.5)
|
||||
nw = int(width * scale + 0.5)
|
||||
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
|
||||
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
|
||||
image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
|
||||
face_cx = int(face_cx * scale + 0.5)
|
||||
face_cy = int(face_cy * scale + 0.5)
|
||||
height, width = nh, nw
|
||||
@@ -1600,7 +1613,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
if self.enable_bucket:
|
||||
img, original_size, crop_ltrb = trim_and_resize_if_required(
|
||||
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
|
||||
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
|
||||
)
|
||||
else:
|
||||
if face_cx > 0: # 顔位置情報あり
|
||||
@@ -1842,7 +1855,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
class DreamBoothDataset(BaseDataset):
|
||||
IMAGE_INFO_CACHE_FILE = "metadata_cache.json"
|
||||
|
||||
# The is_training_dataset defines the type of dataset, training or validation
|
||||
# The is_training_dataset defines the type of dataset, training or validation
|
||||
# if is_training_dataset is True -> training dataset
|
||||
# if is_training_dataset is False -> validation dataset
|
||||
def __init__(
|
||||
@@ -1861,8 +1874,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
@@ -1981,29 +1995,25 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||
|
||||
# We want to create a training and validation split. This should be improved in the future
|
||||
# to allow a clearer distinction between training and validation. This can be seen as a
|
||||
# to allow a clearer distinction between training and validation. This can be seen as a
|
||||
# short-term solution to limit what is necessary to implement validation datasets
|
||||
#
|
||||
#
|
||||
# We split the dataset for the subset based on if we are doing a validation split
|
||||
# The self.is_training_dataset defines the type of dataset, training or validation
|
||||
# The self.is_training_dataset defines the type of dataset, training or validation
|
||||
# if self.is_training_dataset is True -> training dataset
|
||||
# if self.is_training_dataset is False -> validation dataset
|
||||
if self.validation_split > 0.0:
|
||||
# For regularization images we do not want to split this dataset.
|
||||
# For regularization images we do not want to split this dataset.
|
||||
if subset.is_reg is True:
|
||||
# Skip any validation dataset for regularization images
|
||||
if self.is_training_dataset is False:
|
||||
img_paths = []
|
||||
sizes = []
|
||||
# Otherwise the img_paths remain as original img_paths and no split
|
||||
# Otherwise the img_paths remain as original img_paths and no split
|
||||
# required for training images dataset of regularization images
|
||||
else:
|
||||
img_paths, sizes = split_train_val(
|
||||
img_paths,
|
||||
sizes,
|
||||
self.is_training_dataset,
|
||||
self.validation_split,
|
||||
self.validation_seed
|
||||
img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed
|
||||
)
|
||||
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
@@ -2091,6 +2101,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
@@ -2146,8 +2157,9 @@ class FineTuningDataset(BaseDataset):
|
||||
debug_dataset: bool,
|
||||
validation_seed: int,
|
||||
validation_split: float,
|
||||
resize_interpolation: Optional[str],
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -2374,8 +2386,9 @@ class ControlNetDataset(BaseDataset):
|
||||
debug_dataset: bool,
|
||||
validation_split: float,
|
||||
validation_seed: Optional[int],
|
||||
resize_interpolation: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
@@ -2407,6 +2420,7 @@ class ControlNetDataset(BaseDataset):
|
||||
subset.caption_suffix,
|
||||
subset.token_warmup_min,
|
||||
subset.token_warmup_step,
|
||||
resize_interpolation=subset.resize_interpolation,
|
||||
)
|
||||
db_subsets.append(db_subset)
|
||||
|
||||
@@ -2425,15 +2439,17 @@ class ControlNetDataset(BaseDataset):
|
||||
debug_dataset,
|
||||
validation_split,
|
||||
validation_seed,
|
||||
resize_interpolation,
|
||||
)
|
||||
|
||||
# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
|
||||
self.image_data = self.dreambooth_dataset_delegate.image_data
|
||||
self.batch_size = batch_size
|
||||
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
self.resize_interpolation = resize_interpolation
|
||||
|
||||
# assert all conditioning data exists
|
||||
missing_imgs = []
|
||||
@@ -2521,9 +2537,8 @@ class ControlNetDataset(BaseDataset):
|
||||
assert (
|
||||
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
|
||||
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
|
||||
cond_img = cv2.resize(
|
||||
cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA
|
||||
) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
|
||||
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||
|
||||
# TODO support random crop
|
||||
# 現在サポートしているcropはrandomではなく中央のみ
|
||||
@@ -2537,7 +2552,7 @@ class ControlNetDataset(BaseDataset):
|
||||
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
|
||||
# resize to target
|
||||
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
|
||||
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
|
||||
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
|
||||
|
||||
if flipped:
|
||||
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
|
||||
@@ -2934,17 +2949,13 @@ def load_image(image_path, alpha=False):
|
||||
|
||||
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
|
||||
def trim_and_resize_if_required(
|
||||
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
|
||||
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None
|
||||
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
|
||||
image_height, image_width = image.shape[0:2]
|
||||
original_size = (image_width, image_height) # size before resize
|
||||
|
||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||
# リサイズする
|
||||
if image_width > resized_size[0] and image_height > resized_size[1]:
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
|
||||
else:
|
||||
image = pil_resize(image, resized_size)
|
||||
image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)
|
||||
|
||||
image_height, image_width = image.shape[0:2]
|
||||
|
||||
@@ -2989,7 +3000,7 @@ def load_images_and_masks_for_caching(
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
|
||||
|
||||
original_sizes.append(original_size)
|
||||
crop_ltrbs.append(crop_ltrb)
|
||||
@@ -3030,7 +3041,7 @@ def cache_batch_latents(
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
|
||||
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
@@ -4508,7 +4519,13 @@ def add_dataset_arguments(
|
||||
action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--resize_interpolation",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"],
|
||||
help="Resize interpolation when required. Default: area Options: lanczos, nearest, bilinear, bicubic, area / 必要に応じてサイズ補間を変更します。デフォルト: area オプション: lanczos, nearest, bilinear, bicubic, area",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token_warmup_min",
|
||||
type=int,
|
||||
@@ -5481,6 +5498,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
|
||||
|
||||
def patch_accelerator_for_fp16_training(accelerator):
|
||||
|
||||
from accelerate import DistributedType
|
||||
if accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
return
|
||||
|
||||
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||
|
||||
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||
@@ -5944,12 +5966,17 @@ def save_sd_model_on_train_end_common(
|
||||
|
||||
|
||||
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||
if min_timestep < max_timestep:
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||
else:
|
||||
timesteps = torch.full((b_size,), max_timestep, device="cpu")
|
||||
timesteps = timesteps.long().to(device)
|
||||
return timesteps
|
||||
|
||||
|
||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
||||
def get_noise_noisy_latents_and_timesteps(
|
||||
args, noise_scheduler, latents: torch.FloatTensor
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
if args.noise_offset:
|
||||
@@ -6159,6 +6186,11 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
prompt_dict["scale"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # guidance scale
|
||||
prompt_dict["guidance_scale"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
prompt_dict["negative_prompt"] = m.group(1)
|
||||
@@ -6441,7 +6473,7 @@ def sample_image_inference(
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
|
||||
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
|
||||
"""
|
||||
Initialize experiment trackers with tracker specific behaviors
|
||||
"""
|
||||
@@ -6458,13 +6490,17 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
|
||||
)
|
||||
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
import wandb
|
||||
import wandb
|
||||
|
||||
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
|
||||
|
||||
# Define specific metrics to handle validation and epochs "steps"
|
||||
wandb_tracker.define_metric("epoch", hidden=True)
|
||||
wandb_tracker.define_metric("val_step", hidden=True)
|
||||
|
||||
wandb_tracker.define_metric("global_step", hidden=True)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -6537,3 +6573,4 @@ class LossRecorder:
|
||||
if losses == 0:
|
||||
return 0
|
||||
return self.loss_total / losses
|
||||
|
||||
|
||||
131
library/utils.py
131
library/utils.py
@@ -16,7 +16,6 @@ from PIL import Image
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
|
||||
@@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(msg_init)
|
||||
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -261,11 +262,10 @@ def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata:
|
||||
|
||||
|
||||
class MemoryEfficientSafeOpen:
|
||||
# does not support metadata loading
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.header, self.header_size = self._read_header()
|
||||
self.file = open(filename, "rb")
|
||||
self.header, self.header_size = self._read_header()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@@ -276,6 +276,9 @@ class MemoryEfficientSafeOpen:
|
||||
def keys(self):
|
||||
return [k for k in self.header.keys() if k != "__metadata__"]
|
||||
|
||||
def metadata(self) -> Dict[str, str]:
|
||||
return self.header.get("__metadata__", {})
|
||||
|
||||
def get_tensor(self, key):
|
||||
if key not in self.header:
|
||||
raise KeyError(f"Tensor '{key}' not found in the file")
|
||||
@@ -293,10 +296,9 @@ class MemoryEfficientSafeOpen:
|
||||
return self._deserialize_tensor(tensor_bytes, metadata)
|
||||
|
||||
def _read_header(self):
|
||||
with open(self.filename, "rb") as f:
|
||||
header_size = struct.unpack("<Q", f.read(8))[0]
|
||||
header_json = f.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
header_size = struct.unpack("<Q", self.file.read(8))[0]
|
||||
header_json = self.file.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
|
||||
def _deserialize_tensor(self, tensor_bytes, metadata):
|
||||
dtype = self._get_torch_dtype(metadata["dtype"])
|
||||
@@ -377,7 +379,7 @@ def load_safetensors(
|
||||
# region Image utils
|
||||
|
||||
|
||||
def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
def pil_resize(image, size, interpolation):
|
||||
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
||||
|
||||
if has_alpha:
|
||||
@@ -385,7 +387,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
else:
|
||||
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||
|
||||
resized_pil = pil_image.resize(size, interpolation)
|
||||
resized_pil = pil_image.resize(size, resample=interpolation)
|
||||
|
||||
# Convert back to cv2 format
|
||||
if has_alpha:
|
||||
@@ -396,6 +398,117 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
return resized_cv2
|
||||
|
||||
|
||||
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
|
||||
"""
|
||||
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
|
||||
|
||||
Args:
|
||||
image: numpy.ndarray
|
||||
width: int Original image width
|
||||
height: int Original image height
|
||||
resized_width: int Resized image width
|
||||
resized_height: int Resized image height
|
||||
resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box"
|
||||
|
||||
Returns:
|
||||
image
|
||||
"""
|
||||
|
||||
# Ensure all size parameters are actual integers
|
||||
width = int(width)
|
||||
height = int(height)
|
||||
resized_width = int(resized_width)
|
||||
resized_height = int(resized_height)
|
||||
|
||||
if resize_interpolation is None:
|
||||
if width >= resized_width and height >= resized_height:
|
||||
resize_interpolation = "area"
|
||||
else:
|
||||
resize_interpolation = "lanczos"
|
||||
|
||||
# we use PIL for lanczos (for backward compatibility) and box, cv2 for others
|
||||
use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"]
|
||||
|
||||
resized_size = (resized_width, resized_height)
|
||||
if use_pil:
|
||||
interpolation = get_pil_interpolation(resize_interpolation)
|
||||
image = pil_resize(image, resized_size, interpolation=interpolation)
|
||||
logger.debug(f"resize image using {resize_interpolation} (PIL)")
|
||||
else:
|
||||
interpolation = get_cv2_interpolation(resize_interpolation)
|
||||
image = cv2.resize(image, resized_size, interpolation=interpolation)
|
||||
logger.debug(f"resize image using {resize_interpolation} (cv2)")
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
|
||||
"""
|
||||
Convert interpolation value to cv2 interpolation integer
|
||||
|
||||
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
|
||||
"""
|
||||
if interpolation is None:
|
||||
return None
|
||||
|
||||
if interpolation == "lanczos" or interpolation == "lanczos4":
|
||||
# Lanczos interpolation over 8x8 neighborhood
|
||||
return cv2.INTER_LANCZOS4
|
||||
elif interpolation == "nearest":
|
||||
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
|
||||
return cv2.INTER_NEAREST_EXACT
|
||||
elif interpolation == "bilinear" or interpolation == "linear":
|
||||
# bilinear interpolation
|
||||
return cv2.INTER_LINEAR
|
||||
elif interpolation == "bicubic" or interpolation == "cubic":
|
||||
# bicubic interpolation
|
||||
return cv2.INTER_CUBIC
|
||||
elif interpolation == "area":
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
return cv2.INTER_AREA
|
||||
elif interpolation == "box":
|
||||
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
|
||||
return cv2.INTER_AREA
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
|
||||
"""
|
||||
Convert interpolation value to PIL interpolation
|
||||
|
||||
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
|
||||
"""
|
||||
if interpolation is None:
|
||||
return None
|
||||
|
||||
if interpolation == "lanczos":
|
||||
return Image.Resampling.LANCZOS
|
||||
elif interpolation == "nearest":
|
||||
# Pick one nearest pixel from the input image. Ignore all other input pixels.
|
||||
return Image.Resampling.NEAREST
|
||||
elif interpolation == "bilinear" or interpolation == "linear":
|
||||
# For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
|
||||
return Image.Resampling.BILINEAR
|
||||
elif interpolation == "bicubic" or interpolation == "cubic":
|
||||
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
|
||||
return Image.Resampling.BICUBIC
|
||||
elif interpolation == "area":
|
||||
# Image.Resampling.BOX may be more appropriate if upscaling
|
||||
# Area interpolation is related to cv2.INTER_AREA
|
||||
# Produces a sharper image than Resampling.BILINEAR, doesn’t have dislocations on local level like with Resampling.BOX.
|
||||
return Image.Resampling.HAMMING
|
||||
elif interpolation == "box":
|
||||
# Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
|
||||
return Image.Resampling.BOX
|
||||
else:
|
||||
return None
|
||||
|
||||
def validate_interpolation_fn(interpolation_str: str) -> bool:
|
||||
"""
|
||||
Check if a interpolation function is supported
|
||||
"""
|
||||
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
|
||||
|
||||
# endregion
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
@@ -268,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
class DyLoRANetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
|
||||
@@ -866,7 +866,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
|
||||
@@ -278,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
|
||||
class LoRANetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
|
||||
@@ -755,7 +755,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
|
||||
@@ -9,11 +9,13 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from diffusers import AutoencoderKL
|
||||
from transformers import CLIPTextModel
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
@@ -44,6 +46,8 @@ class LoRAModule(torch.nn.Module):
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
split_dims: Optional[List[int]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
if alpha == 0 or None, alpha is rank (no scaling).
|
||||
@@ -103,9 +107,20 @@ class LoRAModule(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
|
||||
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
|
||||
self.combined_weight_norms = None
|
||||
self.grad_norms = None
|
||||
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
@@ -140,7 +155,17 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
lx = self.lora_up(lx)
|
||||
|
||||
return org_forwarded + lx * self.multiplier * scale
|
||||
# LoRA Gradient-Guided Perturbation Optimization
|
||||
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
|
||||
with torch.no_grad():
|
||||
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
|
||||
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
|
||||
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
|
||||
perturbation.mul_(perturbation_scale_factor)
|
||||
perturbation_output = x @ perturbation.T # Result: (batch × n)
|
||||
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
|
||||
else:
|
||||
return org_forwarded + lx * self.multiplier * scale
|
||||
else:
|
||||
lxs = [lora_down(x) for lora_down in self.lora_down]
|
||||
|
||||
@@ -167,6 +192,116 @@ class LoRAModule(torch.nn.Module):
|
||||
|
||||
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
|
||||
|
||||
@torch.no_grad()
|
||||
def initialize_norm_cache(self, org_module_weight: Tensor):
|
||||
# Choose a reasonable sample size
|
||||
n_rows = org_module_weight.shape[0]
|
||||
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
|
||||
|
||||
# Sample random indices across all rows
|
||||
indices = torch.randperm(n_rows)[:sample_size]
|
||||
|
||||
# Convert to a supported data type first, then index
|
||||
# Use float32 for indexing operations
|
||||
weights_float32 = org_module_weight.to(dtype=torch.float32)
|
||||
sampled_weights = weights_float32[indices].to(device=self.device)
|
||||
|
||||
# Calculate sampled norms
|
||||
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
|
||||
|
||||
# Store the mean norm as our estimate
|
||||
self.org_weight_norm_estimate = sampled_norms.mean()
|
||||
|
||||
# Optional: store standard deviation for confidence intervals
|
||||
self.org_weight_norm_std = sampled_norms.std()
|
||||
|
||||
# Free memory
|
||||
del sampled_weights, weights_float32
|
||||
|
||||
@torch.no_grad()
|
||||
def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
|
||||
# Calculate the true norm (this will be slow but it's just for validation)
|
||||
true_norms = []
|
||||
chunk_size = 1024 # Process in chunks to avoid OOM
|
||||
|
||||
for i in range(0, org_module_weight.shape[0], chunk_size):
|
||||
end_idx = min(i + chunk_size, org_module_weight.shape[0])
|
||||
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
|
||||
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
|
||||
true_norms.append(chunk_norms.cpu())
|
||||
del chunk
|
||||
|
||||
true_norms = torch.cat(true_norms, dim=0)
|
||||
true_mean_norm = true_norms.mean().item()
|
||||
|
||||
# Compare with our estimate
|
||||
estimated_norm = self.org_weight_norm_estimate.item()
|
||||
|
||||
# Calculate error metrics
|
||||
absolute_error = abs(true_mean_norm - estimated_norm)
|
||||
relative_error = absolute_error / true_mean_norm * 100 # as percentage
|
||||
|
||||
if verbose:
|
||||
logger.info(f"True mean norm: {true_mean_norm:.6f}")
|
||||
logger.info(f"Estimated norm: {estimated_norm:.6f}")
|
||||
logger.info(f"Absolute error: {absolute_error:.6f}")
|
||||
logger.info(f"Relative error: {relative_error:.2f}%")
|
||||
|
||||
return {
|
||||
'true_mean_norm': true_mean_norm,
|
||||
'estimated_norm': estimated_norm,
|
||||
'absolute_error': absolute_error,
|
||||
'relative_error': relative_error
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def update_norms(self):
|
||||
# Not running GGPO so not currently running update norms
|
||||
if self.ggpo_beta is None or self.ggpo_sigma is None:
|
||||
return
|
||||
|
||||
# only update norms when we are training
|
||||
if self.training is False:
|
||||
return
|
||||
|
||||
module_weights = self.lora_up.weight @ self.lora_down.weight
|
||||
module_weights.mul(self.scale)
|
||||
|
||||
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
|
||||
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
|
||||
torch.sum(module_weights**2, dim=1, keepdim=True))
|
||||
|
||||
@torch.no_grad()
|
||||
def update_grad_norms(self):
|
||||
if self.training is False:
|
||||
print(f"skipping update_grad_norms for {self.lora_name}")
|
||||
return
|
||||
|
||||
lora_down_grad = None
|
||||
lora_up_grad = None
|
||||
|
||||
for name, param in self.named_parameters():
|
||||
if name == "lora_down.weight":
|
||||
lora_down_grad = param.grad
|
||||
elif name == "lora_up.weight":
|
||||
lora_up_grad = param.grad
|
||||
|
||||
# Calculate gradient norms if we have both gradients
|
||||
if lora_down_grad is not None and lora_up_grad is not None:
|
||||
with torch.autocast(self.device.type):
|
||||
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
|
||||
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
|
||||
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
|
||||
class LoRAInfModule(LoRAModule):
|
||||
def __init__(
|
||||
@@ -420,6 +555,16 @@ def create_network(
|
||||
if split_qkv is not None:
|
||||
split_qkv = True if split_qkv == "True" else False
|
||||
|
||||
ggpo_beta = kwargs.get("ggpo_beta", None)
|
||||
ggpo_sigma = kwargs.get("ggpo_sigma", None)
|
||||
|
||||
if ggpo_beta is not None:
|
||||
ggpo_beta = float(ggpo_beta)
|
||||
|
||||
if ggpo_sigma is not None:
|
||||
ggpo_sigma = float(ggpo_sigma)
|
||||
|
||||
|
||||
# train T5XXL
|
||||
train_t5xxl = kwargs.get("train_t5xxl", False)
|
||||
if train_t5xxl is not None:
|
||||
@@ -449,6 +594,8 @@ def create_network(
|
||||
in_dims=in_dims,
|
||||
train_double_block_indices=train_double_block_indices,
|
||||
train_single_block_indices=train_single_block_indices,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@@ -561,6 +708,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
in_dims: Optional[List[int]] = None,
|
||||
train_double_block_indices: Optional[List[bool]] = None,
|
||||
train_single_block_indices: Optional[List[bool]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -599,10 +748,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
# logger.info(
|
||||
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
||||
# )
|
||||
|
||||
if ggpo_beta is not None and ggpo_sigma is not None:
|
||||
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
|
||||
|
||||
if self.split_qkv:
|
||||
logger.info(f"split qkv for LoRA")
|
||||
if self.train_blocks is not None:
|
||||
logger.info(f"train {self.train_blocks} blocks only")
|
||||
|
||||
|
||||
if train_t5xxl:
|
||||
logger.info(f"train T5XXL as well")
|
||||
|
||||
@@ -722,6 +877,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
split_dims=split_dims,
|
||||
ggpo_beta=ggpo_beta,
|
||||
ggpo_sigma=ggpo_sigma,
|
||||
)
|
||||
loras.append(lora)
|
||||
|
||||
@@ -790,6 +947,36 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.enabled = is_enabled
|
||||
|
||||
def update_norms(self):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_norms()
|
||||
|
||||
def update_grad_norms(self):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_grad_norms()
|
||||
|
||||
def grad_norms(self) -> Tensor | None:
|
||||
grad_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
||||
grad_norms.append(lora.grad_norms.mean(dim=0))
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else None
|
||||
|
||||
def weight_norms(self) -> Tensor | None:
|
||||
weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
||||
weight_norms.append(lora.weight_norms.mean(dim=0))
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else None
|
||||
|
||||
def combined_weight_norms(self) -> Tensor | None:
|
||||
combined_weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
|
||||
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -6,3 +6,4 @@ filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
ignore::FutureWarning
|
||||
pythonpath = .
|
||||
|
||||
@@ -7,9 +7,11 @@ opencv-python==4.8.1.78
|
||||
einops==0.7.0
|
||||
pytorch-lightning==1.9.0
|
||||
bitsandbytes==0.44.0
|
||||
prodigyopt==1.0
|
||||
lion-pytorch==0.0.6
|
||||
schedulefree==1.4
|
||||
pytorch-optimizer==3.5.0
|
||||
prodigy-plus-schedule-free==1.9.0
|
||||
prodigyopt==1.1.2
|
||||
tensorboard
|
||||
safetensors==0.4.4
|
||||
# gradio==3.16.2
|
||||
|
||||
@@ -26,7 +26,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
def assert_extra_args(
|
||||
self,
|
||||
args,
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
# super().assert_extra_args(args, train_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
@@ -299,7 +304,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, accelerator, vae, images):
|
||||
def encode_images_to_latents(self, args, vae, images):
|
||||
return vae.encode(images)
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
@@ -317,7 +322,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
is_train=True,
|
||||
):
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
@@ -445,14 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
prepare_fp8(text_encoder, weight_dtype)
|
||||
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
# drop cached text encoder outputs
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
|
||||
# drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
|
||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||
if text_encoder_outputs_list is not None:
|
||||
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
||||
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
if self.is_swapping_blocks:
|
||||
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
|
||||
def prepare_unet_with_accelerator(
|
||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||
) -> torch.nn.Module:
|
||||
|
||||
@@ -24,7 +24,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
self.is_sdxl = True
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
|
||||
220
tests/library/test_flux_train_utils.py
Normal file
220
tests/library/test_flux_train_utils.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import pytest
|
||||
import torch
|
||||
from unittest.mock import MagicMock, patch
|
||||
from library.flux_train_utils import (
|
||||
get_noisy_model_input_and_timesteps,
|
||||
)
|
||||
|
||||
# Mock classes and functions
|
||||
class MockNoiseScheduler:
|
||||
def __init__(self, num_train_timesteps=1000):
|
||||
self.config = MagicMock()
|
||||
self.config.num_train_timesteps = num_train_timesteps
|
||||
self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
|
||||
|
||||
|
||||
# Create fixtures for commonly used objects
|
||||
@pytest.fixture
|
||||
def args():
|
||||
args = MagicMock()
|
||||
args.timestep_sampling = "uniform"
|
||||
args.weighting_scheme = "uniform"
|
||||
args.logit_mean = 0.0
|
||||
args.logit_std = 1.0
|
||||
args.mode_scale = 1.0
|
||||
args.sigmoid_scale = 1.0
|
||||
args.discrete_flow_shift = 3.1582
|
||||
args.ip_noise_gamma = None
|
||||
args.ip_noise_gamma_random_strength = False
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise_scheduler():
|
||||
return MockNoiseScheduler(num_train_timesteps=1000)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def latents():
|
||||
return torch.randn(2, 4, 8, 8)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise():
|
||||
return torch.randn(2, 4, 8, 8)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device():
|
||||
# return "cuda" if torch.cuda.is_available() else "cpu"
|
||||
return "cpu"
|
||||
|
||||
|
||||
# Mock the required functions
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_functions():
|
||||
with (
|
||||
patch("torch.sigmoid", side_effect=torch.sigmoid),
|
||||
patch("torch.rand", side_effect=torch.rand),
|
||||
patch("torch.randn", side_effect=torch.randn),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# Test different timestep sampling methods
|
||||
def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "uniform"
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert noisy_input.dtype == dtype
|
||||
assert timesteps.dtype == dtype
|
||||
|
||||
|
||||
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "sigmoid"
|
||||
args.sigmoid_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "shift"
|
||||
args.sigmoid_scale = 1.0
|
||||
args.discrete_flow_shift = 3.1582
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "flux_shift"
|
||||
args.sigmoid_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
|
||||
# Mock the necessary functions for this specific test
|
||||
with patch("library.flux_train_utils.compute_density_for_timestep_sampling",
|
||||
return_value=torch.tensor([0.3, 0.7], device=device)), \
|
||||
patch("library.flux_train_utils.get_sigmas",
|
||||
return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)):
|
||||
|
||||
args.timestep_sampling = "other" # Will trigger the weighting scheme path
|
||||
args.weighting_scheme = "uniform"
|
||||
args.logit_mean = 0.0
|
||||
args.logit_std = 1.0
|
||||
args.mode_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
# Test IP noise options
|
||||
def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
args.ip_noise_gamma = 0.5
|
||||
args.ip_noise_gamma_random_strength = False
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
args.ip_noise_gamma = 0.1
|
||||
args.ip_noise_gamma_random_strength = True
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
# Test different data types
|
||||
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
|
||||
dtype = torch.float16
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.dtype == dtype
|
||||
assert timesteps.dtype == dtype
|
||||
|
||||
|
||||
# Test different batch sizes
|
||||
def test_different_batch_size(args, noise_scheduler, device):
|
||||
latents = torch.randn(5, 4, 8, 8) # batch size of 5
|
||||
noise = torch.randn(5, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (5,)
|
||||
assert sigmas.shape == (5, 1, 1, 1)
|
||||
|
||||
|
||||
# Test different image sizes
|
||||
def test_different_image_size(args, noise_scheduler, device):
|
||||
latents = torch.randn(2, 4, 16, 16) # larger image size
|
||||
noise = torch.randn(2, 4, 16, 16)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (2,)
|
||||
assert sigmas.shape == (2, 1, 1, 1)
|
||||
|
||||
|
||||
# Test edge cases
|
||||
def test_zero_batch_size(args, noise_scheduler, device):
|
||||
with pytest.raises(AssertionError): # expecting an error with zero batch size
|
||||
latents = torch.randn(0, 4, 8, 8)
|
||||
noise = torch.randn(0, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
|
||||
def test_different_timestep_count(args, device):
|
||||
noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count
|
||||
latents = torch.randn(2, 4, 8, 8)
|
||||
noise = torch.randn(2, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (2,)
|
||||
# Check that timesteps are within the proper range
|
||||
assert torch.all(timesteps < 500)
|
||||
@@ -15,7 +15,7 @@ import os
|
||||
from anime_face_detector import create_detector
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging, resize_image
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -170,12 +170,9 @@ def process(args):
|
||||
scale = max(cur_crop_width / w, cur_crop_height / h)
|
||||
|
||||
if scale != 1.0:
|
||||
w = int(w * scale + .5)
|
||||
h = int(h * scale + .5)
|
||||
if scale < 1.0:
|
||||
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
|
||||
else:
|
||||
face_img = pil_resize(face_img, (w, h))
|
||||
rw = int(w * scale + .5)
|
||||
rh = int(h * scale + .5)
|
||||
face_img = resize_image(face_img, w, h, rw, rh)
|
||||
cx = int(cx * scale + .5)
|
||||
cy = int(cy * scale + .5)
|
||||
fw = int(fw * scale + .5)
|
||||
|
||||
166
tools/merge_sd3_safetensors.py
Normal file
166
tools/merge_sd3_safetensors.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import argparse
|
||||
import os
|
||||
import gc
|
||||
from typing import Dict, Optional, Union
|
||||
import torch
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
from library.utils import setup_logging
|
||||
from library.utils import load_safetensors, mem_eff_save_file, str_to_dtype
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def merge_safetensors(
|
||||
dit_path: str,
|
||||
vae_path: Optional[str] = None,
|
||||
clip_l_path: Optional[str] = None,
|
||||
clip_g_path: Optional[str] = None,
|
||||
t5xxl_path: Optional[str] = None,
|
||||
output_path: str = "merged_model.safetensors",
|
||||
device: str = "cpu",
|
||||
save_precision: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Merge multiple safetensors files into a single file
|
||||
|
||||
Args:
|
||||
dit_path: Path to the DiT/MMDiT model
|
||||
vae_path: Path to the VAE model
|
||||
clip_l_path: Path to the CLIP-L model
|
||||
clip_g_path: Path to the CLIP-G model
|
||||
t5xxl_path: Path to the T5-XXL model
|
||||
output_path: Path to save the merged model
|
||||
device: Device to load tensors to
|
||||
save_precision: Target dtype for model weights (e.g. 'fp16', 'bf16')
|
||||
"""
|
||||
logger.info("Starting to merge safetensors files...")
|
||||
|
||||
# Convert save_precision string to torch dtype if specified
|
||||
if save_precision:
|
||||
target_dtype = str_to_dtype(save_precision)
|
||||
else:
|
||||
target_dtype = None
|
||||
|
||||
# 1. Get DiT metadata if available
|
||||
metadata = None
|
||||
try:
|
||||
with safe_open(dit_path, framework="pt") as f:
|
||||
metadata = f.metadata() # may be None
|
||||
if metadata:
|
||||
logger.info(f"Found metadata in DiT model: {metadata}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read metadata from DiT model: {e}")
|
||||
|
||||
# 2. Create empty merged state dict
|
||||
merged_state_dict = {}
|
||||
|
||||
# 3. Load and merge each model with memory management
|
||||
|
||||
# DiT/MMDiT - prefix: model.diffusion_model.
|
||||
# This state dict may have VAE keys.
|
||||
logger.info(f"Loading DiT model from {dit_path}")
|
||||
dit_state_dict = load_safetensors(dit_path, device=device, disable_mmap=True, dtype=target_dtype)
|
||||
logger.info(f"Adding DiT model with {len(dit_state_dict)} keys")
|
||||
for key, value in dit_state_dict.items():
|
||||
if key.startswith("model.diffusion_model.") or key.startswith("first_stage_model."):
|
||||
merged_state_dict[key] = value
|
||||
else:
|
||||
merged_state_dict[f"model.diffusion_model.{key}"] = value
|
||||
# Free memory
|
||||
del dit_state_dict
|
||||
gc.collect()
|
||||
|
||||
# VAE - prefix: first_stage_model.
|
||||
# May be omitted if VAE is already included in DiT model.
|
||||
if vae_path:
|
||||
logger.info(f"Loading VAE model from {vae_path}")
|
||||
vae_state_dict = load_safetensors(vae_path, device=device, disable_mmap=True, dtype=target_dtype)
|
||||
logger.info(f"Adding VAE model with {len(vae_state_dict)} keys")
|
||||
for key, value in vae_state_dict.items():
|
||||
if key.startswith("first_stage_model."):
|
||||
merged_state_dict[key] = value
|
||||
else:
|
||||
merged_state_dict[f"first_stage_model.{key}"] = value
|
||||
# Free memory
|
||||
del vae_state_dict
|
||||
gc.collect()
|
||||
|
||||
# CLIP-L - prefix: text_encoders.clip_l.
|
||||
if clip_l_path:
|
||||
logger.info(f"Loading CLIP-L model from {clip_l_path}")
|
||||
clip_l_state_dict = load_safetensors(clip_l_path, device=device, disable_mmap=True, dtype=target_dtype)
|
||||
logger.info(f"Adding CLIP-L model with {len(clip_l_state_dict)} keys")
|
||||
for key, value in clip_l_state_dict.items():
|
||||
if key.startswith("text_encoders.clip_l.transformer."):
|
||||
merged_state_dict[key] = value
|
||||
else:
|
||||
merged_state_dict[f"text_encoders.clip_l.transformer.{key}"] = value
|
||||
# Free memory
|
||||
del clip_l_state_dict
|
||||
gc.collect()
|
||||
|
||||
# CLIP-G - prefix: text_encoders.clip_g.
|
||||
if clip_g_path:
|
||||
logger.info(f"Loading CLIP-G model from {clip_g_path}")
|
||||
clip_g_state_dict = load_safetensors(clip_g_path, device=device, disable_mmap=True, dtype=target_dtype)
|
||||
logger.info(f"Adding CLIP-G model with {len(clip_g_state_dict)} keys")
|
||||
for key, value in clip_g_state_dict.items():
|
||||
if key.startswith("text_encoders.clip_g.transformer."):
|
||||
merged_state_dict[key] = value
|
||||
else:
|
||||
merged_state_dict[f"text_encoders.clip_g.transformer.{key}"] = value
|
||||
# Free memory
|
||||
del clip_g_state_dict
|
||||
gc.collect()
|
||||
|
||||
# T5-XXL - prefix: text_encoders.t5xxl.
|
||||
if t5xxl_path:
|
||||
logger.info(f"Loading T5-XXL model from {t5xxl_path}")
|
||||
t5xxl_state_dict = load_safetensors(t5xxl_path, device=device, disable_mmap=True, dtype=target_dtype)
|
||||
logger.info(f"Adding T5-XXL model with {len(t5xxl_state_dict)} keys")
|
||||
for key, value in t5xxl_state_dict.items():
|
||||
if key.startswith("text_encoders.t5xxl.transformer."):
|
||||
merged_state_dict[key] = value
|
||||
else:
|
||||
merged_state_dict[f"text_encoders.t5xxl.transformer.{key}"] = value
|
||||
# Free memory
|
||||
del t5xxl_state_dict
|
||||
gc.collect()
|
||||
|
||||
# 4. Save merged state dict
|
||||
logger.info(f"Saving merged model to {output_path} with {len(merged_state_dict)} keys total")
|
||||
mem_eff_save_file(merged_state_dict, output_path, metadata)
|
||||
logger.info("Successfully merged safetensors files")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Merge Stable Diffusion 3.5 model components into a single safetensors file")
|
||||
parser.add_argument("--dit", required=True, help="Path to the DiT/MMDiT model")
|
||||
parser.add_argument("--vae", help="Path to the VAE model. May be omitted if VAE is included in DiT model")
|
||||
parser.add_argument("--clip_l", help="Path to the CLIP-L model")
|
||||
parser.add_argument("--clip_g", help="Path to the CLIP-G model")
|
||||
parser.add_argument("--t5xxl", help="Path to the T5-XXL model")
|
||||
parser.add_argument("--output", default="merged_model.safetensors", help="Path to save the merged model")
|
||||
parser.add_argument("--device", default="cpu", help="Device to load tensors to")
|
||||
parser.add_argument("--save_precision", type=str, help="Precision to save the model in (e.g., 'fp16', 'bf16', 'float16', etc.)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
merge_safetensors(
|
||||
dit_path=args.dit,
|
||||
vae_path=args.vae,
|
||||
clip_l_path=args.clip_l,
|
||||
clip_g_path=args.clip_g,
|
||||
t5xxl_path=args.t5xxl,
|
||||
output_path=args.output,
|
||||
device=args.device,
|
||||
save_precision=args.save_precision,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -6,7 +6,7 @@ import shutil
|
||||
import math
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging, resize_image
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
if not os.path.exists(dst_img_folder):
|
||||
os.makedirs(dst_img_folder)
|
||||
|
||||
# Select interpolation method
|
||||
if interpolation == 'lanczos4':
|
||||
pil_interpolation = Image.LANCZOS
|
||||
elif interpolation == 'cubic':
|
||||
pil_interpolation = Image.BICUBIC
|
||||
else:
|
||||
cv2_interpolation = cv2.INTER_AREA
|
||||
|
||||
# Iterate through all files in src_img_folder
|
||||
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
|
||||
for filename in os.listdir(src_img_folder):
|
||||
@@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
||||
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
||||
|
||||
# Resize image
|
||||
if cv2_interpolation:
|
||||
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
|
||||
else:
|
||||
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
|
||||
img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation)
|
||||
else:
|
||||
new_height, new_width = img.shape[0:2]
|
||||
|
||||
@@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
|
||||
parser.add_argument('--divisible_by', type=int,
|
||||
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
|
||||
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
|
||||
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
|
||||
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'],
|
||||
default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補間方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。')
|
||||
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
|
||||
parser.add_argument('--copy_associated_files', action='store_true',
|
||||
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
|
||||
|
||||
498
train_network.py
498
train_network.py
@@ -9,6 +9,7 @@ import random
|
||||
import time
|
||||
import json
|
||||
from multiprocessing import Value
|
||||
import numpy as np
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -68,13 +69,20 @@ class NetworkTrainer:
|
||||
keys_scaled=None,
|
||||
mean_norm=None,
|
||||
maximum_norm=None,
|
||||
mean_grad_norm=None,
|
||||
mean_combined_norm=None,
|
||||
):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
if keys_scaled is not None:
|
||||
logs["max_norm/keys_scaled"] = keys_scaled
|
||||
logs["max_norm/average_key_norm"] = mean_norm
|
||||
logs["max_norm/max_key_norm"] = maximum_norm
|
||||
if mean_norm is not None:
|
||||
logs["norm/avg_key_norm"] = mean_norm
|
||||
if mean_grad_norm is not None:
|
||||
logs["norm/avg_grad_norm"] = mean_grad_norm
|
||||
if mean_combined_norm is not None:
|
||||
logs["norm/avg_combined_norm"] = mean_combined_norm
|
||||
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
for i, lr in enumerate(lrs):
|
||||
@@ -100,9 +108,7 @@ class NetworkTrainer:
|
||||
if (
|
||||
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
|
||||
): # tracking d*lr value of unet.
|
||||
logs["lr/d*lr"] = (
|
||||
optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
|
||||
)
|
||||
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
|
||||
else:
|
||||
idx = 0
|
||||
if not args.network_train_unet_only:
|
||||
@@ -115,16 +121,56 @@ class NetworkTrainer:
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
if (
|
||||
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
|
||||
):
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
|
||||
)
|
||||
if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
|
||||
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
|
||||
|
||||
return logs
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
||||
def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
|
||||
self.accelerator_logging(accelerator, logs, global_step, global_step, epoch)
|
||||
|
||||
def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
|
||||
self.accelerator_logging(accelerator, logs, epoch, global_step, epoch)
|
||||
|
||||
def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int):
|
||||
self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step)
|
||||
|
||||
def accelerator_logging(
|
||||
self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
step_value is for tensorboard, other values are for wandb
|
||||
"""
|
||||
tensorboard_tracker = None
|
||||
wandb_tracker = None
|
||||
other_trackers = []
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
tensorboard_tracker = accelerator.get_tracker("tensorboard")
|
||||
elif tracker.name == "wandb":
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
else:
|
||||
other_trackers.append(accelerator.get_tracker(tracker.name))
|
||||
|
||||
if tensorboard_tracker is not None:
|
||||
tensorboard_tracker.log(logs, step=step_value)
|
||||
|
||||
if wandb_tracker is not None:
|
||||
logs["global_step"] = global_step
|
||||
logs["epoch"] = epoch
|
||||
if val_step is not None:
|
||||
logs["val_step"] = val_step
|
||||
wandb_tracker.log(logs)
|
||||
|
||||
for tracker in other_trackers:
|
||||
tracker.log(logs, step=step_value)
|
||||
|
||||
def assert_extra_args(
|
||||
self,
|
||||
args,
|
||||
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||
):
|
||||
train_dataset_group.verify_bucket_reso_steps(64)
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.verify_bucket_reso_steps(64)
|
||||
@@ -219,7 +265,7 @@ class NetworkTrainer:
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True
|
||||
is_train=True,
|
||||
):
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
@@ -309,28 +355,31 @@ class NetworkTrainer:
|
||||
) -> torch.nn.Module:
|
||||
return accelerator.prepare(unet)
|
||||
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
|
||||
pass
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
pass
|
||||
|
||||
# endregion
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy: strategy_base.TextEncodingStrategy,
|
||||
tokenize_strategy: strategy_base.TokenizeStrategy,
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True
|
||||
self,
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy: strategy_base.TextEncodingStrategy,
|
||||
tokenize_strategy: strategy_base.TokenizeStrategy,
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process a batch for the network
|
||||
@@ -340,7 +389,18 @@ class NetworkTrainer:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
else:
|
||||
# latentに変換
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
|
||||
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
|
||||
else:
|
||||
chunks = [
|
||||
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
|
||||
]
|
||||
list_latents = []
|
||||
for chunk in chunks:
|
||||
with torch.no_grad():
|
||||
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
|
||||
list_latents.append(chunk)
|
||||
latents = torch.cat(list_latents, dim=0)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
@@ -397,7 +457,7 @@ class NetworkTrainer:
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=is_train
|
||||
is_train=is_train,
|
||||
)
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
@@ -484,7 +544,7 @@ class NetworkTrainer:
|
||||
else:
|
||||
# use arbitrary dataset class
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
|
||||
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
@@ -609,6 +669,10 @@ class NetworkTrainer:
|
||||
return
|
||||
network_has_multiplier = hasattr(network, "set_multiplier")
|
||||
|
||||
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
|
||||
# if not hasattr(network, "prepare_network"):
|
||||
# network.prepare_network = lambda args: None
|
||||
|
||||
if hasattr(network, "prepare_network"):
|
||||
network.prepare_network(args)
|
||||
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
||||
@@ -701,7 +765,7 @@ class NetworkTrainer:
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset_group if val_dataset_group is not None else [],
|
||||
shuffle=False,
|
||||
@@ -900,7 +964,9 @@ class NetworkTrainer:
|
||||
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
|
||||
accelerator.print(
|
||||
f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}"
|
||||
)
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
@@ -968,11 +1034,12 @@ class NetworkTrainer:
|
||||
"ss_huber_c": args.huber_c,
|
||||
"ss_fp8_base": bool(args.fp8_base),
|
||||
"ss_fp8_base_unet": bool(args.fp8_base_unet),
|
||||
"ss_validation_seed": args.validation_seed,
|
||||
"ss_validation_split": args.validation_split,
|
||||
"ss_max_validation_steps": args.max_validation_steps,
|
||||
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||
"ss_validation_seed": args.validation_seed,
|
||||
"ss_validation_split": args.validation_split,
|
||||
"ss_max_validation_steps": args.max_validation_steps,
|
||||
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
|
||||
"ss_validate_every_n_steps": args.validate_every_n_steps,
|
||||
"ss_resize_interpolation": args.resize_interpolation,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
@@ -998,6 +1065,7 @@ class NetworkTrainer:
|
||||
"max_bucket_reso": dataset.max_bucket_reso,
|
||||
"tag_frequency": dataset.tag_frequency,
|
||||
"bucket_info": dataset.bucket_info,
|
||||
"resize_interpolation": dataset.resize_interpolation,
|
||||
}
|
||||
|
||||
subsets_metadata = []
|
||||
@@ -1015,6 +1083,7 @@ class NetworkTrainer:
|
||||
"enable_wildcard": bool(subset.enable_wildcard),
|
||||
"caption_prefix": subset.caption_prefix,
|
||||
"caption_suffix": subset.caption_suffix,
|
||||
"resize_interpolation": subset.resize_interpolation,
|
||||
}
|
||||
|
||||
image_dir_or_metadata_file = None
|
||||
@@ -1163,10 +1232,6 @@ class NetworkTrainer:
|
||||
args.max_train_steps > initial_step
|
||||
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
||||
)
|
||||
|
||||
epoch_to_start = 0
|
||||
if initial_step > 0:
|
||||
if args.skip_until_initial_step:
|
||||
@@ -1247,12 +1312,6 @@ class NetworkTrainer:
|
||||
# log empty object to commit the sample images to wandb
|
||||
accelerator.log({}, step=0)
|
||||
|
||||
validation_steps = (
|
||||
min(args.max_validation_steps, len(val_dataloader))
|
||||
if args.max_validation_steps is not None
|
||||
else len(val_dataloader)
|
||||
)
|
||||
|
||||
# training loop
|
||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||
@@ -1271,13 +1330,57 @@ class NetworkTrainer:
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
||||
)
|
||||
|
||||
validation_steps = (
|
||||
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||
)
|
||||
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
|
||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
|
||||
validation_total_steps = validation_steps * len(validation_timesteps)
|
||||
original_args_min_timestep = args.min_timestep
|
||||
original_args_max_timestep = args.max_timestep
|
||||
|
||||
def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
|
||||
cpu_rng_state = torch.get_rng_state()
|
||||
if accelerator.device.type == "cuda":
|
||||
gpu_rng_state = torch.cuda.get_rng_state()
|
||||
elif accelerator.device.type == "xpu":
|
||||
gpu_rng_state = torch.xpu.get_rng_state()
|
||||
elif accelerator.device.type == "mps":
|
||||
gpu_rng_state = torch.cuda.get_rng_state()
|
||||
else:
|
||||
gpu_rng_state = None
|
||||
python_rng_state = random.getstate()
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
return (cpu_rng_state, gpu_rng_state, python_rng_state)
|
||||
|
||||
def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
|
||||
cpu_rng_state, gpu_rng_state, python_rng_state = rng_states
|
||||
torch.set_rng_state(cpu_rng_state)
|
||||
if gpu_rng_state is not None:
|
||||
if accelerator.device.type == "cuda":
|
||||
torch.cuda.set_rng_state(gpu_rng_state)
|
||||
elif accelerator.device.type == "xpu":
|
||||
torch.xpu.set_rng_state(gpu_rng_state)
|
||||
elif accelerator.device.type == "mps":
|
||||
torch.cuda.set_rng_state(gpu_rng_state)
|
||||
random.setstate(python_rng_state)
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here
|
||||
|
||||
# TRAINING
|
||||
skipped_dataloader = None
|
||||
@@ -1294,25 +1397,25 @@ class NetworkTrainer:
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start_for_network(text_encoder, unet)
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
# preprocess batch for each model
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=True,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=True,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet,
|
||||
)
|
||||
|
||||
accelerator.backward(loss)
|
||||
@@ -1322,6 +1425,11 @@ class NetworkTrainer:
|
||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
if hasattr(network, "update_grad_norms"):
|
||||
network.update_grad_norms()
|
||||
if hasattr(network, "update_norms"):
|
||||
network.update_norms()
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
@@ -1330,9 +1438,25 @@ class NetworkTrainer:
|
||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
if hasattr(network, "weight_norms"):
|
||||
weight_norms = network.weight_norms()
|
||||
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
|
||||
grad_norms = network.grad_norms()
|
||||
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
|
||||
combined_weight_norms = network.combined_weight_norms()
|
||||
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
|
||||
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
|
||||
keys_scaled = None
|
||||
max_mean_logs = {}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
mean_grad_norm = None
|
||||
mean_combined_norm = None
|
||||
max_mean_logs = {}
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
@@ -1364,153 +1488,179 @@ class NetworkTrainer:
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||
|
||||
if is_tracking:
|
||||
logs = self.generate_step_logs(
|
||||
args,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer,
|
||||
keys_scaled,
|
||||
mean_norm,
|
||||
maximum_norm
|
||||
args,
|
||||
current_loss,
|
||||
avr_loss,
|
||||
lr_scheduler,
|
||||
lr_descriptions,
|
||||
optimizer,
|
||||
keys_scaled,
|
||||
mean_norm,
|
||||
maximum_norm,
|
||||
mean_grad_norm,
|
||||
mean_combined_norm,
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||
|
||||
# VALIDATION PER STEP
|
||||
should_validate_step = (
|
||||
args.validate_every_n_steps is not None
|
||||
and global_step != 0 # Skip first step
|
||||
and global_step % args.validate_every_n_steps == 0
|
||||
)
|
||||
# VALIDATION PER STEP: global_step is already incremented
|
||||
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
|
||||
should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0
|
||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||
optimizer_eval_fn()
|
||||
accelerator.unwrap_model(network).eval()
|
||||
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="validation steps"
|
||||
range(validation_total_steps),
|
||||
smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="validation steps",
|
||||
)
|
||||
val_timesteps_step = 0
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
for timestep in validation_timesteps:
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False
|
||||
)
|
||||
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
|
||||
train_unet=train_unet,
|
||||
)
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/step_current": current_loss,
|
||||
"val_step": (epoch * validation_steps) + val_step,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
current_loss = loss.detach().item()
|
||||
val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix(
|
||||
{"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
|
||||
)
|
||||
|
||||
# if is_tracking:
|
||||
# logs = {f"loss/validation/step_current_{timestep}": current_loss}
|
||||
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
|
||||
|
||||
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
val_timesteps_step += 1
|
||||
|
||||
if is_tracking:
|
||||
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
logs = {
|
||||
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
||||
"loss/validation/step_divergence": loss_validation_divergence,
|
||||
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
||||
"loss/validation/step_divergence": loss_validation_divergence,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
|
||||
|
||||
restore_rng_state(rng_states)
|
||||
args.min_timestep = original_args_min_timestep
|
||||
args.max_timestep = original_args_max_timestep
|
||||
optimizer_train_fn()
|
||||
accelerator.unwrap_model(network).train()
|
||||
progress_bar.unpause()
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
# EPOCH VALIDATION
|
||||
should_validate_epoch = (
|
||||
(epoch + 1) % args.validate_every_n_epochs == 0
|
||||
if args.validate_every_n_epochs is not None
|
||||
else True
|
||||
(epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
|
||||
)
|
||||
|
||||
if should_validate_epoch and len(val_dataloader) > 0:
|
||||
optimizer_eval_fn()
|
||||
accelerator.unwrap_model(network).eval()
|
||||
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_steps), smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="epoch validation steps"
|
||||
range(validation_total_steps),
|
||||
smoothing=0,
|
||||
disable=not accelerator.is_local_main_process,
|
||||
desc="epoch validation steps",
|
||||
)
|
||||
|
||||
val_timesteps_step = 0
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
for timestep in validation_timesteps:
|
||||
args.min_timestep = args.max_timestep = timestep
|
||||
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=False,
|
||||
train_unet=False
|
||||
)
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
|
||||
loss = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
network,
|
||||
vae,
|
||||
noise_scheduler,
|
||||
vae_dtype,
|
||||
weight_dtype,
|
||||
accelerator,
|
||||
args,
|
||||
text_encoding_strategy,
|
||||
tokenize_strategy,
|
||||
is_train=False,
|
||||
train_text_encoder=train_text_encoder,
|
||||
train_unet=train_unet,
|
||||
)
|
||||
|
||||
if is_tracking:
|
||||
logs = {
|
||||
"loss/validation/epoch_current": current_loss,
|
||||
"epoch": epoch + 1,
|
||||
"val_step": (epoch * validation_steps) + val_step
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
current_loss = loss.detach().item()
|
||||
val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
|
||||
val_progress_bar.update(1)
|
||||
val_progress_bar.set_postfix(
|
||||
{"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
|
||||
)
|
||||
|
||||
# if is_tracking:
|
||||
# logs = {f"loss/validation/epoch_current_{timestep}": current_loss}
|
||||
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
|
||||
|
||||
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
val_timesteps_step += 1
|
||||
|
||||
if is_tracking:
|
||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
loss_validation_divergence = val_epoch_loss_recorder.moving_average - loss_recorder.moving_average
|
||||
logs = {
|
||||
"loss/validation/epoch_average": avr_loss,
|
||||
"loss/validation/epoch_divergence": loss_validation_divergence,
|
||||
"epoch": epoch + 1
|
||||
"loss/validation/epoch_average": avr_loss,
|
||||
"loss/validation/epoch_divergence": loss_validation_divergence,
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
|
||||
|
||||
restore_rng_state(rng_states)
|
||||
args.min_timestep = original_args_min_timestep
|
||||
args.max_timestep = original_args_max_timestep
|
||||
optimizer_train_fn()
|
||||
accelerator.unwrap_model(network).train()
|
||||
progress_bar.unpause()
|
||||
|
||||
# END OF EPOCH
|
||||
if is_tracking:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
@@ -1696,31 +1846,31 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--validation_seed",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する"
|
||||
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_split",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
|
||||
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validate_every_n_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます"
|
||||
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validate_every_n_epochs",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます"
|
||||
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_validation_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します"
|
||||
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user