Compare commits

..

38 Commits

Author SHA1 Message Date
ykume
ae0872ba3b search block-wise application weights 2024-02-04 14:56:31 +09:00
Kohya S
7f948db158 Merge pull request #1087 from mgz-dev/fix-imports-on-svd_merge_lora
fix broken import in svd_merge_lora script
2024-01-31 21:08:40 +09:00
Kohya S
9d7729c00d Merge pull request #1086 from Disty0/dev
Update IPEX Libs
2024-01-31 21:06:34 +09:00
Disty0
988dee02b9 IPEX torch.tensor FP64 workaround 2024-01-30 01:52:32 +03:00
mgz
d4b9568269 fix broken import in svd_merge_lora script
remove missing import, and remove unused imports
2024-01-28 11:59:07 -06:00
Disty0
ccc3a481e7 Update IPEX Libs 2024-01-28 14:14:31 +03:00
Kohya S
cd19df49cd Merge pull request #1085 from kohya-ss/dev
Dev
2024-01-27 18:32:06 +09:00
Kohya S
736365bdd5 update README.md 2024-01-27 18:31:01 +09:00
Kohya S
6ceedb9448 Merge branch 'main' into dev 2024-01-27 18:23:52 +09:00
Kohya S
930a3912a7 Merge pull request #1084 from fireicewolf/devel
Fix network multiplier cause crashed while use multi-GPUs
2024-01-27 18:22:00 +09:00
Kohya S
cf790d87c4 Merge pull request #1079 from feffy380/fix/fp8savestate
Update safetensors to fix a crash with `--fp8_base --save_state`
2024-01-26 22:34:35 +09:00
DukeG
4e67fb8444 test 2024-01-26 20:22:49 +08:00
DukeG
50f631c768 test 2024-01-26 20:02:48 +08:00
DukeG
85bc371ebc test 2024-01-26 18:58:47 +08:00
feffy380
322ee52c77 Update requirements.txt
Update safetensors to fix a crash when using `--fp8_base --save_state`
2024-01-25 19:15:53 +01:00
Kohya S
c576f80639 Fix ControlNetLLLite training issue #1069 2024-01-25 18:43:07 +09:00
Kohya S
d5ab97b69b Merge pull request #1067 from kohya-ss/dev
Dev
2024-01-23 21:04:16 +09:00
Kohya S
7cb44e4502 update readme 2024-01-23 21:02:40 +09:00
Kohya S
7a20df5ad5 Merge pull request #1064 from KohakuBlueleaf/fix-grad-sync
Avoid grad sync on each step even when doing accumulation
2024-01-23 20:33:55 +09:00
Kohya S
bea4362e21 Merge pull request #1060 from akx/refactor-xpu-init
Deduplicate ipex initialization code
2024-01-23 20:25:37 +09:00
Kohya S
6805cafa9b fix TI training crashes in multigpu #1019 2024-01-23 20:17:19 +09:00
Kohaku-Blueleaf
711b40ccda Avoid always sync 2024-01-23 11:49:03 +08:00
Kohya S
696dd7f668 Fix dtype issue in PyTorch 2.0 for generating samples in training sdxl network 2024-01-22 12:43:37 +09:00
Kohya S
e0a3c69223 update readme 2024-01-20 18:47:10 +09:00
Kohya S
c59249a664 Add options to reduce memory usage in extract_lora_from_models.py closes #1059 2024-01-20 18:45:54 +09:00
Kohya S
fef172966f Add network_multiplier for dataset and train LoRA 2024-01-20 16:24:43 +09:00
Kohya S
5a1ebc4c7c format by black 2024-01-20 13:10:45 +09:00
Kohya S
2a0f45aea9 update readme 2024-01-20 11:08:20 +09:00
Kohya S
1f77bb6e73 fix to work sample generation in fp8 ref #1057 2024-01-20 10:57:42 +09:00
Kohya S
a7ef6422b6 fix to work with torch 2.0 2024-01-20 10:00:30 +09:00
Kohaku-Blueleaf
9cfa68c92f [Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057)
* Add fp8 support

* remove some debug prints

* Better implementation for te

* Fix some misunderstanding

