Compare commits

..

32 Commits

Author SHA1 Message Date
Kohya S
5095f29e7c fix hf upload #1196 2024-03-20 18:46:13 +09:00
Kohya S
855add067b update option help and readme 2024-03-20 18:14:05 +09:00
Kohya S
bf6cd4b9da Merge pull request #1168 from gesen2egee/save_state_on_train_end
Save state on train end
2024-03-20 18:02:13 +09:00
Kohya S
3b0db0f17f update readme 2024-03-20 17:45:35 +09:00
Kohya S
119cc99fb0 Merge pull request #1167 from Horizon1704/patch-1
Add "encoding='utf-8'" for --config_file
2024-03-20 17:39:08 +09:00
Kohya S
5f6196e4c7 update readme 2024-03-20 16:35:23 +09:00
Victor Espinoza-Guerra
46331a9e8e English Translation of config_README-ja.md (#1175)
* Add files via upload

Creating template to work on.

* Update config_README-en.md

Total Conversion from Japanese to English.

* Update config_README-en.md

* Update config_README-en.md

* Update config_README-en.md
2024-03-20 16:31:01 +09:00
Kohya S
cf09c6aa9f Merge pull request #1177 from KohakuBlueleaf/random-strength-noise
Random strength for Noise Offset and input perturbation noise
2024-03-20 16:17:16 +09:00
Kohya S
80dbbf5e48 tagger now stores model under repo_id subdir 2024-03-20 16:14:57 +09:00
Kohya S
7da41be281 Merge pull request #1192 from sdbds/main
Add WDV3 support
2024-03-20 15:49:55 +09:00
Kohya S
e281e867e6 Merge branch 'main' into dev 2024-03-20 15:49:08 +09:00
青龍聖者@bdsqlsz
6c51c971d1 fix typo 2024-03-20 09:35:21 +08:00
青龍聖者@bdsqlsz
a71c35ccd9 Update requirements.txt 2024-03-18 22:31:59 +08:00
青龍聖者@bdsqlsz
5410a8c79b Update requirements.txt 2024-03-18 22:31:00 +08:00
青龍聖者@bdsqlsz
a7dff592d3 Update tag_images_by_wd14_tagger.py
add WDV3
2024-03-18 22:29:05 +08:00
Kohya S
f9317052ed update readme for timestep embs bug 2024-03-18 08:53:23 +09:00
Kohya S
443f02942c fix doc 2024-03-15 21:35:14 +09:00
Kohya S
0a8ec5224e Merge branch 'main' into dev 2024-03-15 21:33:07 +09:00
Kohya S
6b1520a46b Merge pull request #1187 from kohya-ss/fix-timeemb
fix sdxl timestep embedding
2024-03-15 21:17:13 +09:00
Kohya S
f811b115ba fix sdxl timestep embedding 2024-03-15 21:05:00 +09:00
kblueleaf
53954a1e2e use correct settings for parser 2024-03-13 18:21:49 +08:00
kblueleaf
86399407b2 random noise_offset strength 2024-03-13 18:21:49 +08:00
kblueleaf
948029fe61 random ip_noise_gamma strength 2024-03-13 18:21:49 +08:00
gesen2egee
d282c45002 Update train_network.py 2024-03-11 23:56:09 +08:00
gesen2egee
095b8035e6 save state on train end 2024-03-10 23:33:38 +08:00
Horizon1704
124ec45876 Add "encoding='utf-8'" 2024-03-10 22:53:05 +08:00
Kohya S
14c9372a38 add doc about Colab/rich issue 2024-03-03 21:47:37 +09:00
Kohya S
074d32af20 Merge branch 'main' into dev 2024-02-27 18:53:43 +09:00
Kohya S
2d7389185c Merge pull request #1094 from kohya-ss/dependabot/github_actions/crate-ci/typos-1.17.2
Bump crate-ci/typos from 1.16.26 to 1.17.2
2024-02-27 18:23:41 +09:00
Kohya S
577e9913ca add some new dataset settings 2024-02-26 20:01:25 +09:00
Kohya S
fccbee2727 revert logging #1137 2024-02-25 10:43:14 +09:00
dependabot[bot]
716a92cbed Bump crate-ci/typos from 1.16.26 to 1.17.2
Bumps [crate-ci/typos](https://github.com/crate-ci/typos) from 1.16.26 to 1.17.2.
- [Release notes](https://github.com/crate-ci/typos/releases)
- [Changelog](https://github.com/crate-ci/typos/blob/master/CHANGELOG.md)
- [Commits](https://github.com/crate-ci/typos/compare/v1.16.26...v1.17.2)

---
updated-dependencies:
- dependency-name: crate-ci/typos
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-02-01 01:57:52 +00:00
28 changed files with 688 additions and 5249 deletions

View File

@@ -18,4 +18,4 @@ jobs:
- uses: actions/checkout@v4
- name: typos-action
uses: crate-ci/typos@v1.16.26
uses: crate-ci/typos@v1.17.2

364
README.md
View File

@@ -1,174 +1,3 @@
# Training Stable Cascade Stage C
This is an experimental feature. There may be bugs.
__Feb 25, 2024 Update:__ Fixed a bug that the LoRA weights trained can be loaded in ComfyUI. If you still have a problem, please let me know.
__Feb 25, 2024 Update:__ Fixed a bug that Stage C training with mixed precision behaves the same as `--full_bf16` (fp16) regardless of `--full_bf16` (fp16) specified.
This is because the Stage C weights were loaded in bf16/fp16. With this fix, the memory usage without `--full_bf16` (fp16) specified will increase, so you may need to specify `--full_bf16` (fp16) as needed.
__Feb 22, 2024 Update:__ Fixed a bug that LoRA is not applied to some modules (to_q/k/v and to_out) in Attention. Also, the model structure of Stage C has been changed, and you can choose xformers and SDPA (SDPA was used before). Please specify `--sdpa` or `--xformers` option.
__Feb 20, 2024 Update:__ There was a problem with the preprocessing of the EfficientNetEncoder, and the latents became invalid (the saturation of the training results decreases). If you have cached `_sc_latents.npz` files with `--cache_latents_to_disk`, please delete them before training.
## Usage
Training is run with `stable_cascade_train_stage_c.py`.
The main options are the same as `sdxl_train.py`. The following options have been added.
- `--effnet_checkpoint_path`: Specifies the path to the EfficientNetEncoder weights.
- `--stage_c_checkpoint_path`: Specifies the path to the Stage C weights.
- `--text_model_checkpoint_path`: Specifies the path to the Text Encoder weights. If omitted, the model from Hugging Face will be used.
- `--save_text_model`: Saves the model downloaded from Hugging Face to `--text_model_checkpoint_path`.
- `--previewer_checkpoint_path`: Specifies the path to the Previewer weights. Used to generate sample images during training.
- `--adaptive_loss_weight`: Uses [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) . If omitted, P2LossWeight is used. The official settings use Adaptive Loss Weight.
The learning rate is set to 1e-4 in the official settings.
The first time, specify `--text_model_checkpoint_path` and `--save_text_model` to save the Text Encoder weights. From the next time, specify `--text_model_checkpoint_path` to load the saved weights.
Sample image generation during training is done with Perviewer. Perviewer is a simple decoder that converts EfficientNetEncoder latents to images.
Some of the options for SDXL are simply ignored or cause an error (especially noise-related options such as `--noise_offset`). `--vae_batch_size` and `--no_half_vae` are applied directly to the EfficientNetEncoder (when `bf16` is specified for mixed precision, `--no_half_vae` is not necessary).
Options for latents and Text Encoder output caches can be used as is, but since the EfficientNetEncoder is much lighter than the VAE, you may not need to use the cache unless memory is particularly tight.
`--gradient_checkpointing`, `--full_bf16`, and `--full_fp16` (untested) to reduce memory consumption can be used as is.
A scale of about 4 is suitable for sample image generation.
Since the official settings use `bf16` for training, training with `fp16` may be unstable.
The code for training the Text Encoder is also written, but it is untested.
### Command line sample
```batch
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --save_precision bf16 --output_dir ../output --output_name sc_test - --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt --adaptive_loss_weight
```
### About the dataset for fine tuning
If the latents cache files for SD/SDXL exist (extension `*.npz`), it will be read and an error will occur during training. Please move them to another location in advance.
After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (suffix `_sc_latents.npz` is added).
## LoRA training
`stable_cascade_train_c_network.py` is used for LoRA training. The main options are the same as `train_network.py`, and the same options as `stable_cascade_train_stage_c.py` have been added.
__This is an experimental feature, so the format of the saved weights may change in the future and become incompatible.__
There is no compatibility with the official LoRA, and the implementation of Text Encoder embedding training (Pivotal Tuning) in the official implementation is not implemented here.
Text Encoder LoRA training is implemented, but untested.
## Image generation
Basic image generation functionality is available in `stable_cascade_gen_img.py`. See `--help` for usage.
When using LoRA, specify `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors`.
The following prompt options are available.
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
* `--t` Specifies the t_start of the generation.
* `--f` Specifies the shift of the generation.
# Stable Cascade Stage C の学習
実験的機能です。不具合があるかもしれません。
__2024/2/25 追記:__ 学習される LoRA の重みが ComfyUI で読み込めるよう修正しました。依然として不具合がある場合にはご連絡ください。
__2024/2/25 追記:__ Mixed precision 時のStage C の学習が、 `--full_bf16` (fp16) の指定に関わらず `--full_bf16` (fp16) 指定時と同じ動作となる(と思われる)不具合を修正しました。
Stage C の重みを bf16/fp16 で読み込んでいたためです。この修正により `--full_bf16` (fp16) 未指定時のメモリ使用量が増えますので、必要に応じて `--full_bf16` (fp16) を指定してください。
__2024/2/22 追記:__ LoRA が一部のモジュールAttention の to_q/k/v および to_outに適用されない不具合を修正しました。また Stage C のモデル構造を変更し xformers と SDPA を選べるようになりました(今までは SDPA が使用されていました)。`--sdpa` または `--xformers` オプションを指定してください。
__2024/2/20 追記:__ EfficientNetEncoder の前処理に不具合があり、latents が不正になっていました(学習結果の彩度が低下する現象が起きます)。`--cache_latents_to_disk` でキャッシュした `_sc_latents.npz` がある場合、いったん削除してから学習してください。
## 使い方
学習は `stable_cascade_train_stage_c.py` で行います。
主なオプションは `sdxl_train.py` と同様です。以下のオプションが追加されています。
- `--effnet_checkpoint_path` : EfficientNetEncoder の重みのパスを指定します。
- `--stage_c_checkpoint_path` : Stage C の重みのパスを指定します。
- `--text_model_checkpoint_path` : Text Encoder の重みのパスを指定します。省略時は Hugging Face のモデルを使用します。
- `--save_text_model` : `--text_model_checkpoint_path` にHugging Face からダウンロードしたモデルを保存します。
- `--previewer_checkpoint_path` : Previewer の重みのパスを指定します。学習中のサンプル画像生成に使用します。
- `--adaptive_loss_weight` : [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) を用います。省略時は P2LossWeight が使用されます。公式では Adaptive Loss Weight が使用されているようです。
学習率は、公式の設定では 1e-4 のようです。
初回は `--text_model_checkpoint_path``--save_text_model` を指定して、Text Encoder の重みを保存すると良いでしょう。次からは `--text_model_checkpoint_path` を指定して、保存した重みを読み込むことができます。
学習中のサンプル画像生成は Perviewer で行われます。Previewer は EfficientNetEncoder の latents を画像に変換する簡易的な decoder です。
SDXL の向けの一部のオプションは単に無視されるか、エラーになります(特に `--noise_offset` などのノイズ関係)。`--vae_batch_size` および `--no_half_vae` はそのまま EfficientNetEncoder に適用されますmixed precision に `bf16` 指定時は `--no_half_vae` は不要のようです)。
latents および Text Encoder 出力キャッシュのためのオプションはそのまま使用できますが、EfficientNetEncoder は VAE よりもかなり軽量のため、メモリが特に厳しい場合以外はキャッシュを使用する必要はないかもしれません。
メモリ消費を抑えるための `--gradient_checkpointing``--full_bf16``--full_fp16`(未テスト)はそのまま使用できます。
サンプル画像生成時の Scale には 4 程度が適しているようです。
公式の設定では学習に `bf16` を用いているため、`fp16` での学習は不安定かもしれません。
Text Encoder 学習のコードも書いてありますが、未テストです。
### コマンドラインのサンプル
[Command-line-sample](#command-line-sample)を参照してください。
### fine tuning方式のデータセットについて
SD/SDXL 向けの latents キャッシュファイル(拡張子 `*.npz`)が存在するとそれを読み込んでしまい学習時にエラーになります。あらかじめ他の場所に退避しておいてください。
その後、`finetune/prepare_buckets_latents.py` をオプション `--stable_cascade` を指定して実行すると、Stable Cascade 向けの latents キャッシュファイル(接尾辞 `_sc_latents.npz` が付きます)が作成されます。
## LoRA 等の学習
LoRA の学習は `stable_cascade_train_c_network.py` で行います。主なオプションは `train_network.py` と同様で、`stable_cascade_train_stage_c.py` と同様のオプションが追加されています。
__実験的機能のため、保存される重みのフォーマットは将来的に変更され、互換性がなくなる可能性があります。__
公式の LoRA と重みの互換性はありません。また公式で実装されている Text Encoder の embedding 学習Pivotal Tuningも実装されていません。
Text Encoder の LoRA 学習は実装してありますが、未テストです。
## 画像生成
最低限の画像生成機能が `stable_cascade_gen_img.py` にあります。使用法は `--help` を参照してください。
LoRA 使用時は `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors` のように指定します。
プロンプトオプションとして以下が使用できます。
* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
* `--t` Specifies the t_start of the generation.
* `--f` Specifies the shift of the generation.
---
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
This repository contains training, generation and utility scripts for Stable Diffusion.
@@ -204,6 +33,7 @@ Most of the documents are written in Japanese.
* [Training guide - common](./docs/train_README-ja.md) : data preparation, options etc...
* [Chinese version](./docs/train_README-zh.md)
* [Dataset config](./docs/config_README-ja.md)
* [English version](./docs/config_README-en.md)
* [DreamBooth training guide](./docs/train_db_README-ja.md)
* [Step by Step fine-tuning guide](./docs/fine_tune_README_ja.md):
* [training LoRA](./docs/train_network_README-ja.md)
@@ -420,6 +250,140 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History
### Working in progress
- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.
- The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704!
- `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`).
- Some features are added to the dataset subset settings.
- `secondary_separator` is added to specify the tag separator that is not the target of shuffling or dropping.
- Specify `secondary_separator=";;;"`. When you specify `secondary_separator`, the part is not shuffled or dropped. See the example below.
- `enable_wildcard` is added. When set to `true`, the wildcard notation `{aaa|bbb|ccc}` can be used. See the example below.
- `keep_tokens_separator` is updated to be used twice in the caption. When you specify `keep_tokens_separator="|||"`, the part divided by the second `|||` is not shuffled or dropped and remains at the end.
- The existing features `caption_prefix` and `caption_suffix` can be used together. `caption_prefix` and `caption_suffix` are processed first, and then `enable_wildcard`, `keep_tokens_separator`, shuffling and dropping, and `secondary_separator` are processed in order.
- The examples are [shown below](#example-of-dataset-settings--データセット設定の記述例).
- The support for v3 repositories is added to `tag_image_by_wd14_tagger.py` (`--onnx` option only). PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) Thanks to sdbds!
- Onnx may need to be updated. Onnx is not installed by default, so please install or update it with `pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` etc. Please also check the comments in `requirements.txt`.
- The model is now saved in the subdirectory as `--repo_id` in `tag_image_by_wd14_tagger.py` . This caches multiple repo_id models. Please delete unnecessary files under `--model_dir`.
- The options `--noise_offset_random_strength` and `--ip_noise_gamma_random_strength` are added to each training script. These options can be used to vary the noise offset and ip noise gamma in the range of 0 to the specified value. PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) Thanks to KohakuBlueleaf!
- The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150!
- The options `--save_state_on_train_end` are added to each training script. PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) Thanks to gesen2egee!
- Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。
- データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。
- `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix``caption_suffix``keep_tokens_separator``secondary_separator``enable_wildcard`)。
- データセットのサブセット設定にいくつかの機能を追加しました。
- シャッフルの対象とならないタグ分割識別子の指定 `secondary_separator` を追加しました。`secondary_separator=";;;"` のように指定します。`secondary_separator` で区切ることで、その部分はシャッフル、drop 時にまとめて扱われます。詳しくは記述例をご覧ください。
- `enable_wildcard` を追加しました。`true` にするとワイルドカード記法 `{aaa|bbb|ccc}` が使えます。詳しくは記述例をご覧ください。
- `keep_tokens_separator` をキャプション内に 2 つ使えるようにしました。たとえば `keep_tokens_separator="|||"` と指定したとき、`1girl, hatsune miku, vocaloid ||| stage, mic ||| best quality, rating: general` とキャプションを指定すると、二番目の `|||` で分割された部分はシャッフル、drop されず末尾に残ります。
- 既存の機能 `caption_prefix``caption_suffix` とあわせて使えます。`caption_prefix``caption_suffix` は一番最初に処理され、その後、ワイルドカード、`keep_tokens_separator`、シャッフルおよび drop、`secondary_separator` の順に処理されます。
- `tag_image_by_wd14_tagger.py` で v3 のリポジトリがサポートされました(`--onnx` 指定時のみ有効)。 PR [#1192](https://github.com/kohya-ss/sd-scripts/pull/1192) sdbds 氏に感謝します。
- Onnx のバージョンアップが必要になるかもしれません。デフォルトでは Onnx はインストールされていませんので、`pip install onnx==1.15.0 onnxruntime-gpu==1.17.1` 等でインストール、アップデートしてください。`requirements.txt` のコメントもあわせてご確認ください。
- `tag_image_by_wd14_tagger.py` で、モデルを`--repo_id` のサブディレクトリに保存するようにしました。これにより複数のモデルファイルがキャッシュされます。`--model_dir` 直下の不要なファイルは削除願います。
- 各学習スクリプトに、noise offset、ip noise gammaを、それぞれ 0~指定した値の範囲で変動させるオプション `--noise_offset_random_strength` および `--ip_noise_gamma_random_strength` が追加されました。 PR [#1177](https://github.com/kohya-ss/sd-scripts/pull/1177) KohakuBlueleaf 氏に感謝します。
- データセット設定の[英語版ドキュメント](./docs/config_README-en.md) が追加されました。PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) darkstorm2150 氏に感謝します。
- 各学習スクリプトに、学習終了時に state を保存する `--save_state_on_train_end` オプションが追加されました。 PR [#1168](https://github.com/kohya-ss/sd-scripts/pull/1168) gesen2egee 氏に感謝します。
#### Example of dataset settings / データセット設定の記述例:
```toml
[general]
flip_aug = true
color_aug = false
resolution = [1024, 1024]
[[datasets]]
batch_size = 6
enable_bucket = true
bucket_no_upscale = true
caption_extension = ".txt"
keep_tokens_separator= "|||"
shuffle_caption = true
caption_tag_dropout_rate = 0.1
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
enable_wildcard = true # 同上 / same as above
[[datasets.subsets]]
image_dir = "/path/to/image_dir"
num_repeats = 1
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
caption_prefix = "1girl, hatsune miku, vocaloid |||"
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
```
#### Example of caption, secondary_separator notation: `secondary_separator = ";;;"`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped).
#### Example of caption, enable_wildcard notation: `enable_wildcard = true`
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`.
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`).
#### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"`
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc.
#### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
```
`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` drop されたケース)などになります。
#### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合
```txt
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
```
ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。
```txt
1girl, hatsune miku, vocaloid, {{retro style}}
```
タグ文字列に `{``}` そのものを含めたい場合は `{{``}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。
#### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合
```txt
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
```
`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general``1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。
### Mar 15, 2024 / 2024/3/15: v0.8.5
- Fixed a bug that the value of timestep embedding during SDXL training was incorrect.
- Please update for SDXL training.
- The inference with the generation script is also fixed.
- This fix appears to resolve an issue where unintended artifacts occurred in trained models under certain conditions.
We would like to express our deep gratitude to Mark Saint (cacoe) from leonardo.ai, for reporting the issue and cooperating with the verification, and to gcem156 for the advice provided in identifying the part of the code that needed to be fixed.
- SDXL 学習時の timestep embedding の値が誤っていたのを修正しました。
- SDXL の学習時にはアップデートをお願いいたします。
- 生成スクリプトでの推論時についてもあわせて修正しました。
- この修正により、特定の条件下で学習されたモデルに意図しないアーティファクトが発生する問題が解消されるようです。問題を報告いただき、また検証にご協力いただいた leonardo.ai の Mark Saint (cacoe) 氏、および修正点の特定に関するアドバイスをいただいた gcem156 氏に深く感謝いたします。
### Feb 24, 2024 / 2024/2/24: v0.8.4
- The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu!
@@ -475,64 +439,6 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
- 複数 GPU での学習時に `network_multiplier` を指定するとクラッシュする不具合が修正されました。 PR [#1084](https://github.com/kohya-ss/sd-scripts/pull/1084) fireicewolf 氏に感謝します。
- ControlNet-LLLite の学習がエラーになる不具合を修正しました。
### Jan 23, 2024 / 2024/1/23: v0.8.2
- [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
- Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
- PyTorch 2.1 or later is required.
- If you use xformers with PyTorch 2.1, please see [xformers repository](https://github.com/facebookresearch/xformers) and install the appropriate version according to your CUDA version.
- The sample image generation during training consumes a lot of memory. It is recommended to turn it off.
- [Experimental] The network multiplier can be specified for each dataset in the training scripts for LoRA etc.
- This is an experimental option and may be removed or changed in the future.
- For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
- Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
- Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
- `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
- `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
- Fixed a bug in multi-GPU Textual Inversion training.
- 実験的 LoRA等の学習スクリプトで、ベースモデルU-Net、および Text Encoder のモジュール学習時は Text Encoder も)の重みを fp8 にして学習するオプションが追加されました。 PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) KohakuBlueleaf 氏に感謝します。
- `train_network.py` または `sdxl_train_network.py``--fp8_base` を指定してください。
- PyTorch 2.1 以降が必要です。
- PyTorch 2.1 で xformers を使用する場合は、[xformers のリポジトリ](https://github.com/facebookresearch/xformers) を参照し、CUDA バージョンに応じて適切なバージョンをインストールしてください。
- 学習中のサンプル画像生成はメモリを大量に消費するため、オフにすることをお勧めします。
- (実験的) LoRA 等の学習で、データセットごとに異なるネットワーク適用率を指定できるようになりました。
- 実験的オプションのため、将来的に削除または仕様変更される可能性があります。
- たとえば状態 A を `1.0`、状態 B を `-1.0` として学習すると、LoRA の適用率に応じて状態 A と B を切り替えつつ生成できるかもしれません。
- また、五段階の状態を用意し、それぞれ `0.2``0.4``0.6``0.8``1.0` として学習すると、適用率でなめらかに状態を切り替えて生成できるかもしれません。
- `.toml` ファイルで `[[datasets]]``network_multiplier` を指定してください。
- `networks/extract_lora_from_models.py` に使用メモリ量を削減するいくつかのオプションを追加しました。
- `--load_precision` で読み込み時の精度を指定できます。モデルが fp16 で保存されている場合は `--load_precision fp16` を指定して精度を変えずにメモリ量を削減できます。
- `--load_original_model_to` で元モデルを読み込むデバイスを、`--load_tuned_model_to` で派生モデルを読み込むデバイスを指定できます。デフォルトは両方とも `cpu` ですがそれぞれ `cuda` 等を指定できます。片方を GPU に読み込むことでメモリ量を削減できます。SDXL の場合のみ有効です。
- マルチ GPU での LoRA 等の学習時に勾配の同期が改善されました。 PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) KohakuBlueleaf 氏に感謝します。
- Intel IPEX サポートのコードが改善されました。PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) akx 氏に感謝します。
- マルチ GPU での Textual Inversion 学習の不具合を修正しました。
- `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例
```toml
[general]
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = 1.0
... subset settings ...
[[datasets]]
resolution = 512
batch_size = 8
network_multiplier = -1.0
... subset settings ...
```
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。

279
docs/config_README-en.md Normal file
View File

@@ -0,0 +1,279 @@
Original Source by kohya-ss
A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150
# Config Readme
This README is about the configuration files that can be passed with the `--dataset_config` option.
## Overview
By passing a configuration file, users can make detailed settings.
* Multiple datasets can be configured
* For example, by setting `resolution` for each dataset, they can be mixed and trained.
* In training methods that support both the DreamBooth approach and the fine-tuning approach, datasets of the DreamBooth method and the fine-tuning method can be mixed.
* Settings can be changed for each subset
* A subset is a partition of the dataset by image directory or metadata. Several subsets make up a dataset.
* Options such as `keep_tokens` and `flip_aug` can be set for each subset. On the other hand, options such as `resolution` and `batch_size` can be set for each dataset, and their values are common among subsets belonging to the same dataset. More details will be provided later.
The configuration file format can be JSON or TOML. Considering the ease of writing, it is recommended to use [TOML](https://toml.io/ja/v1.0.0-rc.2). The following explanation assumes the use of TOML.
Here is an example of a configuration file written in TOML.
```toml
[general]
shuffle_caption = true
caption_extension = '.txt'
keep_tokens = 1
# This is a DreamBooth-style dataset
[[datasets]]
resolution = 512
batch_size = 4
keep_tokens = 2
[[datasets.subsets]]
image_dir = 'C:\hoge'
class_tokens = 'hoge girl'
# This subset uses keep_tokens = 2 (the value of the parent datasets)
[[datasets.subsets]]
image_dir = 'C:\fuga'
class_tokens = 'fuga boy'
keep_tokens = 3
[[datasets.subsets]]
is_reg = true
image_dir = 'C:\reg'
class_tokens = 'human'
keep_tokens = 1
# This is a fine-tuning dataset
[[datasets]]
resolution = [768, 768]
batch_size = 2
[[datasets.subsets]]
image_dir = 'C:\piyo'
metadata_file = 'C:\piyo\piyo_md.json'
# This subset uses keep_tokens = 1 (the value of [general])
```
In this example, three directories are trained as a DreamBooth-style dataset at 512x512 (batch size 4), and one directory is trained as a fine-tuning dataset at 768x768 (batch size 2).
## Settings for datasets and subsets
Settings for datasets and subsets are divided into several registration locations.
* `[general]`
* This is where options that apply to all datasets or all subsets are specified.
* If there are options with the same name in the dataset-specific or subset-specific settings, the dataset-specific or subset-specific settings take precedence.
* `[[datasets]]`
* `datasets` is where settings for datasets are registered. This is where options that apply individually to each dataset are specified.
* If there are subset-specific settings, the subset-specific settings take precedence.
* `[[datasets.subsets]]`
* `datasets.subsets` is where settings for subsets are registered. This is where options that apply individually to each subset are specified.
Here is an image showing the correspondence between image directories and registration locations in the previous example.
```
C:\
├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
├─ reg -> [[datasets.subsets]] No.3 ┘ |
└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
```
The image directory corresponds to each `[[datasets.subsets]]`. Then, multiple `[[datasets.subsets]]` are combined to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`.
The available options for each registration location may differ, but if the same option is specified, the value in the lower registration location will take precedence. You can check how the `keep_tokens` option is handled in the previous example for better understanding.
Additionally, the available options may vary depending on the method that the learning approach supports.
* Options specific to the DreamBooth method
* Options specific to the fine-tuning method
* Options available when using the caption dropout technique
When using both the DreamBooth method and the fine-tuning method, they can be used together with a learning approach that supports both.
When using them together, a point to note is that the method is determined based on the dataset, so it is not possible to mix DreamBooth method subsets and fine-tuning method subsets within the same dataset.
In other words, if you want to use both methods together, you need to set up subsets of different methods belonging to different datasets.
In terms of program behavior, if the `metadata_file` option exists, it is determined to be a subset of fine-tuning. Therefore, for subsets belonging to the same dataset, as long as they are either "all have the `metadata_file` option" or "all have no `metadata_file` option," there is no problem.
Below, the available options will be explained. For options with the same name as the command-line argument, the explanation will be omitted in principle. Please refer to other READMEs.
### Common options for all learning methods
These are options that can be specified regardless of the learning method.
#### Data set specific options
These are options related to the configuration of the data set. They cannot be described in `datasets.subsets`.
| Option Name | Example Setting | `[general]` | `[[datasets]]` |
| ---- | ---- | ---- | ---- |
| `batch_size` | `1` | o | o |
| `bucket_no_upscale` | `true` | o | o |
| `bucket_reso_steps` | `64` | o | o |
| `enable_bucket` | `true` | o | o |
| `max_bucket_reso` | `1024` | o | o |
| `min_bucket_reso` | `128` | o | o |
| `resolution` | `256`, `[512, 512]` | o | o |
* `batch_size`
* This corresponds to the command-line argument `--train_batch_size`.
These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
#### Options for Subsets
These options are related to subset configuration.
| Option Name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `color_aug` | `false` | o | o | o |
| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
| `flip_aug` | `true` | o | o | o |
| `keep_tokens` | `2` | o | o | o |
| `num_repeats` | `10` | o | o | o |
| `random_crop` | `false` | o | o | o |
| `shuffle_caption` | `true` | o | o | o |
| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o |
| `caption_suffix` | `", from side"` | o | o | o |
* `num_repeats`
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
* `caption_prefix`, `caption_suffix`
* Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`.
### DreamBooth-specific options
DreamBooth-specific options only exist as subsets-specific options.
#### Subset-specific options
Options related to the configuration of DreamBooth subsets.
| Option Name | Example Setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
| `caption_extension` | `".txt"` | o | o | o |
| `class_tokens` | `"sks girl"` | - | - | o |
| `is_reg` | `false` | - | - | o |
Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`.
* `image_dir`
* Specifies the path to the image directory. This is a required option.
* Images must be placed directly under the directory.
* `class_tokens`
* Sets the class tokens.
* Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur.
* `is_reg`
* Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization.
### Fine-tuning method specific options
The options for the fine-tuning method only exist for subset-specific options.
#### Subset-specific options
These options are related to the configuration of the fine-tuning method's subsets.
| Option name | Example setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- | ---- |
| `image_dir` | `'C:\hoge'` | - | - | o |
| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) |
* `image_dir`
* Specify the path to the image directory. Unlike the DreamBooth method, specifying it is not mandatory, but it is recommended to do so.
* The case where it is not necessary to specify is when the `--full_path` is added to the command line when generating the metadata file.
* The images must be placed directly under the directory.
* `metadata_file`
* Specify the path to the metadata file used for the subset. This is a required option.
* It is equivalent to the command-line argument `--in_json`.
* Due to the specification that a metadata file must be specified for each subset, it is recommended to avoid creating a metadata file with images from different directories as a single metadata file. It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets.
### Options available when caption dropout method can be used
The options available when the caption dropout method can be used exist only for subsets. Regardless of whether it's the DreamBooth method or fine-tuning method, if it supports caption dropout, it can be specified.
#### Subset-specific options
Options related to the setting of subsets that caption dropout can be used for.
| Option Name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
| ---- | ---- | ---- | ---- |
| `caption_dropout_every_n_epochs` | o | o | o |
| `caption_dropout_rate` | o | o | o |
| `caption_tag_dropout_rate` | o | o | o |
## Behavior when there are duplicate subsets
In the case of the DreamBooth dataset, if there are multiple `image_dir` directories with the same content, they are considered to be duplicate subsets. For the fine-tuning dataset, if there are multiple `metadata_file` files with the same content, they are considered to be duplicate subsets. If duplicate subsets exist in the dataset, subsequent subsets will be ignored.
However, if they belong to different datasets, they are not considered duplicates. For example, if you have subsets with the same `image_dir` in different datasets, they will not be considered duplicates. This is useful when you want to train with the same image but with different resolutions.
```toml
# If data sets exist separately, they are not considered duplicates and are both used for training.
[[datasets]]
resolution = 512
[[datasets.subsets]]
image_dir = 'C:\hoge'
[[datasets]]
resolution = 768
[[datasets.subsets]]
image_dir = 'C:\hoge'
```
## Command Line Argument and Configuration File
There are options in the configuration file that have overlapping roles with command line argument options.
The following command line argument options are ignored if a configuration file is passed:
* `--train_data_dir`
* `--reg_data_dir`
* `--in_json`
The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file.
| Command Line Argument Option | Prioritized Configuration File Option |
| ------------------------------- | ------------------------------------- |
| `--bucket_no_upscale` | |
| `--bucket_reso_steps` | |
| `--caption_dropout_every_n_epochs` | |
| `--caption_dropout_rate` | |
| `--caption_extension` | |
| `--caption_tag_dropout_rate` | |
| `--color_aug` | |
| `--dataset_repeats` | `num_repeats` |
| `--enable_bucket` | |
| `--face_crop_aug_range` | |
| `--flip_aug` | |
| `--keep_tokens` | |
| `--min_bucket_reso` | |
| `--random_crop` | |
| `--resolution` | |
| `--shuffle_caption` | |
| `--train_batch_size` | `batch_size` |
## Error Guide
Currently, we are using an external library to check if the configuration file is written correctly, but the development has not been completed, and there is a problem that the error message is not clear. In the future, we plan to improve this problem.
As a temporary measure, we will list common errors and their solutions. If you encounter an error even though it should be correct or if the error content is not understandable, please contact us as it may be a bug.
* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: This error occurs when a required option is not provided. It is highly likely that you forgot to specify the option or misspelled the option name.
* The error location is indicated by `...` in the error message. For example, if you encounter an error like `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']`, it means that the `image_dir` option does not exist in the 0th `subsets` of the 0th `datasets` setting.
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful.
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it.

View File

@@ -457,7 +457,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す

View File

@@ -11,19 +11,15 @@ import cv2
import torch
from library.device_utils import init_ipex, get_preferred_device
init_ipex()
from torchvision import transforms
import library.model_util as model_util
import library.stable_cascade_utils as sc_utils
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
DEVICE = get_preferred_device()
@@ -46,7 +42,7 @@ def collate_fn_remove_corrupted(batch):
return batch
def get_npz_filename(data_dir, image_key, is_full_path, recursive, stable_cascade):
def get_npz_filename(data_dir, image_key, is_full_path, recursive):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
@@ -54,11 +50,10 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive, stable_cascad
base_name = image_key
relative_path = ""
ext = ".npz" if not stable_cascade else train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name) + ext
return os.path.join(data_dir, relative_path, base_name) + ".npz"
else:
return os.path.join(data_dir, base_name) + ext
return os.path.join(data_dir, base_name) + ".npz"
def main(args):
@@ -88,20 +83,13 @@ def main(args):
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
if not args.stable_cascade:
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
divisor = 8
else:
vae = sc_utils.load_effnet(args.model_name_or_path, DEVICE)
divisor = 32
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
vae.eval()
vae.to(DEVICE, dtype=weight_dtype)
# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
assert (
len(max_reso) == 2
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_manager = train_util.BucketManager(
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
@@ -166,10 +154,6 @@ def main(args):
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
# 追加情報を記録
metadata[image_key]["original_size"] = (image.width, image.height)
metadata[image_key]["train_resized_size"] = resized_size
if not args.bucket_no_upscale:
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
assert (
@@ -184,9 +168,9 @@ def main(args):
), f"internal error resized size is small: {resized_size}, {reso}"
# 既に存在するファイルがあればshape等を確認して同じならskipする
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive, args.stable_cascade)
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug, divisor):
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
continue
# バッチへ追加
@@ -224,14 +208,7 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument(
"--stable_cascade",
action="store_true",
help="prepare EffNet latents for stable cascade / stable cascade用のEffNetのlatentsを準備する",
)
parser.add_argument(
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
)
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
@@ -254,16 +231,10 @@ def setup_parser() -> argparse.ArgumentParser:
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
)
parser.add_argument(
"--bucket_no_upscale",
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="use mixed precision / 混合精度を使う場合、その精度",
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
)
parser.add_argument(
"--full_path",
@@ -271,9 +242,7 @@ def setup_parser() -> argparse.ArgumentParser:
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--flip_aug",
action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
)
parser.add_argument(
"--skip_existing",

View File

@@ -12,8 +12,10 @@ from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# from wd14 tagger
@@ -79,34 +81,42 @@ def collate_fn_remove_corrupted(batch):
def main(args):
# model location is model_dir + repo_id
# repo id may be like "user/repo" or "user/repo/branch", so we need to remove slash
model_location = os.path.join(args.model_dir, args.repo_id.replace("/", "_"))
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
if not os.path.exists(model_location) or args.force_download:
os.makedirs(args.model_dir, exist_ok=True)
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
files = FILES
if args.onnx:
files = ["selected_tags.csv"]
files += FILES_ONNX
else:
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(model_location, SUB_DIR),
force_download=True,
force_filename=file,
)
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(args.model_dir, SUB_DIR),
force_download=True,
force_filename=file,
)
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
else:
logger.info("using existing wd14 tagger model")
# 画像を読み込む
if args.onnx:
import torch
import onnx
import onnxruntime as ort
onnx_path = f"{args.model_dir}/model.onnx"
onnx_path = f"{model_location}/model.onnx"
logger.info("Running wd14 tagger with onnx")
logger.info(f"loading onnx model: {onnx_path}")
@@ -123,7 +133,7 @@ def main(args):
except:
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
if args.batch_size != batch_size and type(batch_size) != str:
if args.batch_size != batch_size and type(batch_size) != str and batch_size > 0:
# some rebatch model may use 'N' as dynamic axes
logger.warning(
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
@@ -134,19 +144,19 @@ def main(args):
ort_sess = ort.InferenceSession(
onnx_path,
providers=["CUDAExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers()
else ["CPUExecutionProvider"],
providers=(
["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
),
)
else:
from tensorflow.keras.models import load_model
model = load_model(f"{args.model_dir}")
model = load_model(f"{model_location}")
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
with open(os.path.join(model_location, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
@@ -172,8 +182,8 @@ def main(args):
imgs = np.array([im for _, im in path_imgs])
if args.onnx:
if len(imgs) < args.batch_size:
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
# if len(imgs) < args.batch_size:
# imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
probs = probs[: len(path_imgs)]
else:
@@ -314,7 +324,9 @@ def setup_parser() -> argparse.ArgumentParser:
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ",
)
parser.add_argument(
"--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします"
"--force_download",
action="store_true",
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
@@ -329,8 +341,12 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument(
"--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子"
)
parser.add_argument(
"--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値"
)
parser.add_argument(
"--general_threshold",
type=float,
@@ -343,7 +359,9 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
)
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
parser.add_argument(
"--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する"
)
parser.add_argument(
"--remove_underscore",
action="store_true",
@@ -356,9 +374,13 @@ def setup_parser() -> argparse.ArgumentParser:
default="",
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
parser.add_argument(
"--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する"
)
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
parser.add_argument(
"--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する"
)
parser.add_argument(
"--caption_separator",
type=str,

View File

@@ -60,6 +60,8 @@ class BaseSubsetParams:
caption_separator: str = (",",)
keep_tokens: int = 0
keep_tokens_separator: str = (None,)
secondary_separator: Optional[str] = None
enable_wildcard: bool = False
color_aug: bool = False
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
@@ -181,6 +183,8 @@ class ConfigSanitizer:
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
"secondary_separator": str,
"enable_wildcard": bool,
"token_warmup_min": int,
"token_warmup_step": Any(float, int),
"caption_prefix": str,
@@ -504,6 +508,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
secondary_separator: {subset.secondary_separator}
enable_wildcard: {subset.enable_wildcard}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}

View File

@@ -3,11 +3,6 @@ import gc
import torch
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
try:
HAS_CUDA = torch.cuda.is_available()
except Exception:
@@ -64,7 +59,7 @@ def get_preferred_device() -> torch.device:
device = torch.device("mps")
else:
device = torch.device("cpu")
logger.info(f"get_preferred_device() -> {device}")
print(f"get_preferred_device() -> {device}")
return device
@@ -82,8 +77,8 @@ def init_ipex():
is_initialized, error_message = ipex_init()
if not is_initialized:
logger.error("failed to initialize ipex: {error_message}")
print("failed to initialize ipex:", error_message)
else:
return
except Exception as e:
logger.error("failed to initialize ipex: {e}")
print("failed to initialize ipex:", e)

View File

@@ -5,10 +5,13 @@ import argparse
import os
from library.utils import fire_in_thread
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
api = HfApi(
token=token,
@@ -44,19 +47,14 @@ def upload(
def uploader():
try:
# 自前でスレッド化しているので run_as_future は明示的に False にするHub APIのバグかもしれない
if is_folder:
api.upload_folder(
repo_id=repo_id,
repo_type=repo_type,
folder_path=src,
path_in_repo=path_in_repo,
repo_id=repo_id, repo_type=repo_type, folder_path=src, path_in_repo=path_in_repo, run_as_future=False
)
else:
api.upload_file(
repo_id=repo_id,
repo_type=repo_type,
path_or_fileobj=src,
path_in_repo=path_in_repo,
repo_id=repo_id, repo_type=repo_type, path_or_fileobj=src, path_in_repo=path_in_repo, run_as_future=False
)
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
logger.error("===========================================")

View File

@@ -6,10 +6,8 @@ import os
from typing import List, Optional, Tuple, Union
import safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
r"""
@@ -57,13 +55,11 @@ ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_STABLE_CASCADE = "stable-cascade"
ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_STABILITY_AI_STABLE_CASCADE = "https://github.com/Stability-AI/StableCascade"
IMPL_DIFFUSERS = "diffusers"
PRED_TYPE_EPSILON = "epsilon"
@@ -117,7 +113,6 @@ def build_metadata(
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
stable_cascade: Optional[bool] = None,
):
# if state_dict is None, hash is not calculated
@@ -129,9 +124,7 @@ def build_metadata(
# hash = precalculate_safetensors_hashes(state_dict)
# metadata["modelspec.hash_sha256"] = hash
if stable_cascade:
arch = ARCH_STABLE_CASCADE
elif sdxl:
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif v2:
if v_parameterization:
@@ -149,11 +142,9 @@ def build_metadata(
metadata["modelspec.architecture"] = arch
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
if stable_cascade:
impl = IMPL_STABILITY_AI_STABLE_CASCADE
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
@@ -245,7 +236,7 @@ def build_metadata(
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
logger.error(f"Internal error: some metadata values are None: {metadata}")
return metadata
@@ -259,7 +250,7 @@ def get_title(metadata: dict) -> Optional[str]:
def load_metadata_from_safetensors(model: str) -> dict:
if not model.endswith(".safetensors"):
return {}
with safetensors.safe_open(model, framework="pt") as f:
metadata = f.metadata()
if metadata is None:

View File

@@ -31,8 +31,10 @@ from torch import nn
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
IN_CHANNELS: int = 4
@@ -1074,7 +1076,7 @@ class SdxlUNet2DConditionModel(nn.Module):
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)
@@ -1132,7 +1134,7 @@ class InferSdxlUNet2DConditionModel:
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)
def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)
@@ -1164,7 +1166,7 @@ class InferSdxlUNet2DConditionModel:
timesteps = timesteps.expand(x.shape[0])
hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)

File diff suppressed because it is too large Load Diff

View File

@@ -1,668 +0,0 @@
import argparse
import json
import math
import os
import time
from typing import List
import numpy as np
import toml
import torch
import torchvision
from safetensors.torch import load_file, save_file
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
from accelerate import init_empty_weights, Accelerator, PartialState
from PIL import Image
from library import stable_cascade as sc
from library.sdxl_model_util import _load_state_dict_on_device
from library.device_utils import clean_memory_on_device
from library.train_util import (
save_sd_model_on_epoch_end_or_stepwise_common,
save_sd_model_on_train_end_common,
line_to_prompt_dict,
get_hidden_states_stable_cascade,
)
from library import sai_model_spec
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLIP_TEXT_MODEL_NAME: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz"
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
resolution_multiple = 42.67
latent_height = math.ceil(height / compression_factor_b)
latent_width = math.ceil(width / compression_factor_b)
stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
latent_height = math.ceil(height / compression_factor_a)
latent_width = math.ceil(width / compression_factor_a)
stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
return stage_c_latent_shape, stage_b_latent_shape
# region load and save
def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder:
logger.info(f"Loading EfficientNet encoder from {effnet_checkpoint_path}")
effnet = sc.EfficientNetEncoder()
effnet_checkpoint = load_file(effnet_checkpoint_path)
info = effnet.load_state_dict(effnet_checkpoint if "state_dict" not in effnet_checkpoint else effnet_checkpoint["state_dict"])
logger.info(info)
del effnet_checkpoint
return effnet
def load_tokenizer(args: argparse.Namespace):
# TODO commonize with sdxl_train_util.load_tokenizers
logger.info("prepare tokenizers")
original_paths = [CLIP_TEXT_MODEL_NAME]
tokenizers = []
for i, original_path in enumerate(original_paths):
tokenizer: CLIPTokenizer = None
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
if tokenizer is None:
tokenizer = CLIPTokenizer.from_pretrained(original_path)
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
tokenizers.append(tokenizer)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
logger.info(f"update token length: {args.max_token_length}")
return tokenizers[0]
def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.StageC:
# Generator
logger.info(f"Instantiating Stage C generator")
with init_empty_weights():
generator_c = sc.StageC()
logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
stage_c_checkpoint = load_file(stage_c_checkpoint_path)
stage_c_checkpoint = convert_state_dict_mha_to_normal_attn(stage_c_checkpoint)
logger.info(f"Loading state dict")
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
logger.info(info)
return generator_c
def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.StageB:
logger.info(f"Instantiating Stage B generator")
with init_empty_weights():
generator_b = sc.StageB()
logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
stage_b_checkpoint = load_file(stage_b_checkpoint_path)
stage_b_checkpoint = convert_state_dict_mha_to_normal_attn(stage_b_checkpoint)
logger.info(f"Loading state dict")
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
logger.info(info)
return generator_b
def load_clip_text_model(text_model_checkpoint_path, dtype=None, device="cpu", save_text_model=False):
# CLIP encoders
logger.info(f"Loading CLIP text model")
if save_text_model or text_model_checkpoint_path is None:
logger.info(f"Loading CLIP text model from {CLIP_TEXT_MODEL_NAME}")
text_model = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_MODEL_NAME)
if save_text_model:
sd = text_model.state_dict()
logger.info(f"Saving CLIP text model to {text_model_checkpoint_path}")
save_file(sd, text_model_checkpoint_path)
else:
logger.info(f"Loading CLIP text model from {text_model_checkpoint_path}")
# copy from sdxl_model_util.py
text_model2_cfg = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
text_model = CLIPTextModelWithProjection(text_model2_cfg)
text_model_checkpoint = load_file(text_model_checkpoint_path)
info = _load_state_dict_on_device(text_model, text_model_checkpoint, device, dtype=dtype)
logger.info(info)
return text_model
def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.StageA:
logger.info(f"Loading Stage A vqGAN from {stage_a_checkpoint_path}")
stage_a = sc.StageA().to(device)
stage_a_checkpoint = load_file(stage_a_checkpoint_path)
info = stage_a.load_state_dict(
stage_a_checkpoint if "state_dict" not in stage_a_checkpoint else stage_a_checkpoint["state_dict"]
)
logger.info(info)
return stage_a
def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") -> sc.Previewer:
logger.info(f"Loading Previewer from {previewer_checkpoint_path}")
previewer = sc.Previewer().to(device)
previewer_checkpoint = load_file(previewer_checkpoint_path)
info = previewer.load_state_dict(
previewer_checkpoint if "state_dict" not in previewer_checkpoint else previewer_checkpoint["state_dict"]
)
logger.info(info)
return previewer
def convert_state_dict_mha_to_normal_attn(state_dict):
# convert nn.MultiheadAttention to to_q/k/v and out_proj
print("convert_state_dict_mha_to_normal_attn")
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "in_proj_bias" in key:
value = state_dict.pop(key)
qkv = torch.chunk(value, 3, dim=0)
state_dict[key.replace("in_proj_bias", "to_q.bias")] = qkv[0]
state_dict[key.replace("in_proj_bias", "to_k.bias")] = qkv[1]
state_dict[key.replace("in_proj_bias", "to_v.bias")] = qkv[2]
elif "in_proj_weight" in key:
value = state_dict.pop(key)
qkv = torch.chunk(value, 3, dim=0)
state_dict[key.replace("in_proj_weight", "to_q.weight")] = qkv[0]
state_dict[key.replace("in_proj_weight", "to_k.weight")] = qkv[1]
state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
elif "out_proj.bias" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = value
elif "out_proj.weight" in key:
value = state_dict.pop(key)
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = value
return state_dict
def convert_state_dict_normal_attn_to_mha(state_dict):
# convert to_q/k/v and out_proj to nn.MultiheadAttention
for key in list(state_dict.keys()):
if "attention.attn." in key:
if "to_q.bias" in key:
q = state_dict.pop(key)
k = state_dict.pop(key.replace("to_q.bias", "to_k.bias"))
v = state_dict.pop(key.replace("to_q.bias", "to_v.bias"))
state_dict[key.replace("to_q.bias", "in_proj_bias")] = torch.cat([q, k, v])
elif "to_q.weight" in key:
q = state_dict.pop(key)
k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
elif "out_proj.bias" in key:
v = state_dict.pop(key)
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = v
elif "out_proj.weight" in key:
v = state_dict.pop(key)
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = v
return state_dict
def get_sai_model_spec(args, lora=False):
timestamp = time.time()
reso = args.resolution
title = args.metadata_title if args.metadata_title is not None else args.output_name
if args.min_timestep is not None or args.max_timestep is not None:
min_time_step = args.min_timestep if args.min_timestep is not None else 0
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
timesteps = (min_time_step, max_time_step)
else:
timesteps = None
metadata = sai_model_spec.build_metadata(
None,
False,
False,
False,
lora,
False,
timestamp,
title=title,
reso=reso,
is_stable_diffusion_ckpt=False,
author=args.metadata_author,
description=args.metadata_description,
license=args.metadata_license,
tags=args.metadata_tags,
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
stable_cascade=True,
)
return metadata
def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata):
state_dict = stage_c.state_dict()
if save_dtype is not None:
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
state_dict = convert_state_dict_normal_attn_to_mha(state_dict)
save_file(state_dict, ckpt_file, metadata=sai_metadata)
# save text model
if text_model is not None:
text_model_sd = text_model.state_dict()
if save_dtype is not None:
text_model_sd = {k: v.to(save_dtype) for k, v in text_model_sd.items()}
text_model_ckpt_file = os.path.splitext(ckpt_file)[0] + "_text_model.safetensors"
save_file(text_model_sd, text_model_ckpt_file)
def save_stage_c_model_on_epoch_end_or_stepwise(
args: argparse.Namespace,
on_epoch_end: bool,
accelerator,
save_dtype: torch.dtype,
epoch: int,
num_train_epochs: int,
global_step: int,
stage_c,
text_model,
):
def stage_c_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(args)
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
save_sd_model_on_epoch_end_or_stepwise_common(
args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None
)
def save_stage_c_model_on_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
stage_c,
text_model,
):
def stage_c_saver(ckpt_file, epoch_no, global_step):
sai_metadata = get_sai_model_spec(args)
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)
# endregion
# region sample generation
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
previewer,
tokenizer,
text_encoder,
stage_c,
gdf,
prompt_replacement=None,
):
if steps == 0:
if not args.sample_at_first:
return
else:
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
logger.info("")
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
# unwrap unet and text_encoder(s)
stage_c = accelerator.unwrap_model(stage_c)
text_encoder = accelerator.unwrap_model(text_encoder)
# read prompts
if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# preprocess prompts
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
# save random state to restore later
rng_state = torch.get_rng_state()
cuda_rng_state = None
try:
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
for prompt_dict in prompts:
sample_image_inference(
accelerator,
args,
tokenizer,
text_encoder,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
accelerator,
args,
tokenizer,
text_encoder,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
)
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
# with torch.cuda.device(torch.cuda.current_device()):
# torch.cuda.empty_cache()
clean_memory_on_device(accelerator.device)
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
tokenizer,
text_model,
stage_c,
previewer,
gdf,
save_dir,
prompt_dict,
epoch,
steps,
prompt_replacement,
):
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 1024)
height = prompt_dict.get("height", 1024)
scale = prompt_dict.get("scale", 4)
seed = prompt_dict.get("seed")
# controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
logger.info(f"prompt: {prompt}")
logger.info(f"negative_prompt: {negative_prompt}")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
negative_prompt = "" if negative_prompt is None else negative_prompt
cfg = scale
timesteps = sample_steps
shift = 2
t_start = 1.0
stage_c_latent_shape, _ = calculate_latent_sizes(height, width, batch_size=1)
# PREPARE CONDITIONS
input_ids = tokenizer(
[prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
cond_text, cond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
input_ids = tokenizer(
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
uncond_text, uncond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
device = accelerator.device
dtype = stage_c.dtype
cond_text = cond_text.to(device, dtype=dtype)
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
uncond_text = uncond_text.to(device, dtype=dtype)
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
zero_img_emb = torch.zeros(1, 768, device=device)
# 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
with torch.no_grad(): # , torch.cuda.amp.autocast(dtype=dtype):
sampling_c = gdf.sample(
stage_c,
conditions,
stage_c_latent_shape,
unconditions,
device=device,
cfg=cfg,
shift=shift,
timesteps=timesteps,
t_start=t_start,
)
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
sampled_c = sampled_c
sampled_c = sampled_c.to(previewer.device, dtype=previewer.dtype)
image = previewer(sampled_c)[0]
image = torch.clamp(image, 0, 1)
image = image.cpu().numpy().transpose(1, 2, 0)
image = image * 255
image = image.astype(np.uint8)
image = Image.fromarray(image)
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
# endregion
def add_effnet_arguments(parser):
parser.add_argument(
"--effnet_checkpoint_path",
type=str,
required=True,
help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
)
return parser
def add_text_model_arguments(parser):
parser.add_argument(
"--text_model_checkpoint_path",
type=str,
help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
)
parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
return parser
def add_stage_a_arguments(parser):
parser.add_argument(
"--stage_a_checkpoint_path",
type=str,
required=True,
help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
)
return parser
def add_stage_b_arguments(parser):
parser.add_argument(
"--stage_b_checkpoint_path",
type=str,
required=True,
help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
)
return parser
def add_stage_c_arguments(parser):
parser.add_argument(
"--stage_c_checkpoint_path",
type=str,
required=True,
help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
)
return parser
def add_previewer_arguments(parser):
parser.add_argument(
"--previewer_checkpoint_path",
type=str,
required=False,
help="path to previewer checkpoint / previewerのチェックポイントのパス",
)
return parser
def add_training_arguments(parser):
parser.add_argument(
"--adaptive_loss_weight",
action="store_true",
help="if specified, use adaptive loss weight. if not, use P2 loss weight"
+ " / Adaptive Loss Weightを使用する。指定しない場合はP2 Loss Weightを使用する",
)

View File

@@ -133,7 +133,6 @@ IMAGE_TRANSFORMS = transforms.Compose(
)
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
STABLE_CASCADE_LATENTS_CACHE_SUFFIX = "_sc_latents.npz"
class ImageInfo:
@@ -365,6 +364,8 @@ class BaseSubset:
caption_separator: str,
keep_tokens: int,
keep_tokens_separator: str,
secondary_separator: Optional[str],
enable_wildcard: bool,
color_aug: bool,
flip_aug: bool,
face_crop_aug_range: Optional[Tuple[float, float]],
@@ -383,6 +384,8 @@ class BaseSubset:
self.caption_separator = caption_separator
self.keep_tokens = keep_tokens
self.keep_tokens_separator = keep_tokens_separator
self.secondary_separator = secondary_separator
self.enable_wildcard = enable_wildcard
self.color_aug = color_aug
self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range
@@ -411,6 +414,8 @@ class DreamBoothSubset(BaseSubset):
caption_separator: str,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -432,6 +437,8 @@ class DreamBoothSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -467,6 +474,8 @@ class FineTuningSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -488,6 +497,8 @@ class FineTuningSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -520,6 +531,8 @@ class ControlNetSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -541,6 +554,8 @@ class ControlNetSubset(BaseSubset):
caption_separator,
keep_tokens,
keep_tokens_separator,
secondary_separator,
enable_wildcard,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -676,15 +691,41 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out:
caption = ""
else:
# process wildcards
if subset.enable_wildcard:
# wildcard is like '{aaa|bbb|ccc...}'
# escape the curly braces like {{ or }}
replacer1 = ""
replacer2 = ""
while replacer1 in caption or replacer2 in caption:
replacer1 += ""
replacer2 += ""
caption = caption.replace("{{", replacer1).replace("}}", replacer2)
# replace the wildcard
def replace_wildcard(match):
return random.choice(match.group(1).split("|"))
caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption)
# unescape the curly braces
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
fixed_tokens = []
flex_tokens = []
fixed_suffix_tokens = []
if (
hasattr(subset, "keep_tokens_separator")
and subset.keep_tokens_separator
and subset.keep_tokens_separator in caption
):
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
if subset.keep_tokens_separator in flex_part:
flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1)
fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()]
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
else:
@@ -719,7 +760,11 @@ class BaseDataset(torch.utils.data.Dataset):
flex_tokens = dropout_tags(flex_tokens)
caption = ", ".join(fixed_tokens + flex_tokens)
caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens)
# process secondary separator
if subset.secondary_separator:
caption = caption.replace(subset.secondary_separator, subset.caption_separator)
# textual inversion対応
for str_from, str_to in self.replacements.items():
@@ -857,7 +902,7 @@ class BaseDataset(torch.utils.data.Dataset):
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
self.buckets_indices: List[BucketBatchIndex] = []
self.buckets_indices: List(BucketBatchIndex) = []
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
@@ -911,7 +956,7 @@ class BaseDataset(torch.utils.data.Dataset):
]
)
def cache_latents(self, vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
logger.info("caching latents.")
@@ -932,11 +977,11 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
if not is_main_process: # store to info only
continue
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug, divisor)
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
if cache_available: # do not add to batch
continue
@@ -968,13 +1013,9 @@ class BaseDataset(torch.utils.data.Dataset):
# SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
# SD1/2に対応するにはv2のフラグを持つ必要があるので後回し
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
):
"""
最後の Text Encoder の pool がキャッシュされる。
The last Text Encoder's pool is cached.
"""
# assert len(tokenizers) == 2, "only support SDXL"
assert len(tokenizers) == 2, "only support SDXL"
# latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
@@ -986,7 +1027,7 @@ class BaseDataset(torch.utils.data.Dataset):
for info in tqdm(image_infos):
# subset = self.image_to_subset[info.image_key]
if cache_to_disk:
te_out_npz = os.path.splitext(info.absolute_path)[0] + cache_file_suffix
te_out_npz = os.path.splitext(info.absolute_path)[0] + TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
info.text_encoder_outputs_npz = te_out_npz
if not is_main_process: # store to info only
@@ -1011,7 +1052,7 @@ class BaseDataset(torch.utils.data.Dataset):
batches = []
for info in image_infos_to_cache:
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
input_ids2 = self.get_input_ids(info.caption, tokenizers[1]) if len(tokenizers) > 1 else None
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
batch.append((info, input_ids1, input_ids2))
if len(batch) >= self.batch_size:
@@ -1026,7 +1067,7 @@ class BaseDataset(torch.utils.data.Dataset):
for batch in tqdm(batches):
infos, input_ids1, input_ids2 = zip(*batch)
input_ids1 = torch.stack(input_ids1, dim=0)
input_ids2 = torch.stack(input_ids2, dim=0) if input_ids2[0] is not None else None
input_ids2 = torch.stack(input_ids2, dim=0)
cache_batch_text_encoder_outputs(
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, weight_dtype
)
@@ -1275,9 +1316,7 @@ class BaseDataset(torch.utils.data.Dataset):
# example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
# example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
example["text_encoder_outputs2_list"] = (
torch.stack(text_encoder_outputs2_list) if text_encoder_outputs2_list[0] is not None else None
)
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
if images[0] is not None:
@@ -1334,7 +1373,7 @@ class BaseDataset(torch.utils.data.Dataset):
if self.caching_mode == "text":
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
input_ids2 = self.get_input_ids(caption, self.tokenizers[1]) if len(self.tokenizers) > 1 else None
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
else:
input_ids1 = None
input_ids2 = None
@@ -1606,15 +1645,12 @@ class FineTuningDataset(BaseDataset):
# なければnpzを探す
if abs_path is None:
abs_path = os.path.splitext(image_key)[0] + ".npz"
if not os.path.exists(abs_path):
abs_path = os.path.splitext(image_key)[0] + STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if not os.path.exists(abs_path):
abs_path = os.path.join(subset.image_dir, image_key + ".npz")
if not os.path.exists(abs_path):
abs_path = os.path.join(subset.image_dir, image_key + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)
if not os.path.exists(abs_path):
abs_path = None
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
abs_path = npz_path
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
@@ -1634,7 +1670,7 @@ class FineTuningDataset(BaseDataset):
if not subset.color_aug and not subset.random_crop:
# if npz exists, use them
image_info.latents_npz = self.image_key_to_npz_file(subset, image_key)
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
self.register_image(image_info, subset)
@@ -1648,7 +1684,7 @@ class FineTuningDataset(BaseDataset):
# check existence of all npz files
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
if use_npz_latents:
# flip_aug_in_subset = False
flip_aug_in_subset = False
npz_any = False
npz_all = True
@@ -1658,12 +1694,9 @@ class FineTuningDataset(BaseDataset):
has_npz = image_info.latents_npz is not None
npz_any = npz_any or has_npz
# flip は同一の .npz 内に格納するようにした:
# そのためここでチェック漏れがあり実行時にエラーになる可能性があるので要検討
# if subset.flip_aug:
# has_npz = has_npz and image_info.latents_npz_flipped is not None
# flip_aug_in_subset = True
if subset.flip_aug:
has_npz = has_npz and image_info.latents_npz_flipped is not None
flip_aug_in_subset = True
npz_all = npz_all and has_npz
if npz_any and not npz_all:
@@ -1677,8 +1710,8 @@ class FineTuningDataset(BaseDataset):
logger.warning(
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
)
# if flip_aug_in_subset:
# logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
if flip_aug_in_subset:
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
# else:
# logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
@@ -1727,29 +1760,34 @@ class FineTuningDataset(BaseDataset):
# npz情報をきれいにしておく
if not use_npz_latents:
for image_info in self.image_data.values():
image_info.latents_npz = None # image_info.latents_npz_flipped =
image_info.latents_npz = image_info.latents_npz_flipped = None
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + ".npz"
if not os.path.exists(npz_file_norm):
npz_file_norm = base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if os.path.exists(npz_file_norm):
return npz_file_norm
# image_key is full path
npz_file_flip = base_name + "_flip.npz"
if not os.path.exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip
# if not full path, check image_dir. if image_dir is None, return None
if subset.image_dir is None:
return None
return None, None
# image_key is relative path
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
if not os.path.exists(npz_file_norm):
npz_file_norm = os.path.join(subset.image_dir, base_name + STABLE_CASCADE_LATENTS_CACHE_SUFFIX)
if os.path.exists(npz_file_norm):
return npz_file_norm
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
return None
if not os.path.exists(npz_file_norm):
npz_file_norm = None
npz_file_flip = None
elif not os.path.exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip
class ControlNetDataset(BaseDataset):
@@ -1782,6 +1820,8 @@ class ControlNetDataset(BaseDataset):
subset.caption_separator,
subset.keep_tokens,
subset.keep_tokens_separator,
subset.secondary_separator,
subset.enable_wildcard,
subset.color_aug,
subset.flip_aug,
subset.face_crop_aug_range,
@@ -1951,26 +1991,17 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor)
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def cache_text_encoder_outputs(
self,
tokenizers,
text_encoders,
device,
weight_dtype,
cache_to_disk=False,
is_main_process=True,
cache_file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.cache_text_encoder_outputs(
tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process, cache_file_suffix
)
dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
@@ -2003,8 +2034,8 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.disable_token_padding()
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, divisor: int = 8) -> bool:
expected_latents_size = (reso[1] // divisor, reso[0] // divisor) # bucket_resoはWxHなので注意
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
if not os.path.exists(npz_path):
return False
@@ -2096,7 +2127,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
if show_input_ids:
logger.info(f"input ids: {iid}")
if "input_ids2" in example and example["input_ids2"] is not None:
if "input_ids2" in example:
logger.info(f"input ids2: {example['input_ids2'][j]}")
if example["images"] is not None:
im = example["images"][j]
@@ -2273,7 +2304,7 @@ def trim_and_resize_if_required(
def cache_batch_latents(
vae: Union[AutoencoderKL, torch.nn.Module], cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
) -> None:
r"""
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
@@ -2328,36 +2359,23 @@ def cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
input_ids1 = input_ids1.to(text_encoders[0].device)
input_ids2 = input_ids2.to(text_encoders[1].device) if input_ids2 is not None else None
input_ids2 = input_ids2.to(text_encoders[1].device)
with torch.no_grad():
# TODO SDXL と Stable Cascade で統一する
if len(tokenizers) == 1:
# Stable Cascade
b_hidden_state1, b_pool2 = get_hidden_states_stable_cascade(
max_token_length, input_ids1, tokenizers[0], text_encoders[0], dtype
)
b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768
b_pool2 = b_pool2.detach().to("cpu") # b,1280
b_hidden_state2 = [None] * input_ids1.shape[0]
else:
# SDXL
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
dtype,
)
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
max_token_length,
input_ids1,
input_ids2,
tokenizers[0],
tokenizers[1],
text_encoders[0],
text_encoders[1],
dtype,
)
# ここでcpuに移動しておかないと、上書きされてしまう
b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768
b_hidden_state2 = b_hidden_state2.detach().to("cpu") if b_hidden_state2[0] is not None else b_hidden_state2 # b,n*75+2,1280
b_hidden_state2 = b_hidden_state2.detach().to("cpu") # b,n*75+2,1280
b_pool2 = b_pool2.detach().to("cpu") # b,1280
for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2):
@@ -2370,25 +2388,18 @@ def cache_batch_text_encoder_outputs(
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
save_kwargs = {
"hidden_state1": hidden_state1.cpu().float().numpy(),
"pool2": pool2.cpu().float().numpy(),
}
if hidden_state2 is not None:
save_kwargs["hidden_state2"] = hidden_state2.cpu().float().numpy()
np.savez(npz_path, **save_kwargs)
# np.savez(
# npz_path,
# hidden_state1=hidden_state1.cpu().float().numpy(),
# hidden_state2=hidden_state2.cpu().float().numpy() if hidden_state2 is not None else None,
# pool2=pool2.cpu().float().numpy(),
# )
np.savez(
npz_path,
hidden_state1=hidden_state1.cpu().float().numpy(),
hidden_state2=hidden_state2.cpu().float().numpy(),
pool2=pool2.cpu().float().numpy(),
)
def load_text_encoder_outputs_from_disk(npz_path):
with np.load(npz_path) as f:
hidden_state1 = torch.from_numpy(f["hidden_state1"])
hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f and f["hidden_state2"] is not None else None
hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None
pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None
return hidden_state1, hidden_state2, pool2
@@ -2735,15 +2746,6 @@ def get_sai_model_spec(
return metadata
def add_tokenizer_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--tokenizer_cache_dir",
type=str,
default=None,
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの学習のため",
)
def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument(
@@ -2758,7 +2760,12 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
default=None,
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル",
)
add_tokenizer_arguments(parser)
parser.add_argument(
"--tokenizer_cache_dir",
type=str,
default=None,
help="directory for caching Tokenizer (for offline training) / Tokenizerをキャッシュするディレクトリネット接続なしでの学習のため",
)
def add_optimizer_arguments(parser: argparse.ArgumentParser):
@@ -2929,7 +2936,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--save_state",
action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
help="save training state additionally (including optimizer states etc.) when saving model / optimizerなど学習状態も含めたstateをモデル保存時に追加で保存する",
)
parser.add_argument(
"--save_state_on_train_end",
action="store_true",
help="save training state (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを学習完了時に保存する",
)
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
@@ -3080,6 +3092,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する有効にする場合は0.1程度を推奨)",
)
parser.add_argument(
"--noise_offset_random_strength",
action="store_true",
help="use random strength between 0~noise_offset for noise offset. / noise offsetにおいて、0からnoise_offsetの間でランダムな強度を使用します。",
)
parser.add_argument(
"--multires_noise_iterations",
type=int,
@@ -3093,6 +3110,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) "
+ "/ input perturbation noiseを有効にする。正則化に使用される。推奨値: 0.1程度 (arxiv.org/abs/2301.11706 より)",
)
parser.add_argument(
"--ip_noise_gamma_random_strength",
action="store_true",
help="Use random strength between 0~ip_noise_gamma for input perturbation noise."
+ "/ input perturbation noiseにおいて、0からip_noise_gammaの間でランダムな強度を使用します。",
)
# parser.add_argument(
# "--perlin_noise",
# type=int,
@@ -3246,16 +3269,6 @@ def verify_training_args(args: argparse.Namespace):
global HIGH_VRAM
HIGH_VRAM = True
if args.cache_latents_to_disk and not args.cache_latents:
args.cache_latents = True
logger.warning(
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
)
if not hasattr(args, "v_parameterization"):
# Stable Cascade: skip following checks
return
if args.v_parameterization and not args.v2:
logger.warning(
"v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません"
@@ -3263,6 +3276,12 @@ def verify_training_args(args: argparse.Namespace):
if args.v2 and args.clip_skip is not None:
logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
if args.cache_latents_to_disk and not args.cache_latents:
args.cache_latents = True
logger.warning(
"cache_latents_to_disk is enabled, so cache_latents is also enabled / cache_latents_to_diskが有効なため、cache_latentsを有効にします"
)
# noise_offset, perlin_noise, multires_noise_iterations cannot be enabled at the same time
# # Listを使って数えてもいいけど並べてしまえ
# if args.noise_offset is not None and args.multires_noise_iterations is not None:
@@ -3329,6 +3348,18 @@ def add_dataset_arguments(
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
)
parser.add_argument(
"--secondary_separator",
type=str,
default=None,
help="a secondary separator for caption. This separator is replaced to caption_separator after dropping/shuffling caption"
+ " / captionのセカンダリ区切り文字。この区切り文字はcaptionのドロップやシャッフル後にcaption_separatorに置き換えられる",
)
parser.add_argument(
"--enable_wildcard",
action="store_true",
help="enable wildcard for caption (e.g. '{image|picture|rendition}') / captionのワイルドカードを有効にする'{image|picture|rendition}'",
)
parser.add_argument(
"--caption_prefix",
type=str,
@@ -3519,7 +3550,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
exit(1)
logger.info(f"Loading settings from {config_path}...")
with open(config_path, "r") as f:
with open(config_path, "r", encoding="utf-8") as f:
config_dict = toml.load(f)
# combine all sections into one
@@ -4342,54 +4373,6 @@ def get_hidden_states_sdxl(
return hidden_states1, hidden_states2, pool2
def get_hidden_states_stable_cascade(
max_token_length: int,
input_ids2: torch.Tensor,
tokenizer2: CLIPTokenizer,
text_encoder2: CLIPTextModel,
weight_dtype: Optional[str] = None,
accelerator: Optional[Accelerator] = None,
):
# ここに Stable Cascade 用のコードがあるのはとても気持ち悪いが、変に整理するよりわかりやすいので、とりあえずこのまま
# It's very awkward to have Stable Cascade code here, but it's easier to understand than to organize it in a strange way, so for now it's as it is.
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids2.size()[0]
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
# text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-1] # ** last layer **
# pool2 = enc_out["text_embeds"]
unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
if max_token_length is not None:
# bs*3, 77, 768 or 1024
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, tokenizer2.model_max_length):
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
hidden_states2 = torch.cat(states_list, dim=1)
# pool はnの最初のものを使う
pool2 = pool2[::n_size]
if weight_dtype is not None:
# this is required for additional network training
hidden_states2 = hidden_states2.to(weight_dtype)
return hidden_states2, pool2
def default_if_none(value, default):
return default if value is None else value
@@ -4689,7 +4672,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
noise = custom_train_functions.apply_noise_offset(latents, noise, args.noise_offset, args.adaptive_noise_scale)
if args.noise_offset_random_strength:
noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
else:
noise_offset = args.noise_offset
noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
if args.multires_noise_iterations:
noise = custom_train_functions.pyramid_noise_like(
noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
@@ -4706,7 +4693,11 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps)
if args.ip_noise_gamma_random_strength:
strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
else:
strength = args.ip_noise_gamma
noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

View File

@@ -841,14 +841,9 @@ class LoRANetwork(torch.nn.Module):
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
is_group_conv2d = is_conv2d and child_module.groups > 1
# if is_group_conv2d:
# logger.info(f"skip group conv2d: {name}.{child_name}")
# continue
if is_linear or (is_conv2d and not is_group_conv2d):
lora_name = prefix + "." + name + ("." + child_name if child_name else "")
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
dim = None
@@ -920,11 +915,6 @@ class LoRANetwork(torch.nn.Module):
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
# XXX temporary solution for Stable Cascade Stage C: replace all modules
if "StageC" in unet.__class__.__name__:
logger.info("replace all modules for Stable Cascade Stage C")
target_modules = ["Linear", "Conv2d"]
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

View File

@@ -22,9 +22,12 @@ huggingface-hub==0.20.1
# for WD14 captioning (tensorflow)
# tensorflow==2.10.1
# for WD14 captioning (onnx)
# onnx==1.14.1
# onnxruntime-gpu==1.16.0
# onnxruntime==1.16.0
# onnx==1.15.0
# onnxruntime-gpu==1.17.1
# onnxruntime==1.17.1
# for cuda 12.1(default 11.8)
# onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
# this is for onnx:
# protobuf==3.20.3
# open clip for SDXL

View File

@@ -712,7 +712,7 @@ def train(args):
accelerator.end_training()
if args.save_state: # and is_main_process:
if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す

View File

@@ -549,7 +549,7 @@ def train(args):
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:

View File

@@ -1,367 +0,0 @@
import argparse
import importlib
import math
import os
import random
import time
import numpy as np
from safetensors.torch import load_file, save_file
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
from PIL import Image
from accelerate import init_empty_weights
import library.stable_cascade as sc
import library.stable_cascade_utils as sc_utils
import library.device_utils as device_utils
from library import train_util
from library.sdxl_model_util import _load_state_dict_on_device
def main(args):
device = device_utils.get_preferred_device()
loading_device = device if not args.lowvram else "cpu"
text_model_device = "cpu"
dtype = torch.float32
if args.bf16:
dtype = torch.bfloat16
elif args.fp16:
dtype = torch.float16
text_model_dtype = torch.float32
# EfficientNet encoder
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
effnet.eval().requires_grad_(False).to(loading_device)
generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
generator_c.eval().requires_grad_(False).to(loading_device)
# if args.xformers or args.sdpa:
print(f"Stage C: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
generator_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
generator_b.eval().requires_grad_(False).to(loading_device)
# if args.xformers or args.sdpa:
print(f"Stage B: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
generator_b.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
# CLIP encoders
tokenizer = sc_utils.load_tokenizer(args)
text_model = sc_utils.load_clip_text_model(
args.text_model_checkpoint_path, text_model_dtype, text_model_device, args.save_text_model
)
text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device)
# image_model = (
# CLIPVisionModelWithProjection.from_pretrained(clip_image_model_name).requires_grad_(False).to(dtype).to(device)
# )
# vqGAN
stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
stage_a.eval().requires_grad_(False)
# previewer
if args.previewer_checkpoint_path is not None:
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=dtype, device=loading_device)
previewer.eval().requires_grad_(False)
else:
previewer = None
# LoRA
if args.network_module:
for i, network_module in enumerate(args.network_module):
print("import network module:", network_module)
imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
net_kwargs = {}
if args.network_args and i < len(args.network_args):
network_args = args.network_args[i]
# TODO escape special chars
network_args = network_args.split(";")
for net_arg in network_args:
key, value = net_arg.split("=")
net_kwargs[key] = value
if args.network_weights is None or len(args.network_weights) <= i:
raise ValueError("No weight. Weight is required.")
network_weight = args.network_weights[i]
print("load network weights from:", network_weight)
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs
)
if network is None:
return
mergeable = network.is_mergeable()
assert mergeable, "not-mergeable network is not supported yet."
network.merge_to(text_model, generator_c, weights_sd, dtype, device)
# 謎のクラス gdf
gdf_c = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
input_scaler=sc.VPScaler(),
target=sc.EpsilonTarget(),
noise_cond=sc.CosineTNoiseCond(),
loss_weight=None,
)
gdf_b = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
input_scaler=sc.VPScaler(),
target=sc.EpsilonTarget(),
noise_cond=sc.CosineTNoiseCond(),
loss_weight=None,
)
# Stage C Parameters
# extras.sampling_configs["cfg"] = 4
# extras.sampling_configs["shift"] = 2
# extras.sampling_configs["timesteps"] = 20
# extras.sampling_configs["t_start"] = 1.0
# # Stage B Parameters
# extras_b.sampling_configs["cfg"] = 1.1
# extras_b.sampling_configs["shift"] = 1
# extras_b.sampling_configs["timesteps"] = 10
# extras_b.sampling_configs["t_start"] = 1.0
b_cfg = 1.1
b_shift = 1
b_timesteps = 10
b_t_start = 1.0
# caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
# height, width = 1024, 1024
while True:
print("type caption:")
# if Ctrl+Z is pressed, it will raise EOFError
try:
caption = input()
except EOFError:
break
caption = caption.strip()
if caption == "":
continue
# parse options: '--w' and '--h' for size, '--l' for cfg, '--s' for timesteps, '--f' for shift. if not specified, use default values
# e.g. "caption --w 4 --h 4 --l 20 --s 20 --f 1.0"
tokens = caption.split()
width = height = 1024
cfg = 4
timesteps = 20
shift = 2
t_start = 1.0
negative_prompt = ""
seed = None
caption_tokens = []
i = 0
while i < len(tokens):
token = tokens[i]
if i == len(tokens) - 1:
caption_tokens.append(token)
i += 1
continue
if token == "--w":
width = int(tokens[i + 1])
elif token == "--h":
height = int(tokens[i + 1])
elif token == "--l":
cfg = float(tokens[i + 1])
elif token == "--s":
timesteps = int(tokens[i + 1])
elif token == "--f":
shift = float(tokens[i + 1])
elif token == "--t":
t_start = float(tokens[i + 1])
elif token == "--n":
negative_prompt = tokens[i + 1]
elif token == "--d":
seed = int(tokens[i + 1])
else:
caption_tokens.append(token)
i += 1
continue
i += 2
caption = " ".join(caption_tokens)
stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)
# PREPARE CONDITIONS
# cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
input_ids = tokenizer(
[caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade(
tokenizer.model_max_length, input_ids, tokenizer, text_model
)
cond_text = cond_text.to(device, dtype=dtype)
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
# uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
input_ids = tokenizer(
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
)["input_ids"].to(text_model.device)
uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade(
tokenizer.model_max_length, input_ids, tokenizer, text_model
)
uncond_text = uncond_text.to(device, dtype=dtype)
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
zero_img_emb = torch.zeros(1, 768, device=device)
# 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
unconditions = {
"clip_text_pooled": uncond_pooled,
"clip": uncond_pooled,
"clip_text": uncond_text,
"clip_img": zero_img_emb,
}
conditions_b = {}
conditions_b.update(conditions)
unconditions_b = {}
unconditions_b.update(unconditions)
# seed everything
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
if args.lowvram:
generator_c = generator_c.to(device)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampling_c = gdf_c.sample(
generator_c,
conditions,
stage_c_latent_shape,
unconditions,
device=device,
cfg=cfg,
shift=shift,
timesteps=timesteps,
t_start=t_start,
)
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
sampled_c = sampled_c
conditions_b["effnet"] = sampled_c
unconditions_b["effnet"] = torch.zeros_like(sampled_c)
if previewer is not None:
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
preview = previewer(sampled_c)
preview = preview.clamp(0, 1)
preview = preview.permute(0, 2, 3, 1).squeeze(0)
preview = preview.detach().float().cpu().numpy()
preview = Image.fromarray((preview * 255).astype(np.uint8))
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
os.makedirs(args.outdir, exist_ok=True)
preview.save(os.path.join(args.outdir, f"preview_{timestamp_str}.png"))
if args.lowvram:
generator_c = generator_c.to(loading_device)
device_utils.clean_memory_on_device(device)
generator_b = generator_b.to(device)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampling_b = gdf_b.sample(
generator_b,
conditions_b,
stage_b_latent_shape,
unconditions_b,
device=device,
cfg=b_cfg,
shift=b_shift,
timesteps=b_timesteps,
t_start=b_t_start,
)
for sampled_b, _, _ in tqdm(sampling_b, total=b_t_start):
sampled_b = sampled_b
if args.lowvram:
generator_b = generator_b.to(loading_device)
device_utils.clean_memory_on_device(device)
stage_a = stage_a.to(device)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
sampled = stage_a.decode(sampled_b).float()
# print(sampled.shape, sampled.min(), sampled.max())
if args.lowvram:
stage_a = stage_a.to(loading_device)
device_utils.clean_memory_on_device(device)
# float 0-1 to PIL Image
sampled = sampled.clamp(0, 1)
sampled = sampled.mul(255).to(dtype=torch.uint8)
sampled = sampled.permute(0, 2, 3, 1)
sampled = sampled.cpu().numpy()
sampled = Image.fromarray(sampled[0])
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
os.makedirs(args.outdir, exist_ok=True)
sampled.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png"))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
sc_utils.add_effnet_arguments(parser)
train_util.add_tokenizer_arguments(parser)
sc_utils.add_stage_a_arguments(parser)
sc_utils.add_stage_b_arguments(parser)
sc_utils.add_stage_c_arguments(parser)
sc_utils.add_previewer_arguments(parser)
sc_utils.add_text_model_arguments(parser)
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--xformers", action="store_true")
parser.add_argument("--sdpa", action="store_true")
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
parser.add_argument(
"--network_module",
type=str,
default=None,
nargs="*",
help="additional network module to use / 追加ネットワークを使う時そのモジュール名",
)
parser.add_argument(
"--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
)
parser.add_argument(
"--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率"
)
parser.add_argument(
"--network_args",
type=str,
default=None,
nargs="*",
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
)
args = parser.parse_args()
main(args)

File diff suppressed because it is too large Load Diff

View File

@@ -1,564 +0,0 @@
# training with captions
import argparse
import math
import os
from multiprocessing import Value
from typing import List
import toml
from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
import library.train_util as train_util
from library.sdxl_train_util import add_sdxl_training_arguments
import library.stable_cascade_utils as sc_utils
import library.stable_cascade as sc
from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
logger = logging.getLogger(__name__)
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
setup_logging(args, reset=True)
# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
# TODO add assertions for other unsupported options
cache_latents = args.cache_latents
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
tokenizer = sc_utils.load_tokenizer(args)
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer])
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer])
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
train_dataset_group.verify_bucket_reso_steps(32)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, True)
return
if len(train_dataset_group) == 0:
logger.error(
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
)
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if args.cache_text_encoder_outputs:
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
# acceleratorを準備する
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
loading_device = accelerator.device if args.lowram else "cpu"
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, device=loading_device) # dtype is as it is
text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device)
if args.sample_at_first or args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
# Previewer is small enough to be loaded on CPU
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=torch.float32, device="cpu")
previewer.eval()
else:
previewer = None
# モデルに xformers とか memory efficient attention を組み込む
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
# 学習を準備する
if cache_latents:
effnet.to(accelerator.device, dtype=effnet_dtype)
effnet.requires_grad_(False)
effnet.eval()
with torch.no_grad():
train_dataset_group.cache_latents(
effnet,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX,
32,
)
effnet.to("cpu")
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
if args.gradient_checkpointing:
accelerator.print("enable gradient checkpointing")
stage_c.set_gradient_checkpointing(True)
train_stage_c = args.learning_rate > 0
train_text_encoder1 = False
if args.train_text_encoder:
accelerator.print("enable text encoder training")
if args.gradient_checkpointing:
text_encoder1.gradient_checkpointing_enable()
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
train_text_encoder1 = lr_te1 > 0
assert (
train_text_encoder1
), "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
if not train_text_encoder1:
text_encoder1.to(weight_dtype)
text_encoder1.requires_grad_(train_text_encoder1)
text_encoder1.train(train_text_encoder1)
else:
text_encoder1.to(weight_dtype)
text_encoder1.requires_grad_(False)
text_encoder1.eval()
# TextEncoderの出力をキャッシュする
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
with torch.no_grad(), accelerator.autocast():
train_dataset_group.cache_text_encoder_outputs(
(tokenizer,),
(text_encoder1,),
accelerator.device,
None,
args.cache_text_encoder_outputs_to_disk,
accelerator.is_main_process,
sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
)
accelerator.wait_for_everyone()
if not cache_latents:
effnet.requires_grad_(False)
effnet.eval()
effnet.to(accelerator.device, dtype=effnet_dtype)
stage_c.requires_grad_(True)
if not train_stage_c:
stage_c.to(accelerator.device, dtype=weight_dtype) # because of stage_c will not be prepared
training_models = []
params_to_optimize = []
if train_stage_c:
training_models.append(stage_c)
params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
if train_text_encoder1:
training_models.append(text_encoder1)
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
# calculate number of trainable parameters
n_params = 0
for params in params_to_optimize:
for p in params["params"]:
n_params += p.numel()
accelerator.print(f"train stage-C: {train_stage_c}, text_encoder1: {train_text_encoder1}")
accelerator.print(f"number of models: {len(training_models)}")
accelerator.print(f"number of trainable parameters: {n_params}")
# 学習に必要なクラスを準備する
accelerator.print("prepare optimizer, data loader etc.")
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# 実験的機能勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
if args.full_fp16:
assert (
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
stage_c.to(weight_dtype)
text_encoder1.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
stage_c.to(weight_dtype)
text_encoder1.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_stage_c:
stage_c = accelerator.prepare(stage_c)
if train_text_encoder1:
text_encoder1 = accelerator.prepare(text_encoder1)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
if args.cache_text_encoder_outputs:
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
clean_memory_on_device(accelerator.device)
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
# resumeする
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
# 謎のクラス GDF
gdf = sc.GDF(
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
input_scaler=sc.VPScaler(),
target=sc.EpsilonTarget(),
noise_cond=sc.CosineTNoiseCond(),
loss_weight=sc.AdaptiveLossWeight() if args.adaptive_loss_weight else sc.P2LossWeight(),
)
# 以下2つの変数は、どうもデフォルトのままっぽい
# gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
# gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
if accelerator.is_main_process:
init_kwargs = {}
if args.wandb_run_name:
init_kwargs["wandb"] = {"name": args.wandb_run_name}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
sc_utils.sample_images(accelerator, args, 0, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
for m in training_models:
m.train()
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
with torch.no_grad():
# latentに変換
# XXX Effnet preprocessing is included in encode method
latents = effnet.encode(batch["images"].to(effnet_dtype)).latent_dist.sample().to(weight_dtype)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.nan_to_num(latents, 0, out=latents)
# # debug: decode latent with previewer and save it
# import time
# import numpy as np
# from PIL import Image
# ts = time.time()
# images = previewer(latents.to(previewer.device, dtype=previewer.dtype))
# for i, img in enumerate(images):
# img = img.detach().cpu().numpy().transpose(1, 2, 0)
# img = np.clip(img, 0, 1)
# img = (img * 255).astype(np.uint8)
# img = Image.fromarray(img)
# img.save(f"logs/previewer_{i}_{ts}.png")
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
input_ids1 = batch["input_ids"]
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
# TODO support weighted captions
input_ids1 = input_ids1.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states, pool = train_util.get_hidden_states_stable_cascade(
args.max_token_length,
input_ids1,
tokenizer,
text_encoder1,
None if not args.full_fp16 else weight_dtype,
accelerator,
)
else:
encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
pool = pool.unsqueeze(1) # add extra dimension b,1280 -> b,1,1280
# FORWARD PASS
with torch.no_grad():
noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1)
zero_img_emb = torch.zeros(noised.shape[0], 768, device=accelerator.device)
with accelerator.autocast():
pred = stage_c(
noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
)
loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
loss_adjusted = (loss * loss_weight).mean()
if args.adaptive_loss_weight:
gdf.loss_weight.update_buckets(logSNR, loss) # use loss instead of loss_adjusted
accelerator.backward(loss_adjusted)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = []
for m in training_models:
params_to_clip.extend(m.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
sc_utils.sample_images(accelerator, args, None, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
args,
False,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(stage_c),
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
)
current_loss = loss_adjusted.detach().item() # 平均なのでbatch sizeは関係ないはず
if args.logging_dir is not None:
logs = {"loss": current_loss}
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
accelerator.log(logs, step=global_step)
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(stage_c),
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
)
sc_utils.sample_images(accelerator, args, epoch + 1, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
is_main_process = accelerator.is_main_process
# if is_main_process:
stage_c = accelerator.unwrap_model(stage_c)
text_encoder1 = accelerator.unwrap_model(text_encoder1)
accelerator.end_training()
if args.save_state: # and is_main_process:
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
sc_utils.save_stage_c_model_on_end(
args, save_dtype, epoch, global_step, stage_c, text_encoder1 if train_text_encoder1 else None
)
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
add_logging_arguments(parser)
sc_utils.add_effnet_arguments(parser)
sc_utils.add_stage_c_arguments(parser)
sc_utils.add_text_model_arguments(parser)
sc_utils.add_previewer_arguments(parser)
sc_utils.add_training_arguments(parser)
train_util.add_tokenizer_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
add_sdxl_training_arguments(parser) # cache text encoder outputs
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
parser.add_argument(
"--learning_rate_te1",
type=float,
default=None,
help="learning rate for text encoder / text encoderの学習率",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)

View File

@@ -1,191 +0,0 @@
# Stable Cascadeのlatentsをdiskにキャッシュする
# cache latents of Stable Cascade to disk
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import stable_cascade_utils as sc_utils
from library import config_util
from library import train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
# check cache latents arg
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
# tokenizerを準備するdatasetを動かすために必要
tokenizer = sc_utils.load_tokenizer(args)
tokenizers = [tokenizer]
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
# datasetのcache_latentsを呼ばなければ、生の画像が返る
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, _ = train_util.prepare_dtype(args)
effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
logger.info("load model")
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, accelerator.device)
effnet.to(accelerator.device, dtype=effnet_dtype)
effnet.requires_grad_(False)
effnet.eval()
# dataloaderを準備する
train_dataset_group.set_caching_mode("latents")
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# acceleratorを使ってモデルを準備するマルチGPUで使えるようになるはず
train_dataloader = accelerator.prepare(train_dataloader)
# データ取得のためのループ
for batch in tqdm(train_dataloader):
b_size = len(batch["images"])
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
flip_aug = batch["flip_aug"]
random_crop = batch["random_crop"]
bucket_reso = batch["bucket_reso"]
# バッチを分割して処理する
for i in range(0, b_size, vae_batch_size):
images = batch["images"][i : i + vae_batch_size]
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
image_infos = []
for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.image = image
image_info.bucket_reso = bucket_reso
image_info.resized_size = resized_size
image_info.latents_npz = os.path.splitext(absolute_path)[0] + train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
if args.skip_existing:
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug, 32):
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
continue
image_infos.append(image_info)
if len(image_infos) > 0:
train_util.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_tokenizer_arguments(parser)
sc_utils.add_effnet_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
config_util.add_config_arguments(parser)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
cache_to_disk(args)

