mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 13:47:06 +00:00
Merge branch 'dev' into gradual_latent_hires_fix
This commit is contained in:
181
README.md
181
README.md
@@ -281,98 +281,113 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
||||
|
||||
## Change History
|
||||
|
||||
### Dec 24, 2023 / 2023/12/24
|
||||
### Jan 27, 2024 / 2024/1/27: v0.8.3
|
||||
|
||||
- Fixed to work `tools/convert_diffusers20_original_sd.py`. Thanks to Disty0! PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016)
|
||||
- Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380!
|
||||
- `safetensors` is updated. Please see [Upgrade](#upgrade) and update the library.
|
||||
- Fixed a bug that the training crashes when `network_multiplier` is specified with multi-GPU training. PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) Thanks to fireicewolf!
|
||||
- Fixed a bug that the training crashes when training ControlNet-LLLite.
|
||||
|
||||
- `tools/convert_diffusers20_original_sd.py` が動かなくなっていたのが修正されました。Disty0 氏に感謝します。 PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016)
|
||||
- `--fp8_base` 指定時に `--save_state` での保存がエラーになる不具合が修正されました。 PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) feffy380 氏に感謝します。
|
||||
- `safetensors` がバージョンアップされていますので、[Upgrade](#upgrade) を参照し更新をお願いします。
|
||||
- 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。
|
||||
- ControlNet-LLLite の学習がエラーになる不具合を修正しました。
|
||||
|
||||
### Jan 23, 2024 / 2024/1/23: v0.8.2
|
||||
|
||||
- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
|
||||
- Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
|
||||
- PyTorch 2.1 or later is required.
|
||||
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
|
||||
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.
|
||||
|
||||
- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc.
|
||||
- This is an experimental option and may be removed or changed in the future.
|
||||
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
|
||||
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
|
||||
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
|
||||
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
|
||||
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
|
||||
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
|
||||
|
||||
- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
|
||||
- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
|
||||
- Fixed a bug in multi-GPU Textual Inversion training.
|
||||
|
||||
- (実験的) LoRA等の学習スクリプトで、ベースモデル(U-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
|
||||
- `train_network.py` または `sdxl_train_network.py` で `--fp8_base` を指定してください。
|
||||
- PyTorch 2.1 以降が必要です。
|
||||
- PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。
|
||||
- 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。
|
||||
- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。
|
||||
- 実験的オプションのため、将来的に削除または仕様変更される可能性があります。
|
||||
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
|
||||
- また、五段階の状態を用意し、それぞれ `0.2`、`0.4`、`0.6`、`0.8`、`1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
|
||||
- `.toml` ファイルで `[[datasets]]` に `network_multiplier` を指定してください。
|
||||
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
|
||||
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
|
||||
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。
|
||||
- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。
|
||||
- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。
|
||||
- マルチ GPU での Textual Inversion 学習の不具合を修正しました。
|
||||
|
||||
- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例
|
||||
|
||||
```toml
|
||||
[general]
|
||||
[[datasets]]
|
||||
resolution = 512
|
||||
batch_size = 8
|
||||
network_multiplier = 1.0
|
||||
|
||||
... subset settings ...
|
||||
|
||||
[[datasets]]
|
||||
resolution = 512
|
||||
batch_size = 8
|
||||
network_multiplier = -1.0
|
||||
|
||||
... subset settings ...
|
||||
```
|
||||
|
||||
|
||||
### Dec 21, 2023 / 2023/12/21
|
||||
### Jan 17, 2024 / 2024/1/17: v0.8.1
|
||||
|
||||
- The issues in multi-GPU training are fixed. Thanks to Isotr0py! PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) and [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000)
|
||||
- `--ddp_gradient_as_bucket_view` and `--ddp_bucket_view`options are added to `sdxl_train.py`. Please specify these options for multi-GPU training.
|
||||
- IPEX support is updated. Thanks to Disty0!
|
||||
- Fixed the bug that the size of the bucket becomes less than `min_bucket_reso`. Thanks to Cauldrath! PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008)
|
||||
- `--sample_at_first` option is added to each training script. This option is useful to generate images at the first step, before training. Thanks to shirayu! PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907)
|
||||
- `--ss` option is added to the sampling prompt in training. You can specify the scheduler for the sampling like `--ss euler_a`. Thanks to shirayu! PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906)
|
||||
- `keep_tokens_separator` is added to the dataset config. This option is useful to keep (prevent from shuffling) the tokens in the captions. See [#975](https://github.com/kohya-ss/sd-scripts/pull/975) for details. Thanks to Linaqruf!
|
||||
- You can specify the separator with an option like `--keep_tokens_separator "|||"` or with `keep_tokens_separator: "|||"` in `.toml`. The tokens before `|||` are not shuffled.
|
||||
- Attention processor hook is added. See [#961](https://github.com/kohya-ss/sd-scripts/pull/961) for details. Thanks to rockerBOO!
|
||||
- The optimizer `PagedAdamW` is added. Thanks to xzuyn! PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955)
|
||||
- NaN replacement in SDXL VAE is sped up. Thanks to liubo0902! PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009)
|
||||
- Fixed the path error in `finetune/make_captions.py`. Thanks to CjangCjengh! PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986)
|
||||
- Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`).
|
||||
- Text Encoders were not moved to CPU.
|
||||
- Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053)
|
||||
|
||||
- マルチGPUでの学習の不具合を修正しました。Isotr0py 氏に感謝します。 PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) および [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000)
|
||||
- `sdxl_train.py` に `--ddp_gradient_as_bucket_view` と `--ddp_bucket_view` オプションが追加されました。マルチGPUでの学習時にはこれらのオプションを指定してください。
|
||||
- IPEX サポートが更新されました。Disty0 氏に感謝します。
|
||||
- Aspect Ratio Bucketing で bucket のサイズが `min_bucket_reso` 未満になる不具合を修正しました。Cauldrath 氏に感謝します。 PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008)
|
||||
- 各学習スクリプトに `--sample_at_first` オプションが追加されました。学習前に画像を生成することで、学習結果が比較しやすくなります。shirayu 氏に感謝します。 PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907)
|
||||
- 学習時のプロンプトに `--ss` オプションが追加されました。`--ss euler_a` のようにスケジューラを指定できます。shirayu 氏に感謝します。 PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906)
|
||||
- データセット設定に `keep_tokens_separator` が追加されました。キャプション内のトークンをどの位置までシャッフルしないかを指定できます。詳細は [#975](https://github.com/kohya-ss/sd-scripts/pull/975) を参照してください。Linaqruf 氏に感謝します。
|
||||
- オプションで `--keep_tokens_separator "|||"` のように指定するか、`.toml` で `keep_tokens_separator: "|||"` のように指定します。`|||` の前のトークンはシャッフルされません。
|
||||
- Attention processor hook が追加されました。詳細は [#961](https://github.com/kohya-ss/sd-scripts/pull/961) を参照してください。rockerBOO 氏に感謝します。
|
||||
- オプティマイザ `PagedAdamW` が追加されました。xzuyn 氏に感謝します。 PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955)
|
||||
- 学習時、SDXL VAE で NaN が発生した時の置き換えが高速化されました。liubo0902 氏に感謝します。 PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009)
|
||||
- `finetune/make_captions.py` で相対パス指定時のエラーが修正されました。CjangCjengh 氏に感謝します。 PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986)
|
||||
- LoRA 等の学習スクリプト(`train_network.py`、`sdxl_train_network.py`)で、Text Encoder を学習しない場合の VRAM 使用量が以前に比べて大きくなっていた不具合を修正しました。
|
||||
- Text Encoder が GPU に保持されたままになっていました。
|
||||
- 誤字が修正されました。 [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) akx 氏に感謝します。
|
||||
|
||||
### Dec 3, 2023 / 2023/12/3
|
||||
### Jan 15, 2024 / 2024/1/15: v0.8.0
|
||||
|
||||
- `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
|
||||
- Min SNR Gamma with V-predicition (SD 2.1) is fixed. Thanks to feffy380! PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934)
|
||||
- See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details.
|
||||
- `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
|
||||
- The default values are same as the previous version.
|
||||
- Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py`.
|
||||
- `--ds_timesteps_1` and `--ds_timesteps_2` options denote the timesteps of the Deep Shrink for the first and second stages.
|
||||
- `--ds_depth_1` and `--ds_depth_2` options denote the depth (block index) of the Deep Shrink for the first and second stages.
|
||||
- `--ds_ratio` option denotes the ratio of the Deep Shrink. `0.5` means the half of the original latent size for the Deep Shrink.
|
||||
- `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available.
|
||||
- Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade).
|
||||
- Some model files (Text Encoder without position_id) based on the latest Transformers can be loaded.
|
||||
- `torch.compile` is supported (experimental). PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) Thanks to p1atdev!
|
||||
- This feature works only on Linux or WSL.
|
||||
- Please specify `--torch_compile` option in each training script.
|
||||
- You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work.
|
||||
- Please use `--sdpa` option instead of `--xformers` option.
|
||||
- PyTorch 2.1 or later is recommended.
|
||||
- Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details.
|
||||
- The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t!
|
||||
- IPEX library is updated. PR [#1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Thanks to Disty0!
|
||||
- Fixed a bug that Diffusers format model cannot be saved.
|
||||
|
||||
- `finetune\tag_images_by_wd14_tagger.py` で `--caption_separator` オプションでカンマ以外の区切り文字を指定できるようになりました。KohakuBlueleaf 氏に感謝します。 PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
|
||||
- V-predicition (SD 2.1) での Min SNR Gamma が修正されました。feffy380 氏に感謝します。 PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934)
|
||||
- 詳細は [#673](https://github.com/kohya-ss/sd-scripts/issues/673) を参照してください。
|
||||
- `networks/extract_lora_from_models.py` に `--min_diff` と `--clamp_quantile` オプションが追加されました。wkpark 氏に感謝します。 PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
|
||||
- デフォルト値は前のバージョンと同じです。
|
||||
- `sdxl_gen_img.py` と `gen_img_diffusers.py` で Deep Shrink hires fix をサポートしました。
|
||||
- `--ds_timesteps_1` と `--ds_timesteps_2` オプションは Deep Shrink の第一段階と第二段階の timesteps を指定します。
|
||||
- `--ds_depth_1` と `--ds_depth_2` オプションは Deep Shrink の第一段階と第二段階の深さ(ブロックの index)を指定します。
|
||||
- `--ds_ratio` オプションは Deep Shrink の比率を指定します。`0.5` を指定すると Deep Shrink 適用時の latent は元のサイズの半分になります。
|
||||
- `--dst1`、`--dst2`、`--dsd1`、`--dsd2`、`--dsr` プロンプトオプションも使用できます。
|
||||
|
||||
### Nov 5, 2023 / 2023/11/5
|
||||
|
||||
- `sdxl_train.py` now supports different learning rates for each Text Encoder.
|
||||
- Example:
|
||||
- `--learning_rate 1e-6`: train U-Net only
|
||||
- `--train_text_encoder --learning_rate 1e-6`: train U-Net and two Text Encoders with the same learning rate (same as the previous version)
|
||||
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train U-Net and two Text Encoders with the different learning rates
|
||||
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train two Text Encoders only
|
||||
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 0`: train U-Net and one Text Encoder only
|
||||
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 0 --learning_rate_te2 1e-6`: train one Text Encoder only
|
||||
|
||||
- `train_db.py` and `fine_tune.py` now support different learning rates for Text Encoder. Specify with `--learning_rate_te` option.
|
||||
- To train Text Encoder with `fine_tune.py`, specify `--train_text_encoder` option too. `train_db.py` trains Text Encoder by default.
|
||||
|
||||
- Fixed the bug that Text Encoder is not trained when block lr is specified in `sdxl_train.py`.
|
||||
|
||||
- Debiased Estimation loss is added to each training script. Thanks to sdbds!
|
||||
- Specify `--debiased_estimation_loss` option to enable it. See PR [#889](https://github.com/kohya-ss/sd-scripts/pull/889) for details.
|
||||
- Training of Text Encoder is improved in `train_network.py` and `sdxl_train_network.py`. Thanks to KohakuBlueleaf! PR [#895](https://github.com/kohya-ss/sd-scripts/pull/895)
|
||||
- The moving average of the loss is now displayed in the progress bar in each training script. Thanks to shirayu! PR [#899](https://github.com/kohya-ss/sd-scripts/pull/899)
|
||||
- PagedAdamW32bit optimizer is supported. Specify `--optimizer_type=PagedAdamW32bit`. Thanks to xzuyn! PR [#900](https://github.com/kohya-ss/sd-scripts/pull/900)
|
||||
- Other bug fixes and improvements.
|
||||
|
||||
- `sdxl_train.py` で、二つのText Encoderそれぞれに独立した学習率が指定できるようになりました。サンプルは上の英語版を参照してください。
|
||||
- `train_db.py` および `fine_tune.py` で Text Encoder に別の学習率を指定できるようになりました。`--learning_rate_te` オプションで指定してください。
|
||||
- `fine_tune.py` で Text Encoder を学習するには `--train_text_encoder` オプションをあわせて指定してください。`train_db.py` はデフォルトで学習します。
|
||||
- `sdxl_train.py` で block lr を指定すると Text Encoder が学習されない不具合を修正しました。
|
||||
- Debiased Estimation loss が各学習スクリプトに追加されました。sdbsd 氏に感謝します。
|
||||
- `--debiased_estimation_loss` を指定すると有効になります。詳細は PR [#889](https://github.com/kohya-ss/sd-scripts/pull/889) を参照してください。
|
||||
- `train_network.py` と `sdxl_train_network.py` でText Encoderの学習が改善されました。KohakuBlueleaf 氏に感謝します。 PR [#895](https://github.com/kohya-ss/sd-scripts/pull/895)
|
||||
- 各学習スクリプトで移動平均のlossがプログレスバーに表示されるようになりました。shirayu 氏に感謝します。 PR [#899](https://github.com/kohya-ss/sd-scripts/pull/899)
|
||||
- PagedAdamW32bit オプティマイザがサポートされました。`--optimizer_type=PagedAdamW32bit` と指定してください。xzuyn 氏に感謝します。 PR [#900](https://github.com/kohya-ss/sd-scripts/pull/900)
|
||||
- その他のバグ修正と改善。
|
||||
- Diffusers、Accelerate、Transformers 等の関連ライブラリを更新しました。[Upgrade](#upgrade) を参照し更新をお願いします。
|
||||
- 最新の Transformers を前提とした一部のモデルファイル(Text Encoder が position_id を持たないもの)が読み込めるようになりました。
|
||||
- `torch.compile` がサポートされしました(実験的)。 PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) p1atdev 氏に感謝します。
|
||||
- Linux または WSL でのみ動作します。
|
||||
- 各学習スクリプトで `--torch_compile` オプションを指定してください。
|
||||
- `--dynamo_backend` オプションで使用される backend を選択できます。デフォルトは `"inductor"` です。 `inductor` または `eager` が動作するようです。
|
||||
- `--xformers` オプションとは互換性がありません。 代わりに `--sdpa` オプションを使用してください。
|
||||
- PyTorch 2.1以降を推奨します。
|
||||
- 詳細は [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) をご覧ください。
|
||||
- wandb 保存時のセッション名が各学習スクリプトの `--wandb_run_name` オプションで指定できるようになりました。 PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) hopl1t 氏に感謝します。
|
||||
- IPEX ライブラリが更新されました。[PR #1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Disty0 氏に感謝します。
|
||||
- Diffusers 形式でのモデル保存ができなくなっていた不具合を修正しました。
|
||||
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
from typing import Union, List, Optional, Dict, Any, Tuple
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||
|
||||
|
||||
11
fine_tune.py
11
fine_tune.py
@@ -11,15 +11,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
@@ -291,6 +286,8 @@ def train(args):
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
@@ -66,15 +66,10 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
import torchvision
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
24
library/ipex_interop.py
Normal file
24
library/ipex_interop.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
|
||||
def init_ipex():
|
||||
"""
|
||||
Try to import `intel_extension_for_pytorch`, and apply
|
||||
the hijacks using `library.ipex.ipex_init`.
|
||||
|
||||
If IPEX is not installed, this function does nothing.
|
||||
"""
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
try:
|
||||
from library.ipex import ipex_init
|
||||
|
||||
if torch.xpu.is_available():
|
||||
is_initialized, error_message = ipex_init()
|
||||
if not is_initialized:
|
||||
print("failed to initialize ipex:", error_message)
|
||||
except Exception as e:
|
||||
print("failed to initialize ipex:", e)
|
||||
@@ -5,15 +5,9 @@ import math
|
||||
import os
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
init_ipex()
|
||||
import diffusers
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
||||
@@ -1245,8 +1239,13 @@ def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_mod
|
||||
if vae is None:
|
||||
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
|
||||
|
||||
# original U-Net cannot be saved, so we need to convert it to the Diffusers version
|
||||
# TODO this consumes a lot of memory
|
||||
diffusers_unet = diffusers.UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
|
||||
diffusers_unet.load_state_dict(unet.state_dict())
|
||||
|
||||
pipeline = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
unet=diffusers_unet,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
|
||||
@@ -1262,9 +1262,9 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
for attn in self.attentions:
|
||||
attn.set_use_memory_efficient_attention(xformers, mem_eff)
|
||||
|
||||
def set_use_sdpa(self, spda):
|
||||
def set_use_sdpa(self, sdpa):
|
||||
for attn in self.attentions:
|
||||
attn.set_use_sdpa(spda)
|
||||
attn.set_use_sdpa(sdpa)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -923,7 +923,11 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
if up1 is not None:
|
||||
uncond_pool = up1
|
||||
|
||||
dtype = self.unet.dtype
|
||||
unet_dtype = self.unet.dtype
|
||||
dtype = unet_dtype
|
||||
if hasattr(dtype, "itemsize") and dtype.itemsize == 1: # fp8
|
||||
dtype = torch.float16
|
||||
self.unet.to(dtype)
|
||||
|
||||
# 4. Preprocess image and mask
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
@@ -1028,6 +1032,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
self.unet.to(unet_dtype)
|
||||
return latents
|
||||
|
||||
def latents_to_image(self, latents):
|
||||
|
||||
@@ -558,6 +558,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
|
||||
max_token_length: int,
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -567,6 +568,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.max_token_length = max_token_length
|
||||
# width/height is used when enable_bucket==False
|
||||
self.width, self.height = (None, None) if resolution is None else resolution
|
||||
self.network_multiplier = network_multiplier
|
||||
self.debug_dataset = debug_dataset
|
||||
|
||||
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
|
||||
@@ -1106,7 +1108,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
loss_weights.append(
|
||||
self.prior_loss_weight if image_info.is_reg else 1.0
|
||||
) # in case of fine tuning, is_reg is always False
|
||||
|
||||
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
||||
|
||||
@@ -1272,6 +1276,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
example["target_sizes_hw"] = torch.stack([torch.LongTensor(x) for x in target_sizes_hw])
|
||||
example["flippeds"] = flippeds
|
||||
|
||||
example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))
|
||||
|
||||
if self.debug_dataset:
|
||||
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
|
||||
return example
|
||||
@@ -1346,15 +1352,16 @@ class DreamBoothDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
prior_loss_weight: float,
|
||||
debug_dataset,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
@@ -1520,14 +1527,15 @@ class FineTuningDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -1724,14 +1732,15 @@ class ControlNetDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
min_bucket_reso: int,
|
||||
max_bucket_reso: int,
|
||||
bucket_reso_steps: int,
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset,
|
||||
debug_dataset: float,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
@@ -1765,6 +1774,7 @@ class ControlNetDataset(BaseDataset):
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier,
|
||||
enable_bucket,
|
||||
min_bucket_reso,
|
||||
max_bucket_reso,
|
||||
@@ -2039,6 +2049,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(
|
||||
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
|
||||
)
|
||||
if "network_multipliers" in example:
|
||||
print(f"network multiplier: {example['network_multipliers'][j]}")
|
||||
|
||||
if show_input_ids:
|
||||
print(f"input ids: {iid}")
|
||||
@@ -2105,8 +2117,8 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
|
||||
|
||||
class MinimalDataset(BaseDataset):
|
||||
def __init__(self, tokenizer, max_token_length, resolution, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
|
||||
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.num_train_images = 0 # update in subclass
|
||||
self.num_reg_images = 0 # update in subclass
|
||||
@@ -2850,14 +2862,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
)
|
||||
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
|
||||
parser.add_argument(
|
||||
"--dynamo_backend",
|
||||
type=str,
|
||||
default="inductor",
|
||||
"--dynamo_backend",
|
||||
type=str,
|
||||
default="inductor",
|
||||
# available backends:
|
||||
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
|
||||
# https://pytorch.org/docs/stable/torch.compiler.html
|
||||
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)"
|
||||
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument(
|
||||
@@ -2904,6 +2916,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")
|
||||
parser.add_argument(
|
||||
"--ddp_timeout",
|
||||
type=int,
|
||||
@@ -2946,6 +2959,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wandb_run_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_tracker_config",
|
||||
type=str,
|
||||
@@ -3880,7 +3899,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
os.environ["WANDB_DIR"] = logging_dir
|
||||
if args.wandb_api_key is not None:
|
||||
wandb.login(key=args.wandb_api_key)
|
||||
|
||||
|
||||
# torch.compile のオプション。 NO の場合は torch.compile は使わない
|
||||
dynamo_backend = "NO"
|
||||
if args.torch_compile:
|
||||
|
||||
@@ -43,6 +43,9 @@ def svd(
|
||||
clamp_quantile=0.99,
|
||||
min_diff=0.01,
|
||||
no_metadata=False,
|
||||
load_precision=None,
|
||||
load_original_model_to=None,
|
||||
load_tuned_model_to=None,
|
||||
):
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
@@ -57,28 +60,51 @@ def svd(
|
||||
if v_parameterization is None:
|
||||
v_parameterization = v2
|
||||
|
||||
load_dtype = str_to_dtype(load_precision) if load_precision else None
|
||||
save_dtype = str_to_dtype(save_precision)
|
||||
work_device = "cpu"
|
||||
|
||||
# load models
|
||||
if not sdxl:
|
||||
print(f"loading original SD model : {model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||
text_encoders_o = [text_encoder_o]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o = text_encoder_o.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
print(f"loading tuned SD model : {model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||
text_encoders_t = [text_encoder_t]
|
||||
if load_dtype is not None:
|
||||
text_encoder_t = text_encoder_t.to(load_dtype)
|
||||
unet_t = unet_t.to(load_dtype)
|
||||
|
||||
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
||||
else:
|
||||
device_org = load_original_model_to if load_original_model_to else "cpu"
|
||||
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
|
||||
|
||||
print(f"loading original SDXL model : {model_org}")
|
||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
||||
)
|
||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o1 = text_encoder_o1.to(load_dtype)
|
||||
text_encoder_o2 = text_encoder_o2.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
print(f"loading original SDXL model : {model_tuned}")
|
||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
||||
)
|
||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||
if load_dtype is not None:
|
||||
text_encoder_t1 = text_encoder_t1.to(load_dtype)
|
||||
text_encoder_t2 = text_encoder_t2.to(load_dtype)
|
||||
unet_t = unet_t.to(load_dtype)
|
||||
|
||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||
@@ -100,38 +126,54 @@ def svd(
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight - module_o.weight
|
||||
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
||||
|
||||
# clear weight to save memory
|
||||
module_o.weight = None
|
||||
module_t.weight = None
|
||||
|
||||
# Text Encoder might be same
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||
text_encoder_different = True
|
||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||
|
||||
diff = diff.float()
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# clear target Text Encoder to save memory
|
||||
for text_encoder in text_encoders_t:
|
||||
del text_encoder
|
||||
|
||||
if not text_encoder_different:
|
||||
print("Text encoder is same. Extract U-Net only.")
|
||||
lora_network_o.text_encoder_loras = []
|
||||
diffs = {}
|
||||
diffs = {} # clear diffs
|
||||
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight - module_o.weight
|
||||
diff = diff.float()
|
||||
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
||||
|
||||
if args.device:
|
||||
diff = diff.to(args.device)
|
||||
# clear weight to save memory
|
||||
module_o.weight = None
|
||||
module_t.weight = None
|
||||
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# clear LoRA network, target U-Net to save memory
|
||||
del lora_network_o
|
||||
del lora_network_t
|
||||
del unet_t
|
||||
|
||||
# make LoRA with svd
|
||||
print("calculating by svd")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name, mat in tqdm(list(diffs.items())):
|
||||
if args.device:
|
||||
mat = mat.to(args.device)
|
||||
mat = mat.to(torch.float) # calc by float
|
||||
|
||||
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
@@ -171,8 +213,8 @@ def svd(
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
|
||||
U = U.to("cpu").contiguous()
|
||||
Vh = Vh.to("cpu").contiguous()
|
||||
U = U.to(work_device, dtype=save_dtype).contiguous()
|
||||
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
|
||||
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
|
||||
@@ -230,6 +272,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
@@ -285,6 +334,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
||||
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_original_model_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_tuned_model_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ einops==0.6.1
|
||||
pytorch-lightning==1.9.0
|
||||
# bitsandbytes==0.39.1
|
||||
tensorboard==2.10.1
|
||||
safetensors==0.3.1
|
||||
safetensors==0.4.2
|
||||
# gradio==3.16.2
|
||||
altair==4.2.2
|
||||
easygui==0.98.3
|
||||
|
||||
@@ -18,15 +18,10 @@ import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
import torchvision
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
|
||||
@@ -9,13 +9,11 @@ import random
|
||||
from einops import repeat
|
||||
import numpy as np
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
@@ -11,15 +11,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import sdxl_model_util
|
||||
@@ -457,6 +452,8 @@ def train(args):
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
@@ -14,13 +14,11 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
import accelerate
|
||||
@@ -342,6 +340,8 @@ def train(args):
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
|
||||
@@ -11,13 +11,11 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
import train_network
|
||||
|
||||
@@ -95,8 +90,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
unet.to(org_unet_device)
|
||||
else:
|
||||
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
||||
text_encoders[0].to(accelerator.device)
|
||||
text_encoders[1].to(accelerator.device)
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
|
||||
@@ -3,13 +3,9 @@ import os
|
||||
|
||||
import regex
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
import open_clip
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
|
||||
|
||||
@@ -12,15 +12,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler, ControlNetModel
|
||||
@@ -336,6 +331,8 @@ def train(args):
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
|
||||
11
train_db.py
11
train_db.py
@@ -12,15 +12,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
@@ -268,6 +263,8 @@ def train(args):
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
@@ -14,15 +14,10 @@ from tqdm import tqdm
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import model_util
|
||||
@@ -117,7 +112,7 @@ class NetworkTrainer:
|
||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
|
||||
):
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device)
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, weight_dtype):
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
@@ -278,6 +273,7 @@ class NetworkTrainer:
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 必要ならテキストエンコーダーの出力をキャッシュする: Text Encoderはcpuまたはgpuへ移される
|
||||
# cache text encoder outputs if needed: Text Encoder is moved to cpu or gpu
|
||||
self.cache_text_encoder_outputs_if_needed(
|
||||
args, accelerator, unet, vae, tokenizers, text_encoders, train_dataset_group, weight_dtype
|
||||
)
|
||||
@@ -309,6 +305,7 @@ class NetworkTrainer:
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
network_has_multiplier = hasattr(network, "set_multiplier")
|
||||
|
||||
if hasattr(network, "prepare_network"):
|
||||
network.prepare_network(args)
|
||||
@@ -389,17 +386,33 @@ class NetworkTrainer:
|
||||
accelerator.print("enable full bf16 training.")
|
||||
network.to(weight_dtype)
|
||||
|
||||
unet_weight_dtype = te_weight_dtype = weight_dtype
|
||||
# Experimental Feature: Put base model into fp8 to save vram
|
||||
if args.fp8_base:
|
||||
assert torch.__version__ >= "2.1.0", "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
|
||||
assert (
|
||||
args.mixed_precision != "no"
|
||||
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
|
||||
accelerator.print("enable fp8 training.")
|
||||
unet_weight_dtype = torch.float8_e4m3fn
|
||||
te_weight_dtype = torch.float8_e4m3fn
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(dtype=weight_dtype)
|
||||
unet.to(dtype=unet_weight_dtype)
|
||||
for t_enc in text_encoders:
|
||||
t_enc.requires_grad_(False)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
# TODO めちゃくちゃ冗長なのでコードを整理する
|
||||
# in case of cpu, dtype is already set to fp32 because cpu does not support fp8/fp16/bf16
|
||||
if t_enc.device.type != "cpu":
|
||||
t_enc.to(dtype=te_weight_dtype)
|
||||
# nn.Embedding not support FP8
|
||||
t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
else:
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
if train_text_encoder:
|
||||
if len(text_encoders) > 1:
|
||||
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
||||
@@ -407,8 +420,8 @@ class NetworkTrainer:
|
||||
text_encoder = accelerator.prepare(text_encoder)
|
||||
text_encoders = [text_encoder]
|
||||
else:
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set
|
||||
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
@@ -421,9 +434,6 @@ class NetworkTrainer:
|
||||
if train_text_encoder:
|
||||
t_enc.text_model.embeddings.requires_grad_(True)
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if not train_text_encoder: # train U-Net only
|
||||
unet.parameters().__next__().requires_grad_(True)
|
||||
else:
|
||||
unet.eval()
|
||||
for t_enc in text_encoders:
|
||||
@@ -684,6 +694,8 @@ class NetworkTrainer:
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
@@ -752,7 +764,17 @@ class NetworkTrainer:
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * self.vae_scale_factor
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# get multiplier for each sample
|
||||
if network_has_multiplier:
|
||||
multipliers = batch["network_multipliers"]
|
||||
# if all multipliers are same, use single multiplier
|
||||
if torch.all(multipliers == multipliers[0]):
|
||||
multipliers = multipliers[0].item()
|
||||
else:
|
||||
raise NotImplementedError("multipliers for each sample is not supported yet")
|
||||
# print(f"set multiplier: {multipliers}")
|
||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
@@ -776,10 +798,24 @@ class NetworkTrainer:
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
for x in noisy_latents:
|
||||
x.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents.requires_grad_(train_unet),
|
||||
timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
@@ -806,10 +842,11 @@ class NetworkTrainer:
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
if accelerator.sync_gradients:
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
if args.max_grad_norm != 0.0:
|
||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
@@ -8,15 +8,10 @@ import toml
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
init_ipex()
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from transformers import CLIPTokenizer
|
||||
@@ -504,6 +499,8 @@ class TextualInversionTrainer:
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
@@ -728,14 +725,13 @@ class TextualInversionTrainer:
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
|
||||
if is_main_process:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
@@ -8,13 +8,11 @@ from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
@@ -394,6 +392,8 @@ def train(args):
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user