* as same as unet, add explicit convert

* better impl for convert TE to fp8

* fp8 for not only unet

* Better cache TE and TE lr

* match arg name

* Fix with list

* Add timeout settings

* Fix arg style

* Add custom seperator

* Fix typo

* Fix typo again

* Fix dtype error

* Fix gradient problem

* Fix req grad

* fix merge

* Fix merge

* Resolve merge

* arrangement and document

* Resolve merge error

* Add assert for mixed precision
2024-01-20 09:46:53 +09:00
Aarni Koskela
6f3f701d3d Deduplicate ipex initialization code 2024-01-19 18:07:36 +02:00
Kohya S
d2a99a19d4 Merge pull request #1056 from kohya-ss/dev
fix vram usage in LoRA training
2024-01-17 21:41:36 +09:00
Kohya S
0395a35543 Merge branch 'main' into dev 2024-01-17 21:39:13 +09:00
Kohya S
987d4a969d update readme 2024-01-17 21:38:49 +09:00
Kohya S
976d092c68 fix text encodes are on gpu even when not trained 2024-01-17 21:31:50 +09:00
Kohya S
e6b15c7e4a Merge pull request #1053 from akx/sdpa
Fix typo `--spda` (it's `--sdpa`)
2024-01-16 21:50:45 +09:00
Aarni Koskela
ef50436464 Fix typo --spda (it's --sdpa) 2024-01-16 14:32:48 +02:00
31 changed files with 2024 additions and 633 deletions

120
README.md
View File

@@ -1,3 +1,39 @@
## LoRAの層別適用率の探索について
層別適用率を探索する `train_network_appl_weights.py` を追加してあります。現在は SDXL のみ対応しています。
LoRA 等の学習済みネットワークに対して、層別適用率を変化させながら通常の学習プロセスを実行することで、適用率を探索します。つまり、どのような層別適用率を適用すると、学習データに近い画像が生成されるかを探索することができます。
層別適用率の合計をペナルティとすることが可能です。つまり、画像を再現しつつ、影響の少ない層の適用率が低くなるような適用率が探索できるはずです。
複数のネットワークを対象に探索できます。また探索には最低 1 枚の学習データが必要になります。
何枚程度から正しく動くかは確認していません。50枚程度の画像でテスト済みです。また学習データは LoRA 学習時のデータでなくてもよいはずですが、未確認です。)
コマンドラインオプションは `sdxl_train_network.py` とほぼ同じですが、以下のオプションが追加、拡張されています。
- `--application_loss_weight` : 層別適用率を loss に加える際の重みです。デフォルトは 0.0001 です。大きくすると、なるべく適用率を低くするように学習します。0 を指定するとペナルティが適用されないため、再現度が最も高くなる適用率を自由に探索します。
- `--network_module` : 探索対象の複数のモジュールを指定することができます。たとえば `--network_module networks.lora networks.lora` のように指定します。
- `--network_weights` : 探索対象の複数のネットワークの重みを指定することができます。たとえば `--network_weights model1.safetensors model2.safetensors` のように指定します。
層別適用率のパラメータ数は 20個で、`BASE, IN00-08, MID, OUT00-08` となります。`BASE` は Text Encoder に適用されます。Text Encoder を対象とした LoRA の動作は未確認です。)
パラメータは一応ファイルに保存されますが、画面に表示される値をコピーして保存することをお勧めします。
### 備考
オプティマイザ AdamW、学習率 1e-1 で動作確認しています。学習率はかなり高めに設定してよいようです。この設定では LoRA 学習時の 1/20 ~ 1/10 ほどの epoch 数でそれなりの結果が得られます。
`application_loss_weight` を 0.0001 より大きくすると合計の適用率がかなり低くなるLoRA があまり適用されない)ようです。条件にもよると思いますので、適宜調整してください。
適用率に負の値を使うと、影響の少ない層の適用率を極端に低くして合計を小さくする、という動きをしてしまうので、負の値は10倍の重み付けをしてあります-0.01 は 0.1 とほぼ同じペナルティ)。重み付けを変更するときはソースを修正してください。
「必要ない層への適用率を下げて影響範囲を小さくする」という使い方だけでなく、「あるキャラクターがあるポーズをしている画像を教師データに、キャラクターを維持しつつポーズを取るための LoRA の適用率を探索する」、「ある画風のあるキャラクターの画像を教師データに、画風 LoRA とキャラクター LoRA の適用率を探索する」などの使い方が考えられます。
もしかすると、「あるキャラクターの、あえて別の画風の画像を教師データに、キャラクターの属性を再現するのに必要な層を探す」、「理想とする画像を教師データに、使えそうな LoRA を多数適用し、その中から最も再現度が高い適用率を探す(ただし LoRA の数が多いほど学習が遅くなります)」といった使い方もできるかもしれません。
---
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
This repository contains training, generation and utility scripts for Stable Diffusion.
@@ -249,6 +285,86 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History
### Jan 27, 2024 / 2024/1/27: v0.8.3
- Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380!
- `safetensors` is updated. Please see [Upgrade](#upgrade) and update the library.
- Fixed a bug that the training crashes when `network_multiplier` is specified with multi-GPU training. PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) Thanks to fireicewolf!
- Fixed a bug that the training crashes when training ControlNet-LLLite.
- `--fp8_base` 指定時に `--save_state` での保存がエラーになる不具合が修正されました。 PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) feffy380 氏に感謝します。
- `safetensors` がバージョンアップされていますので、[Upgrade](#upgrade) を参照し更新をお願いします。
- 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。
- ControlNet-LLLite の学習がエラーになる不具合を修正しました。
### Jan 23, 2024 / 2024/1/23: v0.8.2
- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
- Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
- PyTorch 2.1 or later is required.
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.
- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc.
- This is an experimental option and may be removed or changed in the future.
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
- Fixed a bug in multi-GPU Textual Inversion training.
- 実験的 LoRA等の学習スクリプトで、ベースモデルU-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
- `train_network.py` または `sdxl_train_network.py``--fp8_base` を指定してください。
- PyTorch 2.1 以降が必要です。
- PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。
- 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。
- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。
- 実験的オプションのため、将来的に削除または仕様変更される可能性があります。
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
- また、五段階の状態を用意し、それぞれ `0.2``0.4``0.6``0.8``1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
- `.toml` ファイルで `[[datasets]]``network_multiplier` を指定してください。
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。
- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。
- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。
- マルチ GPU での Textual Inversion 学習の不具合を修正しました。
- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例
```toml
[general]
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = 1.0
... subset settings ...
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = -1.0
... subset settings ...
```
### Jan 17, 2024 / 2024/1/17: v0.8.1
- Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`).
- Text Encoders were not moved to CPU.
- Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053)
- LoRA 等の学習スクリプト(`train_network.py``sdxl_train_network.py`で、Text Encoder を学習しない場合の VRAM 使用量が以前に比べて大きくなっていた不具合を修正しました。
- Text Encoder が GPU に保持されたままになっていました。
- 誤字が修正されました。 [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) akx 氏に感謝します。
### Jan 15, 2024 / 2024/1/15: v0.8.0
- Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade).
@@ -257,7 +373,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
- This feature works only on Linux or WSL.
- Please specify `--torch_compile` option in each training script.
- You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work.
- Please use `--spda` option instead of `--xformers` option.
- Please use `--sdpa` option instead of `--xformers` option.
- PyTorch 2.1 or later is recommended.
- Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details.
- The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t!
@@ -270,7 +386,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
- Linux または WSL でのみ動作します。
- 各学習スクリプトで `--torch_compile` オプションを指定してください。
- `--dynamo_backend` オプションで使用される backend を選択できます。デフォルトは `"inductor"` です。 `inductor` または `eager` が動作するようです。
- `--xformers` オプションとは互換性がありません。 代わりに `--spda` オプションを使用してください。
- `--xformers` オプションとは互換性がありません。 代わりに `--sdpa` オプションを使用してください。
- PyTorch 2.1以降を推奨します。
- 詳細は [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) をご覧ください。
- wandb 保存時のセッション名が各学習スクリプトの `--wandb_run_name` オプションで指定できるようになりました。 PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) hopl1t 氏に感謝します。