View File

@@ -1,183 +0,0 @@
# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import train_util
from library import sdxl_train_util
from library import stable_cascade_utils as sc_utils
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
def cache_to_disk(args: argparse.Namespace) -> None:
train_util.prepare_dataset_args(args, True)
# check cache arg
assert (
args.cache_text_encoder_outputs_to_disk
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
use_dreambooth_method = args.in_json is None
if args.seed is not None:
set_seed(args.seed) # 乱数系列を初期化する
# tokenizerを準備するdatasetを動かすために必要
tokenizer = sc_utils.load_tokenizer(args)
tokenizers = [tokenizer]
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
logger.warning(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
args.train_data_dir, args.reg_data_dir
)
}
]
}
else:
logger.info("Training with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
# acceleratorを準備する
logger.info("prepare accelerator")
accelerator = train_util.prepare_accelerator(args)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, _ = train_util.prepare_dtype(args)
# モデルを読み込む
logger.info("load model")
text_encoder = sc_utils.load_clip_text_model(
args.text_model_checkpoint_path, weight_dtype, accelerator.device, args.save_text_model
)
text_encoders = [text_encoder]
for text_encoder in text_encoders:
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()
# dataloaderを準備する
train_dataset_group.set_caching_mode("text")
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# acceleratorを使ってモデルを準備するマルチGPUで使えるようになるはず
train_dataloader = accelerator.prepare(train_dataloader)
# データ取得のためのループ
for batch in tqdm(train_dataloader):
absolute_paths = batch["absolute_paths"]
input_ids1_list = batch["input_ids1_list"]
image_infos = []
for absolute_path, input_ids1 in zip(absolute_paths, input_ids1_list):
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
image_info
if args.skip_existing:
if os.path.exists(image_info.text_encoder_outputs_npz):
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
continue
image_info.input_ids1 = input_ids1
image_infos.append(image_info)
if len(image_infos) > 0:
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
train_util.cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, None, weight_dtype
)
accelerator.wait_for_everyone()
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_tokenizer_arguments(parser)
sc_utils.add_text_model_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
config_util.add_config_arguments(parser)
sdxl_train_util.add_sdxl_training_arguments(parser)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
cache_to_disk(args)

View File

@@ -565,7 +565,7 @@ def train(args):
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく

View File

@@ -444,7 +444,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
del accelerator # この後メモリを使うのでこれは消す

View File

@@ -564,6 +564,11 @@ class NetworkTrainer:
"random_crop": bool(subset.random_crop),
"shuffle_caption": bool(subset.shuffle_caption),
"keep_tokens": subset.keep_tokens,
"keep_tokens_separator": subset.keep_tokens_separator,
"secondary_separator": subset.secondary_separator,
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
}
image_dir_or_metadata_file = None
@@ -935,7 +940,7 @@ class NetworkTrainer:
accelerator.end_training()
if is_main_process and args.save_state:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:

View File

@@ -732,7 +732,7 @@ class TextualInversionTrainer:
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
if is_main_process:

View File

@@ -586,7 +586,7 @@ def train(args):
accelerator.end_training()
if args.save_state and is_main_process:
if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()