mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
38 Commits
v0.8.0
...
nw_applica
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae0872ba3b | ||
|
|
7f948db158 | ||
|
|
9d7729c00d | ||
|
|
988dee02b9 | ||
|
|
d4b9568269 | ||
|
|
ccc3a481e7 | ||
|
|
cd19df49cd | ||
|
|
736365bdd5 | ||
|
|
6ceedb9448 | ||
|
|
930a3912a7 | ||
|
|
cf790d87c4 | ||
|
|
4e67fb8444 | ||
|
|
50f631c768 | ||
|
|
85bc371ebc | ||
|
|
322ee52c77 | ||
|
|
c576f80639 | ||
|
|
d5ab97b69b | ||
|
|
7cb44e4502 | ||
|
|
7a20df5ad5 | ||
|
|
bea4362e21 | ||
|
|
6805cafa9b | ||
|
|
711b40ccda | ||
|
|
696dd7f668 | ||
|
|
e0a3c69223 | ||
|
|
c59249a664 | ||
|
|
fef172966f | ||
|
|
5a1ebc4c7c | ||
|
|
2a0f45aea9 | ||
|
|
1f77bb6e73 | ||
|
|
a7ef6422b6 | ||
|
|
9cfa68c92f | ||
|
|
6f3f701d3d | ||
|
|
d2a99a19d4 | ||
|
|
0395a35543 | ||
|
|
987d4a969d | ||
|
|
976d092c68 | ||
|
|
e6b15c7e4a | ||
|
|
ef50436464 |
120
README.md
120
README.md
@@ -1,3 +1,39 @@
|
||||
## LoRAの層別適用率の探索について
|
||||
|
||||
層別適用率を探索する `train_network_appl_weights.py` を追加してあります。現在は SDXL のみ対応しています。
|
||||
|
||||
LoRA 等の学習済みネットワークに対して、層別適用率を変化させながら通常の学習プロセスを実行することで、適用率を探索します。つまり、どのような層別適用率を適用すると、学習データに近い画像が生成されるかを探索することができます。
|
||||
|
||||
層別適用率の合計をペナルティとすることが可能です。つまり、画像を再現しつつ、影響の少ない層の適用率が低くなるような適用率が探索できるはずです。
|
||||
|
||||
複数のネットワークを対象に探索できます。また探索には最低 1 枚の学習データが必要になります。
|
||||
|
||||
(何枚程度から正しく動くかは確認していません。50枚程度の画像でテスト済みです。また学習データは LoRA 学習時のデータでなくてもよいはずですが、未確認です。)
|
||||
|
||||
コマンドラインオプションは `sdxl_train_network.py` とほぼ同じですが、以下のオプションが追加、拡張されています。
|
||||
|
||||
- `--application_loss_weight` : 層別適用率を loss に加える際の重みです。デフォルトは 0.0001 です。大きくすると、なるべく適用率を低くするように学習します。0 を指定するとペナルティが適用されないため、再現度が最も高くなる適用率を自由に探索します。
|
||||
- `--network_module` : 探索対象の複数のモジュールを指定することができます。たとえば `--network_module networks.lora networks.lora` のように指定します。
|
||||
- `--network_weights` : 探索対象の複数のネットワークの重みを指定することができます。たとえば `--network_weights model1.safetensors model2.safetensors` のように指定します。
|
||||
|
||||
層別適用率のパラメータ数は 20個で、`BASE, IN00-08, MID, OUT00-08` となります。`BASE` は Text Encoder に適用されます。(Text Encoder を対象とした LoRA の動作は未確認です。)
|
||||
|
||||
パラメータは一応ファイルに保存されますが、画面に表示される値をコピーして保存することをお勧めします。
|
||||
|
||||
### 備考
|
||||
|
||||
オプティマイザ AdamW、学習率 1e-1 で動作確認しています。学習率はかなり高めに設定してよいようです。この設定では LoRA 学習時の 1/20 ~ 1/10 ほどの epoch 数でそれなりの結果が得られます。
|
||||
|
||||
`application_loss_weight` を 0.0001 より大きくすると合計の適用率がかなり低くなる(=LoRA があまり適用されない)ようです。条件にもよると思いますので、適宜調整してください。
|
||||
|
||||
適用率に負の値を使うと、影響の少ない層の適用率を極端に低くして合計を小さくする、という動きをしてしまうので、負の値は10倍の重み付けをしてあります(-0.01 は 0.1 とほぼ同じペナルティ)。重み付けを変更するときはソースを修正してください。
|
||||
|
||||
「必要ない層への適用率を下げて影響範囲を小さくする」という使い方だけでなく、「あるキャラクターがあるポーズをしている画像を教師データに、キャラクターを維持しつつポーズを取るための LoRA の適用率を探索する」、「ある画風のあるキャラクターの画像を教師データに、画風 LoRA とキャラクター LoRA の適用率を探索する」などの使い方が考えられます。
|
||||
|
||||
もしかすると、「あるキャラクターの、あえて別の画風の画像を教師データに、キャラクターの属性を再現するのに必要な層を探す」、「理想とする画像を教師データに、使えそうな LoRA を多数適用し、その中から最も再現度が高い適用率を探す(ただし LoRA の数が多いほど学習が遅くなります)」といった使い方もできるかもしれません。
|
||||
|
||||
---
|
||||
|
||||
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
|
||||
|
||||
This repository contains training, generation and utility scripts for Stable Diffusion.
|
||||
@@ -249,6 +285,86 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
||||
|
||||
## Change History
|
||||
|
||||
### Jan 27, 2024 / 2024/1/27: v0.8.3
|
||||
|
||||
- Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380!
|
||||
- `safetensors` is updated. Please see [Upgrade](#upgrade) and update the library.
|
||||
- Fixed a bug that the training crashes when `network_multiplier` is specified with multi-GPU training. PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) Thanks to fireicewolf!
|
||||
- Fixed a bug that the training crashes when training ControlNet-LLLite.
|
||||
|
||||
- `--fp8_base` 指定時に `--save_state` での保存がエラーになる不具合が修正されました。 PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) feffy380 氏に感謝します。
|
||||
- `safetensors` がバージョンアップされていますので、[Upgrade](#upgrade) を参照し更新をお願いします。
|
||||
- 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。
|
||||
- ControlNet-LLLite の学習がエラーになる不具合を修正しました。
|
||||
|
||||
### Jan 23, 2024 / 2024/1/23: v0.8.2
|
||||
|
||||
- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
|
||||
- Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
|
||||
- PyTorch 2.1 or later is required.
|
||||
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
|
||||
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.
|
||||
|
||||
- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc.
|
||||
- This is an experimental option and may be removed or changed in the future.
|
||||
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
|
||||
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
|
||||
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
|
||||
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
|
||||
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
|
||||
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
|
||||
|
||||
- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
|
||||
- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
|
||||
- Fixed a bug in multi-GPU Textual Inversion training.
|
||||
|
||||
- (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
|
||||
- `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。
|
||||
- PyTorch 2.1 以降が必要です。
|
||||
- PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。
|
||||
- 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。
|
||||
- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。
|
||||
- 実験的オプションのため、将来的に削除または仕様変更される可能性があります。
|
||||
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
|
||||
- また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
|
||||
- `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。
|
||||
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
|
||||
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
|
||||
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。
|
||||
- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。
|
||||
- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。
|
||||
- マルチ GPU での Textual Inversion 学習の不具合を修正しました。
|
||||
|
||||
- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例
|
||||
|
||||
```toml
|
||||
[general]
|
||||
[[datasets]]
|
||||
resolution = 512
|
||||
batch_size = 8
|
||||
network_multiplier = 1.0
|
||||
|
||||
... subset settings ...
|
||||
|
||||
[[datasets]]
|
||||
resolution = 512
|
||||
batch_size = 8
|
||||
network_multiplier = -1.0
|
||||
|
||||
... subset settings ...
|
||||
```
|
||||
|
||||
|
||||
### Jan 17, 2024 / 2024/1/17: v0.8.1
|
||||
|
||||
- Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`).
|
||||
- Text Encoders were not moved to CPU.
|
||||
- Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053)
|
||||
|
||||
- LoRA 等の学習スクリプト(`train_network.py`、`sdxl_train_network.py`)で、Text Encoder を学習しない場合の VRAM 使用量が以前に比べて大きくなっていた不具合を修正しました。
|
||||
- Text Encoder が GPU に保持されたままになっていました。
|
||||
- 誤字が修正されました。 [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) akx 氏に感謝します。
|
||||
|
||||
### Jan 15, 2024 / 2024/1/15: v0.8.0
|
||||
|
||||
- Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade).
|
||||
@@ -257,7 +373,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
||||
- This feature works only on Linux or WSL.
|
||||
- Please specify `--torch_compile` option in each training script.
|
||||
- You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work.
|
||||
- Please use `--spda` option instead of `--xformers` option.
|
||||
- Please use `--sdpa` option instead of `--xformers` option.
|
||||
- PyTorch 2.1 or later is recommended.
|
||||
- Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details.
|
||||
- The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t!
|
||||
@@ -270,7 +386,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
||||
- Linux または WSL でのみ動作します。
|
||||
- 各学習スクリプトで `--torch_compile` オプションを指定してください。
|
||||
- `--dynamo_backend` オプションで使用される backend を選択できます。デフォルトは `"inductor"` です。 `inductor` または `eager` が動作するようです。
|
||||
- `--xformers` オプションとは互換性がありません。 代わりに `--spda` オプションを使用してください。
|
||||
- `--xformers` オプションとは互換性がありません。 代わりに `--sdpa` オプションを使用してください。
|
||||
- PyTorch 2.1以降を推奨します。
|
||||
- 詳細は [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) をご覧ください。
|
||||
- wandb 保存時のセッション名が各学習スクリプトの `--wandb_run_name` オプションで指定できるようになりました。 PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) hopl1t 氏に感謝します。
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
from typing import Union, List, Optional, Dict, Any, Tuple
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||
|
||||
|
||||
@@ -11,15 +11,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
|
||||
@@ -66,15 +66,10 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
import torchvision
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -125,9 +125,13 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
|
||||
# AMP:
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
||||
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
||||
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = contextlib.nullcontext()
|
||||
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
||||
|
||||
try:
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
@@ -151,15 +155,16 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.has_half = True
|
||||
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
|
||||
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
||||
torch.version.cuda = "11.7"
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7]
|
||||
torch.cuda.get_device_properties.major = 11
|
||||
torch.cuda.get_device_properties.minor = 7
|
||||
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
||||
torch.version.cuda = "12.1"
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
|
||||
torch.cuda.get_device_properties.major = 12
|
||||
torch.cuda.get_device_properties.minor = 1
|
||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||
|
||||
ipex_hijacks()
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
|
||||
@@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
|
||||
)
|
||||
else:
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
torch.xpu.synchronize(input.device)
|
||||
return hidden_states
|
||||
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
@@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
|
||||
)
|
||||
else:
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
torch.xpu.synchronize(query.device)
|
||||
return hidden_states
|
||||
|
||||
@@ -149,6 +149,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
||||
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
||||
del attn_slice
|
||||
torch.xpu.synchronize(query.device)
|
||||
else:
|
||||
query_slice = query[start_idx:end_idx]
|
||||
key_slice = key[start_idx:end_idx]
|
||||
@@ -283,6 +284,7 @@ class AttnProcessor:
|
||||
|
||||
hidden_states[start_idx:end_idx] = attn_slice
|
||||
del attn_slice
|
||||
torch.xpu.synchronize(query.device)
|
||||
else:
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import contextlib
|
||||
import os
|
||||
from functools import wraps
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import numpy as np
|
||||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||
|
||||
@@ -11,7 +16,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
|
||||
return module.to("xpu")
|
||||
|
||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||
return contextlib.nullcontext()
|
||||
return nullcontext()
|
||||
|
||||
@property
|
||||
def is_cuda(self):
|
||||
@@ -25,15 +30,17 @@ def return_xpu(device):
|
||||
|
||||
|
||||
# Autocast
|
||||
original_autocast = torch.autocast
|
||||
def ipex_autocast(*args, **kwargs):
|
||||
if len(args) > 0 and args[0] == "cuda":
|
||||
return original_autocast("xpu", *args[1:], **kwargs)
|
||||
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
|
||||
@wraps(torch.amp.autocast_mode.autocast.__init__)
|
||||
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
|
||||
if device_type == "cuda":
|
||||
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
||||
else:
|
||||
return original_autocast(*args, **kwargs)
|
||||
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
||||
|
||||
# Latent Antialias CPU Offload:
|
||||
original_interpolate = torch.nn.functional.interpolate
|
||||
@wraps(torch.nn.functional.interpolate)
|
||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||
if antialias or align_corners is not None:
|
||||
return_device = tensor.device
|
||||
@@ -44,15 +51,29 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
|
||||
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
||||
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
||||
|
||||
|
||||
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
||||
original_from_numpy = torch.from_numpy
|
||||
@wraps(torch.from_numpy)
|
||||
def from_numpy(ndarray):
|
||||
if ndarray.dtype == float:
|
||||
return original_from_numpy(ndarray.astype('float32'))
|
||||
else:
|
||||
return original_from_numpy(ndarray)
|
||||
|
||||
if torch.xpu.has_fp64_dtype():
|
||||
original_as_tensor = torch.as_tensor
|
||||
@wraps(torch.as_tensor)
|
||||
def as_tensor(data, dtype=None, device=None):
|
||||
if check_device(device):
|
||||
device = return_xpu(device)
|
||||
if isinstance(data, np.ndarray) and data.dtype == float and not (
|
||||
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
|
||||
return original_as_tensor(data, dtype=torch.float32, device=device)
|
||||
else:
|
||||
return original_as_tensor(data, dtype=dtype, device=device)
|
||||
|
||||
|
||||
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
||||
original_torch_bmm = torch.bmm
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
else:
|
||||
@@ -66,20 +87,25 @@ else:
|
||||
|
||||
|
||||
# Data Type Errors:
|
||||
@wraps(torch.bmm)
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
|
||||
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
if query.dtype != key.dtype:
|
||||
key = key.to(dtype=query.dtype)
|
||||
if query.dtype != value.dtype:
|
||||
value = value.to(dtype=query.dtype)
|
||||
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
||||
attn_mask = attn_mask.to(dtype=query.dtype)
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
|
||||
# A1111 FP16
|
||||
original_functional_group_norm = torch.nn.functional.group_norm
|
||||
@wraps(torch.nn.functional.group_norm)
|
||||
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
||||
if weight is not None and input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
@@ -89,6 +115,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
||||
|
||||
# A1111 BF16
|
||||
original_functional_layer_norm = torch.nn.functional.layer_norm
|
||||
@wraps(torch.nn.functional.layer_norm)
|
||||
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
||||
if weight is not None and input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
@@ -98,6 +125,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1
|
||||
|
||||
# Training
|
||||
original_functional_linear = torch.nn.functional.linear
|
||||
@wraps(torch.nn.functional.linear)
|
||||
def functional_linear(input, weight, bias=None):
|
||||
if input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
@@ -106,6 +134,7 @@ def functional_linear(input, weight, bias=None):
|
||||
return original_functional_linear(input, weight, bias=bias)
|
||||
|
||||
original_functional_conv2d = torch.nn.functional.conv2d
|
||||
@wraps(torch.nn.functional.conv2d)
|
||||
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||
if input.dtype != weight.data.dtype:
|
||||
input = input.to(dtype=weight.data.dtype)
|
||||
@@ -115,6 +144,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
|
||||
|
||||
# A1111 Embedding BF16
|
||||
original_torch_cat = torch.cat
|
||||
@wraps(torch.cat)
|
||||
def torch_cat(tensor, *args, **kwargs):
|
||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
||||
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
||||
@@ -123,6 +153,7 @@ def torch_cat(tensor, *args, **kwargs):
|
||||
|
||||
# SwinIR BF16:
|
||||
original_functional_pad = torch.nn.functional.pad
|
||||
@wraps(torch.nn.functional.pad)
|
||||
def functional_pad(input, pad, mode='constant', value=None):
|
||||
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
||||
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
||||
@@ -131,13 +162,20 @@ def functional_pad(input, pad, mode='constant', value=None):
|
||||
|
||||
|
||||
original_torch_tensor = torch.tensor
|
||||
def torch_tensor(*args, device=None, **kwargs):
|
||||
@wraps(torch.tensor)
|
||||
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_tensor(*args, device=return_xpu(device), **kwargs)
|
||||
else:
|
||||
return original_torch_tensor(*args, device=device, **kwargs)
|
||||
device = return_xpu(device)
|
||||
if not device_supports_fp64:
|
||||
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
|
||||
if dtype == torch.float64:
|
||||
dtype = torch.float32
|
||||
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
|
||||
dtype = torch.float32
|
||||
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
|
||||
|
||||
original_Tensor_to = torch.Tensor.to
|
||||
@wraps(torch.Tensor.to)
|
||||
def Tensor_to(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
||||
@@ -145,6 +183,7 @@ def Tensor_to(self, device=None, *args, **kwargs):
|
||||
return original_Tensor_to(self, device, *args, **kwargs)
|
||||
|
||||
original_Tensor_cuda = torch.Tensor.cuda
|
||||
@wraps(torch.Tensor.cuda)
|
||||
def Tensor_cuda(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
||||
@@ -152,6 +191,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
|
||||
return original_Tensor_cuda(self, device, *args, **kwargs)
|
||||
|
||||
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
||||
@wraps(torch.UntypedStorage.__init__)
|
||||
def UntypedStorage_init(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -159,6 +199,7 @@ def UntypedStorage_init(*args, device=None, **kwargs):
|
||||
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
||||
|
||||
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
||||
@wraps(torch.UntypedStorage.cuda)
|
||||
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
||||
if check_device(device):
|
||||
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
||||
@@ -166,6 +207,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
||||
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
||||
|
||||
original_torch_empty = torch.empty
|
||||
@wraps(torch.empty)
|
||||
def torch_empty(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -173,6 +215,7 @@ def torch_empty(*args, device=None, **kwargs):
|
||||
return original_torch_empty(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_randn = torch.randn
|
||||
@wraps(torch.randn)
|
||||
def torch_randn(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -180,6 +223,7 @@ def torch_randn(*args, device=None, **kwargs):
|
||||
return original_torch_randn(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_ones = torch.ones
|
||||
@wraps(torch.ones)
|
||||
def torch_ones(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -187,6 +231,7 @@ def torch_ones(*args, device=None, **kwargs):
|
||||
return original_torch_ones(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_zeros = torch.zeros
|
||||
@wraps(torch.zeros)
|
||||
def torch_zeros(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -194,6 +239,7 @@ def torch_zeros(*args, device=None, **kwargs):
|
||||
return original_torch_zeros(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_linspace = torch.linspace
|
||||
@wraps(torch.linspace)
|
||||
def torch_linspace(*args, device=None, **kwargs):
|
||||
if check_device(device):
|
||||
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
||||
@@ -201,6 +247,7 @@ def torch_linspace(*args, device=None, **kwargs):
|
||||
return original_torch_linspace(*args, device=device, **kwargs)
|
||||
|
||||
original_torch_Generator = torch.Generator
|
||||
@wraps(torch.Generator)
|
||||
def torch_Generator(device=None):
|
||||
if check_device(device):
|
||||
return original_torch_Generator(return_xpu(device))
|
||||
@@ -208,12 +255,14 @@ def torch_Generator(device=None):
|
||||
return original_torch_Generator(device)
|
||||
|
||||
original_torch_load = torch.load
|
||||
@wraps(torch.load)
|
||||
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
|
||||
if check_device(map_location):
|
||||
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||
else:
|
||||
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
|
||||
|
||||
|
||||
# Hijack Functions:
|
||||
def ipex_hijacks():
|
||||
torch.tensor = torch_tensor
|
||||
@@ -232,7 +281,7 @@ def ipex_hijacks():
|
||||
torch.backends.cuda.sdp_kernel = return_null_context
|
||||
torch.nn.DataParallel = DummyDataParallel
|
||||
torch.UntypedStorage.is_cuda = is_cuda
|
||||
torch.autocast = ipex_autocast
|
||||
torch.amp.autocast_mode.autocast.__init__ = autocast_init
|
||||
|
||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||
torch.nn.functional.group_norm = functional_group_norm
|
||||
@@ -244,5 +293,6 @@ def ipex_hijacks():
|
||||
|
||||
torch.bmm = torch_bmm
|
||||
torch.cat = torch_cat
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
if not device_supports_fp64:
|
||||
torch.from_numpy = from_numpy
|
||||
torch.as_tensor = as_tensor
|
||||
|
||||
24
library/ipex_interop.py
Normal file
24
library/ipex_interop.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
|
||||
def init_ipex():
|
||||
"""
|
||||
Try to import `intel_extension_for_pytorch`, and apply
|
||||
the hijacks using `library.ipex.ipex_init`.
|
||||
|
||||
If IPEX is not installed, this function does nothing.
|
||||
"""
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
try:
|
||||
from library.ipex import ipex_init
|
||||
|
||||
if torch.xpu.is_available():
|
||||
is_initialized, error_message = ipex_init()
|
||||
if not is_initialized:
|
||||
print("failed to initialize ipex:", error_message)
|
||||
except Exception as e:
|
||||
print("failed to initialize ipex:", e)
|
||||
@@ -5,15 +5,9 @@ import math
|
||||
import os
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
init_ipex()
|
||||
import diffusers
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
||||
|
||||
@@ -1262,9 +1262,9 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
for attn in self.attentions:
|
||||
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
||||
|
||||
def set_use_sdpa(self, spda):
|
||||
def set_use_sdpa(self, sdpa):
|
||||
for attn in self.attentions:
|
||||
attn.set_use_sdpa(spda)
|
||||
attn.set_use_sdpa(sdpa)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
if up1 is not None:
|
||||
uncond_pool = up1
|
||||
|
||||
dtype = self.unet.dtype
|
||||
unet_dtype = self.unet.dtype
|
||||
dtype = unet_dtype
|
||||
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
|
||||
dtype = torch.float16
|
||||
self.unet.to(dtype)
|
||||
|
||||
# 4. Preprocess image and mask
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
@@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
self.unet.to(unet_dtype)
|
||||
return latents
|
||||
|
||||
def latents_to_image(self, latents):
|
||||
|
||||
@@ -558,6 +558,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
|
||||
max_token_length: int,
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -567,6 +568,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.max_token_length = max_token_length
|
||||
# width/height is used when enable_bucket==False
|
||||
self.width, self.height = (None, None) if resolution is None else resolution
|
||||
self.network_multiplier = network_multiplier
|
||||
self.debug_dataset = debug_dataset
|
||||
|
||||
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
|
||||
@@ -1106,7 +1108,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
loss_weights.append(
|
||||
self.prior_loss_weight if image_info.is_reg else 1.0
|
||||
) # in case of fine tuning, is_reg is always False
|
||||
|
||||
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
||||
|
||||
@@ -1272,6 +1276,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
|
||||
example["flippeds"] = flippeds
|
||||
|
||||
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
return example
|
||||
@@ -1346,15 +1352,16 @@ class DreamBoothDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
prior_loss_weight: float,
|
||||
debug_dataset,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
@@ -1520,14 +1527,15 @@ class FineTuningDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -1724,14 +1732,15 @@ class ControlNetDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset,
|
||||
debug_dataset: float,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
@@ -1765,6 +1774,7 @@ class ControlNetDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier,
|
||||
enable_bucket,
|
||||
min_bucket_reso,
|
||||
max_bucket_reso,
|
||||
@@ -2039,6 +2049,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(
|
||||
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
|
||||
)
|
||||
if "network_multipliers" in example:
|
||||
print(f"network multiplier: {example['network_multipliers'][j]}")
|
||||
|
||||
if show_input_ids:
|
||||
print(f"input ids: {iid}")
|
||||
@@ -2105,8 +2117,8 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
|
||||
|
||||
class MinimalDataset(BaseDataset):
|
||||
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.num_train_images = 0 # update in subclass
|
||||
self.num_reg_images = 0 # update in subclass
|
||||
@@ -2850,14 +2862,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
)
|
||||
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
|
||||
parser.add_argument(
|
||||
"--dynamo_backend",
|
||||
type=str,
|
||||
default="inductor",
|
||||
"--dynamo_backend",
|
||||
type=str,
|
||||
default="inductor",
|
||||
# available backends:
|
||||
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
|
||||
# https://pytorch.org/docs/stable/torch.compiler.html
|
||||
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)"
|
||||
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument(
|
||||
@@ -2904,6 +2916,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
|
||||
parser.add_argument(
|
||||
"--ddp_timeout",
|
||||
type=int,
|
||||
@@ -3886,7 +3899,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
os.environ["WANDB_DIR"] = logging_dir
|
||||
if args.wandb_api_key is not None:
|
||||
wandb.login(key=args.wandb_api_key)
|
||||
|
||||
|
||||
# torch.compile のオプション。 NO の場合は torch.compile は使わない
|
||||
dynamo_backend = "NO"
|
||||
if args.torch_compile:
|
||||
|
||||
@@ -43,6 +43,9 @@ def svd(
|
||||
clamp_quantile=0.99,
|
||||
min_diff=0.01,
|
||||
no_metadata=False,
|
||||
load_precision=None,
|
||||
load_original_model_to=None,
|
||||
load_tuned_model_to=None,
|
||||
):
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
@@ -57,28 +60,51 @@ def svd(
|
||||
if v_parameterization is None:
|
||||
v_parameterization = v2
|
||||
|
||||
load_dtype = str_to_dtype(load_precision) if load_precision else None
|
||||
save_dtype = str_to_dtype(save_precision)
|
||||
work_device = "cpu"
|
||||
|
||||
# load models
|
||||
if not sdxl:
|
||||
print(f"loading original SD model : {model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||
text_encoders_o = [text_encoder_o]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o = text_encoder_o.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
print(f"loading tuned SD model : {model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||
text_encoders_t = [text_encoder_t]
|
||||
if load_dtype is not None:
|
||||
text_encoder_t = text_encoder_t.to(load_dtype)
|
||||
unet_t = unet_t.to(load_dtype)
|
||||
|
||||
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
||||
else:
|
||||
device_org = load_original_model_to if load_original_model_to else "cpu"
|
||||
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
|
||||
|
||||
print(f"loading original SDXL model : {model_org}")
|
||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
||||
)
|
||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o1 = text_encoder_o1.to(load_dtype)
|
||||
text_encoder_o2 = text_encoder_o2.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
print(f"loading original SDXL model : {model_tuned}")
|
||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
||||
)
|
||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||
if load_dtype is not None:
|
||||
text_encoder_t1 = text_encoder_t1.to(load_dtype)
|
||||
text_encoder_t2 = text_encoder_t2.to(load_dtype)
|
||||
unet_t = unet_t.to(load_dtype)
|
||||
|
||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||
@@ -100,38 +126,54 @@ def svd(
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight - module_o.weight
|
||||
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
||||
|
||||
# clear weight to save memory
|
||||
module_o.weight = None
|
||||
module_t.weight = None
|
||||
|
||||
# Text Encoder might be same
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||
text_encoder_different = True
|
||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||
|
||||
diff = diff.float()
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# clear target Text Encoder to save memory
|
||||
for text_encoder in text_encoders_t:
|
||||
del text_encoder
|
||||
|
||||
if not text_encoder_different:
|
||||
print("Text encoder is same. Extract U-Net only.")
|
||||
lora_network_o.text_encoder_loras = []
|
||||
diffs = {}
|
||||
diffs = {} # clear diffs
|
||||
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight - module_o.weight
|
||||
diff = diff.float()
|
||||
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
||||
|
||||
if args.device:
|
||||
diff = diff.to(args.device)
|
||||
# clear weight to save memory
|
||||
module_o.weight = None
|
||||
module_t.weight = None
|
||||
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# clear LoRA network, target U-Net to save memory
|
||||
del lora_network_o
|
||||
del lora_network_t
|
||||
del unet_t
|
||||
|
||||
# make LoRA with svd
|
||||
print("calculating by svd")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name, mat in tqdm(list(diffs.items())):
|
||||
if args.device:
|
||||
mat = mat.to(args.device)
|
||||
mat = mat.to(torch.float) # calc by float
|
||||
|
||||
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
@@ -171,8 +213,8 @@ def svd(
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
|
||||
U = U.to("cpu").contiguous()
|
||||
Vh = Vh.to("cpu").contiguous()
|
||||
U = U.to(work_device, dtype=save_dtype).contiguous()
|
||||
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
|
||||
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
|
||||
@@ -230,6 +272,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
@@ -285,6 +334,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
||||
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_original_model_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_tuned_model_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -511,7 +511,9 @@ def get_block_dims_and_alphas(
|
||||
len(block_dims) == num_total_blocks
|
||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
||||
print(
|
||||
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
|
||||
)
|
||||
block_dims = [network_dim] * num_total_blocks
|
||||
|
||||
if block_alphas is not None:
|
||||
@@ -1223,3 +1225,40 @@ class LoRANetwork(torch.nn.Module):
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
|
||||
# region application weight
|
||||
|
||||
def get_number_of_blocks(self):
|
||||
# only for SDXL
|
||||
return 20
|
||||
|
||||
def has_text_encoder_block(self):
|
||||
return self.text_encoder_loras is not None and len(self.text_encoder_loras) > 0
|
||||
|
||||
def set_block_wise_weights(self, weights):
|
||||
if self.text_encoder_loras:
|
||||
for lora in self.text_encoder_loras:
|
||||
lora.multiplier = weights[0]
|
||||
|
||||
for lora in self.unet_loras:
|
||||
# determine block index
|
||||
key = lora.lora_name[10:] # remove "lora_unet_"
|
||||
if key.startswith("input_blocks"):
|
||||
block_index = int(key.split("_")[2]) + 1 # 1-9
|
||||
elif key.startswith("middle_block"):
|
||||
block_index = 10 # int(key.split("_")[2]) + 10
|
||||
elif key.startswith("output_blocks"):
|
||||
block_index = int(key.split("_")[2]) + 11 # 11-19
|
||||
else:
|
||||
print(f"unknown block: {key}")
|
||||
block_index = 0
|
||||
|
||||
lora.multiplier = weights[block_index]
|
||||
|
||||
# print(f"{lora.lora_name} block index: {block_index}, weight: {lora.multiplier}")
|
||||
# print(f"set block-wise weights to {weights}")
|
||||
|
||||
# TODO LoRA の weight をあらかじめ計算しておいて multiplier を掛けるだけにすると速くなるはず
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
@@ -6,8 +5,6 @@ import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec, train_util
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
@@ -8,7 +8,7 @@ einops==0.6.1
|
||||
pytorch-lightning==1.9.0
|
||||
# bitsandbytes==0.39.1
|
||||
tensorboard==2.10.1
|
||||
safetensors==0.3.1
|
||||
safetensors==0.4.2
|
||||
# gradio==3.16.2
|
||||
altair==4.2.2
|
||||
easygui==0.98.3
|
||||
|
||||
@@ -18,15 +18,10 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
import torchvision
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
|
||||
@@ -9,13 +9,11 @@ import random
|
||||
from einops import repeat
|
||||
import numpy as np
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
@@ -11,15 +11,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import sdxl_model_util
|
||||
|
||||
@@ -14,13 +14,11 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
import accelerate
|
||||
|
||||
@@ -11,13 +11,11 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
import train_network
|
||||
|
||||
@@ -95,8 +90,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
unet.to(org_unet_device)
|
||||
else:
|
||||
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
||||
text_encoders[0].to(accelerator.device)
|
||||
text_encoders[1].to(accelerator.device)
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
|
||||
@@ -3,13 +3,9 @@ import os
|
||||
|
||||
import regex
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
import open_clip
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
|
||||
|
||||
@@ -12,15 +12,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
|
||||
@@ -12,15 +12,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
|
||||
@@ -14,15 +14,10 @@ from tqdm import tqdm
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import model_util
|
||||
@@ -117,7 +112,7 @@ class NetworkTrainer:
|
||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
|
||||
):
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device)
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
@@ -278,6 +273,7 @@ class NetworkTrainer:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
||||
self.cache_text_encoder_outputs_if_needed(
|
||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
|
||||
)
|
||||
@@ -309,6 +305,7 @@ class NetworkTrainer:
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
network_has_multiplier = hasattr(network, "set_multiplier")
|
||||
|
||||
if hasattr(network, "prepare_network"):
|
||||
network.prepare_network(args)
|
||||
@@ -389,17 +386,33 @@ class NetworkTrainer:
|
||||
accelerator.print("enable full bf16 training.")
|
||||
network.to(weight_dtype)
|
||||
|
||||
unet_weight_dtype = te_weight_dtype = weight_dtype
|
||||
# Experimental Feature: Put base model into fp8 to save vram
|
||||
if args.fp8_base:
|
||||
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
||||
assert (
|
||||
args.mixed_precision != "no"
|
||||
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
||||
accelerator.print("enable fp8 training.")
|
||||
unet_weight_dtype = torch.float8_e4m3fn
|
||||
te_weight_dtype = torch.float8_e4m3fn
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(dtype=weight_dtype)
|
||||
unet.to(dtype=unet_weight_dtype)
|
||||
for t_enc in text_encoders:
|
||||
t_enc.requires_grad_(False)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
# TODO めちゃくちゃ冗長なのでコードを整理する
|
||||
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
|
||||
if t_enc.device.type != "cpu":
|
||||
t_enc.to(dtype=te_weight_dtype)
|
||||
# nn.Embedding not support FP8
|
||||
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
else:
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
if train_text_encoder:
|
||||
if len(text_encoders) > 1:
|
||||
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
||||
@@ -407,8 +420,8 @@ class NetworkTrainer:
|
||||
text_encoder = accelerator.prepare(text_encoder)
|
||||
text_encoders = [text_encoder]
|
||||
else:
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
@@ -421,9 +434,6 @@ class NetworkTrainer:
|
||||
if train_text_encoder:
|
||||
t_enc.text_model.embeddings.requires_grad_(True)
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if not train_text_encoder: # train U-Net only
|
||||
unet.parameters().__next__().requires_grad_(True)
|
||||
else:
|
||||
unet.eval()
|
||||
for t_enc in text_encoders:
|
||||
@@ -685,7 +695,7 @@ class NetworkTrainer:
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
@@ -754,7 +764,17 @@ class NetworkTrainer:
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * self.vae_scale_factor
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# get multiplier for each sample
|
||||
if network_has_multiplier:
|
||||
multipliers = batch["network_multipliers"]
|
||||
# if all multipliers are same, use single multiplier
|
||||
if torch.all(multipliers == multipliers[0]):
|
||||
multipliers = multipliers[0].item()
|
||||
else:
|
||||
raise NotImplementedError("multipliers for each sample is not supported yet")
|
||||
# print(f"set multiplier: {multipliers}")
|
||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
@@ -778,10 +798,24 @@ class NetworkTrainer:
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
for x in noisy_latents:
|
||||
x.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents.requires_grad_(train_unet),
|
||||
timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
@@ -808,10 +842,11 @@ class NetworkTrainer:
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
if accelerator.sync_gradients:
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
if args.max_grad_norm != 0.0:
|
||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
1039
train_network_appl_weights.py
Normal file
1039
train_network_appl_weights.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -8,15 +8,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from transformers import CLIPTokenizer
|
||||
@@ -505,7 +500,7 @@ class TextualInversionTrainer:
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
@@ -730,14 +725,13 @@ class TextualInversionTrainer:
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
|
||||
if is_main_process:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
@@ -8,13 +8,11 @@ from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
Reference in New Issue
Block a user