View File

@@ -1,11 +1,7 @@
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

View File

@@ -11,15 +11,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

View File

@@ -66,15 +66,10 @@ import diffusers
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
import torchvision
from diffusers import (
AutoencoderKL,

File diff suppressed because it is too large Load Diff

View File

@@ -125,9 +125,13 @@ def ipex_init(): # pylint: disable=too-many-statements
# AMP:
torch.cuda.amp = torch.xpu.amp
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
@@ -151,15 +155,16 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.version.cuda = "11.7"
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
torch.cuda.get_device_properties.major = 11
torch.cuda.get_device_properties.minor = 7
torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.version.cuda = "12.1"
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
ipex_hijacks()
if not torch.xpu.has_fp64_dtype():
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
try:
from .diffusers import ipex_diffusers
ipex_diffusers()

View File

@@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
)
else:
return original_torch_bmm(input, mat2, out=out)
torch.xpu.synchronize(input.device)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
@@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
)
else:
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
torch.xpu.synchronize(query.device)
return hidden_states

View File

@@ -149,6 +149,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
@@ -283,6 +284,7 @@ class AttnProcessor:
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)

View File

@@ -1,6 +1,11 @@
import contextlib
import os
from functools import wraps
from contextlib import nullcontext
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import numpy as np
device_supports_fp64 = torch.xpu.has_fp64_dtype()
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
@@ -11,7 +16,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return contextlib.nullcontext()
return nullcontext()
@property
def is_cuda(self):
@@ -25,15 +30,17 @@ def return_xpu(device):
# Autocast
original_autocast = torch.autocast
def ipex_autocast(*args, **kwargs):
if len(args) > 0 and args[0] == "cuda":
return original_autocast("xpu", *args[1:], **kwargs)
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
@wraps(torch.amp.autocast_mode.autocast.__init__)
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
if device_type == "cuda":
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
else:
return original_autocast(*args, **kwargs)
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
# Latent Antialias CPU Offload:
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
return_device = tensor.device
@@ -44,15 +51,29 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
original_from_numpy = torch.from_numpy
@wraps(torch.from_numpy)
def from_numpy(ndarray):
if ndarray.dtype == float:
return original_from_numpy(ndarray.astype('float32'))
else:
return original_from_numpy(ndarray)
if torch.xpu.has_fp64_dtype():
original_as_tensor = torch.as_tensor
@wraps(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
if check_device(device):
device = return_xpu(device)
if isinstance(data, np.ndarray) and data.dtype == float and not (
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
return original_as_tensor(data, dtype=torch.float32, device=device)
else:
return original_as_tensor(data, dtype=dtype, device=device)
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
@@ -66,20 +87,25 @@ else:
# Data Type Errors:
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
@wraps(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
if query.dtype != key.dtype:
key = key.to(dtype=query.dtype)
if query.dtype != value.dtype:
value = value.to(dtype=query.dtype)
if attn_mask is not None and query.dtype != attn_mask.dtype:
attn_mask = attn_mask.to(dtype=query.dtype)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
# A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm
@wraps(torch.nn.functional.group_norm)
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
@@ -89,6 +115,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
# A1111 BF16
original_functional_layer_norm = torch.nn.functional.layer_norm
@wraps(torch.nn.functional.layer_norm)
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
@@ -98,6 +125,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1
# Training
original_functional_linear = torch.nn.functional.linear
@wraps(torch.nn.functional.linear)
def functional_linear(input, weight, bias=None):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
@@ -106,6 +134,7 @@ def functional_linear(input, weight, bias=None):
return original_functional_linear(input, weight, bias=bias)
original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
@@ -115,6 +144,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
# A1111 Embedding BF16
original_torch_cat = torch.cat
@wraps(torch.cat)
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
@@ -123,6 +153,7 @@ def torch_cat(tensor, *args, **kwargs):
# SwinIR BF16:
original_functional_pad = torch.nn.functional.pad
@wraps(torch.nn.functional.pad)
def functional_pad(input, pad, mode='constant', value=None):
if mode == 'reflect' and input.dtype == torch.bfloat16:
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
@@ -131,13 +162,20 @@ def functional_pad(input, pad, mode='constant', value=None):
original_torch_tensor = torch.tensor
def torch_tensor(*args, device=None, **kwargs):
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
if check_device(device):
return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_tensor(*args, device=device, **kwargs)
device = return_xpu(device)
if not device_supports_fp64:
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
if dtype == torch.float64:
dtype = torch.float32
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
dtype = torch.float32
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
@@ -145,6 +183,7 @@ def Tensor_to(self, device=None, *args, **kwargs):
return original_Tensor_to(self, device, *args, **kwargs)
original_Tensor_cuda = torch.Tensor.cuda
@wraps(torch.Tensor.cuda)
def Tensor_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
@@ -152,6 +191,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
return original_Tensor_cuda(self, device, *args, **kwargs)
original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
if check_device(device):
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
@@ -159,6 +199,7 @@ def UntypedStorage_init(*args, device=None, **kwargs):
return original_UntypedStorage_init(*args, device=device, **kwargs)
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
@@ -166,6 +207,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs):
if check_device(device):
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
@@ -173,6 +215,7 @@ def torch_empty(*args, device=None, **kwargs):
return original_torch_empty(*args, device=device, **kwargs)
original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, **kwargs):
if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
@@ -180,6 +223,7 @@ def torch_randn(*args, device=None, **kwargs):
return original_torch_randn(*args, device=device, **kwargs)
original_torch_ones = torch.ones
@wraps(torch.ones)
def torch_ones(*args, device=None, **kwargs):
if check_device(device):
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
@@ -187,6 +231,7 @@ def torch_ones(*args, device=None, **kwargs):
return original_torch_ones(*args, device=device, **kwargs)
original_torch_zeros = torch.zeros
@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs):
if check_device(device):
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
@@ -194,6 +239,7 @@ def torch_zeros(*args, device=None, **kwargs):
return original_torch_zeros(*args, device=device, **kwargs)
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
if check_device(device):
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
@@ -201,6 +247,7 @@ def torch_linspace(*args, device=None, **kwargs):
return original_torch_linspace(*args, device=device, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
@@ -208,12 +255,14 @@ def torch_Generator(device=None):
return original_torch_Generator(device)
original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
if check_device(map_location):
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
else:
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
# Hijack Functions:
def ipex_hijacks():
torch.tensor = torch_tensor
@@ -232,7 +281,7 @@ def ipex_hijacks():
torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel
torch.UntypedStorage.is_cuda = is_cuda
torch.autocast = ipex_autocast
torch.amp.autocast_mode.autocast.__init__ = autocast_init
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
torch.nn.functional.group_norm = functional_group_norm
@@ -244,5 +293,6 @@ def ipex_hijacks():
torch.bmm = torch_bmm
torch.cat = torch_cat
if not torch.xpu.has_fp64_dtype():
if not device_supports_fp64:
torch.from_numpy = from_numpy
torch.as_tensor = as_tensor

24
library/ipex_interop.py Normal file
View File

@@ -0,0 +1,24 @@
import torch
def init_ipex():
"""
Try to import `intel_extension_for_pytorch`, and apply
the hijacks using `library.ipex.ipex_init`.
If IPEX is not installed, this function does nothing.
"""
try:
import intel_extension_for_pytorch as ipex # noqa
except ImportError:
return
try:
from library.ipex import ipex_init
if torch.xpu.is_available():
is_initialized, error_message = ipex_init()
if not is_initialized:
print("failed to initialize ipex:", error_message)
except Exception as e:
print("failed to initialize ipex:", e)

View File

@@ -5,15 +5,9 @@ import math
import os
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
init_ipex()
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel

View File

@@ -1262,9 +1262,9 @@ class CrossAttnUpBlock2D(nn.Module):
for attn in self.attentions:
attn.set_use_memory_efficient_attention(xformers, mem_eff)
def set_use_sdpa(self, spda):
def set_use_sdpa(self, sdpa):
for attn in self.attentions:
attn.set_use_sdpa(spda)
attn.set_use_sdpa(sdpa)
def forward(
self,

View File

@@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if up1 is not None:
uncond_pool = up1
dtype = self.unet.dtype
unet_dtype = self.unet.dtype
dtype = unet_dtype
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
dtype = torch.float16
self.unet.to(dtype)
# 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image):
@@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if is_cancelled_callback is not None and is_cancelled_callback():
return None
self.unet.to(unet_dtype)
return latents
def latents_to_image(self, latents):

View File

@@ -558,6 +558,7 @@ class BaseDataset(torch.utils.data.Dataset):
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
max_token_length: int,
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
) -> None:
super().__init__()
@@ -567,6 +568,7 @@ class BaseDataset(torch.utils.data.Dataset):
self.max_token_length = max_token_length
# width/height is used when enable_bucket==False
self.width, self.height = (None, None) if resolution is None else resolution
self.network_multiplier = network_multiplier
self.debug_dataset = debug_dataset
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
@@ -1106,7 +1108,9 @@ class BaseDataset(torch.utils.data.Dataset):
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
loss_weights.append(
self.prior_loss_weight if image_info.is_reg else 1.0
) # in case of fine tuning, is_reg is always False
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
@@ -1272,6 +1276,8 @@ class BaseDataset(torch.utils.data.Dataset):
example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
example["flippeds"] = flippeds
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
@@ -1346,15 +1352,16 @@ class DreamBoothDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -1520,14 +1527,15 @@ class FineTuningDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
self.batch_size = batch_size
@@ -1724,14 +1732,15 @@ class ControlNetDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset,
debug_dataset: float,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
db_subsets = []
for subset in subsets:
@@ -1765,6 +1774,7 @@ class ControlNetDataset(BaseDataset):
tokenizer,
max_token_length,
resolution,
network_multiplier,
enable_bucket,
min_bucket_reso,
max_bucket_reso,
@@ -2039,6 +2049,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
print(
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
)
if "network_multipliers" in example:
print(f"network multiplier: {example['network_multipliers'][j]}")
if show_input_ids:
print(f"input ids: {iid}")
@@ -2105,8 +2117,8 @@ def glob_images_pathlib(dir_path, recursive):
class MinimalDataset(BaseDataset):
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
self.num_train_images = 0 # update in subclass
self.num_reg_images = 0 # update in subclass
@@ -2850,14 +2862,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
)
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
parser.add_argument(
"--dynamo_backend",
type=str,
default="inductor",
"--dynamo_backend",
type=str,
default="inductor",
# available backends:
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
# https://pytorch.org/docs/stable/torch.compiler.html
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor"
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor",
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument(
@@ -2904,6 +2916,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
parser.add_argument(
"--ddp_timeout",
type=int,
@@ -3886,7 +3899,7 @@ def prepare_accelerator(args: argparse.Namespace):
os.environ["WANDB_DIR"] = logging_dir
if args.wandb_api_key is not None:
wandb.login(key=args.wandb_api_key)
# torch.compile のオプション。 NO の場合は torch.compile は使わない
dynamo_backend = "NO"
if args.torch_compile:

View File

@@ -43,6 +43,9 @@ def svd(
clamp_quantile=0.99,
min_diff=0.01,
no_metadata=False,
load_precision=None,
load_original_model_to=None,
load_tuned_model_to=None,
):
def str_to_dtype(p):
if p == "float":
@@ -57,28 +60,51 @@ def svd(
if v_parameterization is None:
v_parameterization = v2
load_dtype = str_to_dtype(load_precision) if load_precision else None
save_dtype = str_to_dtype(save_precision)
work_device = "cpu"
# load models
if not sdxl:
print(f"loading original SD model : {model_org}")
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
text_encoders_o = [text_encoder_o]
if load_dtype is not None:
text_encoder_o = text_encoder_o.to(load_dtype)
unet_o = unet_o.to(load_dtype)
print(f"loading tuned SD model : {model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
text_encoders_t = [text_encoder_t]
if load_dtype is not None:
text_encoder_t = text_encoder_t.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
else:
device_org = load_original_model_to if load_original_model_to else "cpu"
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
print(f"loading original SDXL model : {model_org}")
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
)
text_encoders_o = [text_encoder_o1, text_encoder_o2]
if load_dtype is not None:
text_encoder_o1 = text_encoder_o1.to(load_dtype)
text_encoder_o2 = text_encoder_o2.to(load_dtype)
unet_o = unet_o.to(load_dtype)
print(f"loading original SDXL model : {model_tuned}")
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
)
text_encoders_t = [text_encoder_t1, text_encoder_t2]
if load_dtype is not None:
text_encoder_t1 = text_encoder_t1.to(load_dtype)
text_encoder_t2 = text_encoder_t2.to(load_dtype)
unet_t = unet_t.to(load_dtype)
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
# create LoRA network to extract weights: Use dim (rank) as alpha
@@ -100,38 +126,54 @@ def svd(
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
# Text Encoder might be same
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
text_encoder_different = True
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
diff = diff.float()
diffs[lora_name] = diff
# clear target Text Encoder to save memory
for text_encoder in text_encoders_t:
del text_encoder
if not text_encoder_different:
print("Text encoder is same. Extract U-Net only.")
lora_network_o.text_encoder_loras = []
diffs = {}
diffs = {} # clear diffs
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
lora_name = lora_o.lora_name
module_o = lora_o.org_module
module_t = lora_t.org_module
diff = module_t.weight - module_o.weight
diff = diff.float()
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
if args.device:
diff = diff.to(args.device)
# clear weight to save memory
module_o.weight = None
module_t.weight = None
diffs[lora_name] = diff
# clear LoRA network, target U-Net to save memory
del lora_network_o
del lora_network_t
del unet_t
# make LoRA with svd
print("calculating by svd")
lora_weights = {}
with torch.no_grad():
for lora_name, mat in tqdm(list(diffs.items())):
if args.device:
mat = mat.to(args.device)
mat = mat.to(torch.float) # calc by float
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
conv2d = len(mat.size()) == 4
kernel_size = None if not conv2d else mat.size()[2:4]
@@ -171,8 +213,8 @@ def svd(
U = U.reshape(out_dim, rank, 1, 1)
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
U = U.to("cpu").contiguous()
Vh = Vh.to("cpu").contiguous()
U = U.to(work_device, dtype=save_dtype).contiguous()
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
lora_weights[lora_name] = (U, Vh)
@@ -230,6 +272,13 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
)
parser.add_argument(
"--load_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
)
parser.add_argument(
"--save_precision",
type=str,
@@ -285,6 +334,18 @@ def setup_parser() -> argparse.ArgumentParser:
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
+ "sai modelspecのメタデータを保存しないLoRAの最低限のss_metadataは保存される",
)
parser.add_argument(
"--load_original_model_to",
type=str,
default=None,
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
parser.add_argument(
"--load_tuned_model_to",
type=str,
default=None,
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
)
return parser

View File

@@ -511,7 +511,9 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else:
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
print(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
)
block_dims = [network_dim] * num_total_blocks
if block_alphas is not None:
@@ -1223,3 +1225,40 @@ class LoRANetwork(torch.nn.Module):
norms.append(scalednorm.item())
return keys_scaled, sum(norms) / len(norms), max(norms)
# region application weight
def get_number_of_blocks(self):
# only for SDXL
return 20
def has_text_encoder_block(self):
return self.text_encoder_loras is not None and len(self.text_encoder_loras) > 0
def set_block_wise_weights(self, weights):
if self.text_encoder_loras:
for lora in self.text_encoder_loras:
lora.multiplier = weights[0]
for lora in self.unet_loras:
# determine block index
key = lora.lora_name[10:] # remove "lora_unet_"
if key.startswith("input_blocks"):
block_index = int(key.split("_")[2]) + 1 # 1-9
elif key.startswith("middle_block"):
block_index = 10 # int(key.split("_")[2]) + 10
elif key.startswith("output_blocks"):
block_index = int(key.split("_")[2]) + 11 # 11-19
else:
print(f"unknown block: {key}")
block_index = 0
lora.multiplier = weights[block_index]
# print(f"{lora.lora_name} block index: {block_index}, weight: {lora.multiplier}")
# print(f"set block-wise weights to {weights}")
# TODO LoRA の weight をあらかじめ計算しておいて multiplier を掛けるだけにすると速くなるはず
# endregion

View File

@@ -1,4 +1,3 @@
import math
import argparse
import os
import time
@@ -6,8 +5,6 @@ import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
CLAMP_QUANTILE = 0.99

View File

@@ -8,7 +8,7 @@ einops==0.6.1
pytorch-lightning==1.9.0
# bitsandbytes==0.39.1
tensorboard==2.10.1
safetensors==0.3.1
safetensors==0.4.2
# gradio==3.16.2
altair==4.2.2
easygui==0.98.3

View File

@@ -18,15 +18,10 @@ import diffusers
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
import torchvision
from diffusers import (
AutoencoderKL,

View File

@@ -9,13 +9,11 @@ import random
from einops import repeat
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler

View File

@@ -11,15 +11,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util

View File

@@ -14,13 +14,11 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
import accelerate

View File

@@ -11,13 +11,11 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel

View File

@@ -1,15 +1,10 @@
import argparse
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network
@@ -95,8 +90,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device)
text_encoders[1].to(accelerator.device)
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:

View File

@@ -3,13 +3,9 @@ import os
import regex
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util

View File

@@ -12,15 +12,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel

View File

@@ -12,15 +12,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

View File

@@ -14,15 +14,10 @@ from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
@@ -117,7 +112,7 @@ class NetworkTrainer:
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
):
for t_enc in text_encoders:
t_enc.to(accelerator.device)
t_enc.to(accelerator.device, dtype=weight_dtype)
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
input_ids = batch["input_ids"].to(accelerator.device)
@@ -278,6 +273,7 @@ class NetworkTrainer:
accelerator.wait_for_everyone()
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
self.cache_text_encoder_outputs_if_needed(
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
)
@@ -309,6 +305,7 @@ class NetworkTrainer:
)
if network is None:
return
network_has_multiplier = hasattr(network, "set_multiplier")
if hasattr(network, "prepare_network"):
network.prepare_network(args)
@@ -389,17 +386,33 @@ class NetworkTrainer:
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != "no"
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.requires_grad_(False)
unet.to(dtype=weight_dtype)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
# acceleratorがなんかよろしくやってくれるらしい
# TODO めちゃくちゃ冗長なのでコードを整理する
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
if t_enc.device.type != "cpu":
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
@@ -407,8 +420,8 @@ class NetworkTrainer:
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
else:
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
if args.gradient_checkpointing:
@@ -421,9 +434,6 @@ class NetworkTrainer:
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
# set top parameter requires_grad = True for gradient checkpointing works
if not train_text_encoder: # train U-Net only
unet.parameters().__next__().requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
@@ -685,7 +695,7 @@ class NetworkTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
@@ -754,7 +764,17 @@ class NetworkTrainer:
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
# get multiplier for each sample
if network_has_multiplier:
multipliers = batch["network_multipliers"]
# if all multipliers are same, use single multiplier
if torch.all(multipliers == multipliers[0]):
multipliers = multipliers[0].item()
else:
raise NotImplementedError("multipliers for each sample is not supported yet")
# print(f"set multiplier: {multipliers}")
accelerator.unwrap_model(network).set_multiplier(multipliers)
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
@@ -778,10 +798,24 @@ class NetworkTrainer:
args, noise_scheduler, latents
)
# ensure the hidden state will require grad
if args.gradient_checkpointing:
for x in noisy_latents:
x.requires_grad_(True)
for t in text_encoder_conds:
t.requires_grad_(True)
# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
if args.v_parameterization:
@@ -808,10 +842,11 @@ class NetworkTrainer:
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if args.max_grad_norm != 0.0:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()

File diff suppressed because it is too large Load Diff

View File

@@ -8,15 +8,10 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()
ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
@@ -505,7 +500,7 @@ class TextualInversionTrainer:
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs['wandb'] = {'name': args.wandb_run_name}
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers(
@@ -730,14 +725,13 @@ class TextualInversionTrainer:
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = accelerator.unwrap_model(text_encoder)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
accelerator.end_training()
if args.save_state and is_main_process:
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
if is_main_process:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)

View File

@@ -8,13 +8,11 @@ from multiprocessing import Value
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex
init_ipex()
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler