Compare commits

..

2 Commits

Author SHA1 Message Date
Kohya S
2bdcd9b2db update as official impl 2023-09-23 20:31:02 +09:00
Kohya S
40525d4f4b implement FreeU 2023-09-22 07:56:09 +09:00
42 changed files with 748 additions and 2068 deletions

View File

@@ -1,7 +0,0 @@
---
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "monthly"

View File

@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v3
- name: typos-action
uses: crate-ci/typos@v1.16.15
uses: crate-ci/typos@v1.13.10

View File

@@ -1,7 +1,3 @@
SDXLがサポートされました。sdxlブランチはmainブランチにマージされました。リポジトリを更新したときにはUpgradeの手順を実行してください。また accelerate のバージョンが上がっていますので、accelerate config を再度実行してください。
SDXL学習については[こちら](./README.md#sdxl-training)をご覧ください(英語です)。
## リポジトリについて
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
@@ -13,12 +9,13 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
* DreamBooth、U-NetおよびText Encoderの学習をサポート
* fine-tuning、同上
* LoRAの学習をサポート
* 画像生成
* モデル変換Stable Diffision ckpt/safetensorsとDiffusersの相互変換
## 使用法について
当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください将来的にはすべてこちらへ移すかもしれません
* [学習について、共通編](./docs/train_README-ja.md) : データ整備やオプションなど
* [データセット設定](./docs/config_README-ja.md)
* [DreamBoothの学習について](./docs/train_db_README-ja.md)
@@ -44,13 +41,11 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
## Windows環境でのインストール
スクリプトはPyTorch 2.0.1でテストしています。PyTorch 1.12.1でも動作すると思われます
以下の例ではPyTorchは2.0.1CUDA 11.8版をインストールします。CUDA 11.6版やPyTorch 1.12.1を使う場合は適宜書き換えください。
以下の例ではPyTorchは1.12.1CUDA 11.6版をインストールします。CUDA 11.3版やPyTorch 1.13を使う場合は適宜書き換えください
なお、python -m venvの行で「python」とだけ表示された場合、py -m venvのようにpythonをpyに変更してください。
PowerShellを使う場合、通常の管理者ではないPowerShellを開き以下を順に実行します。
通常の管理者ではないPowerShellを開き以下を順に実行します。
```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
@@ -59,14 +54,43 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install xformers==0.0.20
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config
```
コマンドプロンプトでも同一です。
<!--
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
pip install --use-pep517 --upgrade -r requirements.txt
pip install -U -I --no-deps xformers==0.0.16
-->
コマンドプロンプトでは以下になります。
```bat
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
copy /y .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
copy /y .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config
```
(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。
@@ -87,41 +111,30 @@ accelerate configの質問には以下のように答えてください。bf1
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``に「0」と答えてください。id `0`のGPUが使われます。
### オプション:`bitsandbytes`8bit optimizerを使う
### PyTorchとxformersのバージョンについて
`bitsandbytes`はオプションになりました。Linuxでは通常通りpipでインストールできます0.41.1または以降のバージョンを推奨)
他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください
Windowsでは0.35.0または0.41.1を推奨します。
### オプションLion8bitを使う
- `bitsandbytes` 0.35.0: 安定しているとみられるバージョンです。AdamW8bitは使用できますが、他のいくつかの8bit optimizer、学習時の`full_bf16`オプションは使用できません
- `bitsandbytes` 0.41.1: Lion8bit、PagedAdamW8bit、PagedLion8bitをサポートします。`full_bf16`が使用できます。
注:`bitsandbytes` 0.35.0から0.41.0までのバージョンには問題があるようです。 https://github.com/TimDettmers/bitsandbytes/issues/659
以下の手順に従い、`bitsandbytes`をインストールしてください。
### 0.35.0を使う場合
PowerShellの例です。コマンドプロンプトではcpの代わりにcopyを使ってください。
Lion8bitを使う場合には`bitsandbytes`0.38.0以降にアップグレードする必要があります。`bitsandbytes`をアンインストールし、Windows環境では例えば[こちら](https://github.com/jllllll/bitsandbytes-windows-webui)などからWindows版のwhlファイルをインストールしてください。たとえば以下のような手順になります
```powershell
cd sd-scripts
.\venv\Scripts\activate
pip install bitsandbytes==0.35.0
cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\
cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl
```
### 0.41.1を使う場合
アップグレード時には`pip install .`でこのリポジトリを更新し、必要に応じて他のパッケージもアップグレードしてください。
jllllll氏の配布されている[こちら](https://github.com/jllllll/bitsandbytes-windows-webui) または他の場所から、Windows用のwhlファイルをインストールしてください。
### オプションPagedAdamW8bitとPagedLion8bitを使う
PagedAdamW8bitとPagedLion8bitを使う場合には`bitsandbytes`を0.39.0以降にアップグレードする必要があります。`bitsandbytes`をアンインストールし、Windows環境では例えば[こちら](https://github.com/jllllll/bitsandbytes-windows-webui)などからWindows版のwhlファイルをインストールしてください。たとえば以下のような手順になります。
```powershell
python -m pip install bitsandbytes==0.41.1 --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
```
アップグレード時には`pip install .`でこのリポジトリを更新し、必要に応じて他のパッケージもアップグレードしてください。
## アップグレード
新しいリリースがあった場合、以下のコマンドで更新できます。

491
README.md
View File

@@ -1,11 +1,9 @@
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
This repository contains training, generation and utility scripts for Stable Diffusion.
[__Change History__](#change-history) is moved to the bottom of the page.
[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。
[日本語版READMEはこちら](./README-ja.md)
[日本語版README](./README-ja.md)
For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!
@@ -18,13 +16,135 @@ This repository contains the scripts for:
* Image generation
* Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for great work!!!
## About SDXL training
The feature of SDXL training is now available in sdxl branch as an experimental feature.
Sep 3, 2023: The feature will be merged into the main branch soon. Following are the changes from the previous version.
- ControlNet-LLLite is added. See [documentation](./docs/train_lllite_README.md) for details.
- JPEG XL is supported. [#786](https://github.com/kohya-ss/sd-scripts/pull/786)
- Peak memory usage is reduced. [#791](https://github.com/kohya-ss/sd-scripts/pull/791)
- Input perturbation noise is added. See [#798](https://github.com/kohya-ss/sd-scripts/pull/798) for details.
- Dataset subset now has `caption_prefix` and `caption_suffix` options. The strings are added to the beginning and the end of the captions before shuffling. You can specify the options in `.toml`.
- Other minor changes.
- Thanks for contributions from Isotr0py, vvern999, lansing and others!
Aug 13, 2023:
- LoRA-FA is added experimentally. Specify `--network_module networks.lora_fa` option instead of `--network_module networks.lora`. The trained model can be used as a normal LoRA model.
Aug 12, 2023:
- The default value of noise offset when omitted has been changed to 0 from 0.0357.
- The different learning rates for each U-Net block are now supported. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
Aug 6, 2023:
- [SAI Model Spec](https://github.com/Stability-AI/ModelSpec) metadata is now supported partially. `hash_sha256` is not supported yet.
- The main items are set automatically.
- You can set title, author, description, license and tags with `--metadata_xxx` options in each training script.
- Merging scripts also support minimum SAI Model Spec metadata. See the help message for the usage.
- Metadata editor will be available soon.
- SDXL LoRA has `sdxl_base_v1-0` now for `ss_base_model_version` metadata item, instead of `v0-9`.
Aug 4, 2023:
- `bitsandbytes` is now optional. Please install it if you want to use it. The insructions are in the later section.
- `albumentations` is not required anymore.
- An issue for pooled output for Textual Inversion training is fixed.
- `--v_pred_like_loss ratio` option is added. This option adds the loss like v-prediction loss in SDXL training. `0.1` means that the loss is added 10% of the v-prediction loss. The default value is None (disabled).
- In v-prediction, the loss is higher in the early timesteps (near the noise). This option can be used to increase the loss in the early timesteps.
- Arbitrary options can be used for Diffusers' schedulers. For example `--lr_scheduler_args "lr_end=1e-8"`.
- `sdxl_gen_imgs.py` supports batch size > 1.
- Fix ControlNet to work with attention couple and reginal LoRA in `gen_img_diffusers.py`.
Summary of the feature:
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
- The options are almost the same as `cache_latents.py' and `sdxl_train.py'. See the help message for the usage.
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- However, bitsandbytes==0.35 doesn't seem to support this. Please use a newer version of bitsandbytes or another optimizer.
- I cannot find bitsandbytes>0.35.0 that works correctly on Windows.
- In addition, the full bfloat16 training might be unstable. Please use it at your own risk.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- The image generation during training is now available. `--no_half_vae` option also works to avoid black images.
- `--weighted_captions` option is not supported yet for both scripts.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- `token_string` must be alphabet only currently, due to the limitation of the open-clip tokenizer.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA. See the help message for the usage.
- Textual Inversion is supported, but the name for the embeds in the caption becomes alphabet only. For example, `neg_hand_v1.safetensors` can be activated with `neghandv`.
`requirements.txt` is updated to support SDXL training.
### Tips for SDXL training
- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (-8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
## About requirements.txt
These files do not contain requirements for PyTorch. Because the versions of them depend on your environment. Please install PyTorch at first (see installation guide below.)
The scripts are tested with Pytorch 2.0.1. 1.12.1 is not tested but should work.
The scripts are tested with PyTorch 1.12.1 and 2.0.1, Diffusers 0.18.2.
## Links to usage documentation
## Links to how-to-use documents
Most of the documents are written in Japanese.
@@ -64,9 +184,9 @@ cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install --upgrade -r requirements.txt
pip install xformers==0.0.20
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
accelerate config
```
@@ -95,6 +215,31 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
(Single GPU with id `0` will be used.)
### Experimental: Use PyTorch 2.0
In this case, you need to install PyTorch 2.0 and xformers 0.0.20. Instead of the above, please type the following:
```powershell
git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts
python -m venv venv
.\venv\Scripts\activate
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install --upgrade -r requirements.txt
pip install xformers==0.0.20
accelerate config
```
Answers to accelerate config should be the same as above.
### about PyTorch and xformers
Other versions of PyTorch and xformers seem to have problems with training.
If there is no other reason, please install the specified version.
### Optional: Use `bitsandbytes` (8bit optimizer)
For 8bit optimizer, you need to install `bitsandbytes`. For Linux, please install `bitsandbytes` as usual (0.41.1 or later is recommended.)
@@ -161,154 +306,214 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## SDXL training
The documentation in this section will be moved to a separate document later.
### Training scripts for SDXL
- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth dataset.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables the full bfloat16 training (includes gradients). This option is useful to reduce the GPU memory usage.
- The full bfloat16 training might be unstable. Please use it at your own risk.
- The different learning rates for each U-Net block are now supported in sdxl_train.py. Specify with `--block_lr` option. Specify 23 values separated by commas like `--block_lr 1e-3,1e-3 ... 1e-3`.
- 23 values correspond to `0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out`.
- `prepare_buckets_latents.py` now supports SDXL fine-tuning.
- `sdxl_train_network.py` is a script for LoRA training for SDXL. The usage is almost the same as `train_network.py`.
- Both scripts has following additional options:
- `--cache_text_encoder_outputs` and `--cache_text_encoder_outputs_to_disk`: Cache the outputs of the text encoders. This option is useful to reduce the GPU memory usage. This option cannot be used with options for shuffling or dropping the captions.
- `--no_half_vae`: Disable the half-precision (mixed-precision) VAE. VAE for SDXL seems to produce NaNs in some cases. This option is useful to avoid the NaNs.
- `--weighted_captions` option is not supported yet for both scripts.
- `sdxl_train_textual_inversion.py` is a script for Textual Inversion training for SDXL. The usage is almost the same as `train_textual_inversion.py`.
- `--cache_text_encoder_outputs` is not supported.
- There are two options for captions:
1. Training with captions. All captions must include the token string. The token string is replaced with multiple tokens.
2. Use `--use_object_template` or `--use_style_template` option. The captions are generated from the template. The existing captions are ignored.
- See below for the format of the embeddings.
- `--min_timestep` and `--max_timestep` options are added to each training script. These options can be used to train U-Net with different timesteps. The default values are 0 and 1000.
### Utility scripts for SDXL
- `tools/cache_latents.py` is added. This script can be used to cache the latents to disk in advance.
- The options are almost the same as `sdxl_train.py'. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.
- `tools/cache_text_encoder_outputs.py` is added. This script can be used to cache the text encoder outputs to disk in advance.
- The options are almost the same as `cache_latents.py` and `sdxl_train.py`. See the help message for the usage.
- `sdxl_gen_img.py` is added. This script can be used to generate images with SDXL, including LoRA, Textual Inversion and ControlNet-LLLite. See the help message for the usage.
### Tips for SDXL training
- The default resolution of SDXL is 1024x1024.
- The fine-tuning can be done with 24GB GPU memory with the batch size of 1. For 24GB GPU, the following options are recommended __for the fine-tuning with 24GB GPU memory__:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use Adafactor optimizer. RMSprop 8bit or Adagrad 8bit may work. AdamW 8bit doesn't seem to work.
- The LoRA training can be done with 8GB GPU memory (10GB recommended). For reducing the GPU memory usage, the following options are recommended:
- Train U-Net only.
- Use gradient checkpointing.
- Use `--cache_text_encoder_outputs` option and caching latents.
- Use one of 8bit optimizers or Adafactor optimizer.
- Use lower dim (4 to 8 for 8GB GPU).
- `--network_train_unet_only` option is highly recommended for SDXL LoRA. Because SDXL has two text encoders, the result of the training will be unexpected.
- PyTorch 2 seems to use slightly less GPU memory than PyTorch 1.
- `--bucket_reso_steps` can be set to 32 instead of the default value 64. Smaller values than 32 will not work for SDXL training.
Example of the optimizer settings for Adafactor with the fixed learning rate:
```toml
optimizer_type = "adafactor"
optimizer_args = [ "scale_parameter=False", "relative_step=False", "warmup_init=False" ]
lr_scheduler = "constant_with_warmup"
lr_warmup_steps = 100
learning_rate = 4e-7 # SDXL original learning rate
```
### Format of Textual Inversion embeddings for SDXL
```python
from safetensors.torch import save_file
state_dict = {"clip_g": embs_for_text_encoder_1280, "clip_l": embs_for_text_encoder_768}
save_file(state_dict, file)
```
### ControlNet-LLLite
ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [documentation](./docs/train_lllite_README.md) for details.
## Change History
### Oct 9. 2023 / 2023/10/9
### 15 Jun. 2023, 2023/06/15
- `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!
- `--onnx` option is added. If you use Onnx, specify `--onnx` option.
- Please install Onnx and other required packages.
1. Uninstall TensorFlow.
1. `pip install tensorboard==2.14.1` This is required for the specified version of protobuf.
1. `pip install protobuf==3.20.3` This is required for Onnx.
1. `pip install onnx==1.14.1`
1. `pip install onnxruntime-gpu==1.16.0` or `pip install onnxruntime==1.16.0`
- `--append_tags` option is added to `tag_images_by_wd_14_tagger.py`. This option appends the tags to the existing tags, instead of replacing them. [#858](https://github.com/kohya-ss/sd-scripts/pull/858) Thanks to a-l-e-x-d-s-9!
- [OFT](https://oft.wyliu.com/) is now supported.
- You can use `networks.oft` for the network module in `sdxl_train_network.py`. The usage is the same as `networks.lora`. Some options are not supported.
- `sdxl_gen_img.py` also supports OFT as `--network_module`.
- OFT only supports SDXL currently. Because current OFT tweaks Q/K/V and O in the transformer, and SD1/2 have extremely fewer transformers than SDXL.
- The implementation is heavily based on laksjdjf's [OFT implementation](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py). Thanks to laksjdjf!
- Other bug fixes and improvements.
- Prodigy optimizer is supported in each training script. It is a member of D-Adaptation and is effective for DyLoRA training. [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) Please see the PR for details. Thanks to sdbds!
- Install the package with `pip install prodigyopt`. Then specify the option like `--optimizer_type="prodigy"`.
- Arbitrary Dataset is supported in each training script (except XTI). You can use it by defining a Dataset class that returns images and captions.
- Prepare a Python script and define a class that inherits `train_util.MinimalDataset`. Then specify the option like `--dataset_class package.module.DatasetClass` in each training script.
- Please refer to `MinimalDataset` for implementation. I will prepare a sample later.
- The following features have been added to the generation script.
- Added an option `--highres_fix_disable_control_net` to disable ControlNet in the 2nd stage of Highres. Fix. Please try it if the image is disturbed by some ControlNet such as Canny.
- Added Variants similar to sd-dynamic-propmpts in the prompt.
- If you specify `{spring|summer|autumn|winter}`, one of them will be randomly selected.
- If you specify `{2$$chocolate|vanilla|strawberry}`, two of them will be randomly selected.
- If you specify `{1-2$$ and $$chocolate|vanilla|strawberry}`, one or two of them will be randomly selected and connected by ` and `.
- You can specify the number of candidates in the range `0-2`. You cannot omit one side like `-2` or `1-`.
- It can also be specified for the prompt option.
- If you specify `e` or `E`, all candidates will be selected and the prompt will be repeated multiple times (`--images_per_prompt` is ignored). It may be useful for creating X/Y plots.
- You can also specify `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`. In this case, 15 prompts will be generated with 5*3.
- There is no weighting function.
- `tag_images_by_wd_14_tagger.py` が Onnx をサポートました。Onnx を使用する場合は TensorFlow は不要です。[#864](https://github.com/kohya-ss/sd-scripts/pull/864) Isotr0py氏に感謝します。
- Onnxを使用する場合は、`--onnx` オプションを指定してください
- Onnx とその他の必要なパッケージをインストールしてください
1. TensorFlow をアンインストールしてください。
1. `pip install tensorboard==2.14.1` protobufの指定バージョンにこれが必要
1. `pip install protobuf==3.20.3` Onnxのために必要
1. `pip install onnx==1.14.1`
1. `pip install onnxruntime-gpu==1.16.0` または `pip install onnxruntime==1.16.0`
- `tag_images_by_wd_14_tagger.py``--append_tags` オプションが追加されました。このオプションを指定すると、既存のタグに上書きするのではなく、新しいタグのみが既存のタグに追加されます。 [#858](https://github.com/kohya-ss/sd-scripts/pull/858) a-l-e-x-d-s-9氏に感謝します。
- [OFT](https://oft.wyliu.com/) をサポートしました
- `sdxl_train_network.py``--network_module``networks.oft` を指定してください。使用方法は `networks.lora` と同様ですが一部のオプションは未サポートです。
- `sdxl_gen_img.py` でも同様に OFT を指定できま
- OFT は現在 SDXL のみサポートしています。OFT は現在 transformer の Q/K/V と O を変更しますが、SD1/2 は transformer の数が SDXL よりも極端に少ないためです。
- 実装は laksjdjf 氏の [OFT実装](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py) を多くの部分で参考にしています。laksjdjf 氏に感謝します
- その他のバグ修正と改善
- 各学習スクリプトでProdigyオプティマイザがサポートされました。D-Adaptationの仲間でDyLoRAの学習に有効とのことです。 [PR #585](https://github.com/kohya-ss/sd-scripts/pull/585) 詳細はPRをご覧ください。sdbds氏に感謝します。
- `pip install prodigyopt` としてパッケージをインストールしてください。また `--optimizer_type="prodigy"` のようにオプションを指定します
- 各学習スクリプトで任意のDatasetをサポートしましたXTIを除く。画像とキャプションを返すDatasetクラスを定義することで、学習スクリプトから利用できます
- Pythonスクリプトを用意し、`train_util.MinimalDataset`を継承するクラスを定義してください。そして各学習スクリプトのオプションで `--dataset_class package.module.DatasetClass` のように指定してください。
- 実装方法は `MinimalDataset` を参考にしてください。のちほどサンプルを用意します
- 生成スクリプトに以下の機能追加を行いました
- Highres. Fixの2nd stageでControlNetを無効化するオプション `--highres_fix_disable_control_net` を追加しました。Canny等一部のControlNetで画像が乱れる場合にお試しください。
- プロンプトでsd-dynamic-propmptsに似たVariantをサポートしました。
- `{spring|summer|autumn|winter}` のように指定すると、いずれかがランダムに選択されます。
- `{2$$chocolate|vanilla|strawberry}` のように指定すると、いずれか2個がランダムに選択されます
- `{1-2$$ and $$chocolate|vanilla|strawberry}` のように指定すると、1個か2個がランダムに選択され ` and ` で接続されます。
- 個数のレンジ指定では`0-2`のように0個も指定可能です。`-2`や`1-`のような片側の省略はできません
- プロンプトオプションに対しても指定可能です。
- `{e$$chocolate|vanilla|strawberry}` のように`e`または`E`を指定すると、すべての候補が選択されプロンプトが複数回繰り返されます(`--images_per_prompt`は無視されます。X/Y plotの作成に便利かもしれません
- `--am {e$$0.2|0.4|0.6|0.8|1.0},{e$$0.4|0.7|1.0} --d 1234`のような指定も可能です。この場合、5*3で15回のプロンプトが生成されます
- Weightingの機能はありません。
### Oct 1. 2023 / 2023/10/1
### 8 Jun. 2023, 2023/06/08
- SDXL training is now available in the main branch. The sdxl branch is merged into the main branch.
- Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training.
- 重みづけキャプションでの学習時(`--weighted_captions`指定時および学習中のサンプル画像生成時にclip skipが機能しない不具合を修正しました。
- [SAI Model Spec](https://github.com/Stability-AI/ModelSpec) metadata is now supported partially. `hash_sha256` is not supported yet.
- The main items are set automatically.
- You can set title, author, description, license and tags with `--metadata_xxx` options in each training script.
- Merging scripts also support minimum SAI Model Spec metadata. See the help message for the usage.
- Metadata editor will be available soon.
### 6 Jun. 2023, 2023/06/06
- `bitsandbytes` is now optional. Please install it if you want to use it. The insructions are in the later section.
- Fix `train_network.py` to probably work with older versions of LyCORIS.
- `gen_img_diffusers.py` now supports `BREAK` syntax.
- `train_network.py`がLyCORISの以前のバージョンでも恐らく動作するよう修正しました。
- `gen_img_diffusers.py` で `BREAK` 構文をサポートしました。
- `albumentations` is not required anymore.
### 3 Jun. 2023, 2023/06/03
- `--v_pred_like_loss ratio` option is added. This option adds the loss like v-prediction loss in SDXL training. `0.1` means that the loss is added 10% of the v-prediction loss. The default value is None (disabled).
- In v-prediction, the loss is higher in the early timesteps (near the noise). This option can be used to increase the loss in the early timesteps.
- Max Norm Regularization is now available in `train_network.py`. [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!
- Max Norm Regularization is a technique to stabilize network training by limiting the norm of network weights. It may be effective in suppressing overfitting of LoRA and improving stability when used with other LoRAs. See PR for details.
- Specify as `--scale_weight_norms=1.0`. It seems good to try from `1.0`.
- The networks other than LoRA in this repository (such as LyCORIS) do not support this option.
- Arbitrary options can be used for Diffusers' schedulers. For example `--lr_scheduler_args "lr_end=1e-8"`.
- Three types of dropout have been added to `train_network.py` and LoRA network.
- Dropout is a technique to suppress overfitting and improve network performance by randomly setting some of the network outputs to 0.
- `--network_dropout` is a normal dropout at the neuron level. In the case of LoRA, it is applied to the output of down. Proposed in [PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) Thanks to AI-Casanova!
- `--network_dropout=0.1` specifies the dropout probability to `0.1`.
- Note that the specification method is different from LyCORIS.
- For LoRA network, `--network_args` can specify `rank_dropout` to dropout each rank with specified probability. Also `module_dropout` can be specified to dropout each module with specified probability.
- Specify as `--network_args "rank_dropout=0.2" "module_dropout=0.1"`.
- `--network_dropout`, `rank_dropout`, and `module_dropout` can be specified at the same time.
- Values of 0.1 to 0.3 may be good to try. Values greater than 0.5 should not be specified.
- `rank_dropout` and `module_dropout` are original techniques of this repository. Their effectiveness has not been verified yet.
- The networks other than LoRA in this repository (such as LyCORIS) do not support these options.
- LoRA-FA is added experimentally. Specify `--network_module networks.lora_fa` option instead of `--network_module networks.lora`. The trained model can be used as a normal LoRA model.
- JPEG XL is supported. [#786](https://github.com/kohya-ss/sd-scripts/pull/786)
- Input perturbation noise is added. See [#798](https://github.com/kohya-ss/sd-scripts/pull/798) for details.
- Dataset subset now has `caption_prefix` and `caption_suffix` options. The strings are added to the beginning and the end of the captions before shuffling. You can specify the options in `.toml`.
- Intel ARC support with IPEX is added. [#825](https://github.com/kohya-ss/sd-scripts/pull/825)
- Other bug fixes and improvements.
- Added an option `--scale_v_pred_loss_like_noise_pred` to scale v-prediction loss like noise prediction in each training script.
- By scaling the loss according to the time step, the weights of global noise prediction and local noise prediction become the same, and the improvement of details may be expected.
- See [this article](https://xrg.hatenablog.com/entry/2023/06/02/202418) by xrg for details (written in Japanese). Thanks to xrg for the great suggestion!
- Max Norm Regularizationが`train_network.py`で使えるようになりました。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) AI-Casanova氏に感謝します。
- Max Norm Regularizationは、ネットワークの重みのルムを制限することで、ネットワークの学習を安定させる手法です。LoRAの過学習の抑制、他のLoRAと併用した時の安定性の向上が期待できるかもしれません。詳細はPRを参照してください。
- `--scale_weight_norms=1.0`のように `--scale_weight_norms` で指定してください。`1.0`から試すと良いようです。
- LyCORIS等、当リポジトリ以外のネットワークは現時点では未対応です。
- `train_network.py` およびLoRAに計三種類のdropoutを追加しました。
- dropoutはネットワークの一部の出力をランダムに0にすることで、過学習の抑制、ネットワークの性能向上等を図る手法です。
- `--network_dropout` はニューロン単位の通常のdropoutです。LoRAの場合、downの出力に対して適用されます。[PR #545](https://github.com/kohya-ss/sd-scripts/pull/545) で提案されました。AI-Casanova氏に感謝します。
- `--network_dropout=0.1` などとすることで、dropoutの確率を指定できます。
- LyCORISとは指定方法が異なりますのでご注意ください。
- LoRAの場合、`--network_args`に`rank_dropout`を指定することで各rankを指定確率でdropoutします。また同じくLoRAの場合、`--network_args`に`module_dropout`を指定することで各モジュールを指定確率でdropoutします。
- `--network_args "rank_dropout=0.2" "module_dropout=0.1"` のように指定します。
- `--network_dropout`、`rank_dropout` 、 `module_dropout` は同時に指定できます。
- それぞれの値は0.1~0.3程度から試してみると良いかもしれません。0.5を超える値は指定しない方が良いでしょう。
- `rank_dropout`および`module_dropout`は当リポジトリ独自の手法です。有効性の検証はまだ行っていません。
- これらのdropoutはLyCORIS等、当リポジトリ以外のネットワークは現時点では未対応です。
- 各学習スクリプトにv-prediction lossをnoise predictionと同様の値にスケールするオプション`--scale_v_pred_loss_like_noise_pred`を追加しました。
- タイムステップに応じてlossをスケールすることで、 大域的なノイズの予測と局所的なノイズの予測の重みが同じになり、ディテールの改善が期待できるかもしれません。
- 詳細はxrg氏のこちらの記事をご参照ください[noise_predictionモデルとv_predictionモデルの損失 - 勾配降下党青年局](https://xrg.hatenablog.com/entry/2023/06/02/202418) xrg氏の素晴らしい記事に感謝します。
### 31 May 2023, 2023/05/31
- Show warning when image caption file does not exist during training. [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) Thanks to TingTingin!
- Warning is also displayed when using class+identifier dataset. Please ignore if it is intended.
- `train_network.py` now supports merging network weights before training. [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) Thanks to u-haru!
- `--base_weights` option specifies LoRA or other model files (multiple files are allowed) to merge.
- `--base_weights_multiplier` option specifies multiplier of the weights to merge (multiple values are allowed). If omitted or less than `base_weights`, 1.0 is used.
- This is useful for incremental learning. See PR for details.
- Show warning and continue training when uploading to HuggingFace fails.
- 学習時に画像のキャプションファイルが存在しない場合、警告が表示されるようになりました。 [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) TingTingin氏に感謝します。
- class+identifier方式のデータセットを利用している場合も警告が表示されます。意図している通りの場合は無視してください。
- `train_network.py` に学習前にモデルにnetworkの重みをマージする機能が追加されました。 [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) u-haru氏に感謝します。
- `--base_weights` オプションでLoRA等のモデルファイル複数可を指定すると、それらの重みをマージします。
- `--base_weights_multiplier` オプションでマージする重みの倍率(複数可)を指定できます。省略時または`base_weights`よりも数が少ない場合は1.0になります。
- 差分追加学習などにご利用ください。詳細はPRをご覧ください。
- HuggingFaceへのアップロードに失敗した場合、警告を表示しそのまま学習を続行するよう変更しました。
### 25 May 2023, 2023/05/25
- [D-Adaptation v3.0](https://github.com/facebookresearch/dadaptation) is now supported. [PR #530](https://github.com/kohya-ss/sd-scripts/pull/530) Thanks to sdbds!
- `--optimizer_type` now accepts `DAdaptAdamPreprint`, `DAdaptAdanIP`, and `DAdaptLion`.
- `DAdaptAdam` is now new. The old `DAdaptAdam` is available with `DAdaptAdamPreprint`.
- Simply specifying `DAdaptation` will use `DAdaptAdamPreprint` (same behavior as before).
- You need to install D-Adaptation v3.0. After activating venv, please do `pip install -U dadaptation`.
- See PR and D-Adaptation documentation for details.
- [D-Adaptation v3.0](https://github.com/facebookresearch/dadaptation)がサポートされました。 [PR #530](https://github.com/kohya-ss/sd-scripts/pull/530) sdbds氏に感謝します。
- `--optimizer_type`に`DAdaptAdamPreprint`、`DAdaptAdanIP`、`DAdaptLion` が追加されました。
- `DAdaptAdam`が新しくなりました。今までの`DAdaptAdam`は`DAdaptAdamPreprint`で使用できます。
- 単に `DAdaptation` を指定すると`DAdaptAdamPreprint`が使用されます(今までと同じ動き)。
- D-Adaptation v3.0のインストールが必要です。venvを有効にした後 `pip install -U dadaptation` としてください。
- 詳細はPRおよびD-Adaptationのドキュメントを参照してください。
### 22 May 2023, 2023/05/22
- Fixed several bugs.
- The state is saved even when the `--save_state` option is not specified in `fine_tune.py` and `train_db.py`. [PR #521](https://github.com/kohya-ss/sd-scripts/pull/521) Thanks to akshaal!
- Cannot load LoRA without `alpha`. [PR #527](https://github.com/kohya-ss/sd-scripts/pull/527) Thanks to Manjiz!
- Minor changes to console output during sample generation. [PR #515](https://github.com/kohya-ss/sd-scripts/pull/515) Thanks to yanhuifair!
- The generation script now uses xformers for VAE as well.
- いくつかのバグ修正を行いました。
- `fine_tune.py`と`train_db.py`で`--save_state`オプション未指定時にもstateが保存される。 [PR #521](https://github.com/kohya-ss/sd-scripts/pull/521) akshaal氏に感謝します。
- `alpha`を持たないLoRAを読み込めない。[PR #527](https://github.com/kohya-ss/sd-scripts/pull/527) Manjiz氏に感謝します。
- サンプル生成時のコンソール出力の軽微な変更。[PR #515](https://github.com/kohya-ss/sd-scripts/pull/515) yanhuifair氏に感謝します。
- 生成スクリプトでVAEについてもxformersを使うようにしました。
### 16 May 2023, 2023/05/16
- Fixed an issue where an error would occur if the encoding of the prompt file was different from the default. [PR #510](https://github.com/kohya-ss/sd-scripts/pull/510) Thanks to sdbds!
- Please save the prompt file in UTF-8.
- プロンプトファイルのエンコーディングがデフォルトと異なる場合にエラーが発生する問題を修正しました。 [PR #510](https://github.com/kohya-ss/sd-scripts/pull/510) sdbds氏に感謝します。
- プロンプトファイルはUTF-8で保存してください。
### 15 May 2023, 2023/05/15
- Added [English translation of documents](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation) by darkstorm2150. Thank you very much!
- The prompt for sample generation during training can now be specified in `.toml` or `.json`. [PR #504](https://github.com/kohya-ss/sd-scripts/pull/504) Thanks to Linaqruf!
- For details on prompt description, please see the PR.
- darkstorm2150氏に[ドキュメント類を英訳](https://github.com/darkstorm2150/sd-scripts#links-to-usage-documentation)していただきました。ありがとうございます!
- 学習中のサンプル生成のプロンプトを`.toml`または`.json`で指定可能になりました。 [PR #504](https://github.com/kohya-ss/sd-scripts/pull/504) Linaqruf氏に感謝します。
- プロンプト記述の詳細は当該PRをご覧ください。
### 11 May 2023, 2023/05/11
- Added an option `--dim_from_weights` to `train_network.py` to automatically determine the dim(rank) from the weight file. [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) Thanks to AI-Casanova!
- It is useful in combination with `resize_lora.py`. Please see the PR for details.
- Fixed a bug where the noise resolution was incorrect with Multires noise. [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) Thanks to sdbds!
- Please see the PR for details.
- The image generation scripts can now use img2img and highres fix at the same time.
- Fixed a bug where the hint image of ControlNet was incorrectly BGR instead of RGB in the image generation scripts.
- Added a feature to the image generation scripts to use the memory-efficient VAE.
- If you specify a number with the `--vae_slices` option, the memory-efficient VAE will be used. The maximum output size will be larger, but it will be slower. Please specify a value of about `16` or `32`.
- The implementation of the VAE is in `library/slicing_vae.py`.
- `train_network.py`にdim(rank)を重みファイルから自動決定するオプション`--dim_from_weights`が追加されました。 [PR #491](https://github.com/kohya-ss/sd-scripts/pull/491) AI-Casanova氏に感謝します。
- `resize_lora.py`と組み合わせると有用です。詳細はPRもご参照ください。
- Multires noiseでイズ解像度が正しくない不具合が修正されました。 [PR #489](https://github.com/kohya-ss/sd-scripts/pull/489) sdbds氏に感謝します。
- 詳細は当該PRをご参照ください。
- 生成スクリプトでimg2imgとhighres fixを同時に使用できるようにしました。
- 生成スクリプトでControlNetのhint画像が誤ってBGRだったのをRGBに修正しました。
- 生成スクリプトで省メモリ化VAEを使えるよう機能追加しました。
- `--vae_slices`オプションに数値を指定すると、省メモリ化VAEを用います。出力可能な最大サイズが大きくなりますが、遅くなります。`16`または`32`程度の値を指定してください。
- VAEの実装は`library/slicing_vae.py`にあります。
### 7 May 2023, 2023/05/07
- The documentation has been moved to the `docs` folder. If you have links, please change them.
- Removed `gradio` from `requirements.txt`.
- DAdaptAdaGrad, DAdaptAdan, and DAdaptSGD are now supported by DAdaptation. [PR#455](https://github.com/kohya-ss/sd-scripts/pull/455) Thanks to sdbds!
- DAdaptation needs to be installed. Also, depending on the optimizer, DAdaptation may need to be updated. Please update with `pip install --upgrade dadaptation`.
- Added support for pre-calculation of LoRA weights in image generation scripts. Specify `--network_pre_calc`.
- The prompt option `--am` is available. Also, it is disabled when Regional LoRA is used.
- Added Adaptive noise scale to each training script. Specify a number with `--adaptive_noise_scale` to enable it.
- __Experimental option. It may be removed or changed in the future.__
- This is an original implementation that automatically adjusts the value of the noise offset according to the absolute value of the mean of each channel of the latents. It is expected that appropriate noise offsets will be set for bright and dark images, respectively.
- Specify it together with `--noise_offset`.
- The actual value of the noise offset is calculated as `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale`. Since the latent is close to a normal distribution, it may be a good idea to specify a value of about 1/10 to the same as the noise offset.
- Negative values can also be specified, in which case the noise offset will be clipped to 0 or more.
- Other minor fixes.
- ドキュメントを`docs`フォルダに移動しました。リンク等を張られている場合は変更をお願いいたします。
- `requirements.txt`から`gradio`を削除しました。
- DAdaptationで新しくDAdaptAdaGrad、DAdaptAdan、DAdaptSGDがサポートされました。[PR#455](https://github.com/kohya-ss/sd-scripts/pull/455) sdbds氏に感謝します。
- dadaptationのインストールが必要です。またオプティマイザによってはdadaptationの更新が必要です。`pip install --upgrade dadaptation`で更新してください。
- 画像生成スクリプトでLoRAの重みの事前計算をサポートしました。`--network_pre_calc`を指定してください。
- プロンプトオプションの`--am`が利用できます。またRegional LoRA使用時には無効になります。
- 各学習スクリプトにAdaptive noise scaleを追加しました。`--adaptive_noise_scale`で数値を指定すると有効になります。
- __実験的オプションです。将来的に削除、仕様変更される可能性があります。__
- Noise offsetの値を、latentsの各チャネルの平均値の絶対値に応じて自動調整するオプションです。独自の実装で、明るい画像、暗い画像に対してそれぞれ適切なnoise offsetが設定されることが期待されます。
- `--noise_offset` と同時に指定してください。
- 実際のNoise offsetの値は `noise_offset + abs(mean(latents, dim=(2,3))) * adaptive_noise_scale` で計算されます。 latentは正規分布に近いためnoise_offsetの1/10同程度の値を指定するとよいかもしれません。
- 負の値も指定でき、その場合はnoise offsetは0以上にclipされます。
- その他の細かい修正を行いました。
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。

View File

@@ -1,11 +1,4 @@
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

View File

@@ -9,25 +9,7 @@ parms="parms"
nin="nin"
extention="extention" # Intentionally left
nd="nd"
shs="shs"
sts="sts"
scs="scs"
cpc="cpc"
coc="coc"
cic="cic"
msm="msm"
usu="usu"
ici="ici"
lvl="lvl"
dii="dii"
muk="muk"
ori="ori"
hru="hru"
rik="rik"
koo="koo"
yos="yos"
wn="wn"
[files]
extend-exclude = ["_typos.toml", "venv"]
extend-exclude = ["_typos.toml"]

View File

@@ -181,8 +181,6 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf
詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。
SDXLは現在サポートしていません。
フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
`--network_args` で以下の引数を指定してください。
@@ -248,8 +246,6 @@ network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,
merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。
SDXL向けにはsdxl_merge_lora.pyを用意しています。オプション等は同一ですので、以下のmerge_lora.pyを読み替えてください。
### Stable DiffusionのモデルにLoRAのモデルをマージする
マージ後のモデルは通常のStable Diffusionのckptと同様に扱えます。たとえば以下のようなコマンドラインになります。
@@ -280,29 +276,29 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt
### 複数のLoRAのモデルをマージする
--concatオプションを指定すると、複数のLoRAを単純に結合して新しいLoRAモデルを作成できます。ファイルサイズおよびdim/rankは指定したLoRAの合計サイズになりますマージ時にdim (rank)を変更する場合は `svd_merge_lora.py` を使用してください
__複数のLoRAをマージする場合は原則として `svd_merge_lora.py` を使用してください__ 単純なup同士やdown同士のマージでは、計算結果が正しくなくなるためです
`merge_lora.py` によるマージは差分抽出法でLoRAを生成する場合等、ごく限られた場合でのみ有効です。
たとえば以下のようなコマンドラインになります。
```
python networks\merge_lora.py --save_precision bf16
python networks\merge_lora.py
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
--ratios 1.0 -1.0 --concat --shuffle
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors --ratios 0.6 0.4
```
--concatオプション指定します。
また--shuffleオプションを追加し、重みをシャッフルします。シャッフルしないとマージ後のLoRAから元のLoRAを取り出せるため、コピー機学習などの場合には学習元データが明らかになります。ご注意ください。
--sd_modelオプション指定不要です。
--save_toオプションにマージ後のLoRAモデルの保存先を指定します.ckptまたは.safetensors、拡張子で自動判定
--modelsに学習したLoRAのモデルファイルを指定します。三つ以上も指定可能です。
--ratiosにそれぞれのモデルの比率どのくらい重みを元モデルに反映するかを0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
--ratiosにそれぞれのモデルの比率どのくらい重みを元モデルに反映するかを0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
v1で学習したLoRAとv2で学習したLoRA、rank次元数の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
### その他のオプション
* precision
@@ -310,7 +306,6 @@ v1で学習したLoRAとv2で学習したLoRA、rank次元数の異なるL
* save_precision
* モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。
他にもいくつかのオプションがありますので、--helpで確認してください。
## 複数のrankが異なるLoRAのモデルをマージする

View File

@@ -10,13 +10,6 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
@@ -80,8 +73,8 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
@@ -208,7 +201,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -1,14 +1,16 @@
import argparse
import csv
import glob
import os
from pathlib import Path
import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
from pathlib import Path
import library.train_util as train_util
@@ -18,7 +20,6 @@ IMAGE_SIZE = 448
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
FILES_ONNX = ["model.onnx"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
CSV_FILE = FILES[-1]
@@ -80,10 +81,7 @@ def main(args):
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files += FILES_ONNX
for file in files:
for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
@@ -98,46 +96,7 @@ def main(args):
print("using existing wd14 tagger model")
# 画像を読み込む
if args.onnx:
import onnx
import onnxruntime as ort
onnx_path = f"{args.model_dir}/model.onnx"
print("Running wd14 tagger with onnx")
print(f"loading onnx model: {onnx_path}")
if not os.path.exists(onnx_path):
raise Exception(
f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
+ " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
)
model = onnx.load(onnx_path)
input_name = model.graph.input[0].name
try:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
except:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str:
# some rebatch model may use 'N' as dynamic axes
print(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
)
args.batch_size = batch_size
del model
ort_sess = ort.InferenceSession(
onnx_path,
providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"],
)
else:
from tensorflow.keras.models import load_model
model = load_model(f"{args.model_dir}")
model = load_model(args.model_dir)
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
@@ -165,14 +124,8 @@ def main(args):
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
if args.onnx:
if len(imgs) < args.batch_size:
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
else:
probs = model(imgs, training=False)
probs = probs.numpy()
probs = model(imgs, training=False)
probs = probs.numpy()
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
@@ -212,27 +165,9 @@ def main(args):
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
tag_text = ", ".join(combined_tags)
if args.append_tags:
# Check if file exists
if os.path.exists(caption_file):
with open(caption_file, "rt", encoding="utf-8") as f:
# Read file and remove new lines
existing_content = f.read().strip("\n") # Remove newlines
# Split the content into tags and store them in a list
existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
# Check and remove repeating tags in tag_text
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
# Create new tag_text
tag_text = ", ".join(existing_tags + new_tags)
with open(caption_file, "wt", encoding="utf-8") as f:
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
@@ -348,15 +283,12 @@ def setup_parser() -> argparse.ArgumentParser:
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
# スペルミスしていたオプションを復元する

View File

@@ -65,13 +65,6 @@ import re
import diffusers
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
import torchvision
from diffusers import (
AutoencoderKL,
@@ -3152,7 +3145,7 @@ def main(args):
print("predefined seeds are exhausted")
seed = None
elif args.iter_same_seed:
seed = iter_seed
seeds = iter_seed
else:
seed = None # 前のを消す
@@ -3364,7 +3357,7 @@ def setup_parser() -> argparse.ArgumentParser:
)
parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率")
parser.add_argument(
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
@@ -3390,7 +3383,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--max_embeddings_multiples",
type=int,
default=None,
help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
)
parser.add_argument(
"--clip_guidance_scale",
@@ -3449,7 +3442,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--highres_fix_upscaler_args",
type=str,
default=None,
help="additional arguments for upscaler (key=value) / upscalerへの追加の引数",
help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
)
parser.add_argument(
"--highres_fix_disable_control_net",

View File

@@ -547,13 +547,17 @@ def generate_controlnet_subsets_config_by_subdirs(train_data_dir: Optional[str]
return []
subsets_config = []
subset_config = {"image_dir": train_data_dir, "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
subsets_config.append(subset_config)
for subdir in base_dir.iterdir():
if not subdir.is_dir():
continue
subset_config = {"image_dir": str(subdir), "conditioning_data_dir": conditioning_data_dir, "caption_extension": caption_extension, "num_repeats": 1}
subsets_config.append(subset_config)
return subsets_config
subsets_config = []
subsets_config += generate(train_data_dir)
subsets_config += generate(train_data_dir, False)
return subsets_config

View File

@@ -1,175 +0,0 @@
import os
import sys
import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
from .attention import attention_init
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def ipex_init(): # pylint: disable=too-many-statements
try:
#Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
torch.cuda.device_count = torch.xpu.device_count
torch.cuda.device_of = torch.xpu.device_of
torch.cuda.get_device_name = torch.xpu.get_device_name
torch.cuda.get_device_properties = torch.xpu.get_device_properties
torch.cuda.init = torch.xpu.init
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.classproperty = torch.xpu.classproperty
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
#Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved
torch.cuda.memory_cached = torch.xpu.memory_reserved
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
#RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
torch.cuda.manual_seed = torch.xpu.manual_seed
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
torch.cuda.seed = torch.xpu.seed
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
#AMP:
torch.cuda.amp = torch.xpu.amp
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
#C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
#Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.version.cuda = "11.7"
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
torch.cuda.get_device_properties.major = 11
torch.cuda.get_device_properties.minor = 7
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
if hasattr(torch.xpu, 'getDeviceIdListForCard'):
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
else:
torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
ipex_hijacks()
attention_init()
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
except Exception as e:
return False, e
return True, None

View File

@@ -1,157 +0,0 @@
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
original_torch_bmm = torch.bmm
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
block_multiply = input.element_size()
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
do_split = True
#Find something divisible with the input_tokens
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_2_slice_size = input_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the input_tokens
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
if do_split:
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
if len(query.shape) == 3:
batch_size_attention, query_tokens, shape_four = query.shape
shape_one = 1
no_shape_one = True
else:
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
no_shape_one = False
block_multiply = query.element_size()
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
do_split = True
#Find something divisible with the shape_one
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
else:
do_split = False
split_2_slice_size = query_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the batch_size_attention
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
if do_split:
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if no_shape_one:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
value[start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
if no_shape_one:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
return original_scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)
return hidden_states
def attention_init():
#ARC GPUs can't allocate more than 4GB to a single block:
torch.bmm = torch_bmm
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention

View File

@@ -1,120 +0,0 @@
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.21.1 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
# pylint: disable=protected-access, missing-function-docstring, line-too-long
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
r"""
Processor for implementing sliced attention.
Args:
slice_size (`int`, *optional*):
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
`attention_head_dim` must be a multiple of the `slice_size`.
"""
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, shape_three = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
block_multiply = query.element_size()
slice_block_size = self.slice_size * shape_three / 1024 / 1024 * block_multiply
block_size = query_tokens * slice_block_size
split_2_slice_size = query_tokens
if block_size > 4:
do_split_2 = True
#Find something divisible with the query_tokens
while (split_2_slice_size * slice_block_size) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def ipex_diffusers():
#ARC GPUs can't allocate more than 4GB to a single block:
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor

View File

@@ -1,179 +0,0 @@
from collections import defaultdict
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
per_device_found_inf = _MultiDeviceReplicator(found_inf)
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
# There could be hundreds of grads, so we'd like to iterate through them just once.
# However, we don't know their devices or dtypes in advance.
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
# sync grad to master weight
if hasattr(optimizer, "sync_grad"):
optimizer.sync_grad()
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is None:
continue
if (not allow_fp16) and param.grad.dtype == torch.float16:
raise ValueError("Attempting to unscale FP16 gradients.")
if param.grad.is_sparse:
# is_coalesced() == False means the sparse grad has values with duplicate indices.
# coalesce() deduplicates indices and adds all values that have the same index.
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
# so we should check the coalesced _values().
if param.grad.dtype is torch.float16:
param.grad = param.grad.coalesce()
to_unscale = param.grad._values()
else:
to_unscale = param.grad
# -: is there a way to split by device and dtype without appending in the inner loop?
to_unscale = to_unscale.to("cpu")
per_device_and_dtype_grads[to_unscale.device][
to_unscale.dtype
].append(to_unscale)
for _, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
core._amp_foreach_non_finite_check_and_unscale_(
grads,
per_device_found_inf.get("cpu"),
per_device_inv_scale.get("cpu"),
)
return per_device_found_inf._per_device_tensors
def unscale_(self, optimizer):
"""
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
:meth:`unscale_` is optional, serving cases where you need to
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
between the backward pass(es) and :meth:`step`.
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
...
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
Args:
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
.. warning::
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
and only after all gradients for that optimizer's assigned parameters have been accumulated.
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
.. warning::
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
"""
if not self._enabled:
return
self._check_scale_growth_tracker("unscale_")
optimizer_state = self._per_optimizer_states[id(optimizer)]
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
raise RuntimeError(
"unscale_() has already been called on this optimizer since the last update()."
)
elif optimizer_state["stage"] is OptState.STEPPED:
raise RuntimeError("unscale_() is being called after step().")
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
optimizer, inv_scale, found_inf, False
)
optimizer_state["stage"] = OptState.UNSCALED
def update(self, new_scale=None):
"""
Updates the scale factor.
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
the scale is multiplied by ``growth_factor`` to increase it.
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
used directly, it's used to fill GradScaler's internal scale tensor. So if
``new_scale`` was a tensor, later in-place changes to that tensor will not further
affect the scale GradScaler uses internally.)
Args:
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
.. warning::
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
been invoked for all optimizers used this iteration.
"""
if not self._enabled:
return
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
found_infs = [
found_inf.to(device="cpu", non_blocking=True)
for state in self._per_optimizer_states.values()
for found_inf in state["found_inf_per_device"].values()
]
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
found_inf_combined = found_infs[0]
if len(found_infs) > 1:
for i in range(1, len(found_infs)):
found_inf_combined += found_infs[i]
to_device = _scale.device
_scale = _scale.to("cpu")
_growth_tracker = _growth_tracker.to("cpu")
core._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
_scale = _scale.to(to_device)
_growth_tracker = _growth_tracker.to(to_device)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
def gradscaler_init():
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
torch.xpu.amp.GradScaler.unscale_ = unscale_
torch.xpu.amp.GradScaler.update = update
return torch.xpu.amp.GradScaler

View File

@@ -1,196 +0,0 @@
import contextlib
import importlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
class CondFunc: # pylint: disable=missing-class-docstring
def __new__(cls, orig_func, sub_func, cond_func):
self = super(CondFunc, cls).__new__(cls)
if isinstance(orig_func, str):
func_path = orig_func.split('.')
for i in range(len(func_path)-1, -1, -1):
try:
resolved_obj = importlib.import_module('.'.join(func_path[:i]))
break
except ImportError:
pass
for attr_name in func_path[i:-1]:
resolved_obj = getattr(resolved_obj, attr_name)
orig_func = getattr(resolved_obj, func_path[-1])
setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
self.__init__(orig_func, sub_func, cond_func)
return lambda *args, **kwargs: self(*args, **kwargs)
def __init__(self, orig_func, sub_func, cond_func):
self.__orig_func = orig_func
self.__sub_func = sub_func
self.__cond_func = cond_func
def __call__(self, *args, **kwargs):
if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
return self.__sub_func(self.__orig_func, *args, **kwargs)
else:
return self.__orig_func(*args, **kwargs)
_utils = torch.utils.data._utils
def _shutdown_workers(self):
if torch.utils.data._utils is None or torch.utils.data._utils.python_exit_status is True or torch.utils.data._utils.python_exit_status is None:
return
if hasattr(self, "_shutdown") and not self._shutdown:
self._shutdown = True
try:
if hasattr(self, '_pin_memory_thread'):
self._pin_memory_thread_done_event.set()
self._worker_result_queue.put((None, None))
self._pin_memory_thread.join()
self._worker_result_queue.cancel_join_thread()
self._worker_result_queue.close()
self._workers_done_event.set()
for worker_id in range(len(self._workers)):
if self._persistent_workers or self._workers_status[worker_id]:
self._mark_worker_as_unavailable(worker_id, shutdown=True)
for w in self._workers: # pylint: disable=invalid-name
w.join(timeout=torch.utils.data._utils.MP_STATUS_CHECK_INTERVAL)
for q in self._index_queues: # pylint: disable=invalid-name
q.cancel_join_thread()
q.close()
finally:
if self._worker_pids_set:
torch.utils.data._utils.signal_handling._remove_worker_pids(id(self))
self._worker_pids_set = False
for w in self._workers: # pylint: disable=invalid-name
if w.is_alive():
w.terminate()
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return contextlib.nullcontext()
def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
def ipex_no_cuda(orig_func, *args, **kwargs):
torch.cuda.is_available = lambda: False
orig_func(*args, **kwargs)
torch.cuda.is_available = torch.xpu.is_available
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
if len(args) > 0 and args[0] == "cuda":
return original_autocast("xpu", *args[1:], **kwargs)
else:
return original_autocast(*args, **kwargs)
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
else:
return original_torch_cat(tensor, *args, **kwargs)
original_interpolate = torch.nn.functional.interpolate
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
return_device = tensor.device
return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
else:
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
original_linalg_solve = torch.linalg.solve
def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
if A.device != torch.device("cpu") or B.device != torch.device("cpu"):
return_device = A.device
return original_linalg_solve(A.to("cpu"), B.to("cpu"), *args, **kwargs).to(return_device)
else:
return original_linalg_solve(A, B, *args, **kwargs)
def ipex_hijacks():
CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.Tensor.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
CondFunc('torch.randn',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.ones',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.zeros',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.linspace',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
CondFunc('torch.instance_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
#Functions with dtype errors:
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float)
#Broken functions when torch.cuda.is_available is True:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True)
#Functions that make compile mad with CondFunc:
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
torch.nn.DataParallel = DummyDataParallel
torch.autocast = ipex_autocast
torch.cat = torch_cat
torch.linalg.solve = linalg_solve
torch.nn.functional.interpolate = interpolate
torch.backends.cuda.sdp_kernel = return_null_context

View File

@@ -4,13 +4,6 @@
import math
import os
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel

View File

@@ -131,7 +131,7 @@ DOWN_BLOCK_TYPES = ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDo
UP_BLOCK_TYPES = ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"]
# region memory efficient attention
# region memory effcient attention
# FlashAttentionを使うCrossAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py

View File

@@ -41,7 +41,7 @@ TIME_EMBED_DIM = 320 * 4
USE_REENTRANT = True
# region memory efficient attention
# region memory effcient attention
# FlashAttentionを使うCrossAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
@@ -996,6 +996,109 @@ class SdxlUNet2DConditionModel(nn.Module):
[GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)]
)
# FreeU
self.freeU = False
self.freeUB1 = 1.0
self.freeUB2 = 1.0
self.freeUS1 = 1.0
self.freeUS2 = 1.0
self.freeURThres = 1
# implementation of FreeU
# FreeU: Free Lunch in Diffusion U-Net https://arxiv.org/abs/2309.11497
def set_free_u_enabled(self, enabled: bool, b1=1.0, b2=1.0, s1=1.0, s2=1.0, rthresh=1):
print(f"FreeU: {enabled}, b1={b1}, b2={b2}, s1={s1}, s2={s2}, rthresh={rthresh}")
self.freeU = enabled
self.freeUB1 = b1
self.freeUB2 = b2
self.freeUS1 = s1
self.freeUS2 = s2
self.freeURThres = rthresh
def spectral_modulation(self, skip_feature, sl=1.0, rthresh=1):
"""
スキップ特徴を周波数領域で修正する関数
:param skip_feature: スキップ特徴のテンソル [b, c, H, W]
:param sl: スケーリング係数
:param rthresh: 周波数の閾値
:return: 修正されたスキップ特徴
"""
import torch.fft
r"""
# 論文に従った実装
org_dtype = skip_feature.dtype
if org_dtype == torch.bfloat16:
skip_feature = skip_feature.to(torch.float32)
# FFTを計算
F = torch.fft.fftn(skip_feature, dim=(2, 3))
# 周波数領域での座標を計算
freq_x = torch.fft.fftfreq(skip_feature.size(2), d=1 / skip_feature.size(2)).to(skip_feature.device)
freq_y = torch.fft.fftfreq(skip_feature.size(3), d=1 / skip_feature.size(3)).to(skip_feature.device)
# 2Dグリッドを作成
freq_x = freq_x[:, None] # [H, 1]
freq_y = freq_y[None, :] # [1, W]
# ラジアス(距離)を計算
r = torch.sqrt(freq_x**2 + freq_y**2)
# 32,32: tensor(0., device='cuda:0') tensor(22.6274, device='cuda:0') tensor(12.2521, device='cuda:0')
# 64,64: tensor(0., device='cuda:0') tensor(45.2548, device='cuda:0') tensor(24.4908, device='cuda:0')
# 128,128: tensor(0., device='cuda:0') tensor(90.5097, device='cuda:0') tensor(48.9748, device='cuda:0')
# マスクを作成
mask = torch.ones_like(r)
mask[r < rthresh] = sl
# b,c,H,Wの形状にブロードキャスト
# TODO shapeごとに同じなのでキャッシュすると良さそう
mask = mask[None, None, :, :]
# 周波数領域での要素ごとの乗算
F_prime = F * mask
# 逆FFTを計算
modified_skip_feature = torch.fft.ifftn(F_prime, dim=(2, 3))
modified_skip_feature = modified_skip_feature.real # 実部のみを取得
"""
# 公式リポジトリの実装
org_dtype = skip_feature.dtype
x = skip_feature
threshold = rthresh
scale = sl
# FFT
x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))
B, C, H, W = x_freq.shape
mask = torch.ones((B, C, H, W), device=x.device)
crow, ccol = H // 2, W // 2
mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
x_freq = x_freq * mask
# IFFT
x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real
modified_skip_feature = x_filtered
# if org_dtype == torch.bfloat16:
modified_skip_feature = modified_skip_feature.to(org_dtype)
return modified_skip_feature
# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
@@ -1079,11 +1182,30 @@ class SdxlUNet2DConditionModel(nn.Module):
h = x
for module in self.input_blocks:
h = call_module(module, h, emb, context)
hs.append(h)
if self.freeU:
ch = h.shape[1]
s = self.freeUS1 if ch == 1280 else (self.freeUS2 if ch == 640 else 1.0)
if s == 1.0:
h_mod = h
else:
h_mod = self.spectral_modulation(h, s, self.freeURThres)
hs.append(h_mod)
else:
hs.append(h)
h = call_module(self.middle_block, h, emb, context)
for module in self.output_blocks:
if self.freeU:
ch = h.shape[1]
if ch == 1280:
h[:, : ch // 2] = h[:, : ch // 2] * self.freeUB1
elif ch == 640:
h[:, : ch // 2] = h[:, : ch // 2] * self.freeUB2
# else:
# print(f"disable freeU: {ch}")
h = torch.cat([h, hs.pop()], dim=1)
h = call_module(module, h, emb, context)

View File

@@ -96,7 +96,6 @@ try:
except:
pass
# JPEG-XL on Linux
try:
from jxlpy import JXLImagePlugin
@@ -104,14 +103,6 @@ try:
except:
pass
# JPEG-XL on Windows
try:
import pillow_jxl
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
pass
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(),
@@ -4667,7 +4658,7 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# collate_fn用 epoch,stepはmultiprocessing.Value
class collator_class:
class collater_class:
def __init__(self, epoch, step, dataset):
self.current_epoch = epoch
self.current_step = step

View File

@@ -14,7 +14,7 @@ import lora
CLAMP_QUANTILE = 0.99
MIN_DIFF = 1e-1
MIN_DIFF = 1e-4
def save_to_file(file_name, model, state_dict, dtype):
@@ -200,7 +200,7 @@ def svd(args):
if not args.no_metadata:
title = os.path.splitext(os.path.basename(args.save_to))[0]
sai_metadata = sai_model_spec.build_metadata(
None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title
None, args.v2, args.v_parameterization, False, True, False, time.time(), title=title
)
metadata.update(sai_metadata)

View File

@@ -117,7 +117,7 @@ class LoRAModule(torch.nn.Module):
super().__init__()
self.lora_name = lora_name
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
@@ -126,7 +126,7 @@ class LoRAModule(torch.nn.Module):
self.lora_dim = lora_dim
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
@@ -166,8 +166,7 @@ class LoRAModule(torch.nn.Module):
self.org_module[0].forward = self.org_forward
# forward with lora
# scale is used LoRACompatibleConv, but we ignore it because we have multiplier
def forward(self, x, scale=1.0):
def forward(self, x):
if not self.enabled:
return self.org_forward(x)
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
@@ -319,12 +318,8 @@ class LoRANetwork(torch.nn.Module):
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = (
child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
)
is_conv2d = (
child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
)
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
@@ -364,7 +359,7 @@ class LoRANetwork(torch.nn.Module):
skipped_te += skipped
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
if len(skipped_te) > 0:
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
print(f"skipped {len(skipped_te)} modules because of missing weight.")
# extend U-Net target modules to include Conv2d 3x3
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
@@ -373,7 +368,7 @@ class LoRANetwork(torch.nn.Module):
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
if len(skipped_un) > 0:
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
print(f"skipped {len(skipped_un)} modules because of missing weight.")
# assertion
names = set()

View File

@@ -110,7 +110,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
def merge_lora_models(models, ratios, merge_dtype):
base_alphas = {} # alpha for merged model
base_dims = {}
@@ -158,12 +158,6 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
for key in lora_sd.keys():
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None
lora_module_name = key[: key.rfind(".lora_")]
@@ -171,16 +165,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
merged_sd[key].size() == lora_sd[key].size()
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
@@ -188,13 +178,6 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -273,7 +256,7 @@ def merge(args):
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
)
else:
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"calculating hashes and creating metadata...")
@@ -334,19 +317,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合するLoRAのdim(rank)は入力dimの合計になる",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
)
return parser

View File

@@ -1,430 +0,0 @@
# OFT network module
import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
import re
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
class OFTModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
):
"""
dim -> num blocks
alpha -> constraint
"""
super().__init__()
self.oft_name = oft_name
self.num_blocks = dim
if "Linear" in org_module.__class__.__name__:
out_dim = org_module.out_features
elif "Conv" in org_module.__class__.__name__:
out_dim = org_module.out_channels
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
self.constraint = alpha * out_dim
self.register_buffer("alpha", torch.tensor(alpha))
self.block_size = out_dim // self.num_blocks
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
self.out_dim = out_dim
self.shape = org_module.weight.shape
self.multiplier = multiplier
self.org_module = [org_module] # moduleにならないようにlistに入れる
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
R = torch.block_diag(*block_R_weighted)
return R
def forward(self, x, scale=None):
x = self.org_forward(x)
if self.multiplier == 0.0:
return x
R = self.get_weight().to(x.device, dtype=x.dtype)
if x.dim() == 4:
x = x.permute(0, 2, 3, 1)
x = torch.matmul(x, R)
x = x.permute(0, 3, 1, 2)
else:
x = torch.matmul(x, R)
return x
class OFTInfModule(OFTModule):
def __init__(
self,
oft_name,
org_module: torch.nn.Module,
multiplier=1.0,
dim=4,
alpha=1,
**kwargs,
):
# no dropout for inference
super().__init__(oft_name, org_module, multiplier, dim, alpha)
self.enabled = True
self.network: OFTNetwork = None
def set_network(self, network):
self.network = network
def forward(self, x, scale=None):
if not self.enabled:
return self.org_forward(x)
return super().forward(x, scale)
def merge_to(self, multiplier=None, sign=1):
R = self.get_weight(multiplier) * sign
# get org weight
org_sd = self.org_module[0].state_dict()
org_weight = org_sd["weight"]
R = R.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
else:
weight = torch.einsum("oi, op -> pi", org_weight, R)
# set weight to org_module
org_sd["weight"] = weight
self.org_module[0].load_state_dict(org_sd)
def create_network(
multiplier: float,
network_dim: Optional[int],
network_alpha: Optional[float],
vae: AutoencoderKL,
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
unet,
neuron_dropout: Optional[float] = None,
**kwargs,
):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
enable_all_linear = kwargs.get("enable_all_linear", None)
enable_conv = kwargs.get("enable_conv", None)
if enable_all_linear is not None:
enable_all_linear = bool(enable_all_linear)
if enable_conv is not None:
enable_conv = bool(enable_conv)
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=network_dim,
alpha=network_alpha,
enable_all_linear=enable_all_linear,
enable_conv=enable_conv,
varbose=True,
)
return network
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# check dim, alpha and if weights have for conv2d
dim = None
alpha = None
has_conv2d = None
all_linear = None
for name, param in weights_sd.items():
if name.endswith(".alpha"):
if alpha is None:
alpha = param.item()
else:
if dim is None:
dim = param.size()[0]
if has_conv2d is None and param.dim() == 4:
has_conv2d = True
if all_linear is None:
if param.dim() == 3 and "attn" not in name:
all_linear = True
if dim is not None and alpha is not None and has_conv2d is not None:
break
if has_conv2d is None:
has_conv2d = False
if all_linear is None:
all_linear = False
module_class = OFTInfModule if for_inference else OFTModule
network = OFTNetwork(
text_encoder,
unet,
multiplier=multiplier,
dim=dim,
alpha=alpha,
enable_all_linear=all_linear,
enable_conv=has_conv2d,
module_class=module_class,
)
return network, weights_sd
class OFTNetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
OFT_PREFIX_UNET = "oft_unet" # これ変えないほうがいいかな
def __init__(
self,
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
unet,
multiplier: float = 1.0,
dim: int = 4,
alpha: float = 1,
enable_all_linear: Optional[bool] = False,
enable_conv: Optional[bool] = False,
module_class: Type[object] = OFTModule,
varbose: Optional[bool] = False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.dim = dim
self.alpha = alpha
print(
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
)
# create module instances
def create_modules(
root_module: torch.nn.Module,
target_replace_modules: List[torch.nn.Module],
) -> List[OFTModule]:
prefix = self.OFT_PREFIX_UNET
ofts = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = "Linear" in child_module.__class__.__name__
is_conv2d = "Conv2d" in child_module.__class__.__name__
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
oft_name = prefix + "." + name + "." + child_name
oft_name = oft_name.replace(".", "_")
# print(oft_name)
oft = module_class(
oft_name,
child_module,
self.multiplier,
dim,
alpha,
)
ofts.append(oft)
return ofts
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
if enable_all_linear:
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
else:
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
if enable_conv:
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
# assertion
names = set()
for oft in self.unet_ofts:
assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
names.add(oft.oft_name)
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for oft in self.unet_ofts:
oft.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
assert apply_unet, "apply_unet must be True"
for oft in self.unet_ofts:
oft.apply_to()
self.add_module(oft.oft_name, oft)
# マージできるかどうかを返す
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
print("enable OFT for U-Net")
for oft in self.unet_ofts:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(oft.oft_name):
sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
oft.load_state_dict(sd_for_lora, False)
oft.merge_to()
print(f"weights are merged")
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(ofts):
params = []
for oft in ofts:
params.extend(oft.parameters())
# print num of params
num_params = 0
for p in params:
num_params += p.numel()
print(f"OFT params: {num_params}")
return params
param_data = {"params": enumerate_params(self.unet_ofts)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
return all_params
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
def backup_weights(self):
# 重みのバックアップを行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# 重みのリストアを行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# 事前計算を行う
ofts: List[OFTInfModule] = self.unet_ofts
for oft in ofts:
org_module = oft.org_module[0]
oft.merge_to()
# sd = org_module.state_dict()
# org_weight = sd["weight"]
# lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
# sd["weight"] = org_weight + lora_weight
# assert sd["weight"].shape == org_weight.shape
# org_module.load_state_dict(sd)
org_module._lora_restored = False
oft.enabled = False

View File

@@ -219,8 +219,8 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
for key, value in tqdm(lora_sd.items()):
weight_name = None
if 'lora_down' in key:
block_down_name = key.rsplit('.lora_down', 1)[0]
weight_name = key.rsplit(".", 1)[-1]
block_down_name = key.split(".")[0]
weight_name = key.split(".")[-1]
lora_down_weight = value
else:
continue
@@ -283,10 +283,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
def resize(args):
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
def str_to_dtype(p):
if p == 'float':
return torch.float

View File

@@ -113,7 +113,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
def merge_lora_models(models, ratios, merge_dtype):
base_alphas = {} # alpha for merged model
base_dims = {}
@@ -161,13 +161,6 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
for key in tqdm(lora_sd.keys()):
if "alpha" in key:
continue
if "lora_up" in key and concat:
concat_dim = 1
elif "lora_down" in key and concat:
concat_dim = 0
else:
concat_dim = None
lora_module_name = key[: key.rfind(".lora_")]
@@ -175,16 +168,12 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = abs(scale) if "lora_up" in key else scale # マイナスの重みに対応する。
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size() or concat_dim is not None
merged_sd[key].size() == lora_sd[key].size()
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
if concat_dim is not None:
merged_sd[key] = torch.cat([merged_sd[key], lora_sd[key] * scale], dim=concat_dim)
else:
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
@@ -192,13 +181,6 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
if shuffle:
key_down = lora_module_name + ".lora_down.weight"
key_up = lora_module_name + ".lora_up.weight"
dim = merged_sd[key_down].shape[0]
perm = torch.randperm(dim)
merged_sd[key_down] = merged_sd[key_down][perm]
merged_sd[key_up] = merged_sd[key_up][:,perm]
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
@@ -270,7 +252,7 @@ def merge(args):
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
)
else:
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"calculating hashes and creating metadata...")
@@ -325,18 +307,6 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--concat",
action="store_true",
help="concat lora instead of merge (The dim(rank) of the output LoRA is the sum of the input dims) / "
+ "マージの代わりに結合するLoRAのdim(rank)は入力dimの合計になる",
)
parser.add_argument(
"--shuffle",
action="store_true",
help="shuffle lora weight./ "
+ "LoRAの重みをシャッフルする",
)
return parser

View File

@@ -1,6 +1,6 @@
accelerate==0.23.0
accelerate==0.19.0
transformers==4.30.2
diffusers[torch]==0.21.2
diffusers[torch]==0.18.2
ftfy==6.1.1
# albumentations==1.3.0
opencv-python==4.7.0.68
@@ -15,18 +15,14 @@ easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.15.1
# for loading Diffusers' SDXL
invisible-watermark==0.2.0
# for BLIP captioning
# requests==2.28.2
# timm==0.6.12
# fairscale==0.4.13
# for WD14 captioning (tensorflow)
# for WD14 captioning
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.14.1
# onnxruntime-gpu==1.16.0
# onnxruntime==1.16.0
# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL
open-clip-torch==2.20.0
# for kohya_ss library

View File

@@ -17,13 +17,6 @@ import re
import diffusers
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
import torchvision
from diffusers import (
AutoencoderKL,
@@ -1528,6 +1521,10 @@ def main(args):
text_encoder2.to(dtype).to(device)
unet.to(dtype).to(device)
# freeU
# unet.set_free_u_enabled(False, 1.0, 1.0, 0)
unet.set_free_u_enabled(True, 1.1, 1.2, 0.9, 0.2)
# networkを組み込む
if args.network_module:
networks = []
@@ -2612,7 +2609,7 @@ def setup_parser() -> argparse.ArgumentParser:
)
parser.add_argument("--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率")
parser.add_argument(
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
@@ -2631,7 +2628,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--max_embeddings_multiples",
type=int,
default=None,
help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
help="max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
)
parser.add_argument(
"--guide_image_path", type=str, default=None, nargs="*", help="image to CLIP guidance / CLIP guided SDでガイドに使う画像"
@@ -2666,7 +2663,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--highres_fix_upscaler_args",
type=str,
default=None,
help="additional arguments for upscaler (key=value) / upscalerへの追加の引数",
help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
)
parser.add_argument(
"--highres_fix_disable_control_net",

View File

@@ -9,13 +9,6 @@ import random
from einops import repeat
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler
@@ -101,7 +94,7 @@ if __name__ == "__main__":
type=str,
nargs="*",
default=[],
help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
help="LoRA weights, only supports networks.lora, each arguement is a `path;multiplier` (semi-colon separated)",
)
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()

View File

@@ -10,13 +10,6 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util
@@ -172,8 +165,8 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
train_dataset_group.verify_bucket_reso_steps(32)
@@ -348,7 +341,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -1,6 +1,3 @@
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
import argparse
import gc
import json
@@ -14,16 +11,8 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
import accelerate
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
@@ -45,7 +34,7 @@ from library.custom_train_functions import (
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
import networks.control_net_lllite_for_train as control_net_lllite_for_train
import networks.control_net_lllite as control_net_lllite
# TODO 他のスクリプトと共通化する
@@ -106,8 +95,8 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
train_dataset_group.verify_bucket_reso_steps(32)
@@ -152,6 +141,9 @@ def train(args):
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
@@ -185,53 +177,22 @@ def train(args):
)
accelerator.wait_for_everyone()
# prepare ControlNet-LLLite
control_net_lllite_for_train.replace_unet_linear_and_conv2d()
# prepare ControlNet
network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
network.apply_to()
if args.network_weights is not None:
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
with accelerate.init_empty_weights():
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
unet_lllite.to(accelerator.device, dtype=weight_dtype)
unet_sd = unet.state_dict()
info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
else:
# cosumes large memory, so send to GPU before creating the LLLite model
accelerator.print("sending U-Net to GPU")
unet.to(accelerator.device, dtype=weight_dtype)
unet_sd = unet.state_dict()
# init LLLite weights
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
if args.lowram:
with accelerate.init_on_device(accelerator.device):
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
else:
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
unet_lllite.to(weight_dtype)
info = unet_lllite.load_lllite_weights(None, unet_sd)
accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
del unet_sd, unet
unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
del unet_lllite
unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
info = network.load_weights(args.network_weights)
accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
network.enable_gradient_checkpointing() # may have no effect
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(unet.prepare_params())
trainable_params = list(network.prepare_optimizer_params())
print(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
@@ -245,7 +206,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -264,32 +225,37 @@ def train(args):
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
# if args.full_fp16:
# assert (
# args.mixed_precision == "fp16"
# ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
# accelerator.print("enable full fp16 training.")
# unet.to(weight_dtype)
# elif args.full_bf16:
# assert (
# args.mixed_precision == "bf16"
# ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
# accelerator.print("enable full bf16 training.")
# unet.to(weight_dtype)
unet.to(weight_dtype)
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
network.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
unet.to(weight_dtype)
network.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
network: control_net_lllite.ControlNetLLLite
# transform DDP after prepare (train_network here only)
unet = train_util.transform_models_if_DDP([unet])[0]
unet, network = train_util.transform_models_if_DDP([unet, network])
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
unet.eval()
network.prepare_grad_etc()
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
@@ -355,13 +321,7 @@ def train(args):
del train_dataset_group
# function for saving/removing
def save_model(
ckpt_name,
unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
steps,
epoch_no,
force_sync_upload=False,
):
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
@@ -369,7 +329,7 @@ def train(args):
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -384,9 +344,11 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
network.on_epoch_start() # train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(unet):
with accelerator.accumulate(network):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
@@ -443,9 +405,10 @@ def train(args):
with accelerator.autocast():
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
# 内部でcond_embに変換される / it will be converted to cond_emb inside
network.set_cond_image(controlnet_image)
# それらの値を使いつつ、U-Netでイズを予測する / predict noise with U-Net using those values
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
if args.v_parameterization:
# v-parameterization training
@@ -470,7 +433,7 @@ def train(args):
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = unet.get_trainable_params()
params_to_clip = network.get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -489,7 +452,7 @@ def train(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -528,7 +491,7 @@ def train(args):
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
@@ -543,7 +506,7 @@ def train(args):
# end of epoch
if is_main_process:
unet = accelerator.unwrap_model(unet)
network = accelerator.unwrap_model(network)
accelerator.end_training()
@@ -552,7 +515,7 @@ def train(args):
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")

View File

@@ -1,3 +1,6 @@
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用学習コード
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
import argparse
import gc
import json
@@ -11,15 +14,9 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
import accelerate
from diffusers import DDPMScheduler, ControlNetModel
from safetensors.torch import load_file
from library import sai_model_spec, sdxl_model_util, sdxl_original_unet, sdxl_train_util
@@ -41,7 +38,7 @@ from library.custom_train_functions import (
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
)
import networks.control_net_lllite as control_net_lllite
import networks.control_net_lllite_for_train as control_net_lllite_for_train
# TODO 他のスクリプトと共通化する
@@ -102,8 +99,8 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
train_dataset_group.verify_bucket_reso_steps(32)
@@ -148,9 +145,6 @@ def train(args):
ckpt_info,
) = sdxl_train_util.load_target_model(args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=vae_dtype)
@@ -184,22 +178,53 @@ def train(args):
)
accelerator.wait_for_everyone()
# prepare ControlNet
network = control_net_lllite.ControlNetLLLite(unet, args.cond_emb_dim, args.network_dim, args.network_dropout)
network.apply_to()
# prepare ControlNet-LLLite
control_net_lllite_for_train.replace_unet_linear_and_conv2d()
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
accelerator.print(f"load ControlNet weights from {args.network_weights}: {info}")
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
with accelerate.init_empty_weights():
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
unet_lllite.to(accelerator.device, dtype=weight_dtype)
unet_sd = unet.state_dict()
info = unet_lllite.load_lllite_weights(args.network_weights, unet_sd)
accelerator.print(f"load ControlNet-LLLite weights from {args.network_weights}: {info}")
else:
# cosumes large memory, so send to GPU before creating the LLLite model
accelerator.print("sending U-Net to GPU")
unet.to(accelerator.device, dtype=weight_dtype)
unet_sd = unet.state_dict()
# init LLLite weights
accelerator.print(f"initialize U-Net with ControlNet-LLLite")
if args.lowram:
with accelerate.init_on_device(accelerator.device):
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
else:
unet_lllite = control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite()
unet_lllite.to(weight_dtype)
info = unet_lllite.load_lllite_weights(None, unet_sd)
accelerator.print(f"init U-Net with ControlNet-LLLite weights: {info}")
del unet_sd, unet
unet: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite = unet_lllite
del unet_lllite
unet.apply_lllite(args.cond_emb_dim, args.network_dim, args.network_dropout)
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
network.enable_gradient_checkpointing() # may have no effect
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
trainable_params = list(network.prepare_optimizer_params())
trainable_params = list(unet.prepare_params())
print(f"trainable params count: {len(trainable_params)}")
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
@@ -213,7 +238,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -232,37 +257,32 @@ def train(args):
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
network.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
unet.to(weight_dtype)
network.to(weight_dtype)
# if args.full_fp16:
# assert (
# args.mixed_precision == "fp16"
# ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
# accelerator.print("enable full fp16 training.")
# unet.to(weight_dtype)
# elif args.full_bf16:
# assert (
# args.mixed_precision == "bf16"
# ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
# accelerator.print("enable full bf16 training.")
# unet.to(weight_dtype)
unet.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
network: control_net_lllite.ControlNetLLLite
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare (train_network here only)
unet, network = train_util.transform_models_if_DDP([unet, network])
unet = train_util.transform_models_if_DDP([unet])[0]
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
unet.eval()
network.prepare_grad_etc()
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
@@ -328,7 +348,13 @@ def train(args):
del train_dataset_group
# function for saving/removing
def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, force_sync_upload=False):
def save_model(
ckpt_name,
unwrapped_nw: control_net_lllite_for_train.SdxlUNet2DConditionModelControlNetLLLite,
steps,
epoch_no,
force_sync_upload=False,
):
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, ckpt_name)
@@ -336,7 +362,7 @@ def train(args):
sai_metadata = train_util.get_sai_model_spec(None, args, True, True, False)
sai_metadata["modelspec.architecture"] = sai_model_spec.ARCH_SD_XL_V1_BASE + "/control-net-lllite"
unwrapped_nw.save_weights(ckpt_file, save_dtype, sai_metadata)
unwrapped_nw.save_lllite_weights(ckpt_file, save_dtype, sai_metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
@@ -351,11 +377,9 @@ def train(args):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
network.on_epoch_start() # train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
with accelerator.accumulate(unet):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
@@ -412,10 +436,9 @@ def train(args):
with accelerator.autocast():
# conditioning imageをControlNetに渡す / pass conditioning image to ControlNet
# 内部でcond_embに変換される / it will be converted to cond_emb inside
network.set_cond_image(controlnet_image)
# それらの値を使いつつ、U-Netでイズを予測する / predict noise with U-Net using those values
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding, controlnet_image)
if args.v_parameterization:
# v-parameterization training
@@ -440,7 +463,7 @@ def train(args):
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = network.get_trainable_params()
params_to_clip = unet.get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -459,7 +482,7 @@ def train(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -498,7 +521,7 @@ def train(args):
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(unet), global_step, epoch + 1)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
@@ -513,7 +536,7 @@ def train(args):
# end of epoch
if is_main_process:
network = accelerator.unwrap_model(network)
unet = accelerator.unwrap_model(unet)
accelerator.end_training()
@@ -522,7 +545,7 @@ def train(args):
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")

View File

@@ -1,12 +1,5 @@
import argparse
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network

View File

@@ -3,13 +3,6 @@ import os
import regex
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util

View File

@@ -86,8 +86,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
# acceleratorを準備する
print("prepare accelerator")
@@ -120,7 +120,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -91,8 +91,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
# acceleratorを準備する
print("prepare accelerator")
@@ -125,7 +125,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -1,84 +0,0 @@
import argparse
import os
import torch
from safetensors import safe_open
from safetensors.torch import load_file, save_file
from tqdm import tqdm
def split(args):
# load embedding
if args.embedding.endswith(".safetensors"):
embedding = load_file(args.embedding)
with safe_open(args.embedding, framework="pt") as f:
metadata = f.metadata()
else:
embedding = torch.load(args.embedding)
metadata = None
# check format
if "emb_params" in embedding:
# SD1/2
keys = ["emb_params"]
elif "clip_l" in embedding:
# SDXL
keys = ["clip_l", "clip_g"]
else:
print("Unknown embedding format")
exit()
num_vectors = embedding[keys[0]].shape[0]
# prepare output directory
os.makedirs(args.output_dir, exist_ok=True)
# prepare splits
if args.vectors_per_split is not None:
num_splits = (num_vectors + args.vectors_per_split - 1) // args.vectors_per_split
vectors_for_split = [args.vectors_per_split] * num_splits
if sum(vectors_for_split) > num_vectors:
vectors_for_split[-1] -= sum(vectors_for_split) - num_vectors
assert sum(vectors_for_split) == num_vectors
elif args.vectors is not None:
vectors_for_split = args.vectors
num_splits = len(vectors_for_split)
else:
print("Must specify either --vectors_per_split or --vectors / --vectors_per_split または --vectors のどちらかを指定する必要があります")
exit()
assert (
sum(vectors_for_split) == num_vectors
), "Sum of vectors must be equal to the number of vectors in the embedding / 分割したベクトルの合計はembeddingのベクトル数と等しくなければなりません"
# split
basename = os.path.splitext(os.path.basename(args.embedding))[0]
done_vectors = 0
for i, num_vectors in enumerate(vectors_for_split):
print(f"Splitting {num_vectors} vectors...")
split_embedding = {}
for key in keys:
split_embedding[key] = embedding[key][done_vectors : done_vectors + num_vectors]
output_file = os.path.join(args.output_dir, f"{basename}_{i}.safetensors")
save_file(split_embedding, output_file, metadata)
print(f"Saved to {output_file}")
done_vectors += num_vectors
print("Done")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Merge models")
parser.add_argument("--embedding", type=str, help="Embedding to split")
parser.add_argument("--output_dir", type=str, help="Output directory")
parser.add_argument(
"--vectors_per_split",
type=int,
default=None,
help="Number of vectors per split. If num_vectors is 8 and vectors_per_split is 3, then 3, 3, 2 vectors will be split",
)
parser.add_argument("--vectors", type=int, default=None, nargs="*", help="number of vectors for each split. e.g. 3 3 2")
args = parser.parse_args()
split(args)

View File

@@ -11,13 +11,6 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
@@ -98,8 +91,8 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
@@ -245,7 +238,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -11,13 +11,6 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
@@ -78,8 +71,8 @@ def train(args):
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.no_token_padding:
train_dataset_group.disable_token_padding()
@@ -177,7 +170,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)

View File

@@ -12,16 +12,6 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
@@ -192,8 +182,8 @@ class NetworkTrainer:
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
@@ -283,10 +273,7 @@ class NetworkTrainer:
if args.dim_from_weights:
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
else:
if "dropout" not in net_kwargs:
# workaround for LyCORIS (;^ω^)
net_kwargs["dropout"] = args.network_dropout
# LyCORIS will work with this...
network = network_module.create_network(
1.0,
args.network_dim,
@@ -345,7 +332,7 @@ class NetworkTrainer:
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -432,12 +419,7 @@ class NetworkTrainer:
t_enc.train()
# set top parameter requires_grad = True for gradient checkpointing works
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
# set top parameter requires_grad = True for gradient checkpointing works
if not train_text_encoder: # train U-Net only
unet.parameters().__next__().requires_grad_(True)
t_enc.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
@@ -527,7 +509,6 @@ class NetworkTrainer:
"ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_ip_noise_gamma": args.ip_noise_gamma,
}
if use_user_config:
@@ -957,7 +938,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする0またはNoneはdropoutなし、1は全ニューロンをdropout",
)
parser.add_argument(
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
parser.add_argument(

View File

@@ -7,16 +7,6 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
@@ -170,13 +160,6 @@ class TextualInversionTrainer:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
assert (
args.token_string is not None or args.token_strings is not None
), "token_string or token_strings must be specified / token_stringまたはtoken_stringsを指定してください"
assert (
not use_template or args.token_strings is None
), "token_strings cannot be used with template / token_stringsはテンプレートと一緒に使えません"
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
@@ -225,17 +208,9 @@ class TextualInversionTrainer:
# add new word to tokenizer, count is num_vectors_per_token
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
if args.token_strings is not None:
token_strings = args.token_strings
assert (
len(token_strings) == args.num_vectors_per_token
), f"num_vectors_per_token is mismatch for token_strings / token_stringsの数がnum_vectors_per_tokenと合いません: {len(token_strings)}"
for token_string in token_strings:
self.assert_token_string(token_string, tokenizers)
else:
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
self.assert_token_string(args.token_string, tokenizers)
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
token_ids_list = []
token_embeds_list = []
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
@@ -330,8 +305,8 @@ class TextualInversionTrainer:
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
@@ -350,7 +325,7 @@ class TextualInversionTrainer:
prompt_replacement = None
else:
# サンプル生成用
if args.num_vectors_per_token > 1 and args.token_strings is None:
if args.num_vectors_per_token > 1:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
prompt_replacement = (args.token_string, replace_to)
@@ -407,7 +382,7 @@ class TextualInversionTrainer:
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -770,13 +745,6 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument(
"--token_strings",
type=str,
default=None,
nargs="*",
help="token strings used in training for multiple embedding / 複数のembeddingsの個別学習時に使用されるトークン文字列",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--use_object_template",

View File

@@ -8,13 +8,6 @@ from multiprocessing import Value
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
@@ -236,8 +229,8 @@ def train(args):
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
@@ -309,7 +302,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)