Merge branch 'gesen2egee/val' into validation-loss-upstream

Modified various implementations to restore original behavior
This commit is contained in:
rockerBOO
2025-01-02 23:04:38 -05:00
84 changed files with 11419 additions and 3458 deletions

View File

@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.16.15
uses: crate-ci/typos@v1.19.0

View File

@@ -1,7 +1,3 @@
SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。
SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。
## リポジトリについて
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
@@ -21,6 +17,7 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
* [データセット設定](./docs/config_README-ja.md)
* [SDXL学習](./docs/train_SDXL-en.md) (英語版)
* [DreamBoothの学習について](./docs/train_db_README-ja.md)
* [fine-tuningのガイド](./docs/fine_tune_README_ja.md):
* [LoRAの学習について](./docs/train_network_README-ja.md)
@@ -44,9 +41,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
## Windows環境でのインストール
スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます。
以下の例ではPyTorchは2.0.1CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。
スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。
なお、python -m venvの行で「python」とだけ表示された場合、py -m venvのようにpythonをpyに変更してください。
@@ -59,21 +54,21 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install xformers==0.0.20
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
accelerate config
```
コマンドプロンプトでも同一です。
(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。
注:`bitsandbytes==0.43.0``prodigyopt==1.0``lion-pytorch==0.0.6``requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。
この例では PyTorch および xfomers は2.1.2CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。
accelerate configの質問には以下のように答えてください。bf16で学習する場合、最後の質問にはbf16と答えてください。
※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます……。数字キーの0、1、2……で選択できますので、そちらを使ってください。
```txt
- This machine
- No distributed training
@@ -87,41 +82,6 @@ accelerate configの質問には以下のように答えてください。bf1
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``に「0」と答えてください。id `0`のGPUが使われます。
### オプション:`bitsandbytes`8bit optimizerを使う
`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます0.41.1または以降のバージョンを推奨)。
Windowsでは0.35.0または0.41.1を推奨します。
- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません。
- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。
注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659
以下の手順に従い、`bitsandbytes`をインストールしてください。
### 0.35.0を使う場合
PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。
```powershell
cd sd-scripts
.\venv\Scripts\activate
pip install bitsandbytes==0.35.0
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
```
### 0.41.1を使う場合
jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。
```powershell
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
```
## アップグレード
新しいリリースがあった場合、以下のコマンドで更新できます。
@@ -151,4 +111,47 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## その他の情報
### LoRAの名称について
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
デフォルトではLoRA-LierLaが使われます。LoRA-C3Lierを使う場合は `--network_args` に `conv_dim` を指定してください。
<!--
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
-->
### 学習中のサンプル画像生成
プロンプトファイルは例えば以下のようになります。
```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
`( )` や `[ ]` などの重みづけも動作します。

393
README.md
View File

@@ -1,5 +1,3 @@
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
This repository contains training, generation and utility scripts for Stable Diffusion.
[__Change History__](#change-history) is moved to the bottom of the page.
@@ -20,9 +18,9 @@ This repository contains the scripts for:
## About requirements.txt
These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.)
The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below.
The scripts are tested with Pytorch 2.0.1. 1.12.1 is not tested but should work.
The scripts are tested with Pytorch 2.1.2. 2.0.1 and 1.12.1 is not tested but should work.
## Links to usage documentation
@@ -32,11 +30,13 @@ Most of the documents are written in Japanese.
* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
* [Chinese version](./docs/train_README-zh.md)
* [SDXL training](./docs/train_SDXL-en.md) (English version)
* [Dataset config](./docs/config_README-ja.md)
* [English version](./docs/config_README-en.md)
* [DreamBooth training guide](./docs/train_db_README-ja.md)
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
* [training LoRA](./docs/train_network_README-ja.md)
* [training Textual Inversion](./docs/train_ti_README-ja.md)
* [Training LoRA](./docs/train_network_README-ja.md)
* [Training Textual Inversion](./docs/train_ti_README-ja.md)
* [Image generation](./docs/gen_img_README-ja.md)
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
@@ -64,14 +64,18 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install xformers==0.0.20
pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu118
accelerate config
```
__Note:__ Now bitsandbytes is optional. Please install any version of bitsandbytes as needed. Installation instructions are in the following section.
If `python -m venv` shows only `python`, change `python` to `py`.
__Note:__ Now `bitsandbytes==0.43.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually.
This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`.
<!--
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
@@ -90,48 +94,13 @@ Answers to accelerate config:
- fp16
```
note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occurred in training. In this case, answer `0` for the 6th question:
If you'd like to use bf16, please answer `bf16` to the last question.
Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occurred in training. In this case, answer `0` for the 6th question:
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``
(Single GPU with id `0` will be used.)
### Optional: Use `bitsandbytes` (8bit optimizer)
For 8bit optimizer, you need to install `bitsandbytes`. For Linux, please install `bitsandbytes` as usual (0.41.1 or later is recommended.)
For Windows, there are several versions of `bitsandbytes`:
- `bitsandbytes` 0.35.0: Stable version. AdamW8bit is available. `full_bf16` is not available.
- `bitsandbytes` 0.41.1: Lion8bit, PagedAdamW8bit and PagedLion8bit are available. `full_bf16` is available.
Note: `bitsandbytes`above 0.35.0 till 0.41.0 seems to have an issue: https://github.com/TimDettmers/bitsandbytes/issues/659
Follow the instructions below to install `bitsandbytes` for Windows.
### bitsandbytes 0.35.0 for Windows
Open a regular Powershell terminal and type the following inside:
```powershell
cd sd-scripts
.\venv\Scripts\activate
pip install bitsandbytes==0.35.0
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
```
This will install `bitsandbytes` 0.35.0 and copy the necessary files to the `bitsandbytes` directory.
### bitsandbytes 0.41.1 for Windows
Install the Windows version whl file from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or other sources, like:
```powershell
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
```
## Upgrade
When a new release comes out you can upgrade your repo with the following command:
@@ -145,6 +114,10 @@ pip install --use-pep517 --upgrade -r requirements.txt
Once the commands have completed successfully you should be ready to use the new version.
### Upgrade PyTorch
If you want to upgrade PyTorch, you can upgrade it with `pip install` command in [Windows Installation](#windows-installation) section. `xformers` is also required to be upgraded when PyTorch is upgraded.
## Credits
The implementation for LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for great work!
@@ -162,131 +135,183 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## SDXL training
The documentation in this section will be moved to a separate document later.
### Training scripts for SDXL
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- The full bfloat16 training might be unstable. Please use it at your own risk.
- The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- `--weighted_captions` option is not supported yet for both scripts.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
### Utility scripts for SDXL
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
- The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage.
### Tips for SDXL training
- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (4 to 8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings for SDXL
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
### ControlNet-LLLite
ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details.
## Change History
### Nov 5, 2023 / 2023/11/5
### Apr 7, 2024 / 2024-04-07: v0.8.7
- `sdxl_train.py` now supports different learning rates for each Text Encoder.
- Example:
- `--learning_rate 1e-6`: train U-Net only
- `--train_text_encoder --learning_rate 1e-6`: train U-Net and two Text Encoders with the same learning rate (same as the previous version)
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train U-Net and two Text Encoders with the different learning rates
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train two Text Encoders only
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 0`: train U-Net and one Text Encoder only
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 0 --learning_rate_te2 1e-6`: train one Text Encoder only
- The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results.
- `train_db.py` and `fine_tune.py` now support different learning rates for Text Encoder. Specify with `--learning_rate_te` option.
- To train Text Encoder with `fine_tune.py`, specify `--train_text_encoder` option too. `train_db.py` trains Text Encoder by default.
- Scheduled Huber Loss の `huber_schedule` のデフォルト値を `exponential` から、より良い結果が期待できる `snr` に変更しました。
- Fixed the bug that Text Encoder is not trained when block lr is specified in `sdxl_train.py`.
### Apr 7, 2024 / 2024-04-07: v0.8.6
- Debiased Estimation loss is added to each training script. Thanks to sdbds!
- Specify `--debiased_estimation_loss` option to enable it. See PR [#889](https://github.com/kohya-ss/sd-scripts/pull/889) for details.
- Training of Text Encoder is improved in `train_network.py` and `sdxl_train_network.py`. Thanks to KohakuBlueleaf! PR [#895](https://github.com/kohya-ss/sd-scripts/pull/895)
- The moving average of the loss is now displayed in the progress bar in each training script. Thanks to shirayu! PR [#899](https://github.com/kohya-ss/sd-scripts/pull/899)
- PagedAdamW32bit optimizer is supported. Specify `--optimizer_type=PagedAdamW32bit`. Thanks to xzuyn! PR [#900](https://github.com/kohya-ss/sd-scripts/pull/900)
- Other bug fixes and improvements.
#### Highlights
- `sdxl_train.py` で、二つのText Encoderそれぞれに独立した学習率が指定できるようになりました。サンプルは上の英語版を参照してください。
- `train_db.py` および `fine_tune.py` で Text Encoder に別の学習率を指定できるようになりました。`--learning_rate_te` オプションで指定してください。
- `fine_tune.py` で Text Encoder を学習するには `--train_text_encoder` オプションをあわせて指定してください。`train_db.py` はデフォルトで学習します。
- `sdxl_train.py` で block lr を指定すると Text Encoder が学習されない不具合を修正しました。
- Debiased Estimation loss が各学習スクリプトに追加されました。sdbsd 氏に感謝します。
- `--debiased_estimation_loss` を指定すると有効になります。詳細は PR [#889](https://github.com/kohya-ss/sd-scripts/pull/889) を参照してください。
- `train_network.py``sdxl_train_network.py` でText Encoderの学習が改善されました。KohakuBlueleaf 氏に感謝します。 PR [#895](https://github.com/kohya-ss/sd-scripts/pull/895)
- 各学習スクリプトで移動平均のlossがプログレスバーに表示されるようになりました。shirayu 氏に感謝します。 PR [#899](https://github.com/kohya-ss/sd-scripts/pull/899)
- PagedAdamW32bit オプティマイザがサポートされました。`--optimizer_type=PagedAdamW32bit` と指定してください。xzuyn 氏に感謝します。 PR [#900](https://github.com/kohya-ss/sd-scripts/pull/900)
- その他のバグ修正と改善。
- The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
- Especially `imagesize` is newly added, so if you cannot update the libraries immediately, please install with `pip install imagesize==1.4.1` separately.
- `bitsandbytes==0.43.0`, `prodigyopt==1.0`, `lion-pytorch==0.0.6` are included in the requirements.txt.
- `bitsandbytes` no longer requires complex procedures as it now officially supports Windows.
- Also, the PyTorch version is updated to 2.1.2 (PyTorch does not need to be updated immediately). In the upgrade procedure, PyTorch is not updated, so please manually install or update torch, torchvision, xformers if necessary (see [Upgrade PyTorch](#upgrade-pytorch)).
- When logging to wandb is enabled, the entire command line is exposed. Therefore, it is recommended to write wandb API key and HuggingFace token in the configuration file (`.toml`). Thanks to bghira for raising the issue.
- A warning is displayed at the start of training if such information is included in the command line.
- Also, if there is an absolute path, the path may be exposed, so it is recommended to specify a relative path or write it in the configuration file. In such cases, an INFO log is displayed.
- See [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) and PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) for details.
- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.
- Other improvements include the addition of masked loss, scheduled Huber Loss, DeepSpeed support, dataset settings improvements, and image tagging improvements. See below for details.
#### Training scripts
- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
- Fixed a bug that U-Net and Text Encoders are included in the state in `train_network.py` and `sdxl_train_network.py`. The saving and loading of the state are faster, the file size is smaller, and the memory usage when loading is reduced.
- DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details.
- The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#about-masked-loss) for details.
- Scheduled Huber Loss has been introduced to each training scripts. PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) Thanks to kabachuha for the PR and cheald, drhead, and others for the discussion! See the PR and [Scheduled Huber Loss](#about-scheduled-huber-loss) for details.
- The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf!
- The options `--save_state_on_train_end` are added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee!
- The options `--sample_every_n_epochs` and `--sample_every_n_steps` in each training script now display a warning and ignore them when a number less than or equal to `0` is specified. Thanks to S-Del for raising the issue.
#### Dataset settings
- The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150!
- The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704!
- Fixed a bug that the last subset settings are applied to all images when multiple subsets of regularization images are specified in the dataset settings. The settings for each subset are correctly applied to each image. PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) Thanks to feffy380!
- Some features are added to the dataset subset settings.
- `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping.
- Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped.
- `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. The multi-line caption is also enabled.
- `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end.
- The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order.
- See [Dataset config](./docs/config_README-en.md) for details.
- The dataset with DreamBooth method supports caching image information (size, caption). PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178) and [#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) Thanks to KohakuBlueleaf! See [DreamBooth method specific options](./docs/config_README-en.md#dreambooth-specific-options) for details.
#### Image tagging
- The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds!
- Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`.
- The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`.
- Some options are added to `tag_image_by_wd14_tagger.py`.
- Some are added in PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) Thanks to Disty0!
- Output rating tags `--use_rating_tags` and `--use_rating_tags_as_last_tag`
- Output character tags first `--character_tags_first`
- Expand character tags and series `--character_tag_expand`
- Specify tags to output first `--always_first_tags`
- Replace tags `--tag_replacement`
- See [Tagging documentation](./docs/wd14_tagger_README-en.md) for details.
- Fixed an error when specifying `--beam_search` and a value of 2 or more for `--num_beams` in `make_captions.py`.
#### About Masked loss
The masked loss is supported in each training script. To enable the masked loss, specify the `--masked_loss` option.
The feature is not fully tested, so there may be bugs. If you find any issues, please open an Issue.
ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. The pixel values 0-255 are converted to 0-1 (i.e., the pixel value 128 is treated as the half weight of the loss). See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).
#### About Scheduled Huber Loss
Scheduled Huber Loss has been introduced to each training scripts. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data.
With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images.
To address this, the proposed method employs a clever application of the Huber loss function. By scheduling the use of Huber loss in the early stages of training (when noise is high) and MSE in the later stages, it strikes a balance between outlier robustness and fine detail reproduction.
Experimental results have confirmed that this method achieves higher accuracy on data containing outliers compared to pure Huber loss or MSE. The increase in computational cost is minimal.
The newly added arguments loss_type, huber_schedule, and huber_c allow for the selection of the loss function type (Huber, smooth L1, MSE), scheduling method (exponential, constant, SNR), and Huber's parameter. This enables optimization based on the characteristics of the dataset.
See PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) for details.
- `loss_type`: Specify the loss function type. Choose `huber` for Huber loss, `smooth_l1` for smooth L1 loss, and `l2` for MSE loss. The default is `l2`, which is the same as before.
- `huber_schedule`: Specify the scheduling method. Choose `exponential`, `constant`, or `snr`. The default is `snr`.
- `huber_c`: Specify the Huber's parameter. The default is `0.1`.
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
#### 主要な変更点
- 依存ライブラリが更新されました。[アップグレード](./README-ja.md#アップグレード) を参照しライブラリを更新してください。
- 特に `imagesize` が新しく追加されていますので、すぐにライブラリの更新ができない場合は `pip install imagesize==1.4.1` で個別にインストールしてください。
- `bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` が requirements.txt に含まれるようになりました。
- `bitsandbytes` が公式に Windows をサポートしたため複雑な手順が不要になりました。
- また PyTorch のバージョンを 2.1.2 に更新しました。PyTorch はすぐに更新する必要はありません。更新時は、アップグレードの手順では PyTorch が更新されませんので、torch、torchvision、xformers を手動でインストールしてください。
- wandb へのログ出力が有効の場合、コマンドライン全体が公開されます。そのため、コマンドラインに wandb の API キーや HuggingFace のトークンなどが含まれる場合、設定ファイル(`.toml`)への記載をお勧めします。問題提起していただいた bghira 氏に感謝します。
- このような場合には学習開始時に警告が表示されます。
- また絶対パスの指定がある場合、そのパスが公開される可能性がありますので、相対パスを指定するか設定ファイルに記載することをお勧めします。このような場合は INFO ログが表示されます。
- 詳細は [#1123](https://github.com/kohya-ss/sd-scripts/pull/1123) および PR [#1240](https://github.com/kohya-ss/sd-scripts/pull/1240) をご覧ください。
- Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。
- その他、マスクロス追加、Scheduled Huber Loss 追加、DeepSpeed 対応、データセット設定の改善、画像タグ付けの改善などがあります。詳細は以下をご覧ください。
#### 学習スクリプト
- `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。
- `train_network.py` および `sdxl_train_network.py` で、state に U-Net および Text Encoder が含まれる不具合を修正しました。state の保存、読み込みが高速化され、ファイルサイズも小さくなり、また読み込み時のメモリ使用量も削減されます。
- DeepSpeed がサポートされました。PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) 、[#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) BootsofLagrangian 氏に感謝します。詳細は PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) をご覧ください。
- 各学習スクリプトでマスクロスをサポートしました。PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) 詳細は [マスクロスについて](#マスクロスについて) をご覧ください。
- 各学習スクリプトに Scheduled Huber Loss を追加しました。PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) ご提案いただいた kabachuha 氏、および議論を深めてくださった cheald 氏、drhead 氏を始めとする諸氏に感謝します。詳細は当該 PR および [Scheduled Huber Loss について](#scheduled-huber-loss-について) をご覧ください。
- 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。
- 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。
- 各学習スクリプトで `--sample_every_n_epochs` および `--sample_every_n_steps` オプションに `0` 以下の数値を指定した時、警告を表示するとともにそれらを無視するよう変更しました。問題提起していただいた S-Del 氏に感謝します。
#### データセット設定
- データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。
- データセット設定で、正則化画像のサブセットを複数指定した時、最後のサブセットの各種設定がすべてのサブセットの画像に適用される不具合が修正されました。それぞれのサブセットの設定が、それぞれの画像に正しく適用されます。PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) feffy380 氏に感謝します。
- データセットのサブセット設定にいくつかの機能を追加しました。
- シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。
- `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。また複数行キャプションも有効になります。
- `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。
- 既存の機能 `caption_prefix` と `caption_suffix` とあわせて使えます。`caption_prefix` と `caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。
- 詳細は [データセット設定](./docs/config_README-ja.md) をご覧ください。
- DreamBooth 方式の DataSet で画像情報サイズ、キャプションをキャッシュする機能が追加されました。PR [#1178](https://github.com/kohya-ss/sd-scripts/pull/1178)、[#1206](https://github.com/kohya-ss/sd-scripts/pull/1206) KohakuBlueleaf 氏に感謝します。詳細は [データセット設定](./docs/config_README-ja.md#dreambooth-方式専用のオプション) をご覧ください。
- データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。
#### 画像のタグ付け
- `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。
- Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。
- `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。
- `tag_image_by_wd14_tagger.py` にいくつかのオプションを追加しました。
- 一部は PR [#1216](https://github.com/kohya-ss/sd-scripts/pull/1216) で追加されました。Disty0 氏に感謝します。
- レーティングタグを出力する `--use_rating_tags` および `--use_rating_tags_as_last_tag`
- キャラクタタグを最初に出力する `--character_tags_first`
- キャラクタタグとシリーズを展開する `--character_tag_expand`
- 常に最初に出力するタグを指定する `--always_first_tags`
- タグを置換する `--tag_replacement`
- 詳細は [タグ付けに関するドキュメント](./docs/wd14_tagger_README-ja.md) をご覧ください。
- `make_captions.py` で `--beam_search` を指定し `--num_beams` に2以上の値を指定した時のエラーを修正しました。
#### マスクロスについて
各学習スクリプトでマスクロスをサポートしました。マスクロスを有効にするには `--masked_loss` オプションを指定してください。
機能は完全にテストされていないため、不具合があるかもしれません。その場合は Issue を立てていただけると助かります。
マスクの指定には ControlNet データセットを使用します。マスク画像は RGB 画像である必要があります。R チャンネルのピクセル値 255 がロス計算対象、0 がロス計算対象外になります。0-255 の値は、0-1 の範囲に変換されます(つまりピクセル値 128 の部分はロスの重みが半分になります)。データセットの詳細は [LLLite ドキュメント](./docs/train_lllite_README-ja.md#データセットの準備) をご覧ください。
#### Scheduled Huber Loss について
各学習スクリプトに、学習データ中の異常値や外れ値data corruptionへの耐性を高めるための手法、Scheduled Huber Lossが導入されました。
従来のMSEL2損失関数では、異常値の影響を大きく受けてしまい、生成画像の品質低下を招く恐れがありました。一方、Huber損失関数は異常値の影響を抑えられますが、画像の細部再現性が損なわれがちでした。
この手法ではHuber損失関数の適用を工夫し、学習の初期段階イズが大きい場合ではHuber損失を、後期段階ではMSEを用いるようスケジューリングすることで、異常値耐性と細部再現性のバランスを取ります。
実験の結果では、この手法が純粋なHuber損失やMSEと比べ、異常値を含むデータでより高い精度を達成することが確認されています。また計算コストの増加はわずかです。
具体的には、新たに追加された引数loss_type、huber_schedule、huber_cで、損失関数の種類Huber, smooth L1, MSEとスケジューリング方法exponential, constant, SNRを選択できます。これによりデータセットに応じた最適化が可能になります。
詳細は PR [#1228](https://github.com/kohya-ss/sd-scripts/pull/1228/) をご覧ください。
- `loss_type` : 損失関数の種類を指定します。`huber` で Huber損失、`smooth_l1` で smooth L1 損失、`l2` で MSE 損失を選択します。デフォルトは `l2` で、従来と同様です。
- `huber_schedule` : スケジューリング方法を指定します。`exponential` で指数関数的、`constant` で一定、`snr` で信号対雑音比に基づくスケジューリングを選択します。デフォルトは `snr` です。
- `huber_c` : Huber損失のパラメータを指定します。デフォルトは `0.1` です。
PR 内でいくつかの比較が共有されています。この機能を試す場合、最初は `--loss_type smooth_l1 --huber_schedule snr --huber_c 0.1` などで試してみるとよいかもしれません。
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
## Additional Information
### Naming of LoRA
The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
@@ -299,27 +324,14 @@ The LoRA supported by `train_network.py` has been named to avoid confusion. The
In addition to 1., LoRA for Conv2d layers with 3x3 kernel
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg).
<!--
LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
To use LoRA-C3Lier with Web UI, please use our extension.
To use LoRA-C3Lier with Web UI, please use our extension.
-->
### LoRAの名称について
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
LoRA-C3Lierを使いWeb UIで生成するには拡張を使用してください。
## Sample image generation during training
### Sample image generation during training
A prompt file might look like this, for example
```
@@ -340,26 +352,3 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
* `--s` Specifies the number of steps in the generation.
The prompt weighting such as `( )` and `[ ]` are working.
## サンプル画像生成
プロンプトファイルは例えば以下のようになります。
```
# prompt 1
masterpiece, best quality, (1girl), in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
# prompt 2
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n (low quality, worst quality), bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
```
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
`( )``[ ]` などの重みづけも動作します。

View File

@@ -1,11 +1,7 @@
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex
init_ipex()
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

384
docs/config_README-en.md Normal file
View File

@@ -0,0 +1,384 @@
Original Source by kohya-ss
First version:
A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150
Some parts are manually added.
# Config Readme
This README is about the configuration files that can be passed with the `--dataset_config` option.
## Overview
By passing a configuration file, users can make detailed settings.
* Multiple datasets can be configured
* For example, by setting `resolution` for each dataset, they can be mixed and trained.
* In training methods that support both the DreamBooth approach and the fine-tuning approach, datasets of the DreamBooth method and the fine-tuning method can be mixed.
* Settings can be changed for each subset
* A subset is a partition of the dataset by image directory or metadata. Several subsets make up a dataset.
* Options such as `keep_tokens` and `flip_aug` can be set for each subset. On the other hand, options such as `resolution` and `batch_size` can be set for each dataset, and their values are common among subsets belonging to the same dataset. More details will be provided later.
The configuration file format can be JSON or TOML. Considering the ease of writing, it is recommended to use [TOML](https://toml.io/ja/v1.0.0-rc.2). The following explanation assumes the use of TOML.
Here is an example of a configuration file written in TOML.
```toml
[general]
shuffle_caption = true
caption_extension = '.txt'
keep_tokens = 1
# This is a DreamBooth-style dataset
[[datasets]]
resolution = 512
batch_size = 4
keep_tokens = 2
[[datasets.subsets]]
image_dir = 'C:\hoge'
class_tokens = 'hoge girl'
# This subset uses keep_tokens = 2 (the value of the parent datasets)
[[datasets.subsets]]
image_dir = 'C:\fuga'
class_tokens = 'fuga boy'
keep_tokens = 3
[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg'
class_tokens = 'human'
keep_tokens = 1
# This is a fine-tuning dataset
[[datasets]]
resolution = [768, 768]
batch_size = 2
[[datasets.subsets]]
image_dir = 'C:\piyo'
metadata_file = 'C:\piyo\piyo_md.json'
# This subset uses keep_tokens = 1 (the value of [general])
```
In this example, three directories are trained as a DreamBooth-style dataset at 512x512 (batch size 4), and one directory is trained as a fine-tuning dataset at 768x768 (batch size 2).
## Settings for datasets and subsets
Settings for datasets and subsets are divided into several registration locations.
* `[general]`
* This is where options that apply to all datasets or all subsets are specified.
* If there are options with the same name in the dataset-specific or subset-specific settings, the dataset-specific or subset-specific settings take precedence.
* `[[datasets]]`
* `datasets` is where settings for datasets are registered. This is where options that apply individually to each dataset are specified.
* If there are subset-specific settings, the subset-specific settings take precedence.
* `[[datasets.subsets]]`
* `datasets.subsets` is where settings for subsets are registered. This is where options that apply individually to each subset are specified.
Here is an image showing the correspondence between image directories and registration locations in the previous example.
```
C:\
├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
├─ reg -> [[datasets.subsets]] No.3 ┘ |
└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
```
The image directory corresponds to each `[[datasets.subsets]]`. Then, multiple `[[datasets.subsets]]` are combined to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`.
The available options for each registration location may differ, but if the same option is specified, the value in the lower registration location will take precedence. You can check how the `keep_tokens` option is handled in the previous example for better understanding.
Additionally, the available options may vary depending on the method that the learning approach supports.
* Options specific to the DreamBooth method
* Options specific to the fine-tuning method
* Options available when using the caption dropout technique
When using both the DreamBooth method and the fine-tuning method, they can be used together with a learning approach that supports both.
When using them together, a point to note is that the method is determined based on the dataset, so it is not possible to mix DreamBooth method subsets and fine-tuning method subsets within the same dataset.
In other words, if you want to use both methods together, you need to set up subsets of different methods belonging to different datasets.
In terms of program behavior, if the `metadata_file` option exists, it is determined to be a subset of fine-tuning. Therefore, for subsets belonging to the same dataset, as long as they are either "all have the `metadata_file` option" or "all have no `metadata_file` option," there is no problem.
Below, the available options will be explained. For options with the same name as the command-line argument, the explanation will be omitted in principle. Please refer to other READMEs.
### Common options for all learning methods
These are options that can be specified regardless of the learning method.
#### Data set specific options
These are options related to the configuration of the data set. They cannot be described in `datasets.subsets`.
| Option Name | Example Setting | `[general]` | `[[datasets]]` |
| ---- | ---- | ---- | ---- |
| `batch_size` | `1` | o | o |
| `bucket_no_upscale` | `true` | o | o |
| `bucket_reso_steps` | `64` | o | o |
| `enable_bucket` | `true` | o | o |
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |
* `batch_size`
* This corresponds to the command-line argument `--train_batch_size`.
These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
#### Options for Subsets
These options are related to subset configuration.
| Option Name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `color_aug` | `false` | o | o | o |
| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
| `flip_aug` | `true` | o | o | o |
| `keep_tokens` | `2` | o | o | o |
| `num_repeats` | `10` | o | o | o |
| `random_crop` | `false` | o | o | o |
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o |
| `caption_suffix` | `", from side"` | o | o | o |
| `caption_separator` | (not specified) | o | o | o |
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
* `num_repeats`
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
* `caption_prefix`, `caption_suffix`
* Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`.
* `caption_separator`
* Specifies the string to separate the tags. The default is `,`. This option is usually not necessary to set.
* `keep_tokens_separator`
* Specifies the string to separate the parts to be fixed in the caption. For example, if you specify `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` will remain, and the rest will be shuffled and dropped. The comma in between is not necessary. As a result, the prompt will be `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`, etc.
* `secondary_separator`
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
* `enable_wildcard`
* Enables wildcard notation. This will be explained later.
### DreamBooth-specific options
DreamBooth-specific options only exist as subsets-specific options.
#### Subset-specific options
Options related to the configuration of DreamBooth subsets.
| Option Name | Example Setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `"sks girl"` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |
Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`.
* `image_dir`
* Specifies the path to the image directory. This is a required option.
* Images must be placed directly under the directory.
* `class_tokens`
* Sets the class tokens.
* Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur.
* `cache_info`
* Specifies whether to cache the image size and caption. If not specified, it is set to `false`. The cache is saved in `metadata_cache.json` in `image_dir`.
* Caching speeds up the loading of the dataset after the first time. It is effective when dealing with thousands of images or more.
* `is_reg`
* Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization.
### Fine-tuning method specific options
The options for the fine-tuning method only exist for subset-specific options.
#### Subset-specific options
These options are related to the configuration of the fine-tuning method's subsets.
| Option name | Example setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o |
| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) |
* `image_dir`
* Specify the path to the image directory. Unlike the DreamBooth method, specifying it is not mandatory, but it is recommended to do so.
* The case where it is not necessary to specify is when the `--full_path` is added to the command line when generating the metadata file.
* The images must be placed directly under the directory.
* `metadata_file`
* Specify the path to the metadata file used for the subset. This is a required option.
* It is equivalent to the command-line argument `--in_json`.
* Due to the specification that a metadata file must be specified for each subset, it is recommended to avoid creating a metadata file with images from different directories as a single metadata file. It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets.
### Options available when caption dropout method can be used
The options available when the caption dropout method can be used exist only for subsets. Regardless of whether it's the DreamBooth method or fine-tuning method, if it supports caption dropout, it can be specified.
#### Subset-specific options
Options related to the setting of subsets that caption dropout can be used for.
| Option Name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- |
| `caption_dropout_every_n_epochs` | o | o | o |
| `caption_dropout_rate` | o | o | o |
| `caption_tag_dropout_rate` | o | o | o |
## Behavior when there are duplicate subsets
In the case of the DreamBooth dataset, if there are multiple `image_dir` directories with the same content, they are considered to be duplicate subsets. For the fine-tuning dataset, if there are multiple `metadata_file` files with the same content, they are considered to be duplicate subsets. If duplicate subsets exist in the dataset, subsequent subsets will be ignored.
However, if they belong to different datasets, they are not considered duplicates. For example, if you have subsets with the same `image_dir` in different datasets, they will not be considered duplicates. This is useful when you want to train with the same image but with different resolutions.
```toml
# If data sets exist separately, they are not considered duplicates and are both used for training.
[[datasets]]
resolution = 512
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
```
## Command Line Argument and Configuration File
There are options in the configuration file that have overlapping roles with command line argument options.
The following command line argument options are ignored if a configuration file is passed:
* `--train_data_dir`
* `--reg_data_dir`
* `--in_json`
The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file.
| Command Line Argument Option | Prioritized Configuration File Option |
| ------------------------------- | ------------------------------------- |
| `--bucket_no_upscale` | |
| `--bucket_reso_steps` | |
| `--caption_dropout_every_n_epochs` | |
| `--caption_dropout_rate` | |
| `--caption_extension` | |
| `--caption_tag_dropout_rate` | |
| `--color_aug` | |
| `--dataset_repeats` | `num_repeats` |
| `--enable_bucket` | |
| `--face_crop_aug_range` | |
| `--flip_aug` | |
| `--keep_tokens` | |
| `--min_bucket_reso` | |
| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--train_batch_size` | `batch_size` |
## Error Guide
Currently, we are using an external library to check if the configuration file is written correctly, but the development has not been completed, and there is a problem that the error message is not clear. In the future, we plan to improve this problem.
As a temporary measure, we will list common errors and their solutions. If you encounter an error even though it should be correct or if the error content is not understandable, please contact us as it may be a bug.
* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: This error occurs when a required option is not provided. It is highly likely that you forgot to specify the option or misspelled the option name.
* The error location is indicated by `...` in the error message. For example, if you encounter an error like `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']`, it means that the `image_dir` option does not exist in the 0th `subsets` of the 0th `datasets` setting.
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful.
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it.
## Miscellaneous
### Multi-line captions
By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption.
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
a girl with a microphone standing on a stage
detailed digital art of a girl with a microphone on a stage
```
It can be combined with wildcard notation.
In metadata files, you can also specify multiple-line captions. In the `.json` metadata file, use `\n` to represent a line break. If the caption file consists of multiple lines, `merge_captions_to_metadata.py` will create a metadata file in this format.
The tags in the metadata (`tags`) are added to each line of the caption.
```json
{
"/path/to/image.png": {
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
},
...
}
```
In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc.
### Example of configuration file : `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc.
```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]
[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator= "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above
[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
### Example of caption, secondary_separator notation: `secondary_separator = ";;;"`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped).
### Example of caption, enable_wildcard notation: `enable_wildcard = true`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`.
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`).
### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"`
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc.

View File

@@ -1,5 +1,3 @@
For non-Japanese speakers: this README is provided only in Japanese in the current state. Sorry for inconvenience. We will provide English version in the near future.
`--dataset_config` で渡すことができる設定ファイルに関する説明です。
## 概要
@@ -140,12 +138,28 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o |
| `caption_suffix` | `“, from side”` | o | o | o |
| `caption_separator` | (通常は設定しません) | o | o | o |
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
* `num_repeats`
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
* `caption_prefix`, `caption_suffix`
* キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。
* `caption_separator`
* タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。
* `keep_tokens_separator`
* キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb``ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh``aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。
* `secondary_separator`
* 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。
* `enable_wildcard`
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
### DreamBooth 方式専用のオプション
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
@@ -159,6 +173,7 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
| `image_dir` | `C:\hoge` | - | - | o必須 |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `“sks girl”` | - | - | o |
| `cache_info` | `false` | o | o | o |
| `is_reg` | `false` | - | - | o |
まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats``class_tokens` で明示的に指定する必要があることに注意してください。
@@ -169,6 +184,9 @@ DreamBooth 方式のサブセットの設定に関わるオプションです。
* `class_tokens`
* クラストークンを設定します。
* 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。
* `cache_info`
* 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir``metadata_cache.json` というファイル名で保存されます。
* キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。
* `is_reg`
* サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
@@ -280,4 +298,89 @@ resolution = 768
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。
## その他
### 複数行キャプション
`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
a girl with a microphone standing on a stage
detailed digital art of a girl with a microphone on a stage
```
ワイルドカード記法と組み合わせることも可能です。
メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。
メタデータのタグ (`tags`) は、キャプションの各行に追加されます。
```json
{
"/path/to/image.png": {
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
},
...
}
```
この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...``test multiline caption2, open mouth, simple background ...` 等になります。
### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等
```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]
[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator= "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above
[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` drop されたケース)などになります。
### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
タグ文字列に `{``}` そのものを含めたい場合は `{{``}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。
### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general``1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。

View File

@@ -452,3 +452,36 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
- `--network_show_meta` : 追加ネットワークのメタデータを表示します。
---
# About Gradual Latent
Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options.
- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first.
- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size.
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers.
It is more effective with SD 1.5. It is quite subtle with SDXL.
# Gradual Latent について
latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、``sdxl_gen_img.py``gen_img_diffusers.py` に以下のオプションが追加されています。
- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。
それぞれのオプションは、プロンプトオプション、`--glt``--glr``--gls``--gle` でも指定できます。
サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。
SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。

View File

@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ
サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
- `--sample_at_first`
学習開始前にサンプル出力します。学習前との比較ができます。
- `--sample_prompts`
サンプル出力用プロンプトのファイルを指定します。

84
docs/train_SDXL-en.md Normal file
View File

@@ -0,0 +1,84 @@
## SDXL training
The documentation will be moved to the training documentation in the future. The following is a brief explanation of the training scripts for SDXL.
### Training scripts for SDXL
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- The full bfloat16 training might be unstable. Please use it at your own risk.
- The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- `--weighted_captions` option is not supported yet for both scripts.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
### Utility scripts for SDXL
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
- The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage.
### Tips for SDXL training
- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (4 to 8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings for SDXL
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
### ControlNet-LLLite
ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details.

View File

@@ -21,9 +21,13 @@ ComfyUIのカスタムードを用意しています。: https://github.com/k
## モデルの学習
### データセットの準備
通常のdatasetに加え`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
DreamBooth 方式の dataset`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。
たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。
finetuning 方式の dataset はサポートしていません。)
conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。
たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。
```toml
[[datasets.subsets]]

View File

@@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1
### Preparing the dataset
In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
In addition to the normal DreamBooth method dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
(We do not support the finetuning method dataset.)
```toml
[[datasets.subsets]]

View File

@@ -0,0 +1,88 @@
# Image Tagging using WD14Tagger
This document is based on the information from this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger).
Using onnx for inference is recommended. Please install onnx with the following command:
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
```
The model weights will be automatically downloaded from Hugging Face.
# Usage
Run the script to perform tagging.
```powershell
python finetune/tag_images_by_wd14_tagger.py --onnx --repo_id <model repo id> --batch_size <batch size> <training data folder>
```
For example, if using the repository `SmilingWolf/wd-swinv2-tagger-v3` with a batch size of 4, and the training data is located in the parent folder `train_data`, it would be:
```powershell
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
```
On the first run, the model files will be automatically downloaded to the `wd14_tagger_model` folder (the folder can be changed with an option).
Tag files will be created in the same directory as the training data images, with the same filename and a `.txt` extension.
![Generated tag files](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
![Tags and image](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
## Example
To output in the Animagine XL 3.1 format, it would be as follows (enter on a single line in practice):
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
--batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
--use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
--always_first_tags "1girl,1boy" ..\train_data
```
## Available Repository IDs
[SmilingWolf's V2 and V3 models](https://huggingface.co/SmilingWolf) are available for use. Specify them in the format like `SmilingWolf/wd-vit-tagger-v3`. The default when omitted is `SmilingWolf/wd-v1-4-convnext-tagger-v2`.
# Options
## General Options
- `--onnx`: Use ONNX for inference. If not specified, TensorFlow will be used. If using TensorFlow, please install TensorFlow separately.
- `--batch_size`: Number of images to process at once. Default is 1. Adjust according to VRAM capacity.
- `--caption_extension`: File extension for caption files. Default is `.txt`.
- `--max_data_loader_n_workers`: Maximum number of workers for DataLoader. Specifying a value of 1 or more will use DataLoader to speed up image loading. If unspecified, DataLoader will not be used.
- `--thresh`: Confidence threshold for outputting tags. Default is 0.35. Lowering the value will assign more tags but accuracy will decrease.
- `--general_threshold`: Confidence threshold for general tags. If omitted, same as `--thresh`.
- `--character_threshold`: Confidence threshold for character tags. If omitted, same as `--thresh`.
- `--recursive`: If specified, subfolders within the specified folder will also be processed recursively.
- `--append_tags`: Append tags to existing tag files.
- `--frequency_tags`: Output tag frequencies.
- `--debug`: Debug mode. Outputs debug information if specified.
## Model Download
- `--model_dir`: Folder to save model files. Default is `wd14_tagger_model`.
- `--force_download`: Re-download model files if specified.
## Tag Editing
- `--remove_underscore`: Remove underscores from output tags.
- `--undesired_tags`: Specify tags not to output. Multiple tags can be specified, separated by commas. For example, `black eyes,black hair`.
- `--use_rating_tags`: Output rating tags at the beginning of the tags.
- `--use_rating_tags_as_last_tag`: Add rating tags at the end of the tags.
- `--character_tags_first`: Output character tags first.
- `--character_tag_expand`: Expand character tag series names. For example, split the tag `chara_name_(series)` into `chara_name, series`.
- `--always_first_tags`: Specify tags to always output first when a certain tag appears in an image. Multiple tags can be specified, separated by commas. For example, `1girl,1boy`.
- `--caption_separator`: Separate tags with this string in the output file. Default is `, `.
- `--tag_replacement`: Perform tag replacement. Specify in the format `tag1,tag2;tag3,tag4`. If using `,` and `;`, escape them with `\`. \
For example, specify `aira tsubase,aira tsubase (uniform)` (when you want to train a specific costume), `aira tsubase,aira tsubase\, heir of shadows` (when the series name is not included in the tag).
When using `tag_replacement`, it is applied after `character_tag_expand`.
When specifying `remove_underscore`, specify `undesired_tags`, `always_first_tags`, and `tag_replacement` without including underscores.
When specifying `caption_separator`, separate `undesired_tags` and `always_first_tags` with `caption_separator`. Always separate `tag_replacement` with `,`.

View File

@@ -0,0 +1,88 @@
# WD14Taggerによるタグ付け
こちらのgithubページhttps://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。
onnx を用いた推論を推奨します。以下のコマンドで onnx をインストールしてください。
```powershell
pip install onnx==1.15.0 onnxruntime-gpu==1.17.1
```
モデルの重みはHugging Faceから自動的にダウンロードしてきます。
# 使い方
スクリプトを実行してタグ付けを行います。
```
python fintune/tag_images_by_wd14_tagger.py --onnx --repo_id <モデルのrepo id> --batch_size <バッチサイズ> <教師データフォルダ>
```
レポジトリに `SmilingWolf/wd-swinv2-tagger-v3` を使用し、バッチサイズを4にして、教師データを親フォルダの `train_data`に置いた場合、以下のようになります。
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3 --batch_size 4 ..\train_data
```
初回起動時にはモデルファイルが `wd14_tagger_model` フォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。
タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。
![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png)
![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png)
## 記述例
Animagine XL 3.1 方式で出力する場合、以下のようになります(実際には 1 行で入力してください)。
```
python tag_images_by_wd14_tagger.py --onnx --repo_id SmilingWolf/wd-swinv2-tagger-v3
--batch_size 4 --remove_underscore --undesired_tags "PUT,YOUR,UNDESIRED,TAGS" --recursive
--use_rating_tags_as_last_tag --character_tags_first --character_tag_expand
--always_first_tags "1girl,1boy" ..\train_data
```
## 使用可能なリポジトリID
[SmilingWolf 氏の V2、V3 のモデル](https://huggingface.co/SmilingWolf)が使用可能です。`SmilingWolf/wd-vit-tagger-v3` のように指定してください。省略時のデフォルトは `SmilingWolf/wd-v1-4-convnext-tagger-v2` です。
# オプション
## 一般オプション
- `--onnx` : ONNX を使用して推論します。指定しない場合は TensorFlow を使用します。TensorFlow 使用時は別途 TensorFlow をインストールしてください。
- `--batch_size` : 一度に処理する画像の数。デフォルトは1です。VRAMの容量に応じて増減してください。
- `--caption_extension` : キャプションファイルの拡張子。デフォルトは `.txt` です。
- `--max_data_loader_n_workers` : DataLoader の最大ワーカー数です。このオプションに 1 以上の数値を指定すると、DataLoader を用いて画像読み込みを高速化します。未指定時は DataLoader を用いません。
- `--thresh` : 出力するタグの信頼度の閾値。デフォルトは0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。
- `--general_threshold` : 一般タグの信頼度の閾値。省略時は `--thresh` と同じです。
- `--character_threshold` : キャラクタータグの信頼度の閾値。省略時は `--thresh` と同じです。
- `--recursive` : 指定すると、指定したフォルダ内のサブフォルダも再帰的に処理します。
- `--append_tags` : 既存のタグファイルにタグを追加します。
- `--frequency_tags` : タグの頻度を出力します。
- `--debug` : デバッグモード。指定するとデバッグ情報を出力します。
## モデルのダウンロード
- `--model_dir` : モデルファイルの保存先フォルダ。デフォルトは `wd14_tagger_model` です。
- `--force_download` : 指定するとモデルファイルを再ダウンロードします。
## タグ編集関連
- `--remove_underscore` : 出力するタグからアンダースコアを削除します。
- `--undesired_tags` : 出力しないタグを指定します。カンマ区切りで複数指定できます。たとえば `black eyes,black hair` のように指定します。
- `--use_rating_tags` : タグの最初にレーティングタグを出力します。
- `--use_rating_tags_as_last_tag` : タグの最後にレーティングタグを追加します。
- `--character_tags_first` : キャラクタータグを最初に出力します。
- `--character_tag_expand` : キャラクタータグのシリーズ名を展開します。たとえば `chara_name_(series)` のタグを `chara_name, series` に分割します。
- `--always_first_tags` : あるタグが画像に出力されたとき、そのタグを最初に出力するタグを指定します。カンマ区切りで複数指定できます。たとえば `1girl,1boy` のように指定します。
- `--caption_separator` : 出力するファイルでタグをこの文字列で区切ります。デフォルトは `, ` です。
- `--tag_replacement` : タグの置換を行います。`tag1,tag2;tag3,tag4` のように指定します。`,` および `;` を使う場合は `\` でエスケープしてください。\
たとえば `aira tsubase,aira tsubase (uniform)` (特定の衣装を学習させたいとき)、`aira tsubase,aira tsubase\, heir of shadows` (シリーズ名がタグに含まれないとき)のように指定します。
`tag_replacement``character_tag_expand` の後に適用されます。
`remove_underscore` 指定時は、`undesired_tags``always_first_tags``tag_replacement` はアンダースコアを含めずに指定してください。
`caption_separator` 指定時は、`undesired_tags``always_first_tags``caption_separator` で区切ってください。`tag_replacement` は必ず `,` で区切ってください。

View File

@@ -2,27 +2,29 @@
# XXX dropped option: hypernetwork training
import argparse
import gc
import math
import os
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
try:
import intel_extension_for_pytorch as ipex
init_ipex()
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.train_util as train_util
import library.config_util as config_util
from library.config_util import (
@@ -42,6 +44,8 @@ from library.custom_train_functions import (
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -54,11 +58,11 @@ def train(args):
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -91,7 +95,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
@@ -102,11 +106,12 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -157,15 +162,13 @@ def train(args):
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -192,7 +195,7 @@ def train(args):
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
for m in training_models:
m.requires_grad_(True)
@@ -212,8 +215,8 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -228,7 +231,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -245,16 +250,23 @@ def train(args):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
@@ -294,10 +306,15 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -308,13 +325,13 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with accelerator.accumulate(*training_models):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
latents = latents * 0.18215
b_size = latents.shape[0]
@@ -337,7 +354,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -351,11 +368,11 @@ def train(args):
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
@@ -363,7 +380,7 @@ def train(args):
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -454,7 +471,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -464,21 +481,25 @@ def train(args):
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument(
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
)
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--learning_rate_te",
@@ -486,6 +507,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser
@@ -494,6 +520,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -21,6 +21,10 @@ import torch.nn.functional as F
import os
from urllib.parse import urlparse
from timm.models.hub import download_cached_file
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class BLIP_Base(nn.Module):
def __init__(self,
@@ -130,8 +134,9 @@ class BLIP_Decoder(nn.Module):
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
image_embeds = self.visual_encoder(image)
if not sample:
image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
# recent version of transformers seems to do repeat_interleave automatically
# if not sample:
# image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
@@ -235,6 +240,6 @@ def load_checkpoint(model,url_or_filename):
del state_dict[key]
msg = model.load_state_dict(state_dict,strict=False)
print('load checkpoint from %s'%url_or_filename)
logger.info('load checkpoint from %s'%url_or_filename)
return model,msg

View File

@@ -8,6 +8,10 @@ import json
import re
from tqdm import tqdm
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
@@ -36,13 +40,13 @@ def clean_tags(image_key, tags):
tokens = tags.split(", rating")
if len(tokens) == 1:
# WD14 taggerのときはこちらになるのでメッセージは出さない
# print("no rating:")
# print(f"{image_key} {tags}")
# logger.info("no rating:")
# logger.info(f"{image_key} {tags}")
pass
else:
if len(tokens) > 2:
print("multiple ratings:")
print(f"{image_key} {tags}")
logger.info("multiple ratings:")
logger.info(f"{image_key} {tags}")
tags = tokens[0]
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
@@ -124,43 +128,43 @@ def clean_caption(caption):
def main(args):
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
logger.info(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
print("no metadata / メタデータファイルがありません")
logger.error("no metadata / メタデータファイルがありません")
return
print("cleaning captions and tags.")
logger.info("cleaning captions and tags.")
image_keys = list(metadata.keys())
for image_key in tqdm(image_keys):
tags = metadata[image_key].get('tags')
if tags is None:
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
else:
org = tags
tags = clean_tags(image_key, tags)
metadata[image_key]['tags'] = tags
if args.debug and org != tags:
print("FROM: " + org)
print("TO: " + tags)
logger.info("FROM: " + org)
logger.info("TO: " + tags)
caption = metadata[image_key].get('caption')
if caption is None:
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
else:
org = caption
caption = clean_caption(caption)
metadata[image_key]['caption'] = caption
if args.debug and org != caption:
print("FROM: " + org)
print("TO: " + caption)
logger.info("FROM: " + org)
logger.info("TO: " + caption)
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
logger.info(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
@@ -178,10 +182,10 @@ if __name__ == '__main__':
args, unknown = parser.parse_known_args()
if len(unknown) == 1:
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
print("All captions and tags in the metadata are processed.")
print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
print("メタデータ内のすべてのキャプションとタグが処理されます。")
logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
logger.warning("All captions and tags in the metadata are processed.")
logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
args.in_json = args.out_json
args.out_json = unknown[0]
elif len(unknown) > 0:

View File

@@ -9,14 +9,22 @@ from pathlib import Path
from PIL import Image
from tqdm import tqdm
import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder
from blip.blip import blip_decoder, is_url
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()
IMAGE_SIZE = 384
@@ -47,7 +55,7 @@ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
@@ -74,19 +82,21 @@ def main(args):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
cwd = os.getcwd()
print("Current Working Directory is: ", cwd)
logger.info(f"Current Working Directory is: {cwd}")
os.chdir("finetune")
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
args.caption_weights = os.path.join("..", args.caption_weights)
print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
print(f"loading BLIP caption: {args.caption_weights}")
logger.info(f"loading BLIP caption: {args.caption_weights}")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")
logger.info("BLIP loaded")
# captioningする
def run_batch(path_imgs):
@@ -106,7 +116,7 @@ def main(args):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f'{image_path} {caption}')
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
@@ -136,7 +146,7 @@ def main(args):
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, img_tensor))
@@ -146,7 +156,7 @@ def main(args):
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -5,12 +5,19 @@ import re
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers.generation.utils import GenerationMixin
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -35,8 +42,8 @@ def remove_words(captions, debug):
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
logger.info(caption)
logger.info(cap)
removed_caps.append(cap)
return removed_caps
@@ -70,16 +77,16 @@ def main(args):
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
"""
print(f"load images from {args.train_data_dir}")
logger.info(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
# できればcacheに依存せず明示的にダウンロードしたい
print(f"loading GIT: {args.model_id}")
logger.info(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
logger.info("GIT loaded")
# captioningする
def run_batch(path_imgs):
@@ -97,7 +104,7 @@ def main(args):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
logger.info(f"{image_path} {caption}")
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
@@ -126,7 +133,7 @@ def main(args):
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
@@ -137,7 +144,7 @@ def main(args):
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -5,72 +5,96 @@ from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
assert not args.recursive or (
args.recursive and args.full_path
), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
print("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding='utf-8').strip()
logger.info("merge caption texts to metadata json.")
for image_path in tqdm(image_paths):
caption_path = image_path.with_suffix(args.caption_extension)
caption = caption_path.read_text(encoding="utf-8").strip()
if not os.path.exists(caption_path):
caption_path = os.path.join(image_path, args.caption_extension)
if not os.path.exists(caption_path):
caption_path = os.path.join(image_path, args.caption_extension)
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
metadata[image_key]['caption'] = caption
if args.debug:
print(image_key, caption)
metadata[image_key]["caption"] = caption
if args.debug:
logger.info(f"{image_key} {caption}")
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
print("done!")
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--debug", action="store_true", help="debug mode")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument(
"--in_json",
type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む",
)
parser.add_argument(
"--caption_extention",
type=str,
default=None,
help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument(
"--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子"
)
parser.add_argument(
"--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--recursive",
action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = parser.parse_args()
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)
main(args)

View File

@@ -5,67 +5,89 @@ from typing import List
from tqdm import tqdm
import library.train_util as train_util
import os
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(args):
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
assert not args.recursive or (
args.recursive and args.full_path
), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
logger.info(f"found {len(image_paths)} images.")
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is None and Path(args.out_json).is_file():
args.in_json = args.out_json
if args.in_json is not None:
print(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
print("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
if args.in_json is not None:
logger.info(f"loading existing metadata: {args.in_json}")
metadata = json.loads(Path(args.in_json).read_text(encoding="utf-8"))
logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
else:
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
metadata = {}
print("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding='utf-8').strip()
logger.info("merge tags to metadata json.")
for image_path in tqdm(image_paths):
tags_path = image_path.with_suffix(args.caption_extension)
tags = tags_path.read_text(encoding="utf-8").strip()
if not os.path.exists(tags_path):
tags_path = os.path.join(image_path, args.caption_extension)
if not os.path.exists(tags_path):
tags_path = os.path.join(image_path, args.caption_extension)
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
image_key = str(image_path) if args.full_path else image_path.stem
if image_key not in metadata:
metadata[image_key] = {}
metadata[image_key]['tags'] = tags
if args.debug:
print(image_key, tags)
metadata[image_key]["tags"] = tags
if args.debug:
logger.info(f"{image_key} {tags}")
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
# metadataを書き出して終わり
logger.info(f"writing metadata: {args.out_json}")
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding="utf-8")
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("--in_json", type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
parser.add_argument("--caption_extension", type=str, default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument(
"--in_json",
type=str,
help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル省略時、out_jsonが存在すればそれを読み込む",
)
parser.add_argument(
"--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--recursive",
action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
)
parser.add_argument(
"--caption_extension",
type=str,
default=".txt",
help="extension of caption (tag) file / 読み込むキャプション(タグ)ファイルの拡張子",
)
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
main(args)
args = parser.parse_args()
main(args)

View File

@@ -8,13 +8,21 @@ from tqdm import tqdm
import numpy as np
from PIL import Image
import cv2
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()
IMAGE_TRANSFORMS = transforms.Compose(
[
@@ -51,22 +59,22 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive):
def main(args):
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
if args.bucket_reso_steps % 32 > 0:
print(
logger.warning(
f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
)
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
logger.info(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding="utf-8") as f:
metadata = json.load(f)
else:
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
weight_dtype = torch.float32
@@ -89,7 +97,7 @@ def main(args):
if not args.bucket_no_upscale:
bucket_manager.make_buckets()
else:
print(
logger.warning(
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
)
@@ -130,7 +138,7 @@ def main(args):
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
@@ -183,15 +191,15 @@ def main(args):
for i, reso in enumerate(bucket_manager.resos):
count = bucket_counts.get(reso, 0)
if count > 0:
print(f"bucket {i} {reso}: {count}")
logger.info(f"bucket {i} {reso}: {count}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error: {np.mean(img_ar_errors)}")
logger.info(f"mean ar error: {np.mean(img_ar_errors)}")
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
logger.info(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding="utf-8") as f:
json.dump(metadata, f, indent=2)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -11,6 +11,12 @@ from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# from wd14 tagger
IMAGE_SIZE = 448
@@ -56,12 +62,12 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
tensor = torch.tensor(image)
# tensor = torch.tensor(image) # これ Tensor に変換する必要ないな……(;・∀・)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
return (image, img_path)
def collate_fn_remove_corrupted(batch):
@@ -75,36 +81,44 @@ def collate_fn_remove_corrupted(batch):
def main(args):
# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
if not os.path.exists(model_location) or args.force_download:
os.makedirs(args.model_dir, exist_ok=True)
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files = ["selected_tags.csv"]
files += FILES_ONNX
else:
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(model_location, SUB_DIR),
force_download=True,
force_filename=file,
)
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(args.model_dir, SUB_DIR),
force_download=True,
force_filename=file,
)
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
else:
print("using existing wd14 tagger model")
logger.info("using existing wd14 tagger model")
# 画像を読み込む
# モデルを読み込む
if args.onnx:
import torch
import onnx
import onnxruntime as ort
onnx_path = f"{args.model_dir}/model.onnx"
print("Running wd14 tagger with onnx")
print(f"loading onnx model: {onnx_path}")
onnx_path = f"{model_location}/model.onnx"
logger.info("Running wd14 tagger with onnx")
logger.info(f"loading onnx model: {onnx_path}")
if not os.path.exists(onnx_path):
raise Exception(
@@ -116,58 +130,112 @@ def main(args):
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
except Exception:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str:
if args.batch_size != batch_size and not isinstance(batch_size, str) and batch_size > 0:
# some rebatch model may use 'N' as dynamic axes
print(
logger.warning(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
)
args.batch_size = batch_size
del model
ort_sess = ort.InferenceSession(
onnx_path,
providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"],
)
if "OpenVINOExecutionProvider" in ort.get_available_providers():
# requires provider options for gpu support
# fp16 causes nonsense outputs
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU_FP32"}],
)
else:
ort_sess = ort.InferenceSession(
onnx_path,
providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else
["ROCMExecutionProvider"] if "ROCMExecutionProvider" in ort.get_available_providers() else
["CPUExecutionProvider"]
),
)
else:
from tensorflow.keras.models import load_model
model = load_model(f"{args.model_dir}")
model = load_model(f"{model_location}")
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
line = [row for row in reader]
header = line[0] # tag_id,name,category,count
rows = line[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
rating_tags = [row[1] for row in rows[0:] if row[2] == "9"]
general_tags = [row[1] for row in rows[0:] if row[2] == "0"]
character_tags = [row[1] for row in rows[0:] if row[2] == "4"]
# preprocess tags in advance
if args.character_tag_expand:
for i, tag in enumerate(character_tags):
if tag.endswith(")"):
# chara_name_(series) -> chara_name, series
# chara_name_(costume)_(series) -> chara_name_(costume), series
tags = tag.split("(")
character_tag = "(".join(tags[:-1])
if character_tag.endswith("_"):
character_tag = character_tag[:-1]
series_tag = tags[-1].replace(")", "")
character_tags[i] = character_tag + args.caption_separator + series_tag
if args.remove_underscore:
rating_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in rating_tags]
general_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in general_tags]
character_tags = [tag.replace("_", " ") if len(tag) > 3 else tag for tag in character_tags]
if args.tag_replacement is not None:
# escape , and ; in tag_replacement: wd14 tag names may contain , and ;
escaped_tag_replacements = args.tag_replacement.replace("\\,", "@@@@").replace("\\;", "####")
tag_replacements = escaped_tag_replacements.split(";")
for tag_replacement in tag_replacements:
tags = tag_replacement.split(",") # source, target
assert len(tags) == 2, f"tag replacement must be in the format of `source,target` / タグの置換は `置換元,置換先` の形式で指定してください: {args.tag_replacement}"
source, target = [tag.replace("@@@@", ",").replace("####", ";") for tag in tags]
logger.info(f"replacing tag: {source} -> {target}")
if source in general_tags:
general_tags[general_tags.index(source)] = target
elif source in character_tags:
character_tags[character_tags.index(source)] = target
elif source in rating_tags:
rating_tags[rating_tags.index(source)] = target
# 画像を読み込む
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
logger.info(f"found {len(image_paths)} images.")
tag_freq = {}
undesired_tags = set(args.undesired_tags.split(","))
caption_separator = args.caption_separator
stripped_caption_separator = caption_separator.strip()
undesired_tags = args.undesired_tags.split(stripped_caption_separator)
undesired_tags = set([tag.strip() for tag in undesired_tags if tag.strip() != ""])
always_first_tags = None
if args.always_first_tags is not None:
always_first_tags = [tag for tag in args.always_first_tags.split(stripped_caption_separator) if tag.strip() != ""]
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
if args.onnx:
if len(imgs) < args.batch_size:
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
# if len(imgs) < args.batch_size:
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
else:
@@ -175,46 +243,64 @@ def main(args):
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = label_names[:4]
# rating_index = ratings_names["probs"].argmax()
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
# Everything else is tags: pick any where prediction confidence > threshold
combined_tags = []
general_tag_text = ""
rating_tag_text = ""
character_tag_text = ""
general_tag_text = ""
# 最初の4つ以降はタグなのでconfidenceがthreshold以上のものを追加する
# First 4 labels are ratings, the rest are tags: pick any where prediction confidence >= threshold
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
tag_name = tag_name.replace("_", " ")
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name
general_tag_text += caption_separator + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
if args.remove_underscore and len(tag_name) > 3:
tag_name = tag_name.replace("_", " ")
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name
combined_tags.append(tag_name)
character_tag_text += caption_separator + tag_name
if args.character_tags_first: # insert to the beginning
combined_tags.insert(0, tag_name)
else:
combined_tags.append(tag_name)
# 最初の4つはratingなのでargmaxで選ぶ
# First 4 labels are actually ratings: pick one with argmax
if args.use_rating_tags or args.use_rating_tags_as_last_tag:
ratings_probs = prob[:4]
rating_index = ratings_probs.argmax()
found_rating = rating_tags[rating_index]
if found_rating not in undesired_tags:
tag_freq[found_rating] = tag_freq.get(found_rating, 0) + 1
rating_tag_text = found_rating
if args.use_rating_tags:
combined_tags.insert(0, found_rating) # insert to the beginning
else:
combined_tags.append(found_rating)
# 一番最初に置くタグを指定する
# Always put some tags at the beginning
if always_first_tags is not None:
for tag in always_first_tags:
if tag in combined_tags:
combined_tags.remove(tag)
combined_tags.insert(0, tag)
# 先頭のカンマを取る
if len(general_tag_text) > 0:
general_tag_text = general_tag_text[2:]
general_tag_text = general_tag_text[len(caption_separator) :]
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
character_tag_text = character_tag_text[len(caption_separator) :]
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
tag_text = ", ".join(combined_tags)
tag_text = caption_separator.join(combined_tags)
if args.append_tags:
# Check if file exists
@@ -224,18 +310,22 @@ def main(args):
existing_content = f.read().strip("\n") # Remove newlines
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
# Check and remove repeating tags in tag_text
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
# Create new tag_text
tag_text = ", ".join(existing_tags + new_tags)
tag_text = caption_separator.join(existing_tags + new_tags)
with open(caption_file, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
logger.info("")
logger.info(f"{image_path}:")
logger.info(f"\tRating tags: {rating_tag_text}")
logger.info(f"\tCharacter tags: {character_tag_text}")
logger.info(f"\tGeneral tags: {general_tag_text}")
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
@@ -258,16 +348,14 @@ def main(args):
continue
image, image_path = data
if image is not None:
image = image.detach().numpy()
else:
if image is None:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
@@ -282,16 +370,18 @@ def main(args):
if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
print("\nTag frequencies:")
print("Tag frequencies:")
for tag, freq in sorted_tags:
print(f"{tag}: {freq}")
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ"
)
parser.add_argument(
"--repo_id",
type=str,
@@ -305,9 +395,13 @@ def setup_parser() -> argparse.ArgumentParser:
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ",
)
parser.add_argument(
"--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします"
"--force_download",
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
@@ -320,8 +414,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument(
"--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
)
parser.add_argument(
"--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値"
)
parser.add_argument(
"--general_threshold",
type=float,
@@ -334,22 +432,67 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
)
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
parser.add_argument(
"--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する"
)
parser.add_argument(
"--remove_underscore",
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--debug", action="store_true", help="debug mode"
)
parser.add_argument(
"--undesired_tags",
type=str,
default="",
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / タグの出現頻度を表示する"
)
parser.add_argument(
"--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する"
)
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--use_rating_tags", action="store_true", help="Adds rating tags as the first tag / レーティングタグを最初のタグとして追加する",
)
parser.add_argument(
"--use_rating_tags_as_last_tag", action="store_true", help="Adds rating tags as the last tag / レーティングタグを最後のタグとして追加する",
)
parser.add_argument(
"--character_tags_first", action="store_true", help="Always inserts character tags before the general tags / characterタグを常にgeneralタグの前に出力する",
)
parser.add_argument(
"--always_first_tags",
type=str,
default=None,
help="comma-separated list of tags to always put at the beginning, e.g. `1girl,1boy`"
+ " / 必ず先頭に置くタグのカンマ区切りリスト、例 : `1girl,1boy`",
)
parser.add_argument(
"--caption_separator",
type=str,
default=", ",
help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
)
parser.add_argument(
"--tag_replacement",
type=str,
default=None,
help="tag replacement in the format of `source1,target1;source2,target2; ...`. Escape `,` and `;` with `\`. e.g. `tag1,tag2;tag3,tag4`"
+ " / タグの置換を `置換元1,置換先1;置換元2,置換先2; ...`で指定する。`\` で `,` と `;` をエスケープできる。例: `tag1,tag2;tag3,tag4`",
)
parser.add_argument(
"--character_tag_expand",
action="store_true",
help="expand tag tail parenthesis to another tag for character tags. `chara_name_(series)` becomes `chara_name, series`"
+ " / キャラクタタグの末尾の括弧を別のタグに展開する。`chara_name_(series)` は `chara_name, series` になる",
)
return parser

3334
gen_img.py Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +1,16 @@
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
import torch
import argparse
import random
import re
from torch.types import Number
from typing import List, Optional, Union
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def prepare_scheduler_for_custom_training(noise_scheduler, device):
@@ -21,7 +29,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
# fix beta: zero terminal SNR
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
def enforce_zero_terminal_snr(betas):
# Convert betas to alphas_bar_sqrt
@@ -49,50 +57,55 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# print("original:", noise_scheduler.betas)
# print("fixed:", betas)
# logger.info(f"original: {noise_scheduler.betas}")
# logger.info(f"fixed: {betas}")
noise_scheduler.betas = betas
noise_scheduler.alphas = alphas
noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
else:
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
loss = loss * snr_weight
return loss
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
scale = get_snr_scale(timesteps, noise_scheduler)
loss = loss * scale
return loss
def get_snr_scale(timesteps, noise_scheduler):
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1)
# # show debug info
# print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
# logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
return scale
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor):
scale = get_snr_scale(timesteps, noise_scheduler)
# print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
loss = loss + loss / scale * v_pred_like_loss
return loss
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
weight = 1/torch.sqrt(snr_t)
weight = 1 / torch.sqrt(snr_t)
loss = weight * loss
return loss
# TODO train_utilと分散しているのでどちらかに寄せる
@@ -265,7 +278,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
@@ -468,6 +481,17 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise
def apply_masked_loss(loss, batch):
# mask image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
return loss
"""
##########################################
# Perlin Noise

139
library/deepspeed_utils.py Normal file
View File

@@ -0,0 +1,139 @@
import os
import argparse
import torch
from accelerate import DeepSpeedPlugin, Accelerator
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
parser.add_argument(
"--offload_optimizer_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--fp16_master_weights_and_gradients",
action="store_true",
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
)
def prepare_deepspeed_args(args: argparse.Namespace):
if not args.deepspeed:
return
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
args.max_data_loader_n_workers = 1
def prepare_deepspeed_plugin(args: argparse.Namespace):
if not args.deepspeed:
return None
try:
import deepspeed
except ImportError as e:
logger.error(
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
)
exit(1)
deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_clipping=args.max_grad_norm,
offload_optimizer_device=args.offload_optimizer_device,
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device,
offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag,
zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
if args.full_fp16 or args.fp16_master_weights_and_gradients:
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
logger.info("[DeepSpeed] full fp16 enable.")
else:
logger.info(
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
)
if args.offload_optimizer_device is not None:
logger.info("[DeepSpeed] start to manually build cpu_adam.")
deepspeed.ops.op_builder.CPUAdamBuilder().load()
logger.info("[DeepSpeed] building cpu_adam done.")
return deepspeed_plugin
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
def prepare_deepspeed_model(args: argparse.Namespace, **models):
# remove None from models
models = {k: v for k, v in models.items() if v is not None}
class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()
for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(torch.nn.ModuleDict({key: model}))
def get_models(self):
return self.models
ds_model = DeepSpeedWrapper(**models)
return ds_model

84
library/device_utils.py Normal file
View File

@@ -0,0 +1,84 @@
import functools
import gc
import torch
try:
HAS_CUDA = torch.cuda.is_available()
except Exception:
HAS_CUDA = False
try:
HAS_MPS = torch.backends.mps.is_available()
except Exception:
HAS_MPS = False
try:
import intel_extension_for_pytorch as ipex # noqa
HAS_XPU = torch.xpu.is_available()
except Exception:
HAS_XPU = False
def clean_memory():
gc.collect()
if HAS_CUDA:
torch.cuda.empty_cache()
if HAS_XPU:
torch.xpu.empty_cache()
if HAS_MPS:
torch.mps.empty_cache()
def clean_memory_on_device(device: torch.device):
r"""
Clean memory on the specified device, will be called from training scripts.
"""
gc.collect()
# device may "cuda" or "cuda:0", so we need to check the type of device
if device.type == "cuda":
torch.cuda.empty_cache()
if device.type == "xpu":
torch.xpu.empty_cache()
if device.type == "mps":
torch.mps.empty_cache()
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
r"""
Do not call this function from training scripts. Use accelerator.device instead.
"""
if HAS_CUDA:
device = torch.device("cuda")
elif HAS_XPU:
device = torch.device("xpu")
elif HAS_MPS:
device = torch.device("mps")
else:
device = torch.device("cpu")
print(f"get_preferred_device() -> {device}")
return device
def init_ipex():
"""
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
This function should run right after importing torch and before doing anything else.
If IPEX is not available, this function does nothing.
"""
try:
if HAS_XPU:
from library.ipex import ipex_init
is_initialized, error_message = ipex_init()
if not is_initialized:
print("failed to initialize ipex:", error_message)
else:
return
except Exception as e:
print("failed to initialize ipex:", e)

View File

@@ -4,7 +4,10 @@ from pathlib import Path
import argparse
import os
from library.utils import fire_in_thread
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
api = HfApi(
@@ -33,9 +36,9 @@ def upload(
try:
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
print("===========================================")
print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
print("===========================================")
logger.error("===========================================")
logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
logger.error("===========================================")
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
@@ -56,9 +59,9 @@ def upload(
path_in_repo=path_in_repo,
)
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
print("===========================================")
print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
print("===========================================")
logger.error("===========================================")
logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
logger.error("===========================================")
if args.async_upload and not force_sync_upload:
fire_in_thread(uploader)

View File

@@ -4,172 +4,177 @@ import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
from .attention import attention_init
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def ipex_init(): # pylint: disable=too-many-statements
try:
#Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
torch.cuda.device_count = torch.xpu.device_count
torch.cuda.device_of = torch.xpu.device_of
torch.cuda.get_device_name = torch.xpu.get_device_name
torch.cuda.get_device_properties = torch.xpu.get_device_properties
torch.cuda.init = torch.xpu.init
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
return True, "Skipping IPEX hijack"
else:
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
torch.cuda.device_count = torch.xpu.device_count
torch.cuda.device_of = torch.xpu.device_of
torch.cuda.get_device_name = torch.xpu.get_device_name
torch.cuda.get_device_properties = torch.xpu.get_device_properties
torch.cuda.init = torch.xpu.init
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.nn.Module.cuda = torch.nn.Module.xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
#Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
torch.cuda.memory_cached = torch.xpu.memory_reserved
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
# Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
torch.cuda.memory_cached = torch.xpu.memory_reserved
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
#RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
torch.cuda.manual_seed = torch.xpu.manual_seed
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
torch.cuda.seed = torch.xpu.seed
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
# RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
torch.cuda.manual_seed = torch.xpu.manual_seed
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
torch.cuda.seed = torch.xpu.seed
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
# AMP:
torch.cuda.amp = torch.xpu.amp
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
#AMP:
torch.cuda.amp = torch.xpu.amp
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
#C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 2024
ipex._C._DeviceProperties.minor = 0
#Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.version.cuda = "11.7"
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
torch.cuda.get_device_properties.major = 11
torch.cuda.get_device_properties.minor = 7
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
if hasattr(torch.xpu, 'getDeviceIdListForCard'):
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
else:
torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.version.cuda = "12.1"
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
ipex_hijacks()
attention_init()
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
ipex_hijacks()
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
torch.cuda.is_xpu_hijacked = True
except Exception as e:
return False, e
return True, None

View File

@@ -1,45 +1,98 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long
original_torch_bmm = torch.bmm
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
block_multiply = input.element_size()
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
# Find something divisible with the input_tokens
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
# Find slice sizes for SDPA
@cache
def find_sdpa_slice_sizes(query_shape, query_element_size):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > sdpa_slice_trigger_rate:
do_split = True
#Find something divisible with the input_tokens
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
# Find slice sizes for BMM
@cache
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = input_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the input_tokens
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
split_3_slice_size = mat2_atten_shape
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
if input.device.type != "xpu":
return original_torch_bmm(input, mat2, out=out)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
# Slice BMM
if do_split:
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
@@ -48,62 +101,41 @@ def torch_bmm(input, mat2, *, out=None):
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
if do_split_3:
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
out=out
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
)
torch.xpu.synchronize(input.device)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
if len(query.shape) == 3:
batch_size_attention, query_tokens, shape_four = query.shape
shape_one = 1
no_shape_one = True
else:
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
no_shape_one = False
block_multiply = query.element_size()
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
do_split = True
#Find something divisible with the shape_one
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_2_slice_size = query_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the batch_size_attention
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
# Slice SDPA
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
@@ -112,46 +144,34 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if no_shape_one:
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
else:
if no_shape_one:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs
)
torch.xpu.synchronize(query.device)
else:
return original_scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
return hidden_states
def attention_init():
#ARC GPUs can't allocate more than 4GB to a single block:
torch.bmm = torch_bmm
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention

View File

@@ -1,10 +1,62 @@
import os
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.21.1 # pylint: disable=import-error
import diffusers #0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND
from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
@cache
def find_slice_size(slice_size, slice_block_size):
while (slice_size * slice_block_size) > attention_slice_rate:
slice_size = slice_size // 2
if slice_size <= 1:
slice_size = 1
break
return slice_size
@cache
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
if slice_size is not None:
batch_size_attention = slice_size
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
split_2_slice_size = query_tokens
split_3_slice_size = shape_three
do_split = False
do_split_2 = False
do_split_3 = False
if query_device_type != "xpu":
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
r"""
Processor for implementing sliced attention.
@@ -18,7 +70,9 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
input_ndim = hidden_states.ndim
@@ -54,49 +108,62 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
block_multiply = query.element_size()
slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
block_size = query_tokens * slice_block_size
split_2_slice_size = query_tokens
if block_size > 4:
do_split_2 = True
#Find something divisible with the query_tokens
while (split_2_slice_size * slice_block_size) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
@@ -115,6 +182,131 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
return hidden_states
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None,
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
if do_split:
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def ipex_diffusers():
#ARC GPUs can't allocate more than 4GB to a single block:
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
diffusers.models.attention_processor.AttnProcessor = AttnProcessor

View File

@@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
# pylint: disable=protected-access, missing-function-docstring, line-too-long
device_supports_fp64 = torch.xpu.has_fp64_dtype()
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
@@ -96,7 +97,10 @@ def unscale_(self, optimizer):
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
if device_supports_fp64:
inv_scale = self._scale.double().reciprocal().float()
else:
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)

View File

@@ -1,67 +1,14 @@
import contextlib
import importlib
import os
from functools import wraps
from contextlib import nullcontext
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import numpy as np
device_supports_fp64 = torch.xpu.has_fp64_dtype()
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
class CondFunc: # pylint: disable=missing-class-docstring
def __new__(cls, orig_func, sub_func, cond_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
except ImportError:
pass
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)
_utils = torch.utils.data._utils
def _shutdown_workers(self):
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
return
if hasattr(self, "_shutdown") and not self._shutdown:
self._shutdown = True
try:
if hasattr(self, '_pin_memory_thread'):
self._pin_memory_thread_done_event.set()
self._worker_result_queue.put((None, None))
self._pin_memory_thread.join()
self._worker_result_queue.cancel_join_thread()
self._worker_result_queue.close()
self._workers_done_event.set()
for worker_id in range(len(self._workers)):
if self._persistent_workers or self._workers_status[worker_id]:
self._mark_worker_as_unavailable(worker_id, shutdown=True)
for w in self._workers: # pylint: disable=invalid-name
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
for q in self._index_queues: # pylint: disable=invalid-name
q.cancel_join_thread()
q.close()
finally:
if self._worker_pids_set:
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
self._worker_pids_set = False
for w in self._workers: # pylint: disable=invalid-name
if w.is_alive():
w.terminate()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
@@ -69,7 +16,11 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return contextlib.nullcontext()
return nullcontext()
@property
def is_cuda(self):
return self.device.type == 'xpu' or self.device.type == 'cuda'
def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
@@ -77,28 +28,21 @@ def check_device(device):
def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
def ipex_no_cuda(orig_func, *args, **kwargs):
torch.cuda.is_available = lambda: False
orig_func(*args, **kwargs)
torch.cuda.is_available = torch.xpu.is_available
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
if len(args) > 0 and args[0] == "cuda":
return original_autocast("xpu", *args[1:], **kwargs)
# Autocast
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
@wraps(torch.amp.autocast_mode.autocast.__init__)
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
if device_type == "cuda":
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
else:
return original_autocast(*args, **kwargs)
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
# Latent Antialias CPU Offload:
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
if antialias or align_corners is not None or mode == 'bicubic':
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
@@ -107,90 +51,263 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
original_linalg_solve = torch.linalg.solve
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
return_device = A.device
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
original_from_numpy = torch.from_numpy
@wraps(torch.from_numpy)
def from_numpy(ndarray):
if ndarray.dtype == float:
return original_from_numpy(ndarray.astype('float32'))
else:
return original_linalg_solve(A, B, *args, **kwargs)
return original_from_numpy(ndarray)
original_as_tensor = torch.as_tensor
@wraps(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
if check_device(device):
device = return_xpu(device)
if isinstance(data, np.ndarray) and data.dtype == float and not (
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
return original_as_tensor(data, dtype=torch.float32, device=device)
else:
return original_as_tensor(data, dtype=dtype, device=device)
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
# 32 bit attention workarounds for Alchemist:
try:
from .attention import torch_bmm_32_bit as original_torch_bmm
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
except Exception: # pylint: disable=broad-exception-caught
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
# Data Type Errors:
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
@wraps(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
if query.dtype != key.dtype:
key = key.to(dtype=query.dtype)
if query.dtype != value.dtype:
value = value.to(dtype=query.dtype)
if attn_mask is not None and query.dtype != attn_mask.dtype:
attn_mask = attn_mask.to(dtype=query.dtype)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
# A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm
@wraps(torch.nn.functional.group_norm)
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
# A1111 BF16
original_functional_layer_norm = torch.nn.functional.layer_norm
@wraps(torch.nn.functional.layer_norm)
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
# Training
original_functional_linear = torch.nn.functional.linear
@wraps(torch.nn.functional.linear)
def functional_linear(input, weight, bias=None):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_linear(input, weight, bias=bias)
original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
# A1111 Embedding BF16
original_torch_cat = torch.cat
@wraps(torch.cat)
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
# SwinIR BF16:
original_functional_pad = torch.nn.functional.pad
@wraps(torch.nn.functional.pad)
def functional_pad(input, pad, mode='constant', value=None):
if mode == 'reflect' and input.dtype == torch.bfloat16:
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
else:
return original_functional_pad(input, pad, mode=mode, value=value)
original_torch_tensor = torch.tensor
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
if check_device(device):
device = return_xpu(device)
if not device_supports_fp64:
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
if dtype == torch.float64:
dtype = torch.float32
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
dtype = torch.float32
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_to(self, device, *args, **kwargs)
original_Tensor_cuda = torch.Tensor.cuda
@wraps(torch.Tensor.cuda)
def Tensor_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_cuda(self, device, *args, **kwargs)
original_Tensor_pin_memory = torch.Tensor.pin_memory
@wraps(torch.Tensor.pin_memory)
def Tensor_pin_memory(self, device=None, *args, **kwargs):
if device is None:
device = "xpu"
if check_device(device):
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_pin_memory(self, device, *args, **kwargs)
original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
if check_device(device):
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
else:
return original_UntypedStorage_init(*args, device=device, **kwargs)
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
else:
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs):
if check_device(device):
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_empty(*args, device=device, **kwargs)
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, dtype=None, **kwargs):
if dtype == bytes:
dtype = None
if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_randn(*args, device=device, **kwargs)
original_torch_ones = torch.ones
@wraps(torch.ones)
def torch_ones(*args, device=None, **kwargs):
if check_device(device):
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_ones(*args, device=device, **kwargs)
original_torch_zeros = torch.zeros
@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs):
if check_device(device):
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_zeros(*args, device=device, **kwargs)
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
if check_device(device):
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_linspace(*args, device=device, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
if map_location is None:
map_location = "xpu"
if check_device(map_location):
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
else:
return original_torch_load(f, *args, map_location=map_location, **kwargs)
# Hijack Functions:
def ipex_hijacks():
CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.Tensor.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
CondFunc('torch.randn',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.ones',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.zeros',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.linspace',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
torch.Tensor.pin_memory = Tensor_pin_memory
torch.UntypedStorage.__init__ = UntypedStorage_init
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.empty = torch_empty
torch.randn = torch_randn
torch.ones = torch_ones
torch.zeros = torch_zeros
torch.linspace = torch_linspace
torch.Generator = torch_Generator
torch.load = torch_load
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
CondFunc('torch.instance_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
#Functions with dtype errors:
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float)
#Broken functions when torch.cuda.is_available is True:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True)
#Functions that make compile mad with CondFunc:
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
torch.nn.DataParallel = DummyDataParallel
torch.autocast = ipex_autocast
torch.cat = torch_cat
torch.linalg.solve = linalg_solve
torch.nn.functional.interpolate = interpolate
torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel
torch.UntypedStorage.is_cuda = is_cuda
torch.amp.autocast_mode.autocast.__init__ = autocast_init
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
torch.nn.functional.group_norm = functional_group_norm
torch.nn.functional.layer_norm = functional_layer_norm
torch.nn.functional.linear = functional_linear
torch.nn.functional.conv2d = functional_conv2d
torch.nn.functional.interpolate = interpolate
torch.nn.functional.pad = functional_pad
torch.bmm = torch_bmm
torch.cat = torch_cat
if not device_supports_fp64:
torch.from_numpy = from_numpy
torch.as_tensor = as_tensor

View File

@@ -9,7 +9,7 @@ import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
import diffusers
from diffusers import SchedulerMixin, StableDiffusionPipeline
@@ -17,7 +17,6 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.utils import logging
try:
from diffusers.utils import PIL_INTERPOLATION
except ImportError:
@@ -520,6 +519,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
image_encoder: CLIPVisionModelWithProjection = None,
clip_skip: int = 1,
):
super().__init__(
@@ -531,32 +531,11 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker,
image_encoder=image_encoder,
)
self.clip_skip = clip_skip
self.custom_clip_skip = clip_skip
self.__init__additional__()
# else:
# def __init__(
# self,
# vae: AutoencoderKL,
# text_encoder: CLIPTextModel,
# tokenizer: CLIPTokenizer,
# unet: UNet2DConditionModel,
# scheduler: SchedulerMixin,
# safety_checker: StableDiffusionSafetyChecker,
# feature_extractor: CLIPFeatureExtractor,
# ):
# super().__init__(
# vae=vae,
# text_encoder=text_encoder,
# tokenizer=tokenizer,
# unet=unet,
# scheduler=scheduler,
# safety_checker=safety_checker,
# feature_extractor=feature_extractor,
# )
# self.__init__additional__()
def __init__additional__(self):
if not hasattr(self, "vae_scale_factor"):
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
@@ -624,7 +603,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
clip_skip=self.custom_clip_skip,
)
bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
@@ -646,7 +625,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
print(height, width)
logger.info(f'{height} {width}')
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (

View File

@@ -3,19 +3,20 @@
import math
import os
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex
init_ipex()
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
from safetensors.torch import load_file, save_file
from library.original_unet import UNet2DConditionModel
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# DiffUsers版StableDiffusionのモデルパラメータ
NUM_TRAIN_TIMESTEPS = 1000
@@ -571,9 +572,9 @@ def convert_ldm_clip_checkpoint_v1(checkpoint):
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
# support checkpoint without position_ids (invalid checkpoint)
if "text_model.embeddings.position_ids" not in text_model_dict:
text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text
# remove position_ids for newer transformer, which causes error :(
if "text_model.embeddings.position_ids" in text_model_dict:
text_model_dict.pop("text_model.embeddings.position_ids")
return text_model_dict
@@ -947,7 +948,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
# logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict
@@ -1005,7 +1006,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
unet = UNet2DConditionModel(**unet_config).to(device)
info = unet.load_state_dict(converted_unet_checkpoint)
print("loading u-net:", info)
logger.info(f"loading u-net: {info}")
# Convert the VAE model.
vae_config = create_vae_diffusers_config()
@@ -1013,7 +1014,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
vae = AutoencoderKL(**vae_config).to(device)
info = vae.load_state_dict(converted_vae_checkpoint)
print("loading vae:", info)
logger.info(f"loading vae: {info}")
# convert text_model
if v2:
@@ -1047,7 +1048,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
# logging.set_verbosity_error() # don't show annoying warning
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
# logging.set_verbosity_warning()
# print(f"config: {text_model.config}")
# logger.info(f"config: {text_model.config}")
cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
@@ -1070,7 +1071,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
)
text_model = CLIPTextModel._from_config(cfg)
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
print("loading text encoder:", info)
logger.info(f"loading text encoder: {info}")
return text_model, vae, unet
@@ -1145,7 +1146,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
# 最後の層などを捏造するか
if make_dummy_weights:
print("make dummy weights for resblock.23, text_projection and logit scale.")
logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
keys = list(new_sd.keys())
for key in keys:
if key.startswith("transformer.resblocks.22."):
@@ -1242,8 +1243,13 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
if vae is None:
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
# original U-Net cannot be saved, so we need to convert it to the Diffusers version
# TODO this consumes a lot of memory
diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
diffusers_unet.load_state_dict(unet.state_dict())
pipeline = StableDiffusionPipeline(
unet=unet,
unet=diffusers_unet,
text_encoder=text_encoder,
vae=vae,
scheduler=scheduler,
@@ -1259,14 +1265,14 @@ VAE_PREFIX = "first_stage_model."
def load_vae(vae_id, dtype):
print(f"load VAE: {vae_id}")
logger.info(f"load VAE: {vae_id}")
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
# Diffusers local/remote
try:
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
except EnvironmentError as e:
print(f"exception occurs in loading vae: {e}")
print("retry with subfolder='vae'")
logger.error(f"exception occurs in loading vae: {e}")
logger.error("retry with subfolder='vae'")
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
return vae
@@ -1307,19 +1313,19 @@ def load_vae(vae_id, dtype):
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
max_width, max_height = max_reso
max_area = (max_width // divisible) * (max_height // divisible)
max_area = max_width * max_height
resos = set()
size = int(math.sqrt(max_area)) * divisible
resos.add((size, size))
width = int(math.sqrt(max_area) // divisible) * divisible
resos.add((width, width))
size = min_size
while size <= max_size:
width = size
height = min(max_size, (max_area // (width // divisible)) * divisible)
resos.add((width, height))
resos.add((height, width))
width = min_size
while width <= max_size:
height = min(max_size, int((max_area // width) // divisible) * divisible)
if height >= min_size:
resos.add((width, height))
resos.add((height, width))
# # make additional resos
# if width >= height and width - divisible >= min_size:
@@ -1329,7 +1335,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
# resos.add((width, height - divisible))
# resos.add((height - divisible, width))
size += divisible
width += divisible
resos = list(resos)
resos.sort()
@@ -1338,13 +1344,13 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
if __name__ == "__main__":
resos = make_bucket_resolutions((512, 768))
print(len(resos))
print(resos)
logger.info(f"{len(resos)}")
logger.info(f"{resos}")
aspect_ratios = [w / h for w, h in resos]
print(aspect_ratios)
logger.info(f"{aspect_ratios}")
ars = set()
for ar in aspect_ratios:
if ar in ars:
print("error! duplicate ar:", ar)
logger.error(f"error! duplicate ar: {ar}")
ars.add(ar)

View File

@@ -113,6 +113,10 @@ import torch
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
@@ -361,6 +365,23 @@ def get_timestep_embedding(
return emb
# Deep Shrink: We do not common this function, because minimize dependencies.
def resize_like(x, target, mode="bicubic", align_corners=False):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.to(torch.float32)
if x.shape[-2:] != target.shape[-2:]:
if mode == "nearest":
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
else:
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
if org_dtype == torch.bfloat16:
x = x.to(org_dtype)
return x
class SampleOutput:
def __init__(self, sample):
self.sample = sample
@@ -569,6 +590,9 @@ class CrossAttention(nn.Module):
self.use_memory_efficient_attention_mem_eff = False
self.use_sdpa = False
# Attention processor
self.processor = None
def set_use_memory_efficient_attention(self, xformers, mem_eff):
self.use_memory_efficient_attention_xformers = xformers
self.use_memory_efficient_attention_mem_eff = mem_eff
@@ -590,7 +614,28 @@ class CrossAttention(nn.Module):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def forward(self, hidden_states, context=None, mask=None):
def set_processor(self):
return self.processor
def get_processor(self):
return self.processor
def forward(self, hidden_states, context=None, mask=None, **kwargs):
if self.processor is not None:
(
hidden_states,
encoder_hidden_states,
attention_mask,
) = translate_attention_names_from_diffusers(
hidden_states=hidden_states, context=context, mask=mask, **kwargs
)
return self.processor(
attn=self,
hidden_states=hidden_states,
encoder_hidden_states=context,
attention_mask=mask,
**kwargs
)
if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
if self.use_memory_efficient_attention_mem_eff:
@@ -703,6 +748,21 @@ class CrossAttention(nn.Module):
out = self.to_out[0](out)
return out
def translate_attention_names_from_diffusers(
hidden_states: torch.FloatTensor,
context: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
# HF naming
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None
):
# translate from hugging face diffusers
context = context if context is not None else encoder_hidden_states
# translate from hugging face diffusers
mask = mask if mask is not None else attention_mask
return hidden_states, context, mask
# feedforward
class GEGLU(nn.Module):
@@ -1130,6 +1190,7 @@ class UpBlock2D(nn.Module):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -1205,9 +1266,9 @@ class CrossAttnUpBlock2D(nn.Module):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, spda):
def set_use_sdpa(self, sdpa):
for attn in self.attentions:
attn.set_use_sdpa(spda)
attn.set_use_sdpa(sdpa)
def forward(
self,
@@ -1221,6 +1282,7 @@ class CrossAttnUpBlock2D(nn.Module):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
@@ -1322,7 +1384,7 @@ class UNet2DConditionModel(nn.Module):
):
super().__init__()
assert sample_size is not None, "sample_size must be specified"
print(
logger.info(
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
)
@@ -1331,7 +1393,7 @@ class UNet2DConditionModel(nn.Module):
self.out_channels = OUT_CHANNELS
self.sample_size = sample_size
self.prepare_config()
self.prepare_config(sample_size=sample_size)
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
@@ -1418,8 +1480,8 @@ class UNet2DConditionModel(nn.Module):
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
def prepare_config(self, *args, **kwargs):
self.config = SimpleNamespace(**kwargs)
@property
def dtype(self) -> torch.dtype:
@@ -1456,7 +1518,7 @@ class UNet2DConditionModel(nn.Module):
def set_gradient_checkpointing(self, value=False):
modules = self.down_blocks + [self.mid_block] + self.up_blocks
for module in modules:
print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
module.gradient_checkpointing = value
# endregion
@@ -1519,7 +1581,6 @@ class UNet2DConditionModel(nn.Module):
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
@@ -1604,3 +1665,255 @@ class UNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(sample.shape[0])
return timesteps
class InferUNet2DConditionModel:
def __init__(self, original_unet: UNet2DConditionModel):
self.delegate = original_unet
# override original model's forward method: because forward is not called by `__call__`
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
self.delegate.forward = self.forward
# override original model's up blocks' forward method
for up_block in self.delegate.up_blocks:
if up_block.__class__.__name__ == "UpBlock2D":
def resnet_wrapper(func, block):
def forward(*args, **kwargs):
return func(block, *args, **kwargs)
return forward
up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
def cross_attn_up_wrapper(func, block):
def forward(*args, **kwargs):
return func(block, *args, **kwargs)
return forward
up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
logger.info("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
logger.info(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet in _self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if _self.upsamplers is not None:
for upsampler in _self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
def cross_attn_up_block_forward(
self,
_self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
):
for resnet, attn in zip(_self.resnets, _self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
# Deep Shrink
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
hidden_states = resize_like(hidden_states, res_hidden_states)
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
if _self.upsamplers is not None:
for upsampler in _self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
) -> Union[Dict, Tuple]:
r"""
current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
"""
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a dict instead of a plain tuple.
Returns:
`SampleOutput` or `tuple`:
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
_self = self.delegate
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
default_overall_up_factor = 2**_self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
# 64で割り切れないときはupsamplerにサイズを伝える
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
# logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time
timesteps = timestep
timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
t_emb = _self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
# timestepsは重みを含まないので常にfloat32のテンソルを返す
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
# time_projでキャストしておけばいいんじゃね
t_emb = t_emb.to(dtype=_self.dtype)
emb = _self.time_embedding(t_emb)
# 2. pre-process
sample = _self.conv_in(sample)
down_block_res_samples = (sample,)
for depth, downsample_block in enumerate(_self.down_blocks):
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
self.ds_depth_2 is not None
and depth == self.ds_depth_2
and timesteps[0] < self.ds_timesteps_1
and timesteps[0] >= self.ds_timesteps_2
):
org_dtype = sample.dtype
if org_dtype == torch.bfloat16:
sample = sample.to(torch.float32)
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
# まあこちらのほうがわかりやすいかもしれない
if downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# skip connectionにControlNetの出力を追加する
if down_block_additional_residuals is not None:
down_block_res_samples = list(down_block_res_samples)
for i in range(len(down_block_res_samples)):
down_block_res_samples[i] += down_block_additional_residuals[i]
down_block_res_samples = tuple(down_block_res_samples)
# 4. mid
sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# ControlNetの出力を追加する
if mid_block_additional_residual is not None:
sample += mid_block_additional_residual
# 5. up
for i, upsample_block in enumerate(_self.up_blocks):
is_final_block = i == len(_self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
# if we have not reached the final block and need to forward the upsample size, we do it here
# 前述のように最後のブロック以外ではupsample_sizeを伝える
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
sample = _self.conv_norm_out(sample)
sample = _self.conv_act(sample)
sample = _self.conv_out(sample)
if not return_dict:
return (sample,)
return SampleOutput(sample=sample)

View File

@@ -5,6 +5,10 @@ from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
r"""
# Metadata Example
@@ -231,7 +235,7 @@ def build_metadata(
# # assert all values are filled
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
print(f"Internal error: some metadata values are None: {metadata}")
logger.error(f"Internal error: some metadata values are None: {metadata}")
return metadata

View File

@@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if up1 is not None:
uncond_pool = up1
dtype = self.unet.dtype
unet_dtype = self.unet.dtype
dtype = unet_dtype
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
dtype = torch.float16
self.unet.to(dtype)
# 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image):
@@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if is_cancelled_callback is not None and is_cancelled_callback():
return None
self.unet.to(unet_dtype)
return latents
def latents_to_image(self, latents):

View File

@@ -7,7 +7,10 @@ from typing import List
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
VAE_SCALE_FACTOR = 0.13025
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
@@ -100,7 +103,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
key = key.replace(".ln_final", ".final_layer_norm")
# ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids
elif ".embeddings.position_ids" in key:
key = None # remove this key: make position_ids by ourselves
key = None # remove this key: position_ids is not used in newer transformers
return key
keys = list(checkpoint.keys())
@@ -126,13 +129,15 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
# original SD にはないので、position_idsを追加
position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
new_sd["text_model.embeddings.position_ids"] = position_ids
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
# temporary workaround for text_projection.weight.weight for Playground-v2
if "text_projection.weight.weight" in new_sd:
logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
del new_sd["text_projection.weight.weight"]
return new_sd, logit_scale
@@ -184,20 +189,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
checkpoint = None
# U-Net
print("building U-Net")
logger.info("building U-Net")
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
print("loading U-Net from checkpoint")
logger.info("loading U-Net from checkpoint")
unet_sd = {}
for k in list(state_dict.keys()):
if k.startswith("model.diffusion_model."):
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
print("U-Net: ", info)
logger.info(f"U-Net: {info}")
# Text Encoders
print("building text encoders")
logger.info("building text encoders")
# Text Encoder 1 is same to Stability AI's SDXL
text_model1_cfg = CLIPTextConfig(
@@ -250,7 +255,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
with init_empty_weights():
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
print("loading text encoders from checkpoint")
logger.info("loading text encoders from checkpoint")
te1_sd = {}
te2_sd = {}
for k in list(state_dict.keys()):
@@ -258,28 +263,28 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
if "text_model.embeddings.position_ids" not in te1_sd:
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
# 最新の transformers では position_ids を含むとエラーになるので削除 / remove position_ids for latest transformers
if "text_model.embeddings.position_ids" in te1_sd:
te1_sd.pop("text_model.embeddings.position_ids")
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
print("text encoder 1:", info1)
logger.info(f"text encoder 1: {info1}")
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
print("text encoder 2:", info2)
logger.info(f"text encoder 2: {info2}")
# prepare vae
print("building VAE")
logger.info("building VAE")
vae_config = model_util.create_vae_diffusers_config()
with init_empty_weights():
vae = AutoencoderKL(**vae_config)
print("loading VAE from checkpoint")
logger.info("loading VAE from checkpoint")
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
print("VAE:", info)
logger.info(f"VAE: {info}")
ckpt_info = (epoch, global_step) if epoch is not None else None
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info

View File

@@ -24,13 +24,18 @@
import math
from types import SimpleNamespace
from typing import Optional
from typing import Any, Optional
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
IN_CHANNELS: int = 4
OUT_CHANNELS: int = 4
@@ -266,6 +271,23 @@ def get_timestep_embedding(
return emb
# Deep Shrink: We do not common this function, because minimize dependencies.
def resize_like(x, target, mode="bicubic", align_corners=False):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.to(torch.float32)
if x.shape[-2:] != target.shape[-2:]:
if mode == "nearest":
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
else:
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
if org_dtype == torch.bfloat16:
x = x.to(org_dtype)
return x
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
if self.weight.dtype != torch.float32:
@@ -315,7 +337,7 @@ class ResnetBlock2D(nn.Module):
def forward(self, x, emb):
if self.training and self.gradient_checkpointing:
# print("ResnetBlock2D: gradient_checkpointing")
# logger.info("ResnetBlock2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -349,7 +371,7 @@ class Downsample2D(nn.Module):
def forward(self, hidden_states):
if self.training and self.gradient_checkpointing:
# print("Downsample2D: gradient_checkpointing")
# logger.info("Downsample2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -636,7 +658,7 @@ class BasicTransformerBlock(nn.Module):
def forward(self, hidden_states, context=None, timestep=None):
if self.training and self.gradient_checkpointing:
# print("BasicTransformerBlock: checkpointing")
# logger.info("BasicTransformerBlock: checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -779,7 +801,7 @@ class Upsample2D(nn.Module):
def forward(self, hidden_states, output_size=None):
if self.training and self.gradient_checkpointing:
# print("Upsample2D: gradient_checkpointing")
# logger.info("Upsample2D: gradient_checkpointing")
def create_custom_forward(func):
def custom_forward(*inputs):
@@ -1029,7 +1051,7 @@ class SdxlUNet2DConditionModel(nn.Module):
for block in blocks:
for module in block:
if hasattr(module, "set_use_memory_efficient_attention"):
# print(module.__class__.__name__)
# logger.info(module.__class__.__name__)
module.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, sdpa: bool) -> None:
@@ -1044,7 +1066,7 @@ class SdxlUNet2DConditionModel(nn.Module):
for block in blocks:
for module in block.modules():
if hasattr(module, "gradient_checkpointing"):
# print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
# logger.info(f{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
module.gradient_checkpointing = value
# endregion
@@ -1054,7 +1076,7 @@ class SdxlUNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
@@ -1066,7 +1088,7 @@ class SdxlUNet2DConditionModel(nn.Module):
def call_module(module, h, emb, context):
x = h
for layer in module:
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
# logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
if isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, Transformer2DModel):
@@ -1077,6 +1099,7 @@ class SdxlUNet2DConditionModel(nn.Module):
# h = x.type(self.dtype)
h = x
for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
@@ -1093,10 +1116,125 @@ class SdxlUNet2DConditionModel(nn.Module):
return h
class InferSdxlUNet2DConditionModel:
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
self.delegate = original_unet
# override original model's forward method: because forward is not called by `__call__`
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
self.delegate.forward = self.forward
# Deep Shrink
self.ds_depth_1 = None
self.ds_depth_2 = None
self.ds_timesteps_1 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
if ds_depth_1 is None:
logger.info("Deep Shrink is disabled.")
self.ds_depth_1 = None
self.ds_timesteps_1 = None
self.ds_depth_2 = None
self.ds_timesteps_2 = None
self.ds_ratio = None
else:
logger.info(
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
)
self.ds_depth_1 = ds_depth_1
self.ds_timesteps_1 = ds_timesteps_1
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
self.ds_ratio = ds_ratio
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
r"""
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
"""
_self = self.delegate
# broadcast timesteps to batch dimension
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
# assert x.dtype == _self.dtype
emb = emb + _self.label_emb(y)
def call_module(module, h, emb, context):
x = h
for layer in module:
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
if isinstance(layer, ResnetBlock2D):
x = layer(x, emb)
elif isinstance(layer, Transformer2DModel):
x = layer(x, context)
else:
x = layer(x)
return x
# h = x.type(self.dtype)
h = x
for depth, module in enumerate(_self.input_blocks):
# Deep Shrink
if self.ds_depth_1 is not None:
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
self.ds_depth_2 is not None
and depth == self.ds_depth_2
and timesteps[0] < self.ds_timesteps_1
and timesteps[0] >= self.ds_timesteps_2
):
# print("downsample", h.shape, self.ds_ratio)
org_dtype = h.dtype
if org_dtype == torch.bfloat16:
h = h.to(torch.float32)
h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
h = call_module(module, h, emb, context)
hs.append(h)
h = call_module(_self.middle_block, h, emb, context)
for module in _self.output_blocks:
# Deep Shrink
if self.ds_depth_1 is not None:
if hs[-1].shape[-2:] != h.shape[-2:]:
# print("upsample", h.shape, hs[-1].shape)
h = resize_like(h, hs[-1])
h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)
# Deep Shrink: in case of depth 0
if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
# print("upsample", h.shape, x.shape)
h = resize_like(h, x)
h = h.type(x.dtype)
h = call_module(_self.out, h, emb, context)
return h
if __name__ == "__main__":
import time
print("create unet")
logger.info("create unet")
unet = SdxlUNet2DConditionModel()
unet.to("cuda")
@@ -1105,7 +1243,7 @@ if __name__ == "__main__":
unet.train()
# 使用メモリ量確認用の疑似学習ループ
print("preparing optimizer")
logger.info("preparing optimizer")
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
@@ -1120,12 +1258,12 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
logger.info("start training")
steps = 10
batch_size = 1
for step in range(steps):
print(f"step {step}")
logger.info(f"step {step}")
if step == 1:
time_start = time.perf_counter()
@@ -1145,4 +1283,4 @@ if __name__ == "__main__":
optimizer.zero_grad(set_to_none=True)
time_end = time.perf_counter()
print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")

View File

@@ -1,14 +1,21 @@
import argparse
import gc
import math
import os
from typing import Optional
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate import init_empty_weights
from tqdm import tqdm
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -17,11 +24,10 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
(
load_stable_diffusion_format,
@@ -47,12 +53,9 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
@@ -64,7 +67,7 @@ def _load_target_model(
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print(f"load StableDiffusion checkpoint: {name_or_path}")
logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
(
text_encoder1,
text_encoder2,
@@ -78,7 +81,7 @@ def _load_target_model(
from diffusers import StableDiffusionXLPipeline
variant = "fp16" if weight_dtype == torch.float16 else None
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
try:
try:
pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -86,12 +89,12 @@ def _load_target_model(
)
except EnvironmentError as ex:
if variant is not None:
print("try to load fp32 model")
logger.info("try to load fp32 model")
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
else:
raise ex
except EnvironmentError as ex:
print(
logger.error(
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
)
raise ex
@@ -114,7 +117,7 @@ def _load_target_model(
with init_empty_weights():
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
print("U-Net converted to original U-Net")
logger.info("U-Net converted to original U-Net")
logit_scale = None
ckpt_info = None
@@ -122,13 +125,13 @@ def _load_target_model(
# VAEを読み込む
if vae_path is not None:
vae = model_util.load_vae(vae_path, weight_dtype)
print("additional VAE loaded")
logger.info("additional VAE loaded")
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
def load_tokenizers(args: argparse.Namespace):
print("prepare tokenizers")
logger.info("prepare tokenizers")
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
tokeniers = []
@@ -137,14 +140,14 @@ def load_tokenizers(args: argparse.Namespace):
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
print(f"load tokenizer from cache: {local_tokenizer_path}")
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
if tokenizer is None:
tokenizer = CLIPTokenizer.from_pretrained(original_path)
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
print(f"save Tokenizer to cache: {local_tokenizer_path}")
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
if i == 1:
@@ -153,7 +156,7 @@ def load_tokenizers(args: argparse.Namespace):
tokeniers.append(tokenizer)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
logger.info(f"update token length: {args.max_token_length}")
return tokeniers
@@ -334,23 +337,23 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
if args.v_parameterization:
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
if args.clip_skip is not None:
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
# if args.multires_noise_iterations:
# print(
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
# )
# else:
# if args.noise_offset is None:
# args.noise_offset = DEFAULT_NOISE_OFFSET
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
# print(
# logger.info(
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
# )
# print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
assert (
not hasattr(args, "weighted_captions") or not args.weighted_captions
@@ -359,7 +362,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
if supportTextEncoderCaching:
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
args.cache_text_encoder_outputs = True
print(
logger.warning(
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
)

View File

@@ -26,7 +26,10 @@ from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def slice_h(x, num_slices):
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
@@ -62,7 +65,7 @@ def cat_h(sliced):
return x
def resblock_forward(_self, num_slices, input_tensor, temb):
def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
assert _self.upsample is None and _self.downsample is None
assert _self.norm1.num_groups == _self.norm2.num_groups
assert temb is None
@@ -89,7 +92,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb):
# sliced_tensor = torch.chunk(x, num_div, dim=1)
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
# print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
# logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
# normed_tensor = []
# for i in range(num_div):
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
@@ -243,7 +246,7 @@ class SlicingEncoder(nn.Module):
self.num_slices = num_slices
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
# print(f"initial divisor: {div}")
# logger.info(f"initial divisor: {div}")
if div >= 2:
div = int(div)
for resnet in self.mid_block.resnets:
@@ -253,11 +256,11 @@ class SlicingEncoder(nn.Module):
for i, down_block in enumerate(self.down_blocks[::-1]):
if div >= 2:
div = int(div)
# print(f"down block: {i} divisor: {div}")
# logger.info(f"down block: {i} divisor: {div}")
for resnet in down_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
if down_block.downsamplers is not None:
# print("has downsample")
# logger.info("has downsample")
for downsample in down_block.downsamplers:
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
div *= 2
@@ -307,7 +310,7 @@ class SlicingEncoder(nn.Module):
def downsample_forward(self, _self, num_slices, hidden_states):
assert hidden_states.shape[1] == _self.channels
assert _self.use_conv and _self.padding == 0
print("downsample forward", num_slices, hidden_states.shape)
logger.info(f"downsample forward {num_slices} {hidden_states.shape}")
org_device = hidden_states.device
cpu_device = torch.device("cpu")
@@ -350,7 +353,7 @@ class SlicingEncoder(nn.Module):
hidden_states = torch.cat([hidden_states, x], dim=2)
hidden_states = hidden_states.to(org_device)
# print("downsample forward done", hidden_states.shape)
# logger.info(f"downsample forward done {hidden_states.shape}")
return hidden_states
@@ -426,7 +429,7 @@ class SlicingDecoder(nn.Module):
self.num_slices = num_slices
div = num_slices / (2 ** (len(self.up_blocks) - 1))
print(f"initial divisor: {div}")
logger.info(f"initial divisor: {div}")
if div >= 2:
div = int(div)
for resnet in self.mid_block.resnets:
@@ -436,11 +439,11 @@ class SlicingDecoder(nn.Module):
for i, up_block in enumerate(self.up_blocks):
if div >= 2:
div = int(div)
# print(f"up block: {i} divisor: {div}")
# logger.info(f"up block: {i} divisor: {div}")
for resnet in up_block.resnets:
resnet.forward = wrapper(resblock_forward, resnet, div)
if up_block.upsamplers is not None:
# print("has upsample")
# logger.info("has upsample")
for upsample in up_block.upsamplers:
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
div *= 2
@@ -528,7 +531,7 @@ class SlicingDecoder(nn.Module):
del x
hidden_states = torch.cat(sliced, dim=2)
# print("us hidden_states", hidden_states.shape)
# logger.info(f"us hidden_states {hidden_states.shape}")
del sliced
hidden_states = hidden_states.to(org_device)

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,266 @@
import logging
import sys
import threading
import torch
from torchvision import transforms
from typing import *
from diffusers import EulerAncestralDiscreteScheduler
import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
threading.Thread(target=f, args=args, kwargs=kwargs).start()
def add_logging_arguments(parser):
parser.add_argument(
"--console_log_level",
type=str,
default=None,
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
)
parser.add_argument(
"--console_log_file",
type=str,
default=None,
help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
)
parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
def setup_logging(args=None, log_level=None, reset=False):
if logging.root.handlers:
if reset:
# remove all handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
else:
return
# log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
if log_level is None and args is not None:
log_level = args.console_log_level
if log_level is None:
log_level = "INFO"
log_level = getattr(logging, log_level)
msg_init = None
if args is not None and args.console_log_file:
handler = logging.FileHandler(args.console_log_file, mode="w")
else:
handler = None
if not args or not args.console_log_simple:
try:
from rich.logging import RichHandler
from rich.console import Console
from rich.logging import RichHandler
handler = RichHandler(console=Console(stderr=True))
except ImportError:
# print("rich is not installed, using basic logging")
msg_init = "rich is not installed, using basic logging"
if handler is None:
handler = logging.StreamHandler(sys.stdout) # same as print
handler.propagate = False
formatter = logging.Formatter(
fmt="%(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
logging.root.setLevel(log_level)
logging.root.addHandler(handler)
if msg_init is not None:
logger = logging.getLogger(__name__)
logger.info(msg_init)
# TODO make inf_utils.py
# region Gradual Latent hires fix
class GradualLatent:
def __init__(
self,
ratio,
start_timesteps,
every_n_steps,
ratio_step,
s_noise=1.0,
gaussian_blur_ksize=None,
gaussian_blur_sigma=0.5,
gaussian_blur_strength=0.5,
unsharp_target_x=True,
):
self.ratio = ratio
self.start_timesteps = start_timesteps
self.every_n_steps = every_n_steps
self.ratio_step = ratio_step
self.s_noise = s_noise
self.gaussian_blur_ksize = gaussian_blur_ksize
self.gaussian_blur_sigma = gaussian_blur_sigma
self.gaussian_blur_strength = gaussian_blur_strength
self.unsharp_target_x = unsharp_target_x
def __str__(self) -> str:
return (
f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
+ f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
+ f"unsharp_target_x={self.unsharp_target_x})"
)
def apply_unshark_mask(self, x: torch.Tensor):
if self.gaussian_blur_ksize is None:
return x
blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
# mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
mask = (x - blurred) * self.gaussian_blur_strength
sharpened = x + mask
return sharpened
def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
org_dtype = x.dtype
if org_dtype == torch.bfloat16:
x = x.float()
x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
# apply unsharp mask / アンシャープマスクを適用する
if unsharp and self.gaussian_blur_ksize:
x = self.apply_unshark_mask(x)
return x
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.resized_size = None
self.gradual_latent = None
def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
self.resized_size = size
self.gradual_latent = gradual_latent
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`,
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
otherwise a tuple is returned where the first element is the sample tensor.
"""
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if not self.is_scale_input_called:
# logger.warning(
print(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
"See `StableDiffusionPipeline` for a usage example."
)
if self.step_index is None:
self._init_step_index(timestep)
sigma = self.sigmas[self.step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
else:
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
sigma_from = self.sigmas[self.step_index]
sigma_to = self.sigmas[self.step_index + 1]
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
dt = sigma_down - sigma
device = model_output.device
if self.resized_size is None:
prev_sample = sample + derivative * dt
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
model_output.shape, dtype=model_output.dtype, device=device, generator=generator
)
s_noise = 1.0
else:
print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
s_noise = self.gradual_latent.s_noise
if self.gradual_latent.unsharp_target_x:
prev_sample = sample + derivative * dt
prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
else:
sample = self.gradual_latent.interpolate(sample, self.resized_size)
derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
prev_sample = sample + derivative * dt
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
dtype=model_output.dtype,
device=device,
generator=generator,
)
prev_sample = prev_sample + noise * sigma_up * s_noise
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
# endregion

View File

@@ -2,10 +2,13 @@ import argparse
import os
import torch
from safetensors.torch import load_file
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def main(file):
print(f"loading: {file}")
logger.info(f"loading: {file}")
if os.path.splitext(file)[1] == ".safetensors":
sd = load_file(file)
else:

View File

@@ -2,7 +2,10 @@ import os
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
@@ -125,7 +128,7 @@ class LLLiteModule(torch.nn.Module):
return
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
# logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
cx = self.conditioning1(cond_image)
if not self.is_conv2d:
# reshape / b,c,h,w -> b,h*w,c
@@ -155,7 +158,7 @@ class LLLiteModule(torch.nn.Module):
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
if self.use_zeros_for_batch_uncond:
cx[0::2] = 0.0 # uncond is zero
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
# logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
# downで入力の次元数を削減し、conditioning image embeddingと結合する
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
@@ -286,7 +289,7 @@ class ControlNetLLLite(torch.nn.Module):
# create module instances
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
def forward(self, x):
return x # dummy
@@ -319,7 +322,7 @@ class ControlNetLLLite(torch.nn.Module):
return info
def apply_to(self):
print("applying LLLite for U-Net...")
logger.info("applying LLLite for U-Net...")
for module in self.unet_modules:
module.apply_to()
self.add_module(module.lllite_name, module)
@@ -374,19 +377,19 @@ if __name__ == "__main__":
# sdxl_original_unet.USE_REENTRANT = False
# test shape etc
print("create unet")
logger.info("create unet")
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
unet.to("cuda").to(torch.float16)
print("create ControlNet-LLLite")
logger.info("create ControlNet-LLLite")
control_net = ControlNetLLLite(unet, 32, 64)
control_net.apply_to()
control_net.to("cuda")
print(control_net)
logger.info(control_net)
# print number of parameters
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
# logger.info number of parameters
logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}")
input()
@@ -398,12 +401,12 @@ if __name__ == "__main__":
# # visualize
# import torchviz
# print("run visualize")
# logger.info("run visualize")
# controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y)
# print("make_dot")
# logger.info("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render")
# logger.info("render")
# image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input()
@@ -414,12 +417,12 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
logger.info("start training")
steps = 10
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
for step in range(steps):
print(f"step {step}")
logger.info(f"step {step}")
batch_size = 1
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
@@ -439,7 +442,7 @@ if __name__ == "__main__":
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
print(sample_param)
logger.info(f"{sample_param}")
# from safetensors.torch import save_file

View File

@@ -6,7 +6,10 @@ import re
from typing import Optional, List, Type
import torch
from library import sdxl_original_unet
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# input_blocksに適用するかどうか / if True, input_blocks are not applied
SKIP_INPUT_BLOCKS = False
@@ -270,7 +273,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
# create module instances
self.lllite_modules = apply_to_modules(self, target_modules)
print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
# def prepare_optimizer_params(self):
def prepare_params(self):
@@ -281,8 +284,8 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
train_params.append(p)
else:
non_train_params.append(p)
print(f"count of trainable parameters: {len(train_params)}")
print(f"count of non-trainable parameters: {len(non_train_params)}")
logger.info(f"count of trainable parameters: {len(train_params)}")
logger.info(f"count of non-trainable parameters: {len(non_train_params)}")
for p in non_train_params:
p.requires_grad_(False)
@@ -388,7 +391,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
matches = pattern.findall(module_name)
if matches is not None:
for m in matches:
print(module_name, m)
logger.info(f"{module_name} {m}")
module_name = module_name.replace(m, m.replace("_", "@"))
module_name = module_name.replace("_", ".")
module_name = module_name.replace("@", "_")
@@ -407,7 +410,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
def replace_unet_linear_and_conv2d():
print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
@@ -419,10 +422,10 @@ if __name__ == "__main__":
replace_unet_linear_and_conv2d()
# test shape etc
print("create unet")
logger.info("create unet")
unet = SdxlUNet2DConditionModelControlNetLLLite()
print("enable ControlNet-LLLite")
logger.info("enable ControlNet-LLLite")
unet.apply_lllite(32, 64, None, False, 1.0)
unet.to("cuda") # .to(torch.float16)
@@ -439,14 +442,14 @@ if __name__ == "__main__":
# unet_sd[converted_key] = model_sd[key]
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
# print(info)
# logger.info(info)
# print(unet)
# logger.info(unet)
# print number of parameters
# logger.info number of parameters
params = unet.prepare_params()
print("number of parameters", sum(p.numel() for p in params))
# print("type any key to continue")
logger.info(f"number of parameters {sum(p.numel() for p in params)}")
# logger.info("type any key to continue")
# input()
unet.set_use_memory_efficient_attention(True, False)
@@ -455,12 +458,12 @@ if __name__ == "__main__":
# # visualize
# import torchviz
# print("run visualize")
# logger.info("run visualize")
# controlnet.set_control(conditioning_image)
# output = unet(x, t, ctx, y)
# print("make_dot")
# logger.info("make_dot")
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
# print("render")
# logger.info("render")
# image.format = "svg" # "png"
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
# input()
@@ -471,13 +474,13 @@ if __name__ == "__main__":
scaler = torch.cuda.amp.GradScaler(enabled=True)
print("start training")
logger.info("start training")
steps = 10
batch_size = 1
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
for step in range(steps):
print(f"step {step}")
logger.info(f"step {step}")
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
x = torch.randn(batch_size, 4, 128, 128).cuda()
@@ -494,9 +497,9 @@ if __name__ == "__main__":
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
print(sample_param)
logger.info(sample_param)
# from safetensors.torch import save_file
# print("save weights")
# logger.info("save weights")
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)

View File

@@ -12,10 +12,15 @@
import math
import os
import random
from typing import List, Tuple, Union
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import torch
from torch import nn
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class DyLoRAModule(torch.nn.Module):
"""
@@ -165,7 +170,15 @@ class DyLoRAModule(torch.nn.Module):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
@@ -182,6 +195,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
if unit is not None:
unit = int(unit)
else:
@@ -223,7 +237,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -267,11 +281,11 @@ class DyLoRANetwork(torch.nn.Module):
self.apply_to_conv = apply_to_conv
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info("create LoRA network from weights")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
if self.apply_to_conv:
print(f"apply LoRA to Conv2d with kernel size (3,3).")
logger.info("apply LoRA to Conv2d with kernel size (3,3).")
# create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
@@ -306,9 +320,23 @@ class DyLoRANetwork(torch.nn.Module):
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
loras.append(lora)
return loras
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
self.text_encoder_loras = []
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
logger.info(f"create LoRA for Text Encoder {index}")
else:
index = None
logger.info("create LoRA for Text Encoder")
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -316,7 +344,7 @@ class DyLoRANetwork(torch.nn.Module):
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras = create_modules(True, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
def set_multiplier(self, multiplier):
self.multiplier = multiplier
@@ -336,12 +364,12 @@ class DyLoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -359,12 +387,12 @@ class DyLoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -375,7 +403,7 @@ class DyLoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
"""
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):

View File

@@ -10,7 +10,10 @@ from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name):
if model_util.is_safetensors(file_name):
@@ -40,13 +43,13 @@ def split_lora_model(lora_sd, unit):
rank = value.size()[0]
if rank > max_rank:
max_rank = rank
print(f"Max rank: {max_rank}")
logger.info(f"Max rank: {max_rank}")
rank = unit
split_models = []
new_alpha = None
while rank < max_rank:
print(f"Splitting rank {rank}")
logger.info(f"Splitting rank {rank}")
new_sd = {}
for key, value in lora_sd.items():
if "lora_down" in key:
@@ -57,7 +60,7 @@ def split_lora_model(lora_sd, unit):
# なぜかscaleするとおかしくなる……
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
# scale = math.sqrt(this_rank / rank) # rank is > unit
# print(key, value.size(), this_rank, rank, value, scale)
# logger.info(key, value.size(), this_rank, rank, value, scale)
# new_alpha = value * scale # always same
# new_sd[key] = new_alpha
new_sd[key] = value
@@ -69,10 +72,10 @@ def split_lora_model(lora_sd, unit):
def split(args):
print("loading Model...")
logger.info("loading Model...")
lora_sd, metadata = load_state_dict(args.model)
print("Splitting Model...")
logger.info("Splitting Model...")
original_rank, split_models = split_lora_model(lora_sd, args.unit)
comment = metadata.get("ss_training_comment", "")
@@ -94,7 +97,7 @@ def split(args):
filename, ext = os.path.splitext(args.save_to)
model_file_name = filename + f"-{new_rank:04d}{ext}"
print(f"saving model to: {model_file_name}")
logger.info(f"saving model to: {model_file_name}")
save_to_file(model_file_name, state_dict, new_metadata)

View File

@@ -11,10 +11,13 @@ from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, model_util, sdxl_model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-1
# CLAMP_QUANTILE = 0.99
# MIN_DIFF = 1e-1
def save_to_file(file_name, model, state_dict, dtype):
@@ -29,7 +32,24 @@ def save_to_file(file_name, model, state_dict, dtype):
torch.save(model, file_name)
def svd(args):
def svd(
model_org=None,
model_tuned=None,
save_to=None,
dim=4,
v2=None,
sdxl=None,
conv_dim=None,
v_parameterization=None,
device=None,
save_precision=None,
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
):
def str_to_dtype(p):
if p == "float":
return torch.float
@@ -39,44 +59,65 @@ def svd(args):
return torch.bfloat16
return None
assert args.v2 != args.sdxl or (
not args.v2 and not args.sdxl
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if args.v_parameterization is None:
args.v_parameterization = args.v2
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
if v_parameterization is None:
v_parameterization = v2
save_dtype = str_to_dtype(args.save_precision)
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision)
work_device = "cpu"
# load models
if not args.sdxl:
print(f"loading original SD model : {args.model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
if not sdxl:
logger.info(f"loading original SD model : {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
print(f"loading tuned SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
if load_dtype is not None:
text_encoder_o = text_encoder_o.to(load_dtype)
unet_o = unet_o.to(load_dtype)
logger.info(f"loading tuned SD model : {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
if load_dtype is not None:
text_encoder_t = text_encoder_t.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
print(f"loading original SDXL model : {args.model_org}")
device_org = load_original_model_to if load_original_model_to else "cpu"
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
logger.info(f"loading original SDXL model : {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
print(f"loading original SDXL model : {args.model_tuned}")
if load_dtype is not None:
text_encoder_o1 = text_encoder_o1.to(load_dtype)
text_encoder_o2 = text_encoder_o2.to(load_dtype)
unet_o = unet_o.to(load_dtype)
logger.info(f"loading original SDXL model : {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype is not None:
text_encoder_t1 = text_encoder_t1.to(load_dtype)
text_encoder_t2 = text_encoder_t2.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# create LoRA network to extract weights: Use dim (rank) as alpha
if args.conv_dim is None:
if conv_dim is None:
kwargs = {}
else:
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs)
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
@@ -88,50 +129,66 @@ def svd(args):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
# Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
diff = diff.float()
diffs[lora_name] = diff
# clear target Text Encoder to save memory
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.")
logger.warning("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = []
diffs = {}
diffs = {} # clear diffs
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = diff.float()
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
if args.device:
diff = diff.to(args.device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
diffs[lora_name] = diff
# clear LoRA network, target U-Net to save memory
del lora_network_o
del lora_network_t
del unet_t
# make LoRA with svd
print("calculating by svd")
logger.info("calculating by svd")
lora_weights = {}
with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())):
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
if args.device:
mat = mat.to(args.device)
mat = mat.to(torch.float) # calc by float
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
conv2d_3x3 = conv2d and kernel_size != (1, 1)
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
out_dim, in_dim = mat.size()[0:2]
if args.device:
mat = mat.to(args.device)
if device:
mat = mat.to(device)
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
# logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
if conv2d:
@@ -149,7 +206,7 @@ def svd(args):
Vh = Vh[:rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
hi_val = torch.quantile(dist, clamp_quantile)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
@@ -159,8 +216,8 @@ def svd(args):
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
U = U.to("cpu").contiguous()
Vh = Vh.to("cpu").contiguous()
U = U.to(work_device, dtype=save_dtype).contiguous()
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U, Vh)
@@ -176,36 +233,34 @@ def svd(args):
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}")
logger.info(f"Loading extracted LoRA weights: {info}")
dir_name = os.path.dirname(args.save_to)
dir_name = os.path.dirname(save_to)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
# minimum metadata
net_kwargs = {}
if args.conv_dim is not None:
net_kwargs["conv_dim"] = args.conv_dim
net_kwargs["conv_alpha"] = args.conv_dim
if conv_dim is not None:
net_kwargs["conv_dim"] = str(conv_dim)
net_kwargs["conv_alpha"] = str(float(conv_dim))
metadata = {
"ss_v2": str(args.v2),
"ss_v2": str(v2),
"ss_base_model_version": model_version,
"ss_network_module": "networks.lora",
"ss_network_dim": str(args.dim),
"ss_network_alpha": str(args.dim),
"ss_network_dim": str(dim),
"ss_network_alpha": str(float(dim)),
"ss_network_args": json.dumps(net_kwargs),
}
if not args.no_metadata:
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title
)
if not no_metadata:
title = os.path.splitext(os.path.basename(save_to))[0]
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
metadata.update(sai_metadata)
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}")
lora_network_save.save_weights(save_to, save_dtype, metadata)
logger.info(f"LoRA weights are saved to: {save_to}")
def setup_parser() -> argparse.ArgumentParser:
@@ -213,13 +268,20 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
parser.add_argument(
"--v_parameterization",
type=bool,
action="store_true",
default=None,
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する省略時はv2と同じ",
)
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
)
parser.add_argument(
"--load_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
)
parser.add_argument(
"--save_precision",
type=str,
@@ -231,16 +293,22 @@ def setup_parser() -> argparse.ArgumentParser:
"--model_org",
type=str,
default=None,
required=True,
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
)
parser.add_argument(
"--model_tuned",
type=str,
default=None,
required=True,
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
"--save_to",
type=str,
default=None,
required=True,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument(
@@ -250,12 +318,37 @@ def setup_parser() -> argparse.ArgumentParser:
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数rankデフォルトNone、適用なし",
)
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument(
"--clamp_quantile",
type=float,
default=0.99,
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
)
parser.add_argument(
"--min_diff",
type=float,
default=0.01,
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
)
parser.add_argument(
"--no_metadata",
action="store_true",
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--load_original_model_to",
type=str,
default=None,
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
parser.add_argument(
"--load_tuned_model_to",
type=str,
default=None,
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
return parser
@@ -264,4 +357,4 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
svd(args)
svd(**vars(args))

View File

@@ -11,7 +11,12 @@ from transformers import CLIPTextModel
import numpy as np
import torch
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -46,7 +51,7 @@ class LoRAModule(torch.nn.Module):
# if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else:
self.lora_dim = lora_dim
@@ -177,7 +182,7 @@ class LoRAInfModule(LoRAModule):
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
@@ -216,7 +221,7 @@ class LoRAInfModule(LoRAModule):
self.region_mask = None
def default_forward(self, x):
# print("default_forward", self.lora_name, x.size())
# logger.info(f"default_forward {self.lora_name} {x.size()}")
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x):
@@ -242,13 +247,13 @@ class LoRAInfModule(LoRAModule):
area = x.size()[1]
mask = self.network.mask_dic.get(area, None)
if mask is None:
# raise ValueError(f"mask is None for resolution {area}")
if mask is None or len(x.size()) == 2:
# emb_layers in SDXL doesn't have mask
# print(f"mask is None for resolution {area}, {x.size()}")
# if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4:
if len(x.size()) == 3:
mask = torch.reshape(mask, (1, -1, 1))
return mask
@@ -263,6 +268,8 @@ class LoRAInfModule(LoRAModule):
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx)
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
# mask = mask.squeeze(-1)
lx = lx * mask
x = self.org_forward(x)
@@ -291,7 +298,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
query[-self.network.batch_size :] = x[-self.network.batch_size :]
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
# logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}")
return query
def sub_prompt_forward(self, x):
@@ -306,7 +313,7 @@ class LoRAInfModule(LoRAModule):
lx = x[emb_idx :: self.network.num_sub_prompts]
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
# logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}")
x = self.org_forward(x)
x[emb_idx :: self.network.num_sub_prompts] += lx
@@ -314,7 +321,7 @@ class LoRAInfModule(LoRAModule):
return x
def to_out_forward(self, x):
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
# logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}")
if self.network.is_last_network:
masks = [None] * self.network.num_sub_prompts
@@ -332,7 +339,7 @@ class LoRAInfModule(LoRAModule):
)
self.network.shared[self.lora_name] = (lx, masks)
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
@@ -351,7 +358,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
# if num_sub_prompts > num of LoRAs, fill with zero
for i in range(len(masks)):
if masks[i] is None:
@@ -374,7 +381,7 @@ class LoRAInfModule(LoRAModule):
x1 = x1 + lx1
out[self.network.batch_size + i] = x1
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
# logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}")
return out
@@ -511,7 +518,9 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else:
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
)
block_dims = [network_dim] * num_total_blocks
if block_alphas is not None:
@@ -520,7 +529,7 @@ def get_block_dims_and_alphas(
len(block_alphas) == num_total_blocks
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
else:
print(
logger.warning(
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
)
block_alphas = [network_alpha] * num_total_blocks
@@ -540,13 +549,13 @@ def get_block_dims_and_alphas(
else:
if conv_alpha is None:
conv_alpha = 1.0
print(
logger.warning(
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
)
conv_block_alphas = [conv_alpha] * num_total_blocks
else:
if conv_dim is not None:
print(
logger.warning(
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
)
conv_block_dims = [conv_dim] * num_total_blocks
@@ -586,7 +595,7 @@ def get_block_lr_weight(
elif name == "zeros":
return [0.0 + base_lr] * max_len
else:
print(
logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name)
)
@@ -598,14 +607,14 @@ def get_block_lr_weight(
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
@@ -613,24 +622,24 @@ def get_block_lr_weight(
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("apply block learning rate / 階層別学習率を適用します。")
logger.info("apply block learning rate / 階層別学習率を適用します。")
if down_lr_weight != None:
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
else:
print("down_lr_weight: all 1.0, すべて1.0")
logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:", mid_lr_weight)
logger.info(f"mid_lr_weight: {mid_lr_weight}")
else:
print("mid_lr_weight: 1.0")
logger.info("mid_lr_weight: 1.0")
if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
else:
print("up_lr_weight: all 1.0, すべて1.0")
logger.info("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight
@@ -711,7 +720,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(lora_name, value.size(), dim)
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -786,20 +795,26 @@ class LoRANetwork(torch.nn.Module):
self.module_dropout = module_dropout
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
logger.info(f"create LoRA network from block_dims")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
logger.info(f"conv_block_dims: {conv_block_dims}")
logger.info(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
# create module instances
def create_modules(
@@ -884,15 +899,15 @@ class LoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")
logger.info(f"create LoRA for Text Encoder {index}:")
else:
index = None
print(f"create LoRA for Text Encoder:")
logger.info(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -900,15 +915,15 @@ class LoRANetwork(torch.nn.Module):
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0:
print(
logger.warning(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
print(f"\t{name}")
logger.info(f"\t{name}")
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
@@ -926,6 +941,10 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -939,12 +958,12 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -966,12 +985,12 @@ class LoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -982,7 +1001,7 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
def set_block_lr_weight(
@@ -1113,7 +1132,7 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.set_network(self)
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
self.batch_size = batch_size
self.num_sub_prompts = num_sub_prompts
self.current_size = (height, width)
@@ -1128,7 +1147,7 @@ class LoRANetwork(torch.nn.Module):
device = ref_weight.device
def resize_add(mh, mw):
# print(mh, mw, mh * mw)
# logger.info(mh, mw, mh * mw)
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
m = m.to(device, dtype=dtype)
mask_dic[mh * mw] = m
@@ -1139,6 +1158,13 @@ class LoRANetwork(torch.nn.Module):
resize_add(h, w)
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
resize_add(h + h % 2, w + w % 2)
# deep shrink
if ds_ratio is not None:
hd = int(h * ds_ratio)
wd = int(w * ds_ratio)
resize_add(hd, wd)
h = (h + 1) // 2
w = (w + 1) // 2

View File

@@ -9,8 +9,15 @@ from diffusers import UNet2DConditionModel
import numpy as np
from tqdm import tqdm
from transformers import CLIPTextModel
import torch
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def make_unet_conversion_map() -> Dict[str, str]:
unet_conversion_map_layer = []
@@ -248,7 +255,7 @@ def create_network_from_weights(
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -291,12 +298,12 @@ class LoRANetwork(torch.nn.Module):
super().__init__()
self.multiplier = multiplier
print(f"create LoRA network from weights")
logger.info("create LoRA network from weights")
# convert SDXL Stability AI's U-Net modules to Diffusers
converted = self.convert_unet_modules(modules_dim, modules_alpha)
if converted:
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
# create module instances
def create_modules(
@@ -331,7 +338,7 @@ class LoRANetwork(torch.nn.Module):
lora_name = lora_name.replace(".", "_")
if lora_name not in modules_dim:
# print(f"skipped {lora_name} (not found in modules_dim)")
# logger.info(f"skipped {lora_name} (not found in modules_dim)")
skipped.append(lora_name)
continue
@@ -362,18 +369,18 @@ class LoRANetwork(torch.nn.Module):
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
if len(skipped_te) > 0:
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
# extend U-Net target modules to include Conv2d 3x3
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras: List[LoRAModule]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
if len(skipped_un) > 0:
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
# assertion
names = set()
@@ -420,11 +427,11 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
for lora in self.text_encoder_loras:
lora.apply_to(multiplier)
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
for lora in self.unet_loras:
lora.apply_to(multiplier)
@@ -433,16 +440,16 @@ class LoRANetwork(torch.nn.Module):
lora.unapply_to()
def merge_to(self, multiplier=1.0):
print("merge LoRA weights to original weights")
logger.info("merge LoRA weights to original weights")
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
lora.merge_to(multiplier)
print(f"weights are merged")
logger.info(f"weights are merged")
def restore_from(self, multiplier=1.0):
print("restore LoRA weights from original weights")
logger.info("restore LoRA weights from original weights")
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
lora.restore_from(multiplier)
print(f"weights are restored")
logger.info(f"weights are restored")
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
# convert SDXL Stability AI's state dict to Diffusers' based state dict
@@ -463,7 +470,7 @@ class LoRANetwork(torch.nn.Module):
my_state_dict = self.state_dict()
for key in state_dict.keys():
if state_dict[key].size() != my_state_dict[key].size():
# print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
# logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
return super().load_state_dict(state_dict, strict)
@@ -476,7 +483,7 @@ if __name__ == "__main__":
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
@@ -490,7 +497,7 @@ if __name__ == "__main__":
image_prefix = args.model_id.replace("/", "_") + "_"
# load Diffusers model
print(f"load model from {args.model_id}")
logger.info(f"load model from {args.model_id}")
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
if args.sdxl:
# use_safetensors=True does not work with 0.18.2
@@ -503,7 +510,7 @@ if __name__ == "__main__":
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
# load LoRA weights
print(f"load LoRA weights from {args.lora_weights}")
logger.info(f"load LoRA weights from {args.lora_weights}")
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -512,10 +519,10 @@ if __name__ == "__main__":
lora_sd = torch.load(args.lora_weights)
# create by LoRA weights and load weights
print(f"create LoRA network")
logger.info(f"create LoRA network")
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
print(f"load LoRA network weights")
logger.info(f"load LoRA network weights")
lora_network.load_state_dict(lora_sd)
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
@@ -544,34 +551,34 @@ if __name__ == "__main__":
random.seed(seed)
# create image with original weights
print(f"create image with original weights")
logger.info(f"create image with original weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "original.png")
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
print(f"apply LoRA network to the model")
logger.info(f"apply LoRA network to the model")
lora_network.apply_to(multiplier=1.0)
print(f"create image with applied LoRA")
logger.info(f"create image with applied LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "applied_lora.png")
# unapply LoRA network to the model
print(f"unapply LoRA network to the model")
logger.info(f"unapply LoRA network to the model")
lora_network.unapply_to()
print(f"create image with unapplied LoRA")
logger.info(f"create image with unapplied LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "unapplied_lora.png")
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
print(f"merge LoRA network to the model")
logger.info(f"merge LoRA network to the model")
lora_network.merge_to(multiplier=1.0)
print(f"create image with LoRA")
logger.info(f"create image with LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "merged_lora.png")
@@ -579,31 +586,31 @@ if __name__ == "__main__":
# restore (unmerge) LoRA weights: numerically unstable
# マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない
# 保存したstate_dictから元の重みを復元するのが確実
print(f"restore (unmerge) LoRA weights")
logger.info(f"restore (unmerge) LoRA weights")
lora_network.restore_from(multiplier=1.0)
print(f"create image without LoRA")
logger.info(f"create image without LoRA")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "unmerged_lora.png")
# restore original weights
print(f"restore original weights")
logger.info(f"restore original weights")
pipe.unet.load_state_dict(org_unet_sd)
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
if args.sdxl:
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
print(f"create image with restored original weights")
logger.info(f"create image with restored original weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "restore_original.png")
# use convenience function to merge LoRA weights
print(f"merge LoRA weights with convenience function")
logger.info(f"merge LoRA weights with convenience function")
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
print(f"create image with merged LoRA weights")
logger.info(f"create image with merged LoRA weights")
seed_everything(args.seed)
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
image.save(image_prefix + "convenience_merged_lora.png")

View File

@@ -14,7 +14,10 @@ from transformers import CLIPTextModel
import numpy as np
import torch
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -49,7 +52,7 @@ class LoRAModule(torch.nn.Module):
# if limit_rank:
# self.lora_dim = min(lora_dim, in_dim, out_dim)
# if self.lora_dim != lora_dim:
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
# else:
self.lora_dim = lora_dim
@@ -197,7 +200,7 @@ class LoRAInfModule(LoRAModule):
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + self.multiplier * conved * self.scale
# set weight to org_module
@@ -236,7 +239,7 @@ class LoRAInfModule(LoRAModule):
self.region_mask = None
def default_forward(self, x):
# print("default_forward", self.lora_name, x.size())
# logger.info("default_forward", self.lora_name, x.size())
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x):
@@ -278,7 +281,7 @@ class LoRAInfModule(LoRAModule):
# apply mask for LoRA result
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx)
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
# logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
lx = lx * mask
x = self.org_forward(x)
@@ -307,7 +310,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
query[-self.network.batch_size :] = x[-self.network.batch_size :]
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
# logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
return query
def sub_prompt_forward(self, x):
@@ -322,7 +325,7 @@ class LoRAInfModule(LoRAModule):
lx = x[emb_idx :: self.network.num_sub_prompts]
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
# logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
x = self.org_forward(x)
x[emb_idx :: self.network.num_sub_prompts] += lx
@@ -330,7 +333,7 @@ class LoRAInfModule(LoRAModule):
return x
def to_out_forward(self, x):
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
# logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
if self.network.is_last_network:
masks = [None] * self.network.num_sub_prompts
@@ -348,7 +351,7 @@ class LoRAInfModule(LoRAModule):
)
self.network.shared[self.lora_name] = (lx, masks)
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
@@ -367,7 +370,7 @@ class LoRAInfModule(LoRAModule):
if has_real_uncond:
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
# for i in range(len(masks)):
# if masks[i] is None:
# masks[i] = torch.zeros_like(masks[-1])
@@ -389,7 +392,7 @@ class LoRAInfModule(LoRAModule):
x1 = x1 + lx1
out[self.network.batch_size + i] = x1
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
# logger.info("to_out_forward", x.size(), out.size(), has_real_uncond)
return out
@@ -526,7 +529,7 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else:
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
block_dims = [network_dim] * num_total_blocks
if block_alphas is not None:
@@ -535,7 +538,7 @@ def get_block_dims_and_alphas(
len(block_alphas) == num_total_blocks
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
else:
print(
logger.warning(
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
)
block_alphas = [network_alpha] * num_total_blocks
@@ -555,13 +558,13 @@ def get_block_dims_and_alphas(
else:
if conv_alpha is None:
conv_alpha = 1.0
print(
logger.warning(
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
)
conv_block_alphas = [conv_alpha] * num_total_blocks
else:
if conv_dim is not None:
print(
logger.warning(
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
)
conv_block_dims = [conv_dim] * num_total_blocks
@@ -601,7 +604,7 @@ def get_block_lr_weight(
elif name == "zeros":
return [0.0 + base_lr] * max_len
else:
print(
logger.error(
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
% (name)
)
@@ -613,14 +616,14 @@ def get_block_lr_weight(
up_lr_weight = get_list(up_lr_weight)
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
up_lr_weight = up_lr_weight[:max_len]
down_lr_weight = down_lr_weight[:max_len]
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
if down_lr_weight != None and len(down_lr_weight) < max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
@@ -628,24 +631,24 @@ def get_block_lr_weight(
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("apply block learning rate / 階層別学習率を適用します。")
logger.info("apply block learning rate / 階層別学習率を適用します。")
if down_lr_weight != None:
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
else:
print("down_lr_weight: all 1.0, すべて1.0")
logger.info("down_lr_weight: all 1.0, すべて1.0")
if mid_lr_weight != None:
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:", mid_lr_weight)
logger.info(f"mid_lr_weight: {mid_lr_weight}")
else:
print("mid_lr_weight: 1.0")
logger.info("mid_lr_weight: 1.0")
if up_lr_weight != None:
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
else:
print("up_lr_weight: all 1.0, すべて1.0")
logger.info("up_lr_weight: all 1.0, すべて1.0")
return down_lr_weight, mid_lr_weight, up_lr_weight
@@ -726,7 +729,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(lora_name, value.size(), dim)
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -801,20 +804,20 @@ class LoRANetwork(torch.nn.Module):
self.module_dropout = module_dropout
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
print(f"create LoRA network from block_dims")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
print(f"block_dims: {block_dims}")
print(f"block_alphas: {block_alphas}")
logger.info(f"create LoRA network from block_dims")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
print(f"conv_block_dims: {conv_block_dims}")
print(f"conv_block_alphas: {conv_block_alphas}")
logger.info(f"conv_block_dims: {conv_block_dims}")
logger.info(f"conv_block_alphas: {conv_block_alphas}")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
if self.conv_lora_dim is not None:
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
# create module instances
def create_modules(
@@ -899,15 +902,15 @@ class LoRANetwork(torch.nn.Module):
for i, text_encoder in enumerate(text_encoders):
if len(text_encoders) > 1:
index = i + 1
print(f"create LoRA for Text Encoder {index}:")
logger.info(f"create LoRA for Text Encoder {index}:")
else:
index = None
print(f"create LoRA for Text Encoder:")
logger.info(f"create LoRA for Text Encoder:")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -915,15 +918,15 @@ class LoRANetwork(torch.nn.Module):
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
skipped = skipped_te + skipped_un
if varbose and len(skipped) > 0:
print(
logger.warning(
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
)
for name in skipped:
print(f"\t{name}")
logger.info(f"\t{name}")
self.up_lr_weight: List[float] = None
self.down_lr_weight: List[float] = None
@@ -954,12 +957,12 @@ class LoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -981,12 +984,12 @@ class LoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -997,7 +1000,7 @@ class LoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
def set_block_lr_weight(
@@ -1144,7 +1147,7 @@ class LoRANetwork(torch.nn.Module):
device = ref_weight.device
def resize_add(mh, mw):
# print(mh, mw, mh * mw)
# logger.info(mh, mw, mh * mw)
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
m = m.to(device, dtype=dtype)
mask_dic[mh * mw] = m

View File

@@ -5,27 +5,34 @@ from library import model_util
import library.train_util as train_util
import argparse
from transformers import CLIPTokenizer
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = get_preferred_device()
def interrogate(args):
weights_dtype = torch.float16
# いろいろ準備する
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
args.pretrained_model_name_or_path = args.sd_model
args.vae = None
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
print(f"loading LoRA: {args.model}")
logger.info(f"loading LoRA: {args.model}")
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# text encoder向けの重みがあるかチェックする本当はlora側でやるのがいい
@@ -35,11 +42,11 @@ def interrogate(args):
has_te_weight = True
break
if not has_te_weight:
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
return
del vae
print("loading tokenizer")
logger.info("loading tokenizer")
if args.v2:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
@@ -53,7 +60,7 @@ def interrogate(args):
# トークンをひとつひとつ当たっていく
token_id_start = 0
token_id_end = max(tokenizer.all_special_ids)
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")
def get_all_embeddings(text_encoder):
embs = []
@@ -79,24 +86,24 @@ def interrogate(args):
embs.extend(encoder_hidden_states)
return torch.stack(embs)
print("get original text encoder embeddings.")
logger.info("get original text encoder embeddings.")
orig_embs = get_all_embeddings(text_encoder)
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
info = network.load_state_dict(weights_sd, strict=False)
print(f"Loading LoRA weights: {info}")
logger.info(f"Loading LoRA weights: {info}")
network.to(DEVICE, dtype=weights_dtype)
network.eval()
del unet
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
print("get text encoder embeddings with lora.")
logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
logger.info("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder)
# 比べる:とりあえず単純に差分の絶対値で
print("comparing...")
logger.info("comparing...")
diffs = {}
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
diff = torch.mean(torch.abs(orig_emb - lora_emb))

View File

@@ -7,7 +7,10 @@ from safetensors.torch import load_file, save_file
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -61,10 +64,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
@@ -73,10 +76,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
@@ -104,7 +107,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
@@ -118,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
@@ -151,10 +154,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if "alpha" in key:
continue
@@ -196,8 +199,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same
dims_list = list(set(base_dims.values()))
@@ -239,7 +242,7 @@ def merge(args):
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
@@ -264,18 +267,18 @@ def merge(args):
)
if args.v2:
# TODO read sai modelspec
print(
logger.warning(
"Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
print(f"saving SD model to: {args.save_to}")
logger.info(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
)
else:
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...")
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -289,12 +292,12 @@ def merge(args):
)
if v2:
# TODO read sai modelspec
print(
logger.warning(
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)

View File

@@ -6,7 +6,10 @@ import torch
from safetensors.torch import load_file, save_file
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
@@ -54,10 +57,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
@@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
@@ -96,10 +99,10 @@ def merge_lora_models(models, ratios, merge_dtype):
alpha = None
dim = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
@@ -117,7 +120,7 @@ def merge_lora_models(models, ratios, merge_dtype):
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}")
logger.info(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
@@ -142,19 +145,21 @@ def merge(args):
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"\nsaving SD model to: {args.save_to}")
logger.info("")
logger.info(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"\nsaving model to: {args.save_to}")
logger.info(f"")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)

View File

@@ -8,7 +8,10 @@ from transformers import CLIPTextModel
import numpy as np
import torch
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -237,7 +240,7 @@ class OFTNetwork(torch.nn.Module):
self.dim = dim
self.alpha = alpha
print(
logger.info(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
)
@@ -258,7 +261,7 @@ class OFTNetwork(torch.nn.Module):
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_")
# print(oft_name)
# logger.info(oft_name)
oft = module_class(
oft_name,
@@ -279,7 +282,7 @@ class OFTNetwork(torch.nn.Module):
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
# assertion
names = set()
@@ -316,7 +319,7 @@ class OFTNetwork(torch.nn.Module):
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
print("enable OFT for U-Net")
logger.info("enable OFT for U-Net")
for oft in self.unet_ofts:
sd_for_lora = {}
@@ -326,7 +329,7 @@ class OFTNetwork(torch.nn.Module):
oft.load_state_dict(sd_for_lora, False)
oft.merge_to()
print(f"weights are merged")
logger.info(f"weights are merged")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
@@ -338,11 +341,11 @@ class OFTNetwork(torch.nn.Module):
for oft in ofts:
params.extend(oft.parameters())
# print num of params
# logger.info num of params
num_params = 0
for p in params:
num_params += p.numel()
print(f"OFT params: {num_params}")
logger.info(f"OFT params: {num_params}")
return params
param_data = {"params": enumerate_params(self.unet_ofts)}

View File

@@ -2,80 +2,91 @@
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo
import os
import argparse
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np
from library import train_util
from library import model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6
# Model save and load functions
def load_state_dict(file_name, dtype):
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location='cpu')
metadata = None
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location="cpu")
metadata = None
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd, metadata
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
def save_to_file(file_name, state_dict, dtype, metadata):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if model_util.is_safetensors(file_name):
save_file(model, file_name, metadata)
else:
torch.save(model, file_name)
if model_util.is_safetensors(file_name):
save_file(state_dict, file_name, metadata)
else:
torch.save(state_dict, file_name)
# Indexing functions
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S)-1))
return index
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
index = max(1, min(index, len(S) - 1))
return index
def index_sv_fro(S, target):
S_squared = S.pow(2)
s_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S)-1))
S_squared = S.pow(2)
S_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
index = max(1, min(index, len(S) - 1))
return index
return index
def index_sv_ratio(S, target):
max_sv = S[0]
min_sv = max_sv/target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S)-1))
max_sv = S[0]
min_sv = max_sv / target
index = int(torch.sum(S > min_sv).item())
index = max(1, min(index, len(S) - 1))
return index
return index
# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size, kernel_size, _ = weight.size()
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
@@ -92,17 +103,17 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size = weight.size()
U, S, Vh = torch.linalg.svd(weight.to(device))
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]
U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
del U, S, Vh, weight
@@ -113,7 +124,7 @@ def merge_conv(lora_down, lora_up, device):
in_rank, in_size, kernel_size, k_ = lora_down.shape
out_size, out_rank, _, _ = lora_up.shape
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
@@ -127,236 +138,274 @@ def merge_linear(lora_down, lora_up, device):
in_rank, in_size = lora_down.shape
out_size, out_rank = lora_up.shape
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
weight = lora_up @ lora_down
del lora_up, lora_down
return weight
# Calculate new rank
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict = {}
if dynamic_method=="sv_ratio":
if dynamic_method == "sv_ratio":
# Calculate new dim and alpha based off ratio
new_rank = index_sv_ratio(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
elif dynamic_method=="sv_cumulative":
elif dynamic_method == "sv_cumulative":
# Calculate new dim and alpha based off cumulative sum
new_rank = index_sv_cumulative(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
elif dynamic_method=="sv_fro":
elif dynamic_method == "sv_fro":
# Calculate new dim and alpha based off sqrt sum of squares
new_rank = index_sv_fro(S, dynamic_param) + 1
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
else:
new_rank = rank
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
new_rank = 1
new_alpha = float(scale*new_rank)
elif new_rank > rank: # cap max rank at rank
new_alpha = float(scale * new_rank)
elif new_rank > rank: # cap max rank at rank
new_rank = rank
new_alpha = float(scale*new_rank)
new_alpha = float(scale * new_rank)
# Calculate resize info
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))
S_squared = S.pow(2)
s_fro = torch.sqrt(torch.sum(S_squared))
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
fro_percent = float(s_red_fro/s_fro)
fro_percent = float(s_red_fro / s_fro)
param_dict["new_rank"] = new_rank
param_dict["new_alpha"] = new_alpha
param_dict["sum_retained"] = (s_rank)/s_sum
param_dict["sum_retained"] = (s_rank) / s_sum
param_dict["fro_retained"] = fro_percent
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
param_dict["max_ratio"] = S[0] / S[new_rank - 1]
return param_dict
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
network_alpha = None
network_dim = None
verbose_str = "\n"
fro_list = []
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
network_alpha = None
network_dim = None
verbose_str = "\n"
fro_list = []
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and "alpha" in key:
network_alpha = value
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
scale = network_alpha/network_dim
scale = network_alpha / network_dim
if dynamic_method:
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
if dynamic_method:
logger.info(
f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
)
lora_down_weight = None
lora_up_weight = None
lora_down_weight = None
lora_up_weight = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
weight_name = None
if 'lora_down' in key:
block_down_name = key.rsplit('.lora_down', 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
continue
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
weight_name = None
if "lora_down" in key:
block_down_name = key.rsplit(".lora_down", 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
continue
# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
# find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
weights_loaded = lora_down_weight is not None and lora_up_weight is not None
if weights_loaded:
if weights_loaded:
conv2d = (len(lora_down_weight.size()) == 4)
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha/lora_down_weight.size()[0]
conv2d = len(lora_down_weight.size()) == 4
if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha / lora_down_weight.size()[0]
if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
if verbose:
max_ratio = param_dict['max_ratio']
sum_retained = param_dict['sum_retained']
fro_retained = param_dict['fro_retained']
if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))
if verbose:
max_ratio = param_dict["max_ratio"]
sum_retained = param_dict["sum_retained"]
fro_retained = param_dict["fro_retained"]
if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))
verbose_str+=f"{block_down_name:75} | "
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
verbose_str += f"{block_down_name:75} | "
verbose_str += (
f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
)
if verbose and dynamic_method:
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str+=f"\n"
if verbose and dynamic_method:
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str += "\n"
new_alpha = param_dict['new_alpha']
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
new_alpha = param_dict["new_alpha"]
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
del param_dict
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
del param_dict
if verbose:
print(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
print("resizing complete")
return o_lora_sd, network_dim, new_alpha
if verbose:
print(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
logger.info("resizing complete")
return o_lora_sd, network_dim, new_alpha
def resize(args):
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
if args.save_to is None or not (
args.save_to.endswith(".ckpt")
or args.save_to.endswith(".pt")
or args.save_to.endswith(".pth")
or args.save_to.endswith(".safetensors")
):
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
if args.dynamic_method and not args.dynamic_param:
raise Exception("If using dynamic_method, then dynamic_param is required")
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
if args.dynamic_method and not args.dynamic_param:
raise Exception("If using dynamic_method, then dynamic_param is required")
print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
print("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
logger.info("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
# update metadata
if metadata is None:
metadata = {}
logger.info("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
)
comment = metadata.get("ss_training_comment", "")
# update metadata
if metadata is None:
metadata = {}
if not args.dynamic_method:
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
else:
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
metadata["ss_network_dim"] = 'Dynamic'
metadata["ss_network_alpha"] = 'Dynamic'
comment = metadata.get("ss_training_comment", "")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
if not args.dynamic_method:
conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})"
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
else:
metadata["ss_training_comment"] = (
f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
)
metadata["ss_network_dim"] = "Dynamic"
metadata["ss_network_alpha"] = "Dynamic"
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument("--verbose", action="store_true",
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
parser.add_argument("--dynamic_param", type=float, default=None,
help="Specify target for dynamic reduction")
return parser
parser.add_argument(
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat",
)
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument(
"--new_conv_rank",
type=int,
default=None,
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
)
parser.add_argument(
"--save_to",
type=str,
default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
)
parser.add_argument(
"--model",
type=str,
default=None,
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors",
)
parser.add_argument(
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
)
parser.add_argument(
"--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する"
)
parser.add_argument(
"--dynamic_method",
type=str,
default=None,
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
)
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
resize(args)
args = parser.parse_args()
resize(args)

View File

@@ -8,7 +8,10 @@ from tqdm import tqdm
from library import sai_model_spec, sdxl_model_util, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == ".safetensors":
@@ -66,10 +69,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, _ = load_state_dict(model, merge_dtype)
print(f"merging...")
logger.info(f"merging...")
for key in tqdm(lora_sd.keys()):
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
@@ -78,10 +81,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
logger.info(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
# logger.info(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
@@ -92,7 +95,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
# W <- W + U * D
weight = module.weight
# print(module_name, down_weight.size(), up_weight.size())
# logger.info(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
@@ -107,7 +110,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
@@ -121,7 +124,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
@@ -154,10 +157,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
logger.info(f"merging...")
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
@@ -200,8 +203,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
logger.info("merged model")
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
# check all dims are same
dims_list = list(set(base_dims.values()))
@@ -243,7 +246,7 @@ def merge(args):
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
logger.info(f"loading SD model: {args.sd_model}")
(
text_model1,
@@ -265,14 +268,14 @@ def merge(args):
None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
)
print(f"saving SD model to: {args.save_to}")
logger.info(f"saving SD model to: {args.save_to}")
sdxl_model_util.save_stable_diffusion_checkpoint(
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
print(f"calculating hashes and creating metadata...")
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -286,7 +289,7 @@ def merge(args):
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)

View File

@@ -1,4 +1,3 @@
import math
import argparse
import os
import time
@@ -8,7 +7,10 @@ from tqdm import tqdm
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLAMP_QUANTILE = 0.99
@@ -41,12 +43,12 @@ def save_to_file(file_name, state_dict, dtype, metadata):
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
merged_sd = {}
v2 = None
base_model = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
logger.info(f"loading: {model}")
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
if lora_metadata is not None:
@@ -56,7 +58,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
# merge
print(f"merging...")
logger.info(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if "lora_down" not in key:
continue
@@ -73,7 +75,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
kernel_size = None if not conv2d else down_weight.size()[2:4]
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
# logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
# make original weight if not exist
if lora_module_name not in merged_sd:
@@ -110,7 +112,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
merged_sd[lora_module_name] = weight
# extract from merged weights
print("extract new lora...")
logger.info("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
@@ -188,7 +190,7 @@ def merge(args):
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
)
print(f"calculating hashes and creating metadata...")
logger.info(f"calculating hashes and creating metadata...")
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -203,12 +205,12 @@ def merge(args):
)
if v2:
# TODO read sai modelspec
print(
logger.warning(
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
)
metadata.update(sai_metadata)
print(f"saving model to: {args.save_to}")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, save_dtype, metadata)

View File

@@ -1,20 +1,24 @@
accelerate==0.23.0
transformers==4.30.2
diffusers[torch]==0.21.2
accelerate==0.25.0
transformers==4.36.2
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.7.0.68
einops==0.6.0
einops==0.7.0
pytorch-lightning==1.9.0
# bitsandbytes==0.39.1
tensorboard==2.10.1
safetensors==0.3.1
bitsandbytes==0.43.0
prodigyopt==1.0
lion-pytorch==0.0.6
tensorboard
safetensors==0.4.2
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.15.1
huggingface-hub==0.20.1
# for Image utils
imagesize==1.4.1
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
@@ -22,12 +26,17 @@ huggingface-hub==0.15.1
# for WD14 captioning (tensorflow)
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.14.1
# onnxruntime-gpu==1.16.0
# onnxruntime==1.16.0
# onnx==1.15.0
# onnxruntime-gpu==1.17.1
# onnxruntime==1.17.1
# for cuda 12.1(default 11.8)
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
open-clip-torch==2.20.0
# open-clip-torch==2.20.0
# For logging
rich==13.7.0
# for kohya_ss library
-e .

File diff suppressed because it is too large Load Diff

View File

@@ -8,23 +8,28 @@ import os
import random
from einops import repeat
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler
from PIL import Image
import open_clip
# import open_clip
from safetensors.torch import load_file
from library import model_util, sdxl_model_util
import networks.lora as lora
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
# scheduler: The settings around here seem to be the same as SD1/2
@@ -87,7 +92,7 @@ if __name__ == "__main__":
guidance_scale = 7
seed = None # 1
DEVICE = "cuda"
DEVICE = get_preferred_device()
DTYPE = torch.float16 # bfloat16 may work
parser = argparse.ArgumentParser()
@@ -142,7 +147,7 @@ if __name__ == "__main__":
vae_dtype = DTYPE
if DTYPE == torch.float16:
print("use float32 for vae")
logger.info("use float32 for vae")
vae_dtype = torch.float32
vae.to(DEVICE, dtype=vae_dtype)
vae.eval()
@@ -153,12 +158,13 @@ if __name__ == "__main__":
text_model2.eval()
unet.set_use_memory_efficient_attention(True, False)
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
vae.set_use_memory_efficient_attention_xformers(True)
# Tokenizers
tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
# tokenizer2 = lambda x: open_clip.tokenize(x, context_length=77)
tokenizer2 = CLIPTokenizer.from_pretrained(text_encoder_2_name)
# LoRA
for weights_file in args.lora_weights:
@@ -189,9 +195,11 @@ if __name__ == "__main__":
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
# print("emb1", emb1.shape)
# logger.info("emb1", emb1.shape)
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
uc_vector = c_vector.clone().to(
DEVICE, dtype=DTYPE
) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
# crossattn
@@ -214,13 +222,22 @@ if __name__ == "__main__":
# text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) # layer normは通さないらしい
# text encoder 2
with torch.no_grad():
tokens = tokenizer2(text2).to(DEVICE)
# tokens = tokenizer2(text2).to(DEVICE)
tokens = tokenizer2(
text,
truncation=True,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt",
)
tokens = batch_encoding["input_ids"].to(DEVICE)
with torch.no_grad():
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
text_embedding2_penu = enc_out["hidden_states"][-2]
# print("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# logger.info("hidden_states2", text_embedding2_penu.shape)
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
# 連結して終了 concat and finish
text_embedding = torch.cat([text_embedding1, text_embedding2_penu], dim=2)
@@ -228,7 +245,7 @@ if __name__ == "__main__":
# cond
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
# logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
# uncond
@@ -325,4 +342,4 @@ if __name__ == "__main__":
seed = int(seed)
generate_image(prompt, prompt2, negative_prompt, seed)
print("Done!")
logger.info("Done!")

View File

@@ -1,7 +1,6 @@
# training with captions
import argparse
import gc
import math
import os
from multiprocessing import Value
@@ -9,22 +8,26 @@ from typing import List
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util
from library import deepspeed_utils, sdxl_model_util
import library.train_util as train_util
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
@@ -38,6 +41,7 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -96,8 +100,12 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
assert (
not args.weighted_captions
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
assert (
not args.train_text_encoder or not args.cache_text_encoder_outputs
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
@@ -120,20 +128,20 @@ def train(args):
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -144,7 +152,7 @@ def train(args):
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -174,7 +182,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
@@ -190,7 +198,7 @@ def train(args):
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -257,9 +265,7 @@ def train(args):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -352,8 +358,8 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -368,7 +374,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -394,26 +402,40 @@ def train(args):
text_encoder1.to(weight_dtype)
text_encoder2.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
(unet,) = train_util.transform_models_if_DDP([unet])
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
(text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
(text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoder1 if train_text_encoder1 else None,
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
@@ -421,6 +443,8 @@ def train(args):
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
# During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
# -> But we think it's ok to patch accelerator even if deepspeed is enabled.
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
@@ -438,7 +462,9 @@ def train(args):
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
@@ -457,10 +483,17 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -482,7 +515,7 @@ def train(args):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -503,6 +536,7 @@ def train(args):
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
@@ -512,6 +546,7 @@ def train(args):
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
@@ -533,7 +568,7 @@ def train(args):
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified")
# logger.info("text encoder outputs verified")
# get size embeddings
orig_size = batch["original_sizes_hw"]
@@ -547,7 +582,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -562,9 +597,12 @@ def train(args):
or args.scale_v_pred_loss_like_noise_pred
or args.v_pred_like_loss
or args.debiased_estimation_loss
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
@@ -578,7 +616,7 @@ def train(args):
loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -698,7 +736,7 @@ def train(args):
accelerator.end_training()
if args.save_state: # and is_main_process:
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -720,15 +758,18 @@ def train(args):
logit_scale,
ckpt_info,
)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
@@ -748,7 +789,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
parser.add_argument(
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
)
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--no_half_vae",
@@ -762,7 +805,6 @@ def setup_parser() -> argparse.ArgumentParser:
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
)
return parser
@@ -770,6 +812,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -2,7 +2,6 @@
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
import argparse
import gc
import json
import math
import os
@@ -13,20 +12,17 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
import accelerate
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
import library.model_util as model_util
import library.train_util as train_util
@@ -47,6 +43,12 @@ from library.custom_train_functions import (
apply_debiased_estimation,
)
import networks.control_net_lllite_for_train as control_net_lllite_for_train
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO 他のスクリプトと共通化する
@@ -67,6 +69,7 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
@@ -80,11 +83,11 @@ def train(args):
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -116,7 +119,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -126,7 +129,9 @@ def train(args):
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else:
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
logger.warning(
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
)
if args.cache_text_encoder_outputs:
assert (
@@ -134,7 +139,7 @@ def train(args):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -166,9 +171,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -233,14 +236,14 @@ def train(args):
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(unet.prepare_params())
print(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
logger.info(f"trainable params count: {len(trainable_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -256,7 +259,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -283,9 +288,6 @@ def train(args):
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare (train_network here only)
unet = train_util.transform_models_if_DDP([unet])[0]
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
@@ -296,8 +298,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
@@ -328,8 +329,10 @@ def train(args):
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -345,6 +348,8 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
@@ -389,15 +394,15 @@ def train(args):
with accelerator.accumulate(unet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -434,7 +439,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -453,14 +458,14 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
@@ -544,22 +549,24 @@ def train(args):
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -572,8 +579,12 @@ def setup_parser() -> argparse.ArgumentParser:
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument(
"--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
)
parser.add_argument(
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
)
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
parser.add_argument(
"--network_dropout",
@@ -601,6 +612,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,5 +1,4 @@
import argparse
import gc
import json
import math
import os
@@ -10,19 +9,16 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
import library.model_util as model_util
import library.train_util as train_util
@@ -43,6 +39,12 @@ from library.custom_train_functions import (
apply_debiased_estimation,
)
import networks.control_net_lllite as control_net_lllite
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO 他のスクリプトと共通化する
@@ -63,6 +65,7 @@ def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
sdxl_train_util.verify_sdxl_training_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
@@ -76,11 +79,11 @@ def train(args):
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -112,7 +115,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -122,7 +125,9 @@ def train(args):
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
else:
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
logger.warning(
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
)
if args.cache_text_encoder_outputs:
assert (
@@ -130,7 +135,7 @@ def train(args):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -165,9 +170,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -201,14 +204,14 @@ def train(args):
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(network.prepare_optimizer_params())
print(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
logger.info(f"trainable params count: {len(trainable_params)}")
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -224,7 +227,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -254,9 +259,6 @@ def train(args):
)
network: control_net_lllite.ControlNetLLLite
# transform DDP after prepare (train_network here only)
unet, network = train_util.transform_models_if_DDP([unet, network])
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
@@ -269,8 +271,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
@@ -301,8 +302,10 @@ def train(args):
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -358,15 +361,15 @@ def train(args):
with accelerator.accumulate(network):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -403,7 +406,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype
@@ -423,14 +426,14 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
@@ -521,15 +524,17 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -542,8 +547,12 @@ def setup_parser() -> argparse.ArgumentParser:
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .safetensors) / モデル保存時の形式デフォルトはsafetensors",
)
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument(
"--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
)
parser.add_argument(
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
)
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
parser.add_argument(
"--network_dropout",
@@ -571,6 +580,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,15 +1,15 @@
import argparse
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
@@ -62,13 +62,12 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
print("move vae and unet to cpu to save memory")
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
with accelerator.autocast():
@@ -83,17 +82,16 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
if not args.lowram:
print("move vae and unet back to original device")
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device)
text_encoders[1].to(accelerator.device)
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -123,6 +121,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
@@ -144,7 +143,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
# print("text encoder outputs verified")
# logger.info("text encoder outputs verified")
return encoder_hidden_states1, encoder_hidden_states2, pool2
@@ -179,6 +178,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = SdxlNetworkTrainer()

View File

@@ -2,15 +2,11 @@ import argparse
import os
import regex
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
import open_clip
from library.device_utils import init_ipex
init_ipex()
from library import sdxl_model_util, sdxl_train_util, train_util
import train_textual_inversion
@@ -64,6 +60,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2
@@ -134,6 +131,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = SdxlTextualInversionTrainer()

View File

@@ -16,7 +16,10 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
@@ -41,18 +44,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -63,7 +66,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -90,7 +93,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -98,7 +101,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
print("load model")
logger.info("load model")
if args.sdxl:
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
else:
@@ -113,8 +116,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# dataloaderを準備する
train_dataset_group.set_caching_mode("latents")
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -152,7 +155,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
print(f"Skipping {image_info.latents_npz} because it already exists.")
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)

View File

@@ -16,7 +16,10 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
@@ -48,18 +51,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -70,7 +73,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -95,14 +98,14 @@ def cache_to_disk(args: argparse.Namespace) -> None:
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, _ = train_util.prepare_dtype(args)
# モデルを読み込む
print("load model")
logger.info("load model")
if args.sdxl:
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
text_encoders = [text_encoder1, text_encoder2]
@@ -118,8 +121,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
# dataloaderを準備する
train_dataset_group.set_caching_mode("text")
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -147,7 +150,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1

View File

@@ -1,6 +1,10 @@
import argparse
import cv2
import logging
from library.utils import setup_logging
setup_logging()
logger = logging.getLogger(__name__)
def canny(args):
img = cv2.imread(args.input)
@@ -10,7 +14,7 @@ def canny(args):
# canny_img = 255 - canny_img
cv2.imwrite(args.output, canny_img)
print("done!")
logger.info("done!")
def setup_parser() -> argparse.ArgumentParser:

View File

@@ -6,7 +6,10 @@ import torch
from diffusers import StableDiffusionPipeline
import library.model_util as model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def convert(args):
# 引数を確認する
@@ -23,21 +26,23 @@ def convert(args):
is_load_ckpt = os.path.isfile(args.model_to_load)
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
# assert (
# is_save_ckpt or args.reference_model is not None
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
# モデルを読み込む
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
print(f"loading {msg}: {args.model_to_load}")
logger.info(f"loading {msg}: {args.model_to_load}")
if is_load_ckpt:
v2_model = args.v2
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection)
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
)
else:
pipe = StableDiffusionPipeline.from_pretrained(
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
)
text_encoder = pipe.text_encoder
vae = pipe.vae
@@ -46,26 +51,37 @@ def convert(args):
if args.v1 == args.v2:
# 自動判定する
v2_model = unet.config.cross_attention_dim == 1024
print("checking model version: model is " + ("v2" if v2_model else "v1"))
logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
else:
v2_model = not args.v1
# 変換して保存する
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
print(f"converting and saving as {msg}: {args.model_to_save}")
logger.info(f"converting and saving as {msg}: {args.model_to_save}")
if is_save_ckpt:
original_model = args.model_to_load if is_load_ckpt else None
key_count = model_util.save_stable_diffusion_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae
v2_model,
args.model_to_save,
text_encoder,
unet,
original_model,
args.epoch,
args.global_step,
None if args.metadata is None else eval(args.metadata),
save_dtype=save_dtype,
vae=vae,
)
print(f"model saved. total converted state_dict keys: {key_count}")
logger.info(f"model saved. total converted state_dict keys: {key_count}")
else:
print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}")
logger.info(
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
)
model_util.save_diffusers_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
)
print(f"model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
@@ -77,7 +93,9 @@ def setup_parser() -> argparse.ArgumentParser:
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
)
parser.add_argument(
"--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にするstabilityaiのモデルと合わせる"
"--unet_use_linear_projection",
action="store_true",
help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にするstabilityaiのモデルと合わせる",
)
parser.add_argument(
"--fp16",
@@ -99,6 +117,18 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
)
parser.add_argument(
"--metadata",
type=str,
default=None,
help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
)
parser.add_argument(
"--reference_model",
type=str,

View File

@@ -15,6 +15,10 @@ import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
KP_REYE = 11
KP_LEYE = 19
@@ -24,7 +28,7 @@ SCORE_THRES = 0.90
def detect_faces(detector, image, min_size):
preds = detector(image) # bgr
# print(len(preds))
# logger.info(len(preds))
faces = []
for pred in preds:
@@ -78,7 +82,7 @@ def process(args):
assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
# アニメ顔検出モデルを読み込む
print("loading face detector.")
logger.info("loading face detector.")
detector = create_detector('yolov3')
# cropの引数を解析する
@@ -97,7 +101,7 @@ def process(args):
crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
# 画像を処理する
print("processing.")
logger.info("processing.")
output_extension = ".png"
os.makedirs(args.dst_dir, exist_ok=True)
@@ -111,7 +115,7 @@ def process(args):
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
if image.shape[2] == 4:
print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
h, w = image.shape[:2]
@@ -144,11 +148,11 @@ def process(args):
# 顔サイズを基準にリサイズする
scale = args.resize_face_size / face_size
if scale < cur_crop_width / w:
print(
logger.warning(
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい顔が相対的に大きすぎるので顔サイズが変わります: {path}")
scale = cur_crop_width / w
if scale < cur_crop_height / h:
print(
logger.warning(
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい顔が相対的に大きすぎるので顔サイズが変わります: {path}")
scale = cur_crop_height / h
elif crop_h_ratio is not None:
@@ -157,10 +161,10 @@ def process(args):
else:
# 切り出しサイズ指定あり
if w < cur_crop_width:
print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
scale = cur_crop_width / w
if h < cur_crop_height:
print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
scale = cur_crop_height / h
if args.resize_fit:
scale = max(cur_crop_width / w, cur_crop_height / h)
@@ -198,7 +202,7 @@ def process(args):
face_img = face_img[y:y + cur_crop_height]
# # debug
# print(path, cx, cy, angle)
# logger.info(path, cx, cy, angle)
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
# cv2.imshow("image", crp)
# if cv2.waitKey() == 27:

View File

@@ -11,10 +11,16 @@ from typing import Dict, List
import numpy as np
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torch import nn
from tqdm import tqdm
from PIL import Image
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
@@ -216,7 +222,7 @@ class Upscaler(nn.Module):
upsampled_images = upsampled_images / 127.5 - 1.0
# convert upsample images to latents with batch size
# print("Encoding upsampled (LANCZOS4) images...")
# logger.info("Encoding upsampled (LANCZOS4) images...")
upsampled_latents = []
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
@@ -227,7 +233,7 @@ class Upscaler(nn.Module):
upsampled_latents = torch.cat(upsampled_latents, dim=0)
# upscale (refine) latents with this model with batch size
print("Upscaling latents...")
logger.info("Upscaling latents...")
upscaled_latents = []
for i in range(0, upsampled_latents.shape[0], batch_size):
with torch.no_grad():
@@ -242,7 +248,7 @@ def create_upscaler(**kwargs):
weights = kwargs["weights"]
model = Upscaler()
print(f"Loading weights from {weights}...")
logger.info(f"Loading weights from {weights}...")
if os.path.splitext(weights)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -255,20 +261,20 @@ def create_upscaler(**kwargs):
# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE = get_preferred_device()
us_dtype = torch.float16 # TODO: support fp32/bf16
os.makedirs(args.output_dir, exist_ok=True)
# load VAE with Diffusers
assert args.vae_path is not None, "VAE path is required"
print(f"Loading VAE from {args.vae_path}...")
logger.info(f"Loading VAE from {args.vae_path}...")
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
vae.to(DEVICE, dtype=us_dtype)
# prepare model
print("Preparing model...")
logger.info("Preparing model...")
upscaler: Upscaler = create_upscaler(weights=args.weights)
# print("Loading weights from", args.weights)
# logger.info("Loading weights from", args.weights)
# upscaler.load_state_dict(torch.load(args.weights))
upscaler.eval()
upscaler.to(DEVICE, dtype=us_dtype)
@@ -303,14 +309,14 @@ def upscale_images(args: argparse.Namespace):
image_debug.save(dest_file_name)
# upscale
print("Upscaling...")
logger.info("Upscaling...")
upscaled_latents = upscaler.upscale(
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
)
upscaled_latents /= 0.18215
# decode with batch
print("Decoding...")
logger.info("Decoding...")
upscaled_images = []
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
with torch.no_grad():

View File

@@ -5,7 +5,10 @@ import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def is_unet_key(key):
# VAE or TextEncoder, the last one is for SDXL
@@ -45,10 +48,10 @@ def merge(args):
# check if all models are safetensors
for model in args.models:
if not model.endswith("safetensors"):
print(f"Model {model} is not a safetensors model")
logger.info(f"Model {model} is not a safetensors model")
exit()
if not os.path.isfile(model):
print(f"Model {model} does not exist")
logger.info(f"Model {model} does not exist")
exit()
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
@@ -65,7 +68,7 @@ def merge(args):
if merged_sd is None:
# load first model
print(f"Loading model {model}, ratio = {ratio}...")
logger.info(f"Loading model {model}, ratio = {ratio}...")
merged_sd = {}
with safe_open(model, framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
@@ -81,11 +84,11 @@ def merge(args):
value = ratio * value.to(dtype) # first model's value * ratio
merged_sd[key] = value
print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
continue
# load other models
print(f"Loading model {model}, ratio = {ratio}...")
logger.info(f"Loading model {model}, ratio = {ratio}...")
with safe_open(model, framework="pt", device=args.device) as f:
model_keys = f.keys()
@@ -93,7 +96,7 @@ def merge(args):
_, new_key = replace_text_encoder_key(key)
if new_key not in merged_sd:
if args.show_skipped and new_key not in first_model_keys:
print(f"Skip: {new_key}")
logger.info(f"Skip: {new_key}")
continue
value = f.get_tensor(key)
@@ -104,7 +107,7 @@ def merge(args):
for key in merged_sd.keys():
if key in model_keys:
continue
print(f"Key {key} not in model {model}, use first model's value")
logger.warning(f"Key {key} not in model {model}, use first model's value")
if key in supplementary_key_ratios:
supplementary_key_ratios[key] += ratio
else:
@@ -112,7 +115,7 @@ def merge(args):
# add supplementary keys' value (including VAE and TextEncoder)
if len(supplementary_key_ratios) > 0:
print("add first model's value")
logger.info("add first model's value")
with safe_open(args.models[0], framework="pt", device=args.device) as f:
for key in tqdm(f.keys()):
_, new_key = replace_text_encoder_key(key)
@@ -120,7 +123,7 @@ def merge(args):
continue
if is_unet_key(new_key): # not VAE or TextEncoder
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
value = f.get_tensor(key) # original key
@@ -134,7 +137,7 @@ def merge(args):
if not output_file.endswith(".safetensors"):
output_file = output_file + ".safetensors"
print(f"Saving to {output_file}...")
logger.info(f"Saving to {output_file}...")
# convert to save_dtype
for k in merged_sd.keys():
@@ -142,7 +145,7 @@ def merge(args):
save_file(merged_sd, output_file)
print("Done!")
logger.info("Done!")
if __name__ == "__main__":

View File

@@ -7,7 +7,10 @@ from safetensors.torch import load_file
from library.original_unet import UNet2DConditionModel, SampleOutput
import library.model_util as model_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class ControlNetInfo(NamedTuple):
unet: Any
@@ -51,7 +54,7 @@ def load_control_net(v2, unet, model):
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
# state dictを読み込む
print(f"ControlNet: loading control SD model : {model}")
logger.info(f"ControlNet: loading control SD model : {model}")
if model_util.is_safetensors(model):
ctrl_sd_sd = load_file(model)
@@ -61,7 +64,7 @@ def load_control_net(v2, unet, model):
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
is_difference = "difference" in ctrl_sd_sd
print("ControlNet: loading difference:", is_difference)
logger.info(f"ControlNet: loading difference: {is_difference}")
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
# またTransfer Controlの元weightとなる
@@ -89,13 +92,13 @@ def load_control_net(v2, unet, model):
# ControlNetのU-Netを作成する
ctrl_unet = UNet2DConditionModel(**unet_config)
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
print("ControlNet: loading Control U-Net:", info)
logger.info(f"ControlNet: loading Control U-Net: {info}")
# U-Net以外のControlNetを作成する
# TODO support middle only
ctrl_net = ControlNet()
info = ctrl_net.load_state_dict(zero_conv_sd)
print("ControlNet: loading ControlNet:", info)
logger.info("ControlNet: loading ControlNet: {info}")
ctrl_unet.to(unet.device, dtype=unet.dtype)
ctrl_net.to(unet.device, dtype=unet.dtype)
@@ -117,7 +120,7 @@ def load_preprocess(prep_type: str):
return canny
print("Unsupported prep type:", prep_type)
logger.info(f"Unsupported prep type: {prep_type}")
return None
@@ -174,13 +177,26 @@ def call_unet_and_control_net(
cnet_idx = step % cnet_cnt
cnet_info = control_nets[cnet_idx]
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
return original_unet(sample, timestep, encoder_hidden_states)
guided_hint = guided_hints[cnet_idx]
# gradual latent support: match the size of guided_hint to the size of sample
if guided_hint.shape[-2:] != sample.shape[-2:]:
# print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}")
org_dtype = guided_hint.dtype
if org_dtype == torch.bfloat16:
guided_hint = guided_hint.to(torch.float32)
guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic")
if org_dtype == torch.bfloat16:
guided_hint = guided_hint.to(org_dtype)
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
outs = unet_forward(
True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net
)
outs = [o * cnet_info.weight for o in outs]
# U-Net
@@ -192,7 +208,7 @@ def call_unet_and_control_net(
# ControlNet
cnet_outs_list = []
for i, cnet_info in enumerate(control_nets):
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
if cnet_info.ratio < current_ratio:
continue
guided_hint = guided_hints[i]
@@ -232,7 +248,7 @@ def unet_forward(
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
print("Forward upsample size to force interpolation output size.")
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# 1. time

View File

@@ -6,7 +6,10 @@ import shutil
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
# Split the max_resolution string by "," and strip any whitespaces
@@ -83,7 +86,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
proc = "Resized" if current_pixels > max_pixels else "Saved"
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
# If other files with same basename, copy them with resolution suffix
if copy_associated_files:
@@ -94,7 +97,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
continue
for max_resolution in max_resolutions:
new_asoc_file = base + '+' + max_resolution + ext
print(f"Copy {asoc_file} as {new_asoc_file}")
logger.info(f"Copy {asoc_file} as {new_asoc_file}")
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))

View File

@@ -1,6 +1,10 @@
import json
import argparse
from safetensors import safe_open
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
@@ -10,10 +14,10 @@ with safe_open(args.model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:
print("No metadata found")
logger.error("No metadata found")
else:
# metadata is json dict, but not pretty printed
# sort by key and pretty print
print(json.dumps(metadata, indent=4, sort_keys=True))

View File

@@ -1,5 +1,4 @@
import argparse
import gc
import json
import math
import os
@@ -10,14 +9,12 @@ from types import SimpleNamespace
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
@@ -37,6 +34,12 @@ from library.custom_train_functions import (
pyramid_noise_like,
apply_noise_offset,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
# TODO 他のスクリプトと共通化する
@@ -58,6 +61,7 @@ def train(args):
# training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_user_config = args.dataset_config is not None
@@ -71,11 +75,11 @@ def train(args):
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
if use_user_config:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "conditioning_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -105,7 +109,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -116,7 +120,7 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -221,10 +225,8 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
@@ -238,8 +240,8 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -255,7 +257,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -311,8 +315,10 @@ def train(args):
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
# print(f" total train batch size (with parallel & distributed & accumulation) / バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -333,9 +339,13 @@ def train(args):
)
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_recorder = train_util.LossRecorder()
del train_dataset_group
@@ -371,6 +381,11 @@ def train(args):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
)
# training loop
for epoch in range(num_train_epochs):
if is_main_process:
@@ -382,7 +397,7 @@ def train(args):
with accelerator.accumulate(controlnet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -405,13 +420,9 @@ def train(args):
)
# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(b_size,),
device=latents.device,
)
timesteps = timesteps.long()
timesteps = train_util.get_timesteps(args, 0, noise_scheduler.config.num_train_timesteps, b_size)
huber_c = train_util.get_huber_c(args, noise_scheduler, timesteps.item(), latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
@@ -442,14 +453,14 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
@@ -551,7 +562,7 @@ def train(args):
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく
@@ -560,15 +571,17 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, controlnet, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
@@ -600,6 +613,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,26 +1,21 @@
# DreamBooth training
# XXX dropped option: fine_tune
import gc
import argparse
import itertools
import math
import os
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
@@ -39,14 +34,85 @@ from library.custom_train_functions import (
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
import itertools
logger = logging.getLogger(__name__)
# perlin_noise,
def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args):
total_loss = 0.0
timesteps_list = [10, 350, 500, 650, 990]
with accelerator.accumulate(*training_models):
with torch.no_grad():
# latentに変換
if cache_latents:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
with torch.set_grad_enabled(False), accelerator.autocast():
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
for fixed_timesteps in timesteps_list:
with torch.set_grad_enabled(False), accelerator.autocast():
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0]
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
total_loss += loss
average_loss = total_loss / len(timesteps_list)
return average_loss
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, False)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -57,13 +123,13 @@ def train(args):
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -76,9 +142,10 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -98,13 +165,13 @@ def train(args):
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
if args.gradient_accumulation_steps > 1:
print(
logger.warning(
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
)
print(
logger.warning(
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデルU-NetおよびText Encoderの学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
)
@@ -112,6 +179,7 @@ def train(args):
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -136,15 +204,16 @@ def train(args):
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
if val_dataset_group is not None:
print("Cache validation latents...")
val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -181,8 +250,8 @@ def train(args):
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -191,13 +260,24 @@ def train(args):
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset_group if val_dataset_group is not None else [],
shuffle=False,
batch_size=1,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
cyclic_val_dataloader = itertools.cycle(val_dataloader)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -218,15 +298,25 @@ def train(args):
text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
training_models = [ds_model]
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
else:
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
training_models = [unet, text_encoder]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
training_models = [unet]
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
@@ -270,11 +360,18 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
val_loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -293,12 +390,14 @@ def train(args):
if not args.gradient_checkpointing:
text_encoder.train(False)
text_encoder.requires_grad_(False)
if len(training_models) == 2:
training_models = training_models[0] # remove text_encoder from training_models
with accelerator.accumulate(unet):
with accelerator.accumulate(*training_models):
with torch.no_grad():
# latentに変換
if cache_latents:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
@@ -323,7 +422,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -335,14 +434,16 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
@@ -402,12 +503,31 @@ def train(args):
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if len(val_dataloader) > 0:
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
accelerator.print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
total_loss += loss.detach().item()
current_loss = total_loss / validation_steps
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
if args.logging_dir is not None:
logs = {"loss/current_val_loss": current_loss}
accelerator.log(logs, step=global_step)
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/average_val_loss": avr_loss}
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
logs = {"loss/epoch_average": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -441,7 +561,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
@@ -451,15 +571,18 @@ def train(args):
train_util.save_sd_model_on_train_end(
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
@@ -482,7 +605,35 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--validation_seed",
type=int,
default=None,
help="Validation seed"
)
parser.add_argument(
"--validation_split",
type=float,
default=0.0,
help="Split for validation images out of the training dataset"
)
parser.add_argument(
"--validation_every_n_step",
type=int,
default=None,
help="Number of train steps for counting validation loss. By default, validation per train epoch is performed"
)
parser.add_argument(
"--max_validation_steps",
type=int,
default=None,
help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset"
)
return parser
@@ -490,6 +641,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,8 +1,9 @@
import importlib
import argparse
import gc
import math
import os
import typing
from typing import List, Optional, Union
import sys
import random
import time
@@ -11,25 +12,20 @@ from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
from torch.types import Number
from library.device_utils import init_ipex, clean_memory_on_device
try:
import intel_extension_for_pytorch as ipex
init_ipex()
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
from diffusers import DDPMScheduler, AutoencoderKL
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from library import deepspeed_utils, model_util
import library.train_util as train_util
from library.train_util import (
DreamBoothDataset,
)
from library.train_util import DreamBoothDataset
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
@@ -44,7 +40,15 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
import itertools
logger = logging.getLogger(__name__)
class NetworkTrainer:
@@ -116,7 +120,7 @@ class NetworkTrainer:
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
):
for t_enc in text_encoders:
t_enc.to(accelerator.device)
t_enc.to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids = batch["input_ids"].to(accelerator.device)
@@ -127,25 +131,28 @@ class NetworkTrainer:
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
return noise_pred
def all_reduce_network(self, accelerator, network):
for param in network.parameters():
if param.grad is not None:
param.grad = accelerator.reduce(param.grad, reduction="mean")
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True, timesteps_list=None):
total_loss = 0.0
def process_batch(self, batch, tokenizers, text_encoders, unet, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents: torch.Tensor = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
else:
# latentに変換
latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample()
latents: torch.Tensor = typing.cast(torch.FloatTensor, typing.cast(AutoencoderKLOutput, vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype))).latent_dist.sample())
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
latents = typing.cast(torch.FloatTensor, torch.where(torch.isnan(latents), torch.zeros_like(latents), latents))
latents = typing.cast(torch.FloatTensor, latents * self.vae_scale_factor)
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
@@ -163,52 +170,66 @@ class NetworkTrainer:
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
batch_size = latents.shape[0]
# Sample noise,
noise = train_util.make_noise(args, latents)
def pick_timesteps_list() -> torch.IntTensor:
if timesteps_list is None or timesteps_list == []:
return typing.cast(torch.IntTensor, train_util.make_random_timesteps(args, noise_scheduler, batch_size, latents.device).unsqueeze(1))
else:
return typing.cast(torch.IntTensor, torch.tensor(timesteps_list).unsqueeze(1).repeat(1, batch_size).to(latents.device))
choosen_timesteps_list = pick_timesteps_list()
total_loss = torch.zeros((batch_size, 1)).to(latents.device)
# Use input timesteps_list or use described timesteps above
timesteps_list = timesteps_list or [timesteps]
for timesteps in timesteps_list:
for fixed_timestep in choosen_timesteps_list:
fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep)
# Predict the noise residual
with torch.set_grad_enabled(is_train), accelerator.autocast():
# and add noise to the latents
# with noise offset and/or multires noise if specified
noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timestep)
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
args, accelerator, unet, noisy_latents.requires_grad_(train_unet), fixed_timestep, text_encoder_conds, batch, weight_dtype
)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
target = noise_scheduler.get_velocity(latents, noise, fixed_timestep)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, fixed_timestep, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
loss = scale_v_prediction_loss_like_noise_prediction(loss, fixed_timestep, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
loss = add_v_prediction_like_loss(loss, fixed_timestep, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = apply_debiased_estimation(loss, fixed_timestep, noise_scheduler)
total_loss += loss.mean() # 平均なのでbatch_sizeで割る必要なし
total_loss += loss
average_loss = total_loss / len(timesteps_list)
return average_loss
return total_loss.mean()
def train(self, args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
@@ -224,20 +245,20 @@ class NetworkTrainer:
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config:
print(f"Loading dataset config from {args.dataset_config}")
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -248,7 +269,7 @@ class NetworkTrainer:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -278,7 +299,7 @@ class NetworkTrainer:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -295,7 +316,7 @@ class NetworkTrainer:
self.assert_extra_args(args, train_dataset_group)
# acceleratorを準備する
print("preparing accelerator")
logger.info("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -347,13 +368,12 @@ class NetworkTrainer:
print("Cache validation latents...")
val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
)
@@ -385,11 +405,12 @@ class NetworkTrainer:
)
if network is None:
return
network_has_multiplier = hasattr(network, "set_multiplier")
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
print(
logger.warning(
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
)
args.scale_weight_norms = False
@@ -424,8 +445,8 @@ class NetworkTrainer:
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
@@ -445,15 +466,6 @@ class NetworkTrainer:
persistent_workers=args.persistent_data_loader_workers,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset_group if val_dataset_group is not None else [],
shuffle=False,
batch_size=1,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
@@ -483,53 +495,59 @@ class NetworkTrainer:
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.requires_grad_(False)
unet.to(dtype=weight_dtype)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
# acceleratorがなんかよろしくやってくれるらしい
# TODO めちゃくちゃ冗長なのでコードを整理する
if train_unet and train_text_encoder:
if len(text_encoders) > 1:
unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoders = [t_enc1, t_enc2]
del t_enc1, t_enc2
else:
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
text_encoders = [text_encoder]
elif train_unet:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
elif train_text_encoder:
if len(text_encoders) > 1:
t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoders = [t_enc1, t_enc2]
del t_enc1, t_enc2
else:
text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
text_encoders = [text_encoder]
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
unet=unet if train_unet else None,
text_encoder1=text_encoders[0] if train_text_encoder else None,
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
network=network,
)
ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, val_dataloader, lr_scheduler
)
training_model = ds_model
else:
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
)
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
else:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
# transform DDP after prepare (train_network here only)
text_encoders = train_util.transform_models_if_DDP(text_encoders)
unet, network = train_util.transform_models_if_DDP([unet, network])
network, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, val_dataloader, lr_scheduler
)
training_model = network
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
@@ -541,9 +559,6 @@ class NetworkTrainer:
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
# set top parameter requires_grad = True for gradient checkpointing works
if not train_text_encoder: # train U-Net only
unet.parameters().__next__().requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
@@ -551,7 +566,7 @@ class NetworkTrainer:
del t_enc
network.prepare_grad_etc(text_encoder, unet)
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
vae.requires_grad_(False)
@@ -562,6 +577,31 @@ class NetworkTrainer:
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# before resuming make hook for saving/loading to save/load the network weights only
def save_model_hook(models, weights, output_dir):
# pop weights of other models than network to save only network weights
if accelerator.is_main_process:
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
weights.pop(i)
# print(f"save model hook: {len(weights)} weights will be saved")
def load_model_hook(models, input_dir):
# remove models except network
remove_indices = []
for i, model in enumerate(models):
if not isinstance(model, type(accelerator.unwrap_model(network))):
remove_indices.append(i)
for i in reversed(remove_indices):
models.pop(i)
# print(f"load model hook: {len(models)} models will be loaded")
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
@@ -635,6 +675,11 @@ class NetworkTrainer:
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_ip_noise_gamma": args.ip_noise_gamma,
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
"ss_noise_offset_random_strength": args.noise_offset_random_strength,
"ss_ip_noise_gamma_random_strength": args.ip_noise_gamma_random_strength,
"ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule,
"ss_huber_c": args.huber_c,
}
if use_user_config:
@@ -670,6 +715,11 @@ class NetworkTrainer:
"random_crop": bool(subset.random_crop),
"shuffle_caption": bool(subset.shuffle_caption),
"keep_tokens": subset.keep_tokens,
"keep_tokens_separator": subset.keep_tokens_separator,
"secondary_separator": subset.secondary_separator,
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
}
image_dir_or_metadata_file = None
@@ -804,6 +854,8 @@ class NetworkTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
@@ -815,8 +867,8 @@ class NetworkTrainer:
del train_dataset_group
# callback for step start
if hasattr(network, "on_step_start"):
on_step_start = network.on_step_start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
on_step_start = accelerator.unwrap_model(network).on_step_start
else:
on_step_start = lambda *args, **kwargs: None
@@ -844,6 +896,9 @@ class NetworkTrainer:
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -851,27 +906,28 @@ class NetworkTrainer:
metadata["ss_epoch"] = str(epoch + 1)
network.on_epoch_start(text_encoder, unet)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
# TRAINING
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
with accelerator.accumulate(training_model):
on_step_start(text_encoder, unet)
is_train = True
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder)
loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = network.get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
@@ -912,25 +968,30 @@ class NetworkTrainer:
if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
accelerator.log(logs, step=global_step)
if global_step % 25 == 0:
if len(val_dataloader) > 0:
print("Validating バリデーション処理...")
if len(val_dataloader) > 0:
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
accelerator.print("Validating バリデーション処理...")
total_loss = 0.0
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
# batch = next(val_dataloader)
for val_step, batch in tqdm(enumerate(val_dataloader), desc='Validation Steps'):
if val_step >= validation_steps:
break
with torch.no_grad():
val_dataloader_iter = iter(val_dataloader)
batch = next(val_dataloader_iter)
is_train = False
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990])
current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_current": current_loss}
accelerator.log(logs, step=global_step)
loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
total_loss += loss.detach().item()
current_loss = total_loss / validation_steps
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
if args.logging_dir is not None:
logs = {"loss/current_val_loss": current_loss}
accelerator.log(logs, step=global_step)
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/average_val_loss": avr_loss}
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
@@ -939,24 +1000,22 @@ class NetworkTrainer:
if len(val_dataloader) > 0:
print("Validating バリデーション処理...")
with torch.no_grad():
for val_step, batch in enumerate(val_dataloader):
is_train = False
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990])
for val_step, batch in enumerate(val_dataloader):
loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_current": current_loss}
accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_current": current_loss}
accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
if len(val_dataloader) > 0:
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_average": avr_loss}
accelerator.log(logs, step=epoch + 1)
if len(val_dataloader) > 0:
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_average": avr_loss}
accelerator.log(logs, step=epoch + 1)
if args.logging_dir is not None:
@@ -998,27 +1057,32 @@ class NetworkTrainer:
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument(
"--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
)
parser.add_argument(
"--save_model_as",
type=str,
@@ -1030,10 +1094,17 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール")
parser.add_argument(
"--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)"
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
)
parser.add_argument(
"--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
)
parser.add_argument(
"--network_dim",
type=int,
default=None,
help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
)
parser.add_argument(
"--network_alpha",
@@ -1048,14 +1119,25 @@ def setup_parser() -> argparse.ArgumentParser:
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする0またはNoneはdropoutなし、1は全ニューロンをdropout",
)
parser.add_argument(
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
parser.add_argument(
"--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する"
"--network_args",
type=str,
default=None,
nargs="*",
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
)
parser.add_argument(
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
"--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
)
parser.add_argument(
"--network_train_text_encoder_only",
action="store_true",
help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
)
parser.add_argument(
"--training_comment",
type=str,
default=None,
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
)
parser.add_argument(
"--dim_from_weights",
@@ -1098,8 +1180,19 @@ def setup_parser() -> argparse.ArgumentParser:
type=float,
default=0.0,
help="Split for validation images out of the training dataset"
)
)
parser.add_argument(
"--validation_every_n_step",
type=int,
default=None,
help="Number of train steps for counting validation loss. By default, validation per train epoch is performed"
)
parser.add_argument(
"--max_validation_steps",
type=int,
default=None,
help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset"
)
return parser
@@ -1107,6 +1200,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = NetworkTrainer()

View File

@@ -1,23 +1,21 @@
import argparse
import gc
import math
import os
from multiprocessing import Value
import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import model_util
from library import deepspeed_utils, model_util
import library.train_util as train_util
import library.huggingface_util as huggingface_util
@@ -33,7 +31,14 @@ from library.custom_train_functions import (
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
imagenet_templates_small = [
"a photo of a {}",
@@ -170,6 +175,7 @@ class TextualInversionTrainer:
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
cache_latents = args.cache_latents
@@ -180,7 +186,7 @@ class TextualInversionTrainer:
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -265,7 +271,7 @@ class TextualInversionTrainer:
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
if args.dataset_config is not None:
accelerator.print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
@@ -290,7 +296,7 @@ class TextualInversionTrainer:
]
}
else:
print("Train with captions.")
logger.info("Train with captions.")
user_config = {
"datasets": [
{
@@ -365,9 +371,7 @@ class TextualInversionTrainer:
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -384,8 +388,8 @@ class TextualInversionTrainer:
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -415,15 +419,11 @@ class TextualInversionTrainer:
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
elif len(text_encoders) == 2:
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
@@ -442,9 +442,10 @@ class TextualInversionTrainer:
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
unwrapped_text_encoder.text_model.encoder.requires_grad_(False)
unwrapped_text_encoder.text_model.final_layer_norm.requires_grad_(False)
unwrapped_text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
unet.requires_grad_(False)
@@ -504,6 +505,8 @@ class TextualInversionTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
@@ -529,6 +532,20 @@ class TextualInversionTrainer:
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(
accelerator,
args,
0,
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -544,10 +561,10 @@ class TextualInversionTrainer:
with accelerator.accumulate(text_encoders[0]):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)
latents = latents * self.vae_scale_factor
# Get the text embedding for conditioning
@@ -555,7 +572,7 @@ class TextualInversionTrainer:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
@@ -571,14 +588,16 @@ class TextualInversionTrainer:
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
@@ -590,7 +609,7 @@ class TextualInversionTrainer:
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = text_encoder.get_input_embeddings().parameters()
params_to_clip = accelerator.unwrap_model(text_encoder).get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -602,9 +621,11 @@ class TextualInversionTrainer:
for text_encoder, orig_embeds_params, index_no_updates in zip(
text_encoders, orig_embeds_params_list, index_no_updates_list
):
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
# if full_fp16/bf16, input_embeddings_weight is fp16/bf16, orig_embeds_params is fp32
input_embeddings_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
input_embeddings_weight[index_no_updates] = orig_embeds_params.to(input_embeddings_weight.dtype)[
index_no_updates
] = orig_embeds_params[index_no_updates]
]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -712,27 +733,29 @@ class TextualInversionTrainer:
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = accelerator.unwrap_model(text_encoder)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
@@ -745,7 +768,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
)
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
@@ -755,7 +780,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
)
parser.add_argument(
"--use_object_template",
action="store_true",
@@ -779,6 +806,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
trainer = TextualInversionTrainer()

View File

@@ -1,20 +1,18 @@
import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
@@ -35,9 +33,16 @@ from library.custom_train_functions import (
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
apply_masked_loss,
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
imagenet_templates_small = [
"a photo of a {}",
@@ -96,12 +101,13 @@ def train(args):
if args.output_name is None:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
setup_logging(args, reset=True)
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
print(
logger.warning(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
)
assert (
@@ -116,7 +122,7 @@ def train(args):
tokenizer = train_util.load_tokenizer(args)
# acceleratorを準備する
print("prepare accelerator")
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
@@ -129,7 +135,7 @@ def train(args):
if args.init_word is not None:
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
logger.warning(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
)
else:
@@ -143,7 +149,7 @@ def train(args):
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}")
logger.info(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
@@ -171,7 +177,7 @@ def train(args):
tokenizer.add_tokens(token_strings_XTI)
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
print(f"tokens are added (XTI): {token_ids_XTI}")
logger.info(f"tokens are added (XTI): {token_ids_XTI}")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
@@ -180,7 +186,7 @@ def train(args):
if init_token_ids is not None:
for i, token_id in enumerate(token_ids_XTI):
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
@@ -188,22 +194,22 @@ def train(args):
assert len(token_ids) == len(
embeddings
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
# logger.info(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids_XTI, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print(f"weighs loaded")
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
logger.info(f"weighs loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# データセットを準備する
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.info(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
@@ -211,14 +217,14 @@ def train(args):
else:
use_dreambooth_method = args.in_json is None
if use_dreambooth_method:
print("Use DreamBooth method.")
logger.info("Use DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else:
print("Train with captions.")
logger.info("Train with captions.")
user_config = {
"datasets": [
{
@@ -242,7 +248,7 @@ def train(args):
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
print(f"use template for training captions. is object: {args.use_object_template}")
logger.info(f"use template for training captions. is object: {args.use_object_template}")
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
replace_to = " ".join(token_strings)
captions = []
@@ -266,7 +272,7 @@ def train(args):
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
return
if len(train_dataset_group) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
if cache_latents:
@@ -288,9 +294,7 @@ def train(args):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -299,13 +303,13 @@ def train(args):
text_encoder.gradient_checkpointing_enable()
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
logger.info("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
# DataLoaderのプロセス数0 は persistent_workers が使えないので注意
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -320,7 +324,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
logger.info(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
@@ -333,11 +339,8 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
@@ -375,15 +378,17 @@ def train(args):
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
logger.info("running training / 学習開始")
logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
logger.info(f" num epochs / epoch数: {num_train_epochs}")
logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}")
logger.info(
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
)
logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
@@ -397,16 +402,21 @@ def train(args):
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
# function for saving/removing
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"\nsaving checkpoint: {ckpt_file}")
logger.info("")
logger.info(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -414,12 +424,13 @@ def train(args):
def remove_model(old_ckpt_name):
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
logger.info(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# training loop
for epoch in range(num_train_epochs):
print(f"\nepoch {epoch+1}/{num_train_epochs}")
logger.info("")
logger.info(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
text_encoder.train()
@@ -431,7 +442,7 @@ def train(args):
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -450,7 +461,7 @@ def train(args):
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# Predict the noise residual
with accelerator.autocast():
@@ -462,14 +473,16 @@ def train(args):
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.debiased_estimation_loss:
@@ -578,7 +591,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
@@ -589,7 +602,7 @@ def train(args):
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def save_weights(file, updated_embs, save_dtype):
@@ -650,9 +663,12 @@ def load_weights(file):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
@@ -665,7 +681,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
)
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
@@ -675,7 +693,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
)
parser.add_argument(
"--use_object_template",
action="store_true",
@@ -694,6 +714,7 @@ if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
train_util.verify_command_line_training_args(args)
args = train_util.read_config_from_file(args, parser)
train(args)