mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-07 05:58:56 +00:00
Compare commits
1 Commits
stable-cas
...
nw_applica
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae0872ba3b |
255
README.md
255
README.md
@@ -1,173 +1,38 @@
|
||||
# Training Stable Cascade Stage C
|
||||
## LoRAの層別適用率の探索について
|
||||
|
||||
This is an experimental feature. There may be bugs.
|
||||
層別適用率を探索する `train_network_appl_weights.py` を追加してあります。現在は SDXL のみ対応しています。
|
||||
|
||||
__Feb 25, 2024 Update:__ Fixed a bug that the LoRA weights trained can be loaded in ComfyUI. If you still have a problem, please let me know.
|
||||
LoRA 等の学習済みネットワークに対して、層別適用率を変化させながら通常の学習プロセスを実行することで、適用率を探索します。つまり、どのような層別適用率を適用すると、学習データに近い画像が生成されるかを探索することができます。
|
||||
|
||||
__Feb 25, 2024 Update:__ Fixed a bug that Stage C training with mixed precision behaves the same as `--full_bf16` (fp16) regardless of `--full_bf16` (fp16) specified.
|
||||
層別適用率の合計をペナルティとすることが可能です。つまり、画像を再現しつつ、影響の少ない層の適用率が低くなるような適用率が探索できるはずです。
|
||||
|
||||
This is because the Stage C weights were loaded in bf16/fp16. With this fix, the memory usage without `--full_bf16` (fp16) specified will increase, so you may need to specify `--full_bf16` (fp16) as needed.
|
||||
複数のネットワークを対象に探索できます。また探索には最低 1 枚の学習データが必要になります。
|
||||
|
||||
__Feb 22, 2024 Update:__ Fixed a bug that LoRA is not applied to some modules (to_q/k/v and to_out) in Attention. Also, the model structure of Stage C has been changed, and you can choose xformers and SDPA (SDPA was used before). Please specify `--sdpa` or `--xformers` option.
|
||||
(何枚程度から正しく動くかは確認していません。50枚程度の画像でテスト済みです。また学習データは LoRA 学習時のデータでなくてもよいはずですが、未確認です。)
|
||||
|
||||
__Feb 20, 2024 Update:__ There was a problem with the preprocessing of the EfficientNetEncoder, and the latents became invalid (the saturation of the training results decreases). If you have cached `_sc_latents.npz` files with `--cache_latents_to_disk`, please delete them before training.
|
||||
コマンドラインオプションは `sdxl_train_network.py` とほぼ同じですが、以下のオプションが追加、拡張されています。
|
||||
|
||||
## Usage
|
||||
- `--application_loss_weight` : 層別適用率を loss に加える際の重みです。デフォルトは 0.0001 です。大きくすると、なるべく適用率を低くするように学習します。0 を指定するとペナルティが適用されないため、再現度が最も高くなる適用率を自由に探索します。
|
||||
- `--network_module` : 探索対象の複数のモジュールを指定することができます。たとえば `--network_module networks.lora networks.lora` のように指定します。
|
||||
- `--network_weights` : 探索対象の複数のネットワークの重みを指定することができます。たとえば `--network_weights model1.safetensors model2.safetensors` のように指定します。
|
||||
|
||||
Training is run with `stable_cascade_train_stage_c.py`.
|
||||
層別適用率のパラメータ数は 20個で、`BASE, IN00-08, MID, OUT00-08` となります。`BASE` は Text Encoder に適用されます。(Text Encoder を対象とした LoRA の動作は未確認です。)
|
||||
|
||||
The main options are the same as `sdxl_train.py`. The following options have been added.
|
||||
パラメータは一応ファイルに保存されますが、画面に表示される値をコピーして保存することをお勧めします。
|
||||
|
||||
- `--effnet_checkpoint_path`: Specifies the path to the EfficientNetEncoder weights.
|
||||
- `--stage_c_checkpoint_path`: Specifies the path to the Stage C weights.
|
||||
- `--text_model_checkpoint_path`: Specifies the path to the Text Encoder weights. If omitted, the model from Hugging Face will be used.
|
||||
- `--save_text_model`: Saves the model downloaded from Hugging Face to `--text_model_checkpoint_path`.
|
||||
- `--previewer_checkpoint_path`: Specifies the path to the Previewer weights. Used to generate sample images during training.
|
||||
- `--adaptive_loss_weight`: Uses [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) . If omitted, P2LossWeight is used. The official settings use Adaptive Loss Weight.
|
||||
### 備考
|
||||
|
||||
The learning rate is set to 1e-4 in the official settings.
|
||||
オプティマイザ AdamW、学習率 1e-1 で動作確認しています。学習率はかなり高めに設定してよいようです。この設定では LoRA 学習時の 1/20 ~ 1/10 ほどの epoch 数でそれなりの結果が得られます。
|
||||
|
||||
The first time, specify `--text_model_checkpoint_path` and `--save_text_model` to save the Text Encoder weights. From the next time, specify `--text_model_checkpoint_path` to load the saved weights.
|
||||
`application_loss_weight` を 0.0001 より大きくすると合計の適用率がかなり低くなる(=LoRA があまり適用されない)ようです。条件にもよると思いますので、適宜調整してください。
|
||||
|
||||
Sample image generation during training is done with Perviewer. Perviewer is a simple decoder that converts EfficientNetEncoder latents to images.
|
||||
適用率に負の値を使うと、影響の少ない層の適用率を極端に低くして合計を小さくする、という動きをしてしまうので、負の値は10倍の重み付けをしてあります(-0.01 は 0.1 とほぼ同じペナルティ)。重み付けを変更するときはソースを修正してください。
|
||||
|
||||
Some of the options for SDXL are simply ignored or cause an error (especially noise-related options such as `--noise_offset`). `--vae_batch_size` and `--no_half_vae` are applied directly to the EfficientNetEncoder (when `bf16` is specified for mixed precision, `--no_half_vae` is not necessary).
|
||||
「必要ない層への適用率を下げて影響範囲を小さくする」という使い方だけでなく、「あるキャラクターがあるポーズをしている画像を教師データに、キャラクターを維持しつつポーズを取るための LoRA の適用率を探索する」、「ある画風のあるキャラクターの画像を教師データに、画風 LoRA とキャラクター LoRA の適用率を探索する」などの使い方が考えられます。
|
||||
|
||||
Options for latents and Text Encoder output caches can be used as is, but since the EfficientNetEncoder is much lighter than the VAE, you may not need to use the cache unless memory is particularly tight.
|
||||
もしかすると、「あるキャラクターの、あえて別の画風の画像を教師データに、キャラクターの属性を再現するのに必要な層を探す」、「理想とする画像を教師データに、使えそうな LoRA を多数適用し、その中から最も再現度が高い適用率を探す(ただし LoRA の数が多いほど学習が遅くなります)」といった使い方もできるかもしれません。
|
||||
|
||||
`--gradient_checkpointing`, `--full_bf16`, and `--full_fp16` (untested) to reduce memory consumption can be used as is.
|
||||
|
||||
A scale of about 4 is suitable for sample image generation.
|
||||
|
||||
Since the official settings use `bf16` for training, training with `fp16` may be unstable.
|
||||
|
||||
The code for training the Text Encoder is also written, but it is untested.
|
||||
|
||||
### Command line sample
|
||||
|
||||
```batch
|
||||
accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 stable_cascade_train_stage_c.py --mixed_precision bf16 --save_precision bf16 --max_data_loader_n_workers 2 --persistent_data_loader_workers --gradient_checkpointing --learning_rate 1e-4 --optimizer_type adafactor --optimizer_args "scale_parameter=False" "relative_step=False" "warmup_init=False" --max_train_epochs 10 --save_every_n_epochs 1 --save_precision bf16 --output_dir ../output --output_name sc_test - --stage_c_checkpoint_path ../models/stage_c_bf16.safetensors --effnet_checkpoint_path ../models/effnet_encoder.safetensors --previewer_checkpoint_path ../models/previewer.safetensors --dataset_config ../dataset/config_bs1.toml --sample_every_n_epochs 1 --sample_prompts ../dataset/prompts.txt --adaptive_loss_weight
|
||||
```
|
||||
|
||||
### About the dataset for fine tuning
|
||||
|
||||
If the latents cache files for SD/SDXL exist (extension `*.npz`), it will be read and an error will occur during training. Please move them to another location in advance.
|
||||
|
||||
After that, run `finetune/prepare_buckets_latents.py` with the `--stable_cascade` option to create latents cache files for Stable Cascade (suffix `_sc_latents.npz` is added).
|
||||
|
||||
## LoRA training
|
||||
|
||||
`stable_cascade_train_c_network.py` is used for LoRA training. The main options are the same as `train_network.py`, and the same options as `stable_cascade_train_stage_c.py` have been added.
|
||||
|
||||
__This is an experimental feature, so the format of the saved weights may change in the future and become incompatible.__
|
||||
|
||||
There is no compatibility with the official LoRA, and the implementation of Text Encoder embedding training (Pivotal Tuning) in the official implementation is not implemented here.
|
||||
|
||||
Text Encoder LoRA training is implemented, but untested.
|
||||
|
||||
## Image generation
|
||||
|
||||
Basic image generation functionality is available in `stable_cascade_gen_img.py`. See `--help` for usage.
|
||||
|
||||
When using LoRA, specify `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors`.
|
||||
|
||||
The following prompt options are available.
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
* `--t` Specifies the t_start of the generation.
|
||||
* `--f` Specifies the shift of the generation.
|
||||
|
||||
# Stable Cascade Stage C の学習
|
||||
|
||||
実験的機能です。不具合があるかもしれません。
|
||||
|
||||
__2024/2/25 追記:__ 学習される LoRA の重みが ComfyUI で読み込めるよう修正しました。依然として不具合がある場合にはご連絡ください。
|
||||
|
||||
__2024/2/25 追記:__ Mixed precision 時のStage C の学習が、 `--full_bf16` (fp16) の指定に関わらず `--full_bf16` (fp16) 指定時と同じ動作となる(と思われる)不具合を修正しました。
|
||||
|
||||
Stage C の重みを bf16/fp16 で読み込んでいたためです。この修正により `--full_bf16` (fp16) 未指定時のメモリ使用量が増えますので、必要に応じて `--full_bf16` (fp16) を指定してください。
|
||||
|
||||
__2024/2/22 追記:__ LoRA が一部のモジュール(Attention の to_q/k/v および to_out)に適用されない不具合を修正しました。また Stage C のモデル構造を変更し xformers と SDPA を選べるようになりました(今までは SDPA が使用されていました)。`--sdpa` または `--xformers` オプションを指定してください。
|
||||
|
||||
__2024/2/20 追記:__ EfficientNetEncoder の前処理に不具合があり、latents が不正になっていました(学習結果の彩度が低下する現象が起きます)。`--cache_latents_to_disk` でキャッシュした `_sc_latents.npz` がある場合、いったん削除してから学習してください。
|
||||
|
||||
## 使い方
|
||||
|
||||
学習は `stable_cascade_train_stage_c.py` で行います。
|
||||
|
||||
主なオプションは `sdxl_train.py` と同様です。以下のオプションが追加されています。
|
||||
|
||||
- `--effnet_checkpoint_path` : EfficientNetEncoder の重みのパスを指定します。
|
||||
- `--stage_c_checkpoint_path` : Stage C の重みのパスを指定します。
|
||||
- `--text_model_checkpoint_path` : Text Encoder の重みのパスを指定します。省略時は Hugging Face のモデルを使用します。
|
||||
- `--save_text_model` : `--text_model_checkpoint_path` にHugging Face からダウンロードしたモデルを保存します。
|
||||
- `--previewer_checkpoint_path` : Previewer の重みのパスを指定します。学習中のサンプル画像生成に使用します。
|
||||
- `--adaptive_loss_weight` : [Adaptive Loss Weight](https://github.com/Stability-AI/StableCascade/blob/master/gdf/loss_weights.py) を用います。省略時は P2LossWeight が使用されます。公式では Adaptive Loss Weight が使用されているようです。
|
||||
|
||||
学習率は、公式の設定では 1e-4 のようです。
|
||||
|
||||
初回は `--text_model_checkpoint_path` と `--save_text_model` を指定して、Text Encoder の重みを保存すると良いでしょう。次からは `--text_model_checkpoint_path` を指定して、保存した重みを読み込むことができます。
|
||||
|
||||
学習中のサンプル画像生成は Perviewer で行われます。Previewer は EfficientNetEncoder の latents を画像に変換する簡易的な decoder です。
|
||||
|
||||
SDXL の向けの一部のオプションは単に無視されるか、エラーになります(特に `--noise_offset` などのノイズ関係)。`--vae_batch_size` および `--no_half_vae` はそのまま EfficientNetEncoder に適用されます(mixed precision に `bf16` 指定時は `--no_half_vae` は不要のようです)。
|
||||
|
||||
latents および Text Encoder 出力キャッシュのためのオプションはそのまま使用できますが、EfficientNetEncoder は VAE よりもかなり軽量のため、メモリが特に厳しい場合以外はキャッシュを使用する必要はないかもしれません。
|
||||
|
||||
メモリ消費を抑えるための `--gradient_checkpointing` 、`--full_bf16`、`--full_fp16`(未テスト)はそのまま使用できます。
|
||||
|
||||
サンプル画像生成時の Scale には 4 程度が適しているようです。
|
||||
|
||||
公式の設定では学習に `bf16` を用いているため、`fp16` での学習は不安定かもしれません。
|
||||
|
||||
Text Encoder 学習のコードも書いてありますが、未テストです。
|
||||
|
||||
### コマンドラインのサンプル
|
||||
|
||||
[Command-line-sample](#command-line-sample)を参照してください。
|
||||
|
||||
|
||||
### fine tuning方式のデータセットについて
|
||||
|
||||
SD/SDXL 向けの latents キャッシュファイル(拡張子 `*.npz`)が存在するとそれを読み込んでしまい学習時にエラーになります。あらかじめ他の場所に退避しておいてください。
|
||||
|
||||
その後、`finetune/prepare_buckets_latents.py` をオプション `--stable_cascade` を指定して実行すると、Stable Cascade 向けの latents キャッシュファイル(接尾辞 `_sc_latents.npz` が付きます)が作成されます。
|
||||
|
||||
|
||||
## LoRA 等の学習
|
||||
|
||||
LoRA の学習は `stable_cascade_train_c_network.py` で行います。主なオプションは `train_network.py` と同様で、`stable_cascade_train_stage_c.py` と同様のオプションが追加されています。
|
||||
|
||||
__実験的機能のため、保存される重みのフォーマットは将来的に変更され、互換性がなくなる可能性があります。__
|
||||
|
||||
公式の LoRA と重みの互換性はありません。また公式で実装されている Text Encoder の embedding 学習(Pivotal Tuning)も実装されていません。
|
||||
|
||||
Text Encoder の LoRA 学習は実装してありますが、未テストです。
|
||||
|
||||
## 画像生成
|
||||
|
||||
最低限の画像生成機能が `stable_cascade_gen_img.py` にあります。使用法は `--help` を参照してください。
|
||||
|
||||
LoRA 使用時は `--network_module networks.lora --network_mul 1 --network_weights lora_weights.safetensors` のように指定します。
|
||||
|
||||
プロンプトオプションとして以下が使用できます。
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
* `--t` Specifies the t_start of the generation.
|
||||
* `--f` Specifies the shift of the generation.
|
||||
|
||||
|
||||
---
|
||||
---
|
||||
|
||||
__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).
|
||||
|
||||
@@ -420,49 +285,6 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
||||
|
||||
## Change History
|
||||
|
||||
### Feb 24, 2024 / 2024/2/24: v0.8.4
|
||||
|
||||
- The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu!
|
||||
- The log is formatted by default. The `rich` library is required. Please see [Upgrade](#upgrade) and update the library.
|
||||
- If `rich` is not installed, the log output will be the same as before.
|
||||
- The following options are available in each training script:
|
||||
- `--console_log_simple` option can be used to switch to the previous log output.
|
||||
- `--console_log_level` option can be used to specify the log level. The default is `INFO`.
|
||||
- `--console_log_file` option can be used to output the log to a file. The default is `None` (output to the console).
|
||||
- The sample image generation during multi-GPU training is now done with multiple GPUs. PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) Thanks to DKnight54!
|
||||
- The support for mps devices is improved. PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) Thanks to akx! If mps device exists instead of CUDA, the mps device is used automatically.
|
||||
- The `--new_conv_rank` option to specify the new rank of Conv2d is added to `networks/resize_lora.py`. PR [#1102](https://github.com/kohya-ss/sd-scripts/pull/1102) Thanks to mgz-dev!
|
||||
- An option `--highvram` to disable the optimization for environments with little VRAM is added to the training scripts. If you specify it when there is enough VRAM, the operation will be faster.
|
||||
- Currently, only the cache part of latents is optimized.
|
||||
- The IPEX support is improved. PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Thanks to Disty0!
|
||||
- Fixed a bug that `svd_merge_lora.py` crashes in some cases. PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) Thanks to mgz-dev!
|
||||
- DyLoRA is fixed to work with SDXL. PR [#1126](https://github.com/kohya-ss/sd-scripts/pull/1126) Thanks to tamlog06!
|
||||
- The common image generation script `gen_img.py` for SD 1/2 and SDXL is added. The basic functions are the same as the scripts for SD 1/2 and SDXL, but some new features are added.
|
||||
- External scripts to generate prompts can be supported. It can be called with `--from_module` option. (The documentation will be added later)
|
||||
- The normalization method after prompt weighting can be specified with `--emb_normalize_mode` option. `original` is the original method, `abs` is the normalization with the average of the absolute values, `none` is no normalization.
|
||||
- Gradual Latent Hires fix is added to each generation script. See [here](./docs/gen_img_README-ja.md#about-gradual-latent) for details.
|
||||
|
||||
- ログ出力が改善されました。 PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) shirayu 氏に感謝します。
|
||||
- デフォルトでログが成形されます。`rich` ライブラリが必要なため、[Upgrade](#upgrade) を参照し更新をお願いします。
|
||||
- `rich` がインストールされていない場合は、従来のログ出力になります。
|
||||
- 各学習スクリプトでは以下のオプションが有効です。
|
||||
- `--console_log_simple` オプションで従来のログ出力に切り替えられます。
|
||||
- `--console_log_level` でログレベルを指定できます。デフォルトは `INFO` です。
|
||||
- `--console_log_file` でログファイルを出力できます。デフォルトは `None`(コンソールに出力) です。
|
||||
- 複数 GPU 学習時に学習中のサンプル画像生成を複数 GPU で行うようになりました。 PR [#1061](https://github.com/kohya-ss/sd-scripts/pull/1061) DKnight54 氏に感謝します。
|
||||
- mps デバイスのサポートが改善されました。 PR [#1054](https://github.com/kohya-ss/sd-scripts/pull/1054) akx 氏に感謝します。CUDA ではなく mps が存在する場合には自動的に mps デバイスを使用します。
|
||||
- `networks/resize_lora.py` に Conv2d の新しいランクを指定するオプション `--new_conv_rank` が追加されました。 PR [#1102](https://github.com/kohya-ss/sd-scripts/pull/1102) mgz-dev 氏に感謝します。
|
||||
- 学習スクリプトに VRAMが少ない環境向け最適化を無効にするオプション `--highvram` を追加しました。VRAM に余裕がある場合に指定すると動作が高速化されます。
|
||||
- 現在は latents のキャッシュ部分のみ高速化されます。
|
||||
- IPEX サポートが改善されました。 PR [#1086](https://github.com/kohya-ss/sd-scripts/pull/1086) Disty0 氏に感謝します。
|
||||
- `svd_merge_lora.py` が場合によってエラーになる不具合が修正されました。 PR [#1087](https://github.com/kohya-ss/sd-scripts/pull/1087) mgz-dev 氏に感謝します。
|
||||
- DyLoRA が SDXL で動くよう修正されました。PR [#1126](https://github.com/kohya-ss/sd-scripts/pull/1126) tamlog06 氏に感謝します。
|
||||
- SD 1/2 および SDXL 共通の生成スクリプト `gen_img.py` を追加しました。基本的な機能は SD 1/2、SDXL 向けスクリプトと同じですが、いくつかの新機能が追加されています。
|
||||
- プロンプトを動的に生成する外部スクリプトをサポートしました。 `--from_module` で呼び出せます。(ドキュメントはのちほど追加します)
|
||||
- プロンプト重みづけ後の正規化方法を `--emb_normalize_mode` で指定できます。`original` は元の方法、`abs` は絶対値の平均値で正規化、`none` は正規化を行いません。
|
||||
- Gradual Latent Hires fix を各生成スクリプトに追加しました。詳細は [こちら](./docs/gen_img_README-ja.md#about-gradual-latent)。
|
||||
|
||||
|
||||
### Jan 27, 2024 / 2024/1/27: v0.8.3
|
||||
|
||||
- Fixed a bug that the training crashes when `--fp8_base` is specified with `--save_state`. PR [#1079](https://github.com/kohya-ss/sd-scripts/pull/1079) Thanks to feffy380!
|
||||
@@ -533,6 +355,45 @@ network_multiplier = -1.0
|
||||
```
|
||||
|
||||
|
||||
### Jan 17, 2024 / 2024/1/17: v0.8.1
|
||||
|
||||
- Fixed a bug that the VRAM usage without Text Encoder training is larger than before in training scripts for LoRA etc (`train_network.py`, `sdxl_train_network.py`).
|
||||
- Text Encoders were not moved to CPU.
|
||||
- Fixed typos. Thanks to akx! [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053)
|
||||
|
||||
- LoRA 等の学習スクリプト(`train_network.py`、`sdxl_train_network.py`)で、Text Encoder を学習しない場合の VRAM 使用量が以前に比べて大きくなっていた不具合を修正しました。
|
||||
- Text Encoder が GPU に保持されたままになっていました。
|
||||
- 誤字が修正されました。 [PR #1053](https://github.com/kohya-ss/sd-scripts/pull/1053) akx 氏に感謝します。
|
||||
|
||||
### Jan 15, 2024 / 2024/1/15: v0.8.0
|
||||
|
||||
- Diffusers, Accelerate, Transformers and other related libraries have been updated. Please update the libraries with [Upgrade](#upgrade).
|
||||
- Some model files (Text Encoder without position_id) based on the latest Transformers can be loaded.
|
||||
- `torch.compile` is supported (experimental). PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) Thanks to p1atdev!
|
||||
- This feature works only on Linux or WSL.
|
||||
- Please specify `--torch_compile` option in each training script.
|
||||
- You can select the backend with `--dynamo_backend` option. The default is `"inductor"`. `inductor` or `eager` seems to work.
|
||||
- Please use `--sdpa` option instead of `--xformers` option.
|
||||
- PyTorch 2.1 or later is recommended.
|
||||
- Please see [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) for details.
|
||||
- The session name for wandb can be specified with `--wandb_run_name` option. PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) Thanks to hopl1t!
|
||||
- IPEX library is updated. PR [#1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Thanks to Disty0!
|
||||
- Fixed a bug that Diffusers format model cannot be saved.
|
||||
|
||||
- Diffusers、Accelerate、Transformers 等の関連ライブラリを更新しました。[Upgrade](#upgrade) を参照し更新をお願いします。
|
||||
- 最新の Transformers を前提とした一部のモデルファイル(Text Encoder が position_id を持たないもの)が読み込めるようになりました。
|
||||
- `torch.compile` がサポートされしました(実験的)。 PR [#1024](https://github.com/kohya-ss/sd-scripts/pull/1024) p1atdev 氏に感謝します。
|
||||
- Linux または WSL でのみ動作します。
|
||||
- 各学習スクリプトで `--torch_compile` オプションを指定してください。
|
||||
- `--dynamo_backend` オプションで使用される backend を選択できます。デフォルトは `"inductor"` です。 `inductor` または `eager` が動作するようです。
|
||||
- `--xformers` オプションとは互換性がありません。 代わりに `--sdpa` オプションを使用してください。
|
||||
- PyTorch 2.1以降を推奨します。
|
||||
- 詳細は [PR](https://github.com/kohya-ss/sd-scripts/pull/1024) をご覧ください。
|
||||
- wandb 保存時のセッション名が各学習スクリプトの `--wandb_run_name` オプションで指定できるようになりました。 PR [#1032](https://github.com/kohya-ss/sd-scripts/pull/1032) hopl1t 氏に感謝します。
|
||||
- IPEX ライブラリが更新されました。[PR #1030](https://github.com/kohya-ss/sd-scripts/pull/1030) Disty0 氏に感謝します。
|
||||
- Diffusers 形式でのモデル保存ができなくなっていた不具合を修正しました。
|
||||
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from library.device_utils import init_ipex
|
||||
init_ipex()
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
from typing import Union, List, Optional, Dict, Any, Tuple
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||
|
||||
|
||||
@@ -452,36 +452,3 @@ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
|
||||
|
||||
- `--network_show_meta` : 追加ネットワークのメタデータを表示します。
|
||||
|
||||
|
||||
---
|
||||
|
||||
# About Gradual Latent
|
||||
|
||||
Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options.
|
||||
|
||||
- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first.
|
||||
- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size.
|
||||
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
|
||||
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
|
||||
|
||||
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
|
||||
|
||||
__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers.
|
||||
|
||||
It is more effective with SD 1.5. It is quite subtle with SDXL.
|
||||
|
||||
# Gradual Latent について
|
||||
|
||||
latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、``sdxl_gen_img.py` 、`gen_img_diffusers.py` に以下のオプションが追加されています。
|
||||
|
||||
- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
|
||||
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
|
||||
- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
|
||||
- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。
|
||||
|
||||
それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。
|
||||
|
||||
サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。
|
||||
|
||||
SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。
|
||||
|
||||
|
||||
43
fine_tune.py
43
fine_tune.py
@@ -2,27 +2,22 @@
|
||||
# XXX dropped option: hypernetwork training
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.train_util as train_util
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
@@ -42,7 +37,6 @@ from library.custom_train_functions import (
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
@@ -55,11 +49,11 @@ def train(args):
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
@@ -92,7 +86,7 @@ def train(args):
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
print(
|
||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||
)
|
||||
return
|
||||
@@ -103,7 +97,7 @@ def train(args):
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
@@ -164,7 +158,9 @@ def train(args):
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -211,8 +207,8 @@ def train(args):
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
@@ -227,9 +223,7 @@ def train(args):
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
@@ -293,7 +287,7 @@ def train(args):
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
@@ -467,13 +461,12 @@ def train(args):
|
||||
train_util.save_sd_model_on_train_end(
|
||||
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
||||
)
|
||||
logger.info("model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
@@ -482,9 +475,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
|
||||
)
|
||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
parser.add_argument(
|
||||
"--learning_rate_te",
|
||||
|
||||
@@ -21,10 +21,6 @@ import torch.nn.functional as F
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
from timm.models.hub import download_cached_file
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BLIP_Base(nn.Module):
|
||||
def __init__(self,
|
||||
@@ -239,6 +235,6 @@ def load_checkpoint(model,url_or_filename):
|
||||
del state_dict[key]
|
||||
|
||||
msg = model.load_state_dict(state_dict,strict=False)
|
||||
logger.info('load checkpoint from %s'%url_or_filename)
|
||||
print('load checkpoint from %s'%url_or_filename)
|
||||
return model,msg
|
||||
|
||||
|
||||
@@ -8,10 +8,6 @@ import json
|
||||
import re
|
||||
|
||||
from tqdm import tqdm
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
|
||||
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
|
||||
@@ -40,13 +36,13 @@ def clean_tags(image_key, tags):
|
||||
tokens = tags.split(", rating")
|
||||
if len(tokens) == 1:
|
||||
# WD14 taggerのときはこちらになるのでメッセージは出さない
|
||||
# logger.info("no rating:")
|
||||
# logger.info(f"{image_key} {tags}")
|
||||
# print("no rating:")
|
||||
# print(f"{image_key} {tags}")
|
||||
pass
|
||||
else:
|
||||
if len(tokens) > 2:
|
||||
logger.info("multiple ratings:")
|
||||
logger.info(f"{image_key} {tags}")
|
||||
print("multiple ratings:")
|
||||
print(f"{image_key} {tags}")
|
||||
tags = tokens[0]
|
||||
|
||||
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
|
||||
@@ -128,43 +124,43 @@ def clean_caption(caption):
|
||||
|
||||
def main(args):
|
||||
if os.path.exists(args.in_json):
|
||||
logger.info(f"loading existing metadata: {args.in_json}")
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
logger.error("no metadata / メタデータファイルがありません")
|
||||
print("no metadata / メタデータファイルがありません")
|
||||
return
|
||||
|
||||
logger.info("cleaning captions and tags.")
|
||||
print("cleaning captions and tags.")
|
||||
image_keys = list(metadata.keys())
|
||||
for image_key in tqdm(image_keys):
|
||||
tags = metadata[image_key].get('tags')
|
||||
if tags is None:
|
||||
logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
|
||||
print(f"image does not have tags / メタデータにタグがありません: {image_key}")
|
||||
else:
|
||||
org = tags
|
||||
tags = clean_tags(image_key, tags)
|
||||
metadata[image_key]['tags'] = tags
|
||||
if args.debug and org != tags:
|
||||
logger.info("FROM: " + org)
|
||||
logger.info("TO: " + tags)
|
||||
print("FROM: " + org)
|
||||
print("TO: " + tags)
|
||||
|
||||
caption = metadata[image_key].get('caption')
|
||||
if caption is None:
|
||||
logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
|
||||
print(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
|
||||
else:
|
||||
org = caption
|
||||
caption = clean_caption(caption)
|
||||
metadata[image_key]['caption'] = caption
|
||||
if args.debug and org != caption:
|
||||
logger.info("FROM: " + org)
|
||||
logger.info("TO: " + caption)
|
||||
print("FROM: " + org)
|
||||
print("TO: " + caption)
|
||||
|
||||
# metadataを書き出して終わり
|
||||
logger.info(f"writing metadata: {args.out_json}")
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
logger.info("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
@@ -182,10 +178,10 @@ if __name__ == '__main__':
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
if len(unknown) == 1:
|
||||
logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
|
||||
logger.warning("All captions and tags in the metadata are processed.")
|
||||
logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
|
||||
logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
|
||||
print("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
|
||||
print("All captions and tags in the metadata are processed.")
|
||||
print("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
|
||||
print("メタデータ内のすべてのキャプションとタグが処理されます。")
|
||||
args.in_json = args.out_json
|
||||
args.out_json = unknown[0]
|
||||
elif len(unknown) > 0:
|
||||
|
||||
@@ -9,22 +9,14 @@ from pathlib import Path
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
init_ipex()
|
||||
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
from blip.blip import blip_decoder, is_url
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEVICE = get_preferred_device()
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
IMAGE_SIZE = 384
|
||||
@@ -55,7 +47,7 @@ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
|
||||
# convert to tensor temporarily so dataloader will accept it
|
||||
tensor = IMAGE_TRANSFORM(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
|
||||
return (tensor, img_path)
|
||||
@@ -82,21 +74,21 @@ def main(args):
|
||||
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
|
||||
|
||||
cwd = os.getcwd()
|
||||
logger.info(f"Current Working Directory is: {cwd}")
|
||||
print("Current Working Directory is: ", cwd)
|
||||
os.chdir("finetune")
|
||||
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
|
||||
args.caption_weights = os.path.join("..", args.caption_weights)
|
||||
|
||||
logger.info(f"load images from {args.train_data_dir}")
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
logger.info(f"loading BLIP caption: {args.caption_weights}")
|
||||
print(f"loading BLIP caption: {args.caption_weights}")
|
||||
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
|
||||
model.eval()
|
||||
model = model.to(DEVICE)
|
||||
logger.info("BLIP loaded")
|
||||
print("BLIP loaded")
|
||||
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
@@ -116,7 +108,7 @@ def main(args):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
logger.info(f'{image_path} {caption}')
|
||||
print(image_path, caption)
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
@@ -146,7 +138,7 @@ def main(args):
|
||||
raw_image = raw_image.convert("RGB")
|
||||
img_tensor = IMAGE_TRANSFORM(raw_image)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
b_imgs.append((image_path, img_tensor))
|
||||
@@ -156,7 +148,7 @@ def main(args):
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
|
||||
logger.info("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -5,19 +5,12 @@ import re
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
init_ipex()
|
||||
|
||||
from transformers import AutoProcessor, AutoModelForCausalLM
|
||||
from transformers.generation.utils import GenerationMixin
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@@ -42,8 +35,8 @@ def remove_words(captions, debug):
|
||||
for pat in PATTERN_REPLACE:
|
||||
cap = pat.sub("", cap)
|
||||
if debug and cap != caption:
|
||||
logger.info(caption)
|
||||
logger.info(cap)
|
||||
print(caption)
|
||||
print(cap)
|
||||
removed_caps.append(cap)
|
||||
return removed_caps
|
||||
|
||||
@@ -77,16 +70,16 @@ def main(args):
|
||||
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
||||
"""
|
||||
|
||||
logger.info(f"load images from {args.train_data_dir}")
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
# できればcacheに依存せず明示的にダウンロードしたい
|
||||
logger.info(f"loading GIT: {args.model_id}")
|
||||
print(f"loading GIT: {args.model_id}")
|
||||
git_processor = AutoProcessor.from_pretrained(args.model_id)
|
||||
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
|
||||
logger.info("GIT loaded")
|
||||
print("GIT loaded")
|
||||
|
||||
# captioningする
|
||||
def run_batch(path_imgs):
|
||||
@@ -104,7 +97,7 @@ def main(args):
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
||||
f.write(caption + "\n")
|
||||
if args.debug:
|
||||
logger.info(f"{image_path} {caption}")
|
||||
print(image_path, caption)
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
@@ -133,7 +126,7 @@ def main(args):
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
b_imgs.append((image_path, image))
|
||||
@@ -144,7 +137,7 @@ def main(args):
|
||||
if len(b_imgs) > 0:
|
||||
run_batch(b_imgs)
|
||||
|
||||
logger.info("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -5,30 +5,26 @@ from typing import List
|
||||
from tqdm import tqdm
|
||||
import library.train_util as train_util
|
||||
import os
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main(args):
|
||||
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
||||
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if args.in_json is None and Path(args.out_json).is_file():
|
||||
args.in_json = args.out_json
|
||||
|
||||
if args.in_json is not None:
|
||||
logger.info(f"loading existing metadata: {args.in_json}")
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
||||
logger.warning("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
|
||||
print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
|
||||
else:
|
||||
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||
metadata = {}
|
||||
|
||||
logger.info("merge caption texts to metadata json.")
|
||||
print("merge caption texts to metadata json.")
|
||||
for image_path in tqdm(image_paths):
|
||||
caption_path = image_path.with_suffix(args.caption_extension)
|
||||
caption = caption_path.read_text(encoding='utf-8').strip()
|
||||
@@ -42,12 +38,12 @@ def main(args):
|
||||
|
||||
metadata[image_key]['caption'] = caption
|
||||
if args.debug:
|
||||
logger.info(f"{image_key} {caption}")
|
||||
print(image_key, caption)
|
||||
|
||||
# metadataを書き出して終わり
|
||||
logger.info(f"writing metadata: {args.out_json}")
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
||||
logger.info("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -5,30 +5,26 @@ from typing import List
|
||||
from tqdm import tqdm
|
||||
import library.train_util as train_util
|
||||
import os
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main(args):
|
||||
assert not args.recursive or (args.recursive and args.full_path), "recursive requires full_path / recursiveはfull_pathと同時に指定してください"
|
||||
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths: List[Path] = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if args.in_json is None and Path(args.out_json).is_file():
|
||||
args.in_json = args.out_json
|
||||
|
||||
if args.in_json is not None:
|
||||
logger.info(f"loading existing metadata: {args.in_json}")
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
||||
logger.warning("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
||||
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
||||
else:
|
||||
logger.info("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||
metadata = {}
|
||||
|
||||
logger.info("merge tags to metadata json.")
|
||||
print("merge tags to metadata json.")
|
||||
for image_path in tqdm(image_paths):
|
||||
tags_path = image_path.with_suffix(args.caption_extension)
|
||||
tags = tags_path.read_text(encoding='utf-8').strip()
|
||||
@@ -42,13 +38,13 @@ def main(args):
|
||||
|
||||
metadata[image_key]['tags'] = tags
|
||||
if args.debug:
|
||||
logger.info(f"{image_key} {tags}")
|
||||
print(image_key, tags)
|
||||
|
||||
# metadataを書き出して終わり
|
||||
logger.info(f"writing metadata: {args.out_json}")
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
||||
|
||||
logger.info("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -8,25 +8,13 @@ from tqdm import tqdm
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import cv2
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torchvision import transforms
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.stable_cascade_utils as sc_utils
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEVICE = get_preferred_device()
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
@@ -46,7 +34,7 @@ def collate_fn_remove_corrupted(batch):
|
||||
return batch
|
||||
|
||||
|
||||
def get_npz_filename(data_dir, image_key, is_full_path, recursive, stable_cascade):
|
||||
def get_npz_filename(data_dir, image_key, is_full_path, recursive):
|
||||
if is_full_path:
|
||||
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
|
||||
@@ -54,32 +42,31 @@ def get_npz_filename(data_dir, image_key, is_full_path, recursive, stable_cascad
|
||||
base_name = image_key
|
||||
relative_path = ""
|
||||
|
||||
ext = ".npz" if not stable_cascade else train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
|
||||
if recursive and relative_path:
|
||||
return os.path.join(data_dir, relative_path, base_name) + ext
|
||||
return os.path.join(data_dir, relative_path, base_name) + ".npz"
|
||||
else:
|
||||
return os.path.join(data_dir, base_name) + ext
|
||||
return os.path.join(data_dir, base_name) + ".npz"
|
||||
|
||||
|
||||
def main(args):
|
||||
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
|
||||
if args.bucket_reso_steps % 8 > 0:
|
||||
logger.warning(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||
if args.bucket_reso_steps % 32 > 0:
|
||||
logger.warning(
|
||||
print(
|
||||
f"WARNING: bucket_reso_steps is not divisible by 32. It is not working with SDXL / bucket_reso_stepsが32で割り切れません。SDXLでは動作しません"
|
||||
)
|
||||
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if os.path.exists(args.in_json):
|
||||
logger.info(f"loading existing metadata: {args.in_json}")
|
||||
print(f"loading existing metadata: {args.in_json}")
|
||||
with open(args.in_json, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
logger.error(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
|
||||
return
|
||||
|
||||
weight_dtype = torch.float32
|
||||
@@ -88,20 +75,13 @@ def main(args):
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
if not args.stable_cascade:
|
||||
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
||||
divisor = 8
|
||||
else:
|
||||
vae = sc_utils.load_effnet(args.model_name_or_path, DEVICE)
|
||||
divisor = 32
|
||||
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
|
||||
vae.eval()
|
||||
vae.to(DEVICE, dtype=weight_dtype)
|
||||
|
||||
# bucketのサイズを計算する
|
||||
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
|
||||
assert (
|
||||
len(max_reso) == 2
|
||||
), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
|
||||
|
||||
bucket_manager = train_util.BucketManager(
|
||||
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
|
||||
@@ -109,7 +89,7 @@ def main(args):
|
||||
if not args.bucket_no_upscale:
|
||||
bucket_manager.make_buckets()
|
||||
else:
|
||||
logger.warning(
|
||||
print(
|
||||
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
|
||||
)
|
||||
|
||||
@@ -150,7 +130,7 @@ def main(args):
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
|
||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
||||
@@ -166,10 +146,6 @@ def main(args):
|
||||
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
|
||||
metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
|
||||
|
||||
# 追加情報を記録
|
||||
metadata[image_key]["original_size"] = (image.width, image.height)
|
||||
metadata[image_key]["train_resized_size"] = resized_size
|
||||
|
||||
if not args.bucket_no_upscale:
|
||||
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
|
||||
assert (
|
||||
@@ -184,9 +160,9 @@ def main(args):
|
||||
), f"internal error resized size is small: {resized_size}, {reso}"
|
||||
|
||||
# 既に存在するファイルがあればshape等を確認して同じならskipする
|
||||
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive, args.stable_cascade)
|
||||
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
|
||||
if args.skip_existing:
|
||||
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug, divisor):
|
||||
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
|
||||
continue
|
||||
|
||||
# バッチへ追加
|
||||
@@ -207,15 +183,15 @@ def main(args):
|
||||
for i, reso in enumerate(bucket_manager.resos):
|
||||
count = bucket_counts.get(reso, 0)
|
||||
if count > 0:
|
||||
logger.info(f"bucket {i} {reso}: {count}")
|
||||
print(f"bucket {i} {reso}: {count}")
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
logger.info(f"mean ar error: {np.mean(img_ar_errors)}")
|
||||
print(f"mean ar error: {np.mean(img_ar_errors)}")
|
||||
|
||||
# metadataを書き出して終わり
|
||||
logger.info(f"writing metadata: {args.out_json}")
|
||||
print(f"writing metadata: {args.out_json}")
|
||||
with open(args.out_json, "wt", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
logger.info("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
@@ -224,14 +200,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
||||
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
||||
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
|
||||
parser.add_argument(
|
||||
"--stable_cascade",
|
||||
action="store_true",
|
||||
help="prepare EffNet latents for stable cascade / stable cascade用のEffNetのlatentsを準備する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)"
|
||||
)
|
||||
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
||||
parser.add_argument(
|
||||
"--max_data_loader_n_workers",
|
||||
@@ -254,16 +223,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_no_upscale",
|
||||
action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
|
||||
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help="use mixed precision / 混合精度を使う場合、その精度",
|
||||
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full_path",
|
||||
@@ -271,9 +234,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flip_aug",
|
||||
action="store_true",
|
||||
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する",
|
||||
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
|
||||
@@ -11,10 +11,6 @@ from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# from wd14 tagger
|
||||
IMAGE_SIZE = 448
|
||||
@@ -62,7 +58,7 @@ class ImageLoadingPrepDataset(torch.utils.data.Dataset):
|
||||
image = preprocess_image(image)
|
||||
tensor = torch.tensor(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
||||
return None
|
||||
|
||||
return (tensor, img_path)
|
||||
@@ -83,7 +79,7 @@ def main(args):
|
||||
# depreacatedの警告が出るけどなくなったらその時
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||
if not os.path.exists(args.model_dir) or args.force_download:
|
||||
logger.info(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
files = FILES
|
||||
if args.onnx:
|
||||
files += FILES_ONNX
|
||||
@@ -99,7 +95,7 @@ def main(args):
|
||||
force_filename=file,
|
||||
)
|
||||
else:
|
||||
logger.info("using existing wd14 tagger model")
|
||||
print("using existing wd14 tagger model")
|
||||
|
||||
# 画像を読み込む
|
||||
if args.onnx:
|
||||
@@ -107,8 +103,8 @@ def main(args):
|
||||
import onnxruntime as ort
|
||||
|
||||
onnx_path = f"{args.model_dir}/model.onnx"
|
||||
logger.info("Running wd14 tagger with onnx")
|
||||
logger.info(f"loading onnx model: {onnx_path}")
|
||||
print("Running wd14 tagger with onnx")
|
||||
print(f"loading onnx model: {onnx_path}")
|
||||
|
||||
if not os.path.exists(onnx_path):
|
||||
raise Exception(
|
||||
@@ -125,7 +121,7 @@ def main(args):
|
||||
|
||||
if args.batch_size != batch_size and type(batch_size) != str:
|
||||
# some rebatch model may use 'N' as dynamic axes
|
||||
logger.warning(
|
||||
print(
|
||||
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
||||
)
|
||||
args.batch_size = batch_size
|
||||
@@ -160,7 +156,7 @@ def main(args):
|
||||
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||
logger.info(f"found {len(image_paths)} images.")
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
tag_freq = {}
|
||||
|
||||
@@ -241,10 +237,7 @@ def main(args):
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
if args.debug:
|
||||
logger.info("")
|
||||
logger.info(f"{image_path}:")
|
||||
logger.info(f"\tCharacter tags: {character_tag_text}")
|
||||
logger.info(f"\tGeneral tags: {general_tag_text}")
|
||||
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
if args.max_data_loader_n_workers is not None:
|
||||
@@ -276,7 +269,7 @@ def main(args):
|
||||
image = image.convert("RGB")
|
||||
image = preprocess_image(image)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
||||
continue
|
||||
b_imgs.append((image_path, image))
|
||||
|
||||
@@ -291,11 +284,11 @@ def main(args):
|
||||
|
||||
if args.frequency_tags:
|
||||
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
|
||||
print("Tag frequencies:")
|
||||
print("\nTag frequencies:")
|
||||
for tag, freq in sorted_tags:
|
||||
print(f"{tag}: {freq}")
|
||||
|
||||
logger.info("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
3334
gen_img.py
3334
gen_img.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -40,10 +40,7 @@ from .train_util import (
|
||||
ControlNetDataset,
|
||||
DatasetGroup,
|
||||
)
|
||||
from .utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def add_config_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル")
|
||||
@@ -348,7 +345,7 @@ class ConfigSanitizer:
|
||||
return self.user_config_validator(user_config)
|
||||
except MultipleInvalid:
|
||||
# TODO: エラー発生時のメッセージをわかりやすくする
|
||||
logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
|
||||
print("Invalid user config / ユーザ設定の形式が正しくないようです")
|
||||
raise
|
||||
|
||||
# NOTE: In nature, argument parser result is not needed to be sanitize
|
||||
@@ -358,7 +355,7 @@ class ConfigSanitizer:
|
||||
return self.argparse_config_validator(argparse_namespace)
|
||||
except MultipleInvalid:
|
||||
# XXX: this should be a bug
|
||||
logger.error("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
|
||||
print("Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。")
|
||||
raise
|
||||
|
||||
# NOTE: value would be overwritten by latter dict if there is already the same key
|
||||
@@ -541,13 +538,13 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
" ",
|
||||
)
|
||||
|
||||
logger.info(f'{info}')
|
||||
print(info)
|
||||
|
||||
# make buckets first because it determines the length of dataset
|
||||
# and set the same seed for all datasets
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
||||
for i, dataset in enumerate(datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
print(f"[Dataset {i}]")
|
||||
dataset.make_buckets()
|
||||
dataset.set_seed(seed)
|
||||
|
||||
@@ -560,7 +557,7 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
|
||||
try:
|
||||
n_repeats = int(tokens[0])
|
||||
except ValueError as e:
|
||||
logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
|
||||
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
|
||||
return 0, ""
|
||||
caption_by_folder = "_".join(tokens[1:])
|
||||
return n_repeats, caption_by_folder
|
||||
@@ -632,13 +629,17 @@ def load_user_config(file: str) -> dict:
|
||||
with open(file, "r") as f:
|
||||
config = json.load(f)
|
||||
except Exception:
|
||||
logger.error(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
||||
print(
|
||||
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
||||
)
|
||||
raise
|
||||
elif file.name.lower().endswith(".toml"):
|
||||
try:
|
||||
config = toml.load(file)
|
||||
except Exception:
|
||||
logger.error(f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
|
||||
print(
|
||||
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
||||
)
|
||||
raise
|
||||
else:
|
||||
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
||||
@@ -664,26 +665,23 @@ if __name__ == "__main__":
|
||||
argparse_namespace = parser.parse_args(remain)
|
||||
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
||||
|
||||
logger.info("[argparse_namespace]")
|
||||
logger.info(f'{vars(argparse_namespace)}')
|
||||
print("[argparse_namespace]")
|
||||
print(vars(argparse_namespace))
|
||||
|
||||
user_config = load_user_config(config_args.dataset_config)
|
||||
|
||||
logger.info("")
|
||||
logger.info("[user_config]")
|
||||
logger.info(f'{user_config}')
|
||||
print("\n[user_config]")
|
||||
print(user_config)
|
||||
|
||||
sanitizer = ConfigSanitizer(
|
||||
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
||||
)
|
||||
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
||||
|
||||
logger.info("")
|
||||
logger.info("[sanitized_user_config]")
|
||||
logger.info(f'{sanitized_user_config}')
|
||||
print("\n[sanitized_user_config]")
|
||||
print(sanitized_user_config)
|
||||
|
||||
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
||||
|
||||
logger.info("")
|
||||
logger.info("[blueprint]")
|
||||
logger.info(f'{blueprint}')
|
||||
print("\n[blueprint]")
|
||||
print(blueprint)
|
||||
|
||||
@@ -3,10 +3,7 @@ import argparse
|
||||
import random
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
from .utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
if hasattr(noise_scheduler, "all_snr"):
|
||||
@@ -24,7 +21,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
||||
|
||||
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
# fix beta: zero terminal SNR
|
||||
logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
||||
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
||||
|
||||
def enforce_zero_terminal_snr(betas):
|
||||
# Convert betas to alphas_bar_sqrt
|
||||
@@ -52,8 +49,8 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
# logger.info(f"original: {noise_scheduler.betas}")
|
||||
# logger.info(f"fixed: {betas}")
|
||||
# print("original:", noise_scheduler.betas)
|
||||
# print("fixed:", betas)
|
||||
|
||||
noise_scheduler.betas = betas
|
||||
noise_scheduler.alphas = alphas
|
||||
@@ -82,13 +79,13 @@ def get_snr_scale(timesteps, noise_scheduler):
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
scale = snr_t / (snr_t + 1)
|
||||
# # show debug info
|
||||
# logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
||||
# print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
||||
return scale
|
||||
|
||||
|
||||
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
||||
scale = get_snr_scale(timesteps, noise_scheduler)
|
||||
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
||||
# print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
||||
loss = loss + loss / scale * v_pred_like_loss
|
||||
return loss
|
||||
|
||||
@@ -271,7 +268,7 @@ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
||||
tokens.append(text_token)
|
||||
weights.append(text_weight)
|
||||
if truncated:
|
||||
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||
return tokens, weights
|
||||
|
||||
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
import functools
|
||||
import gc
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
except Exception:
|
||||
HAS_CUDA = False
|
||||
|
||||
try:
|
||||
HAS_MPS = torch.backends.mps.is_available()
|
||||
except Exception:
|
||||
HAS_MPS = False
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa
|
||||
|
||||
HAS_XPU = torch.xpu.is_available()
|
||||
except Exception:
|
||||
HAS_XPU = False
|
||||
|
||||
|
||||
def clean_memory():
|
||||
gc.collect()
|
||||
if HAS_CUDA:
|
||||
torch.cuda.empty_cache()
|
||||
if HAS_XPU:
|
||||
torch.xpu.empty_cache()
|
||||
if HAS_MPS:
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
def clean_memory_on_device(device: torch.device):
|
||||
r"""
|
||||
Clean memory on the specified device, will be called from training scripts.
|
||||
"""
|
||||
gc.collect()
|
||||
|
||||
# device may "cuda" or "cuda:0", so we need to check the type of device
|
||||
if device.type == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
if device.type == "xpu":
|
||||
torch.xpu.empty_cache()
|
||||
if device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def get_preferred_device() -> torch.device:
|
||||
r"""
|
||||
Do not call this function from training scripts. Use accelerator.device instead.
|
||||
"""
|
||||
if HAS_CUDA:
|
||||
device = torch.device("cuda")
|
||||
elif HAS_XPU:
|
||||
device = torch.device("xpu")
|
||||
elif HAS_MPS:
|
||||
device = torch.device("mps")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info(f"get_preferred_device() -> {device}")
|
||||
return device
|
||||
|
||||
|
||||
def init_ipex():
|
||||
"""
|
||||
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
|
||||
|
||||
This function should run right after importing torch and before doing anything else.
|
||||
|
||||
If IPEX is not available, this function does nothing.
|
||||
"""
|
||||
try:
|
||||
if HAS_XPU:
|
||||
from library.ipex import ipex_init
|
||||
|
||||
is_initialized, error_message = ipex_init()
|
||||
if not is_initialized:
|
||||
logger.error("failed to initialize ipex: {error_message}")
|
||||
else:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error("failed to initialize ipex: {e}")
|
||||
@@ -4,10 +4,7 @@ from pathlib import Path
|
||||
import argparse
|
||||
import os
|
||||
from library.utils import fire_in_thread
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
|
||||
api = HfApi(
|
||||
@@ -36,9 +33,9 @@ def upload(
|
||||
try:
|
||||
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
|
||||
logger.error("===========================================")
|
||||
logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
|
||||
logger.error("===========================================")
|
||||
print("===========================================")
|
||||
print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
|
||||
print("===========================================")
|
||||
|
||||
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
|
||||
|
||||
@@ -59,9 +56,9 @@ def upload(
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
|
||||
logger.error("===========================================")
|
||||
logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
|
||||
logger.error("===========================================")
|
||||
print("===========================================")
|
||||
print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
|
||||
print("===========================================")
|
||||
|
||||
if args.async_upload and not force_sync_upload:
|
||||
fire_in_thread(uploader)
|
||||
|
||||
@@ -9,171 +9,167 @@ from .hijacks import ipex_hijacks
|
||||
|
||||
def ipex_init(): # pylint: disable=too-many-statements
|
||||
try:
|
||||
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
|
||||
return True, "Skipping IPEX hijack"
|
||||
else:
|
||||
# Replace cuda with xpu:
|
||||
torch.cuda.current_device = torch.xpu.current_device
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.device = torch.xpu.device
|
||||
torch.cuda.device_count = torch.xpu.device_count
|
||||
torch.cuda.device_of = torch.xpu.device_of
|
||||
torch.cuda.get_device_name = torch.xpu.get_device_name
|
||||
torch.cuda.get_device_properties = torch.xpu.get_device_properties
|
||||
torch.cuda.init = torch.xpu.init
|
||||
torch.cuda.is_available = torch.xpu.is_available
|
||||
torch.cuda.is_initialized = torch.xpu.is_initialized
|
||||
torch.cuda.is_current_stream_capturing = lambda: False
|
||||
torch.cuda.set_device = torch.xpu.set_device
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
torch.cuda.synchronize = torch.xpu.synchronize
|
||||
torch.cuda.Event = torch.xpu.Event
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||
torch.Tensor.cuda = torch.Tensor.xpu
|
||||
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
||||
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
||||
torch.cuda._tls = torch.xpu.lazy_init._tls
|
||||
torch.cuda.threading = torch.xpu.lazy_init.threading
|
||||
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
||||
torch.cuda.Optional = torch.xpu.Optional
|
||||
torch.cuda.__cached__ = torch.xpu.__cached__
|
||||
torch.cuda.__loader__ = torch.xpu.__loader__
|
||||
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
||||
torch.cuda.Tuple = torch.xpu.Tuple
|
||||
torch.cuda.streams = torch.xpu.streams
|
||||
torch.cuda._lazy_new = torch.xpu._lazy_new
|
||||
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
||||
torch.cuda.Any = torch.xpu.Any
|
||||
torch.cuda.__doc__ = torch.xpu.__doc__
|
||||
torch.cuda.default_generators = torch.xpu.default_generators
|
||||
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
||||
torch.cuda._get_device_index = torch.xpu._get_device_index
|
||||
torch.cuda.__path__ = torch.xpu.__path__
|
||||
torch.cuda.Device = torch.xpu.Device
|
||||
torch.cuda.IntTensor = torch.xpu.IntTensor
|
||||
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
||||
torch.cuda.set_stream = torch.xpu.set_stream
|
||||
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
||||
torch.cuda.os = torch.xpu.os
|
||||
torch.cuda.torch = torch.xpu.torch
|
||||
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
||||
torch.cuda.Union = torch.xpu.Union
|
||||
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
||||
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
||||
torch.cuda.LongTensor = torch.xpu.LongTensor
|
||||
torch.cuda.IntStorage = torch.xpu.IntStorage
|
||||
torch.cuda.LongStorage = torch.xpu.LongStorage
|
||||
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
||||
torch.cuda.__package__ = torch.xpu.__package__
|
||||
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
||||
torch.cuda.CharTensor = torch.xpu.CharTensor
|
||||
torch.cuda.List = torch.xpu.List
|
||||
torch.cuda._lazy_init = torch.xpu._lazy_init
|
||||
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
||||
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
||||
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
||||
torch.cuda.StreamContext = torch.xpu.StreamContext
|
||||
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
||||
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
||||
torch.cuda._lazy_call = torch.xpu._lazy_call
|
||||
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
||||
torch.cuda.random = torch.xpu.random
|
||||
torch.cuda._device = torch.xpu._device
|
||||
torch.cuda.classproperty = torch.xpu.classproperty
|
||||
torch.cuda.__name__ = torch.xpu.__name__
|
||||
torch.cuda._device_t = torch.xpu._device_t
|
||||
torch.cuda.warnings = torch.xpu.warnings
|
||||
torch.cuda.__spec__ = torch.xpu.__spec__
|
||||
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||
torch.cuda.__file__ = torch.xpu.__file__
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
# Replace cuda with xpu:
|
||||
torch.cuda.current_device = torch.xpu.current_device
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.device = torch.xpu.device
|
||||
torch.cuda.device_count = torch.xpu.device_count
|
||||
torch.cuda.device_of = torch.xpu.device_of
|
||||
torch.cuda.get_device_name = torch.xpu.get_device_name
|
||||
torch.cuda.get_device_properties = torch.xpu.get_device_properties
|
||||
torch.cuda.init = torch.xpu.init
|
||||
torch.cuda.is_available = torch.xpu.is_available
|
||||
torch.cuda.is_initialized = torch.xpu.is_initialized
|
||||
torch.cuda.is_current_stream_capturing = lambda: False
|
||||
torch.cuda.set_device = torch.xpu.set_device
|
||||
torch.cuda.stream = torch.xpu.stream
|
||||
torch.cuda.synchronize = torch.xpu.synchronize
|
||||
torch.cuda.Event = torch.xpu.Event
|
||||
torch.cuda.Stream = torch.xpu.Stream
|
||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||
torch.Tensor.cuda = torch.Tensor.xpu
|
||||
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
||||
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
||||
torch.cuda._tls = torch.xpu.lazy_init._tls
|
||||
torch.cuda.threading = torch.xpu.lazy_init.threading
|
||||
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
||||
torch.cuda.Optional = torch.xpu.Optional
|
||||
torch.cuda.__cached__ = torch.xpu.__cached__
|
||||
torch.cuda.__loader__ = torch.xpu.__loader__
|
||||
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
||||
torch.cuda.Tuple = torch.xpu.Tuple
|
||||
torch.cuda.streams = torch.xpu.streams
|
||||
torch.cuda._lazy_new = torch.xpu._lazy_new
|
||||
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
||||
torch.cuda.Any = torch.xpu.Any
|
||||
torch.cuda.__doc__ = torch.xpu.__doc__
|
||||
torch.cuda.default_generators = torch.xpu.default_generators
|
||||
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
||||
torch.cuda._get_device_index = torch.xpu._get_device_index
|
||||
torch.cuda.__path__ = torch.xpu.__path__
|
||||
torch.cuda.Device = torch.xpu.Device
|
||||
torch.cuda.IntTensor = torch.xpu.IntTensor
|
||||
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
||||
torch.cuda.set_stream = torch.xpu.set_stream
|
||||
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
||||
torch.cuda.os = torch.xpu.os
|
||||
torch.cuda.torch = torch.xpu.torch
|
||||
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
||||
torch.cuda.Union = torch.xpu.Union
|
||||
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
||||
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
||||
torch.cuda.LongTensor = torch.xpu.LongTensor
|
||||
torch.cuda.IntStorage = torch.xpu.IntStorage
|
||||
torch.cuda.LongStorage = torch.xpu.LongStorage
|
||||
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
||||
torch.cuda.__package__ = torch.xpu.__package__
|
||||
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
||||
torch.cuda.CharTensor = torch.xpu.CharTensor
|
||||
torch.cuda.List = torch.xpu.List
|
||||
torch.cuda._lazy_init = torch.xpu._lazy_init
|
||||
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
||||
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
||||
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
||||
torch.cuda.StreamContext = torch.xpu.StreamContext
|
||||
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
||||
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
||||
torch.cuda._lazy_call = torch.xpu._lazy_call
|
||||
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
||||
torch.cuda.random = torch.xpu.random
|
||||
torch.cuda._device = torch.xpu._device
|
||||
torch.cuda.classproperty = torch.xpu.classproperty
|
||||
torch.cuda.__name__ = torch.xpu.__name__
|
||||
torch.cuda._device_t = torch.xpu._device_t
|
||||
torch.cuda.warnings = torch.xpu.warnings
|
||||
torch.cuda.__spec__ = torch.xpu.__spec__
|
||||
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||
torch.cuda.__file__ = torch.xpu.__file__
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
|
||||
# Memory:
|
||||
torch.cuda.memory = torch.xpu.memory
|
||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||
torch.xpu.empty_cache = lambda: None
|
||||
torch.cuda.empty_cache = torch.xpu.empty_cache
|
||||
torch.cuda.memory_stats = torch.xpu.memory_stats
|
||||
torch.cuda.memory_summary = torch.xpu.memory_summary
|
||||
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
|
||||
torch.cuda.memory_allocated = torch.xpu.memory_allocated
|
||||
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
|
||||
torch.cuda.memory_reserved = torch.xpu.memory_reserved
|
||||
torch.cuda.memory_cached = torch.xpu.memory_reserved
|
||||
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
|
||||
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
|
||||
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
|
||||
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
|
||||
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
|
||||
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
|
||||
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
|
||||
# Memory:
|
||||
torch.cuda.memory = torch.xpu.memory
|
||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||
torch.xpu.empty_cache = lambda: None
|
||||
torch.cuda.empty_cache = torch.xpu.empty_cache
|
||||
torch.cuda.memory_stats = torch.xpu.memory_stats
|
||||
torch.cuda.memory_summary = torch.xpu.memory_summary
|
||||
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
|
||||
torch.cuda.memory_allocated = torch.xpu.memory_allocated
|
||||
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
|
||||
torch.cuda.memory_reserved = torch.xpu.memory_reserved
|
||||
torch.cuda.memory_cached = torch.xpu.memory_reserved
|
||||
torch.cuda.max_memory_reserved = torch.xpu.max_memory_reserved
|
||||
torch.cuda.max_memory_cached = torch.xpu.max_memory_reserved
|
||||
torch.cuda.reset_peak_memory_stats = torch.xpu.reset_peak_memory_stats
|
||||
torch.cuda.reset_max_memory_cached = torch.xpu.reset_peak_memory_stats
|
||||
torch.cuda.reset_max_memory_allocated = torch.xpu.reset_peak_memory_stats
|
||||
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
|
||||
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
|
||||
|
||||
# RNG:
|
||||
torch.cuda.get_rng_state = torch.xpu.get_rng_state
|
||||
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
|
||||
torch.cuda.set_rng_state = torch.xpu.set_rng_state
|
||||
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
|
||||
torch.cuda.manual_seed = torch.xpu.manual_seed
|
||||
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
|
||||
torch.cuda.seed = torch.xpu.seed
|
||||
torch.cuda.seed_all = torch.xpu.seed_all
|
||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||
# RNG:
|
||||
torch.cuda.get_rng_state = torch.xpu.get_rng_state
|
||||
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
|
||||
torch.cuda.set_rng_state = torch.xpu.set_rng_state
|
||||
torch.cuda.set_rng_state_all = torch.xpu.set_rng_state_all
|
||||
torch.cuda.manual_seed = torch.xpu.manual_seed
|
||||
torch.cuda.manual_seed_all = torch.xpu.manual_seed_all
|
||||
torch.cuda.seed = torch.xpu.seed
|
||||
torch.cuda.seed_all = torch.xpu.seed_all
|
||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||
|
||||
# AMP:
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
||||
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
||||
# AMP:
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
||||
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
||||
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = contextlib.nullcontext()
|
||||
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = contextlib.nullcontext()
|
||||
torch.cuda.amp.common.amp_definitely_not_available = lambda: False
|
||||
|
||||
try:
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
try:
|
||||
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
|
||||
gradscaler_init()
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
try:
|
||||
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
|
||||
gradscaler_init()
|
||||
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||
|
||||
# C
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
|
||||
ipex._C._DeviceProperties.major = 2023
|
||||
ipex._C._DeviceProperties.minor = 2
|
||||
# C
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_eu_count
|
||||
ipex._C._DeviceProperties.major = 2023
|
||||
ipex._C._DeviceProperties.minor = 2
|
||||
|
||||
# Fix functions with ipex:
|
||||
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||
torch._utils._get_available_device_type = lambda: "xpu"
|
||||
torch.has_cuda = True
|
||||
torch.cuda.has_half = True
|
||||
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
|
||||
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
||||
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
||||
torch.version.cuda = "12.1"
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
|
||||
torch.cuda.get_device_properties.major = 12
|
||||
torch.cuda.get_device_properties.minor = 1
|
||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||
# Fix functions with ipex:
|
||||
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||
torch._utils._get_available_device_type = lambda: "xpu"
|
||||
torch.has_cuda = True
|
||||
torch.cuda.has_half = True
|
||||
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
|
||||
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
||||
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
||||
torch.version.cuda = "12.1"
|
||||
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
|
||||
torch.cuda.get_device_properties.major = 12
|
||||
torch.cuda.get_device_properties.minor = 1
|
||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||
|
||||
ipex_hijacks()
|
||||
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
torch.cuda.is_xpu_hijacked = True
|
||||
ipex_hijacks()
|
||||
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
except Exception as e:
|
||||
return False, e
|
||||
return True, None
|
||||
|
||||
@@ -12,7 +12,7 @@ device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
||||
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
||||
if isinstance(device_ids, list) and len(device_ids) > 1:
|
||||
logger.error("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
||||
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
||||
return module.to("xpu")
|
||||
|
||||
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
||||
|
||||
24
library/ipex_interop.py
Normal file
24
library/ipex_interop.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import torch
|
||||
|
||||
|
||||
def init_ipex():
|
||||
"""
|
||||
Try to import `intel_extension_for_pytorch`, and apply
|
||||
the hijacks using `library.ipex.ipex_init`.
|
||||
|
||||
If IPEX is not installed, this function does nothing.
|
||||
"""
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
try:
|
||||
from library.ipex import ipex_init
|
||||
|
||||
if torch.xpu.is_available():
|
||||
is_initialized, error_message = ipex_init()
|
||||
if not is_initialized:
|
||||
print("failed to initialize ipex:", error_message)
|
||||
except Exception as e:
|
||||
print("failed to initialize ipex:", e)
|
||||
@@ -17,6 +17,7 @@ from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
try:
|
||||
from diffusers.utils import PIL_INTERPOLATION
|
||||
except ImportError:
|
||||
@@ -625,7 +626,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
logger.info(f'{height} {width}')
|
||||
print(height, width)
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
|
||||
@@ -3,20 +3,16 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex
|
||||
init_ipex()
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
import diffusers
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# DiffUsers版StableDiffusionのモデルパラメータ
|
||||
NUM_TRAIN_TIMESTEPS = 1000
|
||||
@@ -948,7 +944,7 @@ def convert_vae_state_dict(vae_state_dict):
|
||||
for k, v in new_state_dict.items():
|
||||
for weight_name in weights_to_convert:
|
||||
if f"mid.attn_1.{weight_name}.weight" in k:
|
||||
# logger.info(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
||||
# print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
|
||||
new_state_dict[k] = reshape_weight_for_sd(v)
|
||||
|
||||
return new_state_dict
|
||||
@@ -1006,7 +1002,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
||||
|
||||
unet = UNet2DConditionModel(**unet_config).to(device)
|
||||
info = unet.load_state_dict(converted_unet_checkpoint)
|
||||
logger.info(f"loading u-net: {info}")
|
||||
print("loading u-net:", info)
|
||||
|
||||
# Convert the VAE model.
|
||||
vae_config = create_vae_diffusers_config()
|
||||
@@ -1014,7 +1010,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
||||
|
||||
vae = AutoencoderKL(**vae_config).to(device)
|
||||
info = vae.load_state_dict(converted_vae_checkpoint)
|
||||
logger.info(f"loading vae: {info}")
|
||||
print("loading vae:", info)
|
||||
|
||||
# convert text_model
|
||||
if v2:
|
||||
@@ -1048,7 +1044,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
||||
# logging.set_verbosity_error() # don't show annoying warning
|
||||
# text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
|
||||
# logging.set_verbosity_warning()
|
||||
# logger.info(f"config: {text_model.config}")
|
||||
# print(f"config: {text_model.config}")
|
||||
cfg = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=768,
|
||||
@@ -1071,7 +1067,7 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dt
|
||||
)
|
||||
text_model = CLIPTextModel._from_config(cfg)
|
||||
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
||||
logger.info(f"loading text encoder: {info}")
|
||||
print("loading text encoder:", info)
|
||||
|
||||
return text_model, vae, unet
|
||||
|
||||
@@ -1146,7 +1142,7 @@ def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=Fals
|
||||
|
||||
# 最後の層などを捏造するか
|
||||
if make_dummy_weights:
|
||||
logger.info("make dummy weights for resblock.23, text_projection and logit scale.")
|
||||
print("make dummy weights for resblock.23, text_projection and logit scale.")
|
||||
keys = list(new_sd.keys())
|
||||
for key in keys:
|
||||
if key.startswith("transformer.resblocks.22."):
|
||||
@@ -1265,14 +1261,14 @@ VAE_PREFIX = "first_stage_model."
|
||||
|
||||
|
||||
def load_vae(vae_id, dtype):
|
||||
logger.info(f"load VAE: {vae_id}")
|
||||
print(f"load VAE: {vae_id}")
|
||||
if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
|
||||
# Diffusers local/remote
|
||||
try:
|
||||
vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
|
||||
except EnvironmentError as e:
|
||||
logger.error(f"exception occurs in loading vae: {e}")
|
||||
logger.error("retry with subfolder='vae'")
|
||||
print(f"exception occurs in loading vae: {e}")
|
||||
print("retry with subfolder='vae'")
|
||||
vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
|
||||
return vae
|
||||
|
||||
@@ -1344,13 +1340,13 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
|
||||
|
||||
if __name__ == "__main__":
|
||||
resos = make_bucket_resolutions((512, 768))
|
||||
logger.info(f"{len(resos)}")
|
||||
logger.info(f"{resos}")
|
||||
print(len(resos))
|
||||
print(resos)
|
||||
aspect_ratios = [w / h for w, h in resos]
|
||||
logger.info(f"{aspect_ratios}")
|
||||
print(aspect_ratios)
|
||||
|
||||
ars = set()
|
||||
for ar in aspect_ratios:
|
||||
if ar in ars:
|
||||
logger.error(f"error! duplicate ar: {ar}")
|
||||
print("error! duplicate ar:", ar)
|
||||
ars.add(ar)
|
||||
|
||||
@@ -113,10 +113,6 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from einops import rearrange
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
BLOCK_OUT_CHANNELS: Tuple[int] = (320, 640, 1280, 1280)
|
||||
TIMESTEP_INPUT_DIM = BLOCK_OUT_CHANNELS[0]
|
||||
@@ -1384,7 +1380,7 @@ class UNet2DConditionModel(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
assert sample_size is not None, "sample_size must be specified"
|
||||
logger.info(
|
||||
print(
|
||||
f"UNet2DConditionModel: {sample_size}, {attention_head_dim}, {cross_attention_dim}, {use_linear_projection}, {upcast_attention}"
|
||||
)
|
||||
|
||||
@@ -1518,7 +1514,7 @@ class UNet2DConditionModel(nn.Module):
|
||||
def set_gradient_checkpointing(self, value=False):
|
||||
modules = self.down_blocks + [self.mid_block] + self.up_blocks
|
||||
for module in modules:
|
||||
logger.info(f"{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
|
||||
print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# endregion
|
||||
@@ -1713,14 +1709,14 @@ class InferUNet2DConditionModel:
|
||||
|
||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||
if ds_depth_1 is None:
|
||||
logger.info("Deep Shrink is disabled.")
|
||||
print("Deep Shrink is disabled.")
|
||||
self.ds_depth_1 = None
|
||||
self.ds_timesteps_1 = None
|
||||
self.ds_depth_2 = None
|
||||
self.ds_timesteps_2 = None
|
||||
self.ds_ratio = None
|
||||
else:
|
||||
logger.info(
|
||||
print(
|
||||
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||
)
|
||||
self.ds_depth_1 = ds_depth_1
|
||||
|
||||
@@ -5,12 +5,6 @@ from io import BytesIO
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import safetensors
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
r"""
|
||||
# Metadata Example
|
||||
@@ -57,13 +51,11 @@ ARCH_SD_V1 = "stable-diffusion-v1"
|
||||
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
|
||||
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
||||
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
ARCH_STABLE_CASCADE = "stable-cascade"
|
||||
|
||||
ADAPTER_LORA = "lora"
|
||||
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
|
||||
|
||||
IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
|
||||
IMPL_STABILITY_AI_STABLE_CASCADE = "https://github.com/Stability-AI/StableCascade"
|
||||
IMPL_DIFFUSERS = "diffusers"
|
||||
|
||||
PRED_TYPE_EPSILON = "epsilon"
|
||||
@@ -117,7 +109,6 @@ def build_metadata(
|
||||
merged_from: Optional[str] = None,
|
||||
timesteps: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
stable_cascade: Optional[bool] = None,
|
||||
):
|
||||
# if state_dict is None, hash is not calculated
|
||||
|
||||
@@ -129,9 +120,7 @@ def build_metadata(
|
||||
# hash = precalculate_safetensors_hashes(state_dict)
|
||||
# metadata["modelspec.hash_sha256"] = hash
|
||||
|
||||
if stable_cascade:
|
||||
arch = ARCH_STABLE_CASCADE
|
||||
elif sdxl:
|
||||
if sdxl:
|
||||
arch = ARCH_SD_XL_V1_BASE
|
||||
elif v2:
|
||||
if v_parameterization:
|
||||
@@ -149,11 +138,9 @@ def build_metadata(
|
||||
metadata["modelspec.architecture"] = arch
|
||||
|
||||
if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
|
||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion
|
||||
|
||||
if stable_cascade:
|
||||
impl = IMPL_STABILITY_AI_STABLE_CASCADE
|
||||
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
|
||||
# Stable Diffusion ckpt, TI, SDXL LoRA
|
||||
impl = IMPL_STABILITY_AI
|
||||
else:
|
||||
@@ -244,8 +231,8 @@ def build_metadata(
|
||||
# # assert all values are filled
|
||||
# assert all([v is not None for v in metadata.values()]), metadata
|
||||
if not all([v is not None for v in metadata.values()]):
|
||||
logger.error(f"Internal error: some metadata values are None: {metadata}")
|
||||
|
||||
print(f"Internal error: some metadata values are None: {metadata}")
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
@@ -259,7 +246,7 @@ def get_title(metadata: dict) -> Optional[str]:
|
||||
def load_metadata_from_safetensors(model: str) -> dict:
|
||||
if not model.endswith(".safetensors"):
|
||||
return {}
|
||||
|
||||
|
||||
with safetensors.safe_open(model, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is None:
|
||||
|
||||
@@ -7,10 +7,7 @@ from typing import List
|
||||
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
||||
from library import model_util
|
||||
from library import sdxl_original_unet
|
||||
from .utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
VAE_SCALE_FACTOR = 0.13025
|
||||
MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0"
|
||||
@@ -134,7 +131,7 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
|
||||
# temporary workaround for text_projection.weight.weight for Playground-v2
|
||||
if "text_projection.weight.weight" in new_sd:
|
||||
logger.info("convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
|
||||
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
|
||||
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
|
||||
del new_sd["text_projection.weight.weight"]
|
||||
|
||||
@@ -189,20 +186,20 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
||||
checkpoint = None
|
||||
|
||||
# U-Net
|
||||
logger.info("building U-Net")
|
||||
print("building U-Net")
|
||||
with init_empty_weights():
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||
|
||||
logger.info("loading U-Net from checkpoint")
|
||||
print("loading U-Net from checkpoint")
|
||||
unet_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith("model.diffusion_model."):
|
||||
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
||||
info = _load_state_dict_on_device(unet, unet_sd, device=map_location, dtype=dtype)
|
||||
logger.info(f"U-Net: {info}")
|
||||
print("U-Net: ", info)
|
||||
|
||||
# Text Encoders
|
||||
logger.info("building text encoders")
|
||||
print("building text encoders")
|
||||
|
||||
# Text Encoder 1 is same to Stability AI's SDXL
|
||||
text_model1_cfg = CLIPTextConfig(
|
||||
@@ -255,7 +252,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
||||
with init_empty_weights():
|
||||
text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
|
||||
|
||||
logger.info("loading text encoders from checkpoint")
|
||||
print("loading text encoders from checkpoint")
|
||||
te1_sd = {}
|
||||
te2_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
@@ -269,22 +266,22 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
||||
te1_sd.pop("text_model.embeddings.position_ids")
|
||||
|
||||
info1 = _load_state_dict_on_device(text_model1, te1_sd, device=map_location) # remain fp32
|
||||
logger.info(f"text encoder 1: {info1}")
|
||||
print("text encoder 1:", info1)
|
||||
|
||||
converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
|
||||
info2 = _load_state_dict_on_device(text_model2, converted_sd, device=map_location) # remain fp32
|
||||
logger.info(f"text encoder 2: {info2}")
|
||||
print("text encoder 2:", info2)
|
||||
|
||||
# prepare vae
|
||||
logger.info("building VAE")
|
||||
print("building VAE")
|
||||
vae_config = model_util.create_vae_diffusers_config()
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
logger.info("loading VAE from checkpoint")
|
||||
print("loading VAE from checkpoint")
|
||||
converted_vae_checkpoint = model_util.convert_ldm_vae_checkpoint(state_dict, vae_config)
|
||||
info = _load_state_dict_on_device(vae, converted_vae_checkpoint, device=map_location, dtype=dtype)
|
||||
logger.info(f"VAE: {info}")
|
||||
print("VAE:", info)
|
||||
|
||||
ckpt_info = (epoch, global_step) if epoch is not None else None
|
||||
return text_model1, text_model2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
@@ -30,10 +30,7 @@ import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from einops import rearrange
|
||||
from .utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
IN_CHANNELS: int = 4
|
||||
OUT_CHANNELS: int = 4
|
||||
@@ -335,7 +332,7 @@ class ResnetBlock2D(nn.Module):
|
||||
|
||||
def forward(self, x, emb):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
# logger.info("ResnetBlock2D: gradient_checkpointing")
|
||||
# print("ResnetBlock2D: gradient_checkpointing")
|
||||
|
||||
def create_custom_forward(func):
|
||||
def custom_forward(*inputs):
|
||||
@@ -369,7 +366,7 @@ class Downsample2D(nn.Module):
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
# logger.info("Downsample2D: gradient_checkpointing")
|
||||
# print("Downsample2D: gradient_checkpointing")
|
||||
|
||||
def create_custom_forward(func):
|
||||
def custom_forward(*inputs):
|
||||
@@ -656,7 +653,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, context=None, timestep=None):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
# logger.info("BasicTransformerBlock: checkpointing")
|
||||
# print("BasicTransformerBlock: checkpointing")
|
||||
|
||||
def create_custom_forward(func):
|
||||
def custom_forward(*inputs):
|
||||
@@ -799,7 +796,7 @@ class Upsample2D(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, output_size=None):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
# logger.info("Upsample2D: gradient_checkpointing")
|
||||
# print("Upsample2D: gradient_checkpointing")
|
||||
|
||||
def create_custom_forward(func):
|
||||
def custom_forward(*inputs):
|
||||
@@ -1049,7 +1046,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
for block in blocks:
|
||||
for module in block:
|
||||
if hasattr(module, "set_use_memory_efficient_attention"):
|
||||
# logger.info(module.__class__.__name__)
|
||||
# print(module.__class__.__name__)
|
||||
module.set_use_memory_efficient_attention(xformers, mem_eff)
|
||||
|
||||
def set_use_sdpa(self, sdpa: bool) -> None:
|
||||
@@ -1064,7 +1061,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
for block in blocks:
|
||||
for module in block.modules():
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
# logger.info(f{module.__class__.__name__} {module.gradient_checkpointing} -> {value}")
|
||||
# print(module.__class__.__name__, module.gradient_checkpointing, "->", value)
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# endregion
|
||||
@@ -1086,7 +1083,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
def call_module(module, h, emb, context):
|
||||
x = h
|
||||
for layer in module:
|
||||
# logger.info(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
||||
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
||||
if isinstance(layer, ResnetBlock2D):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, Transformer2DModel):
|
||||
@@ -1138,14 +1135,14 @@ class InferSdxlUNet2DConditionModel:
|
||||
|
||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||
if ds_depth_1 is None:
|
||||
logger.info("Deep Shrink is disabled.")
|
||||
print("Deep Shrink is disabled.")
|
||||
self.ds_depth_1 = None
|
||||
self.ds_timesteps_1 = None
|
||||
self.ds_depth_2 = None
|
||||
self.ds_timesteps_2 = None
|
||||
self.ds_ratio = None
|
||||
else:
|
||||
logger.info(
|
||||
print(
|
||||
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||
)
|
||||
self.ds_depth_1 = ds_depth_1
|
||||
@@ -1232,7 +1229,7 @@ class InferSdxlUNet2DConditionModel:
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
logger.info("create unet")
|
||||
print("create unet")
|
||||
unet = SdxlUNet2DConditionModel()
|
||||
|
||||
unet.to("cuda")
|
||||
@@ -1241,7 +1238,7 @@ if __name__ == "__main__":
|
||||
unet.train()
|
||||
|
||||
# 使用メモリ量確認用の疑似学習ループ
|
||||
logger.info("preparing optimizer")
|
||||
print("preparing optimizer")
|
||||
|
||||
# optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3, nesterov=True, momentum=0.9) # not working
|
||||
|
||||
@@ -1256,12 +1253,12 @@ if __name__ == "__main__":
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
||||
|
||||
logger.info("start training")
|
||||
print("start training")
|
||||
steps = 10
|
||||
batch_size = 1
|
||||
|
||||
for step in range(steps):
|
||||
logger.info(f"step {step}")
|
||||
print(f"step {step}")
|
||||
if step == 1:
|
||||
time_start = time.perf_counter()
|
||||
|
||||
@@ -1281,4 +1278,4 @@ if __name__ == "__main__":
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
time_end = time.perf_counter()
|
||||
logger.info(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
|
||||
print(f"elapsed time: {time_end - time_start} [sec] for last {steps - 1} steps")
|
||||
|
||||
@@ -1,21 +1,14 @@
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
init_ipex()
|
||||
|
||||
from accelerate import init_empty_weights
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer
|
||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
|
||||
from .utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
||||
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
@@ -28,7 +21,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
|
||||
(
|
||||
load_stable_diffusion_format,
|
||||
@@ -54,7 +47,8 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
unet.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
@@ -68,7 +62,7 @@ def _load_target_model(
|
||||
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
|
||||
|
||||
if load_stable_diffusion_format:
|
||||
logger.info(f"load StableDiffusion checkpoint: {name_or_path}")
|
||||
print(f"load StableDiffusion checkpoint: {name_or_path}")
|
||||
(
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
@@ -82,7 +76,7 @@ def _load_target_model(
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
variant = "fp16" if weight_dtype == torch.float16 else None
|
||||
logger.info(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
||||
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
||||
try:
|
||||
try:
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
@@ -90,12 +84,12 @@ def _load_target_model(
|
||||
)
|
||||
except EnvironmentError as ex:
|
||||
if variant is not None:
|
||||
logger.info("try to load fp32 model")
|
||||
print("try to load fp32 model")
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=None, tokenizer=None)
|
||||
else:
|
||||
raise ex
|
||||
except EnvironmentError as ex:
|
||||
logger.error(
|
||||
print(
|
||||
f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}"
|
||||
)
|
||||
raise ex
|
||||
@@ -118,7 +112,7 @@ def _load_target_model(
|
||||
with init_empty_weights():
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel() # overwrite unet
|
||||
sdxl_model_util._load_state_dict_on_device(unet, state_dict, device=device, dtype=model_dtype)
|
||||
logger.info("U-Net converted to original U-Net")
|
||||
print("U-Net converted to original U-Net")
|
||||
|
||||
logit_scale = None
|
||||
ckpt_info = None
|
||||
@@ -126,13 +120,13 @@ def _load_target_model(
|
||||
# VAEを読み込む
|
||||
if vae_path is not None:
|
||||
vae = model_util.load_vae(vae_path, weight_dtype)
|
||||
logger.info("additional VAE loaded")
|
||||
print("additional VAE loaded")
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
def load_tokenizers(args: argparse.Namespace):
|
||||
logger.info("prepare tokenizers")
|
||||
print("prepare tokenizers")
|
||||
|
||||
original_paths = [TOKENIZER1_PATH, TOKENIZER2_PATH]
|
||||
tokeniers = []
|
||||
@@ -141,14 +135,14 @@ def load_tokenizers(args: argparse.Namespace):
|
||||
if args.tokenizer_cache_dir:
|
||||
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
||||
if os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||
print(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
|
||||
|
||||
if tokenizer is None:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
||||
|
||||
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||
print(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||
tokenizer.save_pretrained(local_tokenizer_path)
|
||||
|
||||
if i == 1:
|
||||
@@ -157,7 +151,7 @@ def load_tokenizers(args: argparse.Namespace):
|
||||
tokeniers.append(tokenizer)
|
||||
|
||||
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
||||
logger.info(f"update token length: {args.max_token_length}")
|
||||
print(f"update token length: {args.max_token_length}")
|
||||
|
||||
return tokeniers
|
||||
|
||||
@@ -338,23 +332,23 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
||||
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
||||
assert not args.v2, "v2 cannot be enabled in SDXL training / SDXL学習ではv2を有効にすることはできません"
|
||||
if args.v_parameterization:
|
||||
logger.warning("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
||||
print("v_parameterization will be unexpected / SDXL学習ではv_parameterizationは想定外の動作になります")
|
||||
|
||||
if args.clip_skip is not None:
|
||||
logger.warning("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
||||
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
||||
|
||||
# if args.multires_noise_iterations:
|
||||
# logger.info(
|
||||
# print(
|
||||
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されていますが、multires_noise_iterationsが有効になっているためnoise_offsetは無効になります"
|
||||
# )
|
||||
# else:
|
||||
# if args.noise_offset is None:
|
||||
# args.noise_offset = DEFAULT_NOISE_OFFSET
|
||||
# elif args.noise_offset != DEFAULT_NOISE_OFFSET:
|
||||
# logger.info(
|
||||
# print(
|
||||
# f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
|
||||
# )
|
||||
# logger.info(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
||||
# print(f"noise_offset is set to {args.noise_offset} / noise_offsetが{args.noise_offset}に設定されました")
|
||||
|
||||
assert (
|
||||
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
||||
@@ -363,7 +357,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
|
||||
if supportTextEncoderCaching:
|
||||
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
||||
args.cache_text_encoder_outputs = True
|
||||
logger.warning(
|
||||
print(
|
||||
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
||||
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
||||
)
|
||||
|
||||
@@ -26,10 +26,7 @@ from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
|
||||
from diffusers.models.vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
|
||||
from .utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def slice_h(x, num_slices):
|
||||
# slice with pad 1 both sides: to eliminate side effect of padding of conv2d
|
||||
@@ -92,7 +89,7 @@ def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
|
||||
# sliced_tensor = torch.chunk(x, num_div, dim=1)
|
||||
# sliced_weight = torch.chunk(norm.weight, num_div, dim=0)
|
||||
# sliced_bias = torch.chunk(norm.bias, num_div, dim=0)
|
||||
# logger.info(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
|
||||
# print(sliced_tensor[0].shape, num_div, sliced_weight[0].shape, sliced_bias[0].shape)
|
||||
# normed_tensor = []
|
||||
# for i in range(num_div):
|
||||
# n = torch.group_norm(sliced_tensor[i], norm.num_groups, sliced_weight[i], sliced_bias[i], norm.eps)
|
||||
@@ -246,7 +243,7 @@ class SlicingEncoder(nn.Module):
|
||||
|
||||
self.num_slices = num_slices
|
||||
div = num_slices / (2 ** (len(self.down_blocks) - 1)) # 深い層はそこまで分割しなくていいので適宜減らす
|
||||
# logger.info(f"initial divisor: {div}")
|
||||
# print(f"initial divisor: {div}")
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
for resnet in self.mid_block.resnets:
|
||||
@@ -256,11 +253,11 @@ class SlicingEncoder(nn.Module):
|
||||
for i, down_block in enumerate(self.down_blocks[::-1]):
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
# logger.info(f"down block: {i} divisor: {div}")
|
||||
# print(f"down block: {i} divisor: {div}")
|
||||
for resnet in down_block.resnets:
|
||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||
if down_block.downsamplers is not None:
|
||||
# logger.info("has downsample")
|
||||
# print("has downsample")
|
||||
for downsample in down_block.downsamplers:
|
||||
downsample.forward = wrapper(self.downsample_forward, downsample, div * 2)
|
||||
div *= 2
|
||||
@@ -310,7 +307,7 @@ class SlicingEncoder(nn.Module):
|
||||
def downsample_forward(self, _self, num_slices, hidden_states):
|
||||
assert hidden_states.shape[1] == _self.channels
|
||||
assert _self.use_conv and _self.padding == 0
|
||||
logger.info(f"downsample forward {num_slices} {hidden_states.shape}")
|
||||
print("downsample forward", num_slices, hidden_states.shape)
|
||||
|
||||
org_device = hidden_states.device
|
||||
cpu_device = torch.device("cpu")
|
||||
@@ -353,7 +350,7 @@ class SlicingEncoder(nn.Module):
|
||||
hidden_states = torch.cat([hidden_states, x], dim=2)
|
||||
|
||||
hidden_states = hidden_states.to(org_device)
|
||||
# logger.info(f"downsample forward done {hidden_states.shape}")
|
||||
# print("downsample forward done", hidden_states.shape)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -429,7 +426,7 @@ class SlicingDecoder(nn.Module):
|
||||
|
||||
self.num_slices = num_slices
|
||||
div = num_slices / (2 ** (len(self.up_blocks) - 1))
|
||||
logger.info(f"initial divisor: {div}")
|
||||
print(f"initial divisor: {div}")
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
for resnet in self.mid_block.resnets:
|
||||
@@ -439,11 +436,11 @@ class SlicingDecoder(nn.Module):
|
||||
for i, up_block in enumerate(self.up_blocks):
|
||||
if div >= 2:
|
||||
div = int(div)
|
||||
# logger.info(f"up block: {i} divisor: {div}")
|
||||
# print(f"up block: {i} divisor: {div}")
|
||||
for resnet in up_block.resnets:
|
||||
resnet.forward = wrapper(resblock_forward, resnet, div)
|
||||
if up_block.upsamplers is not None:
|
||||
# logger.info("has upsample")
|
||||
# print("has upsample")
|
||||
for upsample in up_block.upsamplers:
|
||||
upsample.forward = wrapper(self.upsample_forward, upsample, div * 2)
|
||||
div *= 2
|
||||
@@ -531,7 +528,7 @@ class SlicingDecoder(nn.Module):
|
||||
del x
|
||||
|
||||
hidden_states = torch.cat(sliced, dim=2)
|
||||
# logger.info(f"us hidden_states {hidden_states.shape}")
|
||||
# print("us hidden_states", hidden_states.shape)
|
||||
del sliced
|
||||
|
||||
hidden_states = hidden_states.to(org_device)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,668 +0,0 @@
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import toml
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
|
||||
from accelerate import init_empty_weights, Accelerator, PartialState
|
||||
from PIL import Image
|
||||
|
||||
from library import stable_cascade as sc
|
||||
|
||||
from library.sdxl_model_util import _load_state_dict_on_device
|
||||
from library.device_utils import clean_memory_on_device
|
||||
from library.train_util import (
|
||||
save_sd_model_on_epoch_end_or_stepwise_common,
|
||||
save_sd_model_on_train_end_common,
|
||||
line_to_prompt_dict,
|
||||
get_hidden_states_stable_cascade,
|
||||
)
|
||||
from library import sai_model_spec
|
||||
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CLIP_TEXT_MODEL_NAME: str = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_sc_te_outputs.npz"
|
||||
|
||||
|
||||
def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
|
||||
resolution_multiple = 42.67
|
||||
latent_height = math.ceil(height / compression_factor_b)
|
||||
latent_width = math.ceil(width / compression_factor_b)
|
||||
stage_c_latent_shape = (batch_size, 16, latent_height, latent_width)
|
||||
|
||||
latent_height = math.ceil(height / compression_factor_a)
|
||||
latent_width = math.ceil(width / compression_factor_a)
|
||||
stage_b_latent_shape = (batch_size, 4, latent_height, latent_width)
|
||||
|
||||
return stage_c_latent_shape, stage_b_latent_shape
|
||||
|
||||
|
||||
# region load and save
|
||||
|
||||
|
||||
def load_effnet(effnet_checkpoint_path, loading_device="cpu") -> sc.EfficientNetEncoder:
|
||||
logger.info(f"Loading EfficientNet encoder from {effnet_checkpoint_path}")
|
||||
effnet = sc.EfficientNetEncoder()
|
||||
effnet_checkpoint = load_file(effnet_checkpoint_path)
|
||||
info = effnet.load_state_dict(effnet_checkpoint if "state_dict" not in effnet_checkpoint else effnet_checkpoint["state_dict"])
|
||||
logger.info(info)
|
||||
del effnet_checkpoint
|
||||
return effnet
|
||||
|
||||
|
||||
def load_tokenizer(args: argparse.Namespace):
|
||||
# TODO commonize with sdxl_train_util.load_tokenizers
|
||||
logger.info("prepare tokenizers")
|
||||
|
||||
original_paths = [CLIP_TEXT_MODEL_NAME]
|
||||
tokenizers = []
|
||||
for i, original_path in enumerate(original_paths):
|
||||
tokenizer: CLIPTokenizer = None
|
||||
if args.tokenizer_cache_dir:
|
||||
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
||||
if os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
|
||||
|
||||
if tokenizer is None:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
||||
|
||||
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||
tokenizer.save_pretrained(local_tokenizer_path)
|
||||
|
||||
tokenizers.append(tokenizer)
|
||||
|
||||
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
||||
logger.info(f"update token length: {args.max_token_length}")
|
||||
|
||||
return tokenizers[0]
|
||||
|
||||
|
||||
def load_stage_c_model(stage_c_checkpoint_path, dtype=None, device="cpu") -> sc.StageC:
|
||||
# Generator
|
||||
logger.info(f"Instantiating Stage C generator")
|
||||
with init_empty_weights():
|
||||
generator_c = sc.StageC()
|
||||
logger.info(f"Loading Stage C generator from {stage_c_checkpoint_path}")
|
||||
stage_c_checkpoint = load_file(stage_c_checkpoint_path)
|
||||
|
||||
stage_c_checkpoint = convert_state_dict_mha_to_normal_attn(stage_c_checkpoint)
|
||||
|
||||
logger.info(f"Loading state dict")
|
||||
info = _load_state_dict_on_device(generator_c, stage_c_checkpoint, device, dtype=dtype)
|
||||
logger.info(info)
|
||||
return generator_c
|
||||
|
||||
|
||||
def load_stage_b_model(stage_b_checkpoint_path, dtype=None, device="cpu") -> sc.StageB:
|
||||
logger.info(f"Instantiating Stage B generator")
|
||||
with init_empty_weights():
|
||||
generator_b = sc.StageB()
|
||||
logger.info(f"Loading Stage B generator from {stage_b_checkpoint_path}")
|
||||
stage_b_checkpoint = load_file(stage_b_checkpoint_path)
|
||||
|
||||
stage_b_checkpoint = convert_state_dict_mha_to_normal_attn(stage_b_checkpoint)
|
||||
|
||||
logger.info(f"Loading state dict")
|
||||
info = _load_state_dict_on_device(generator_b, stage_b_checkpoint, device, dtype=dtype)
|
||||
logger.info(info)
|
||||
return generator_b
|
||||
|
||||
|
||||
def load_clip_text_model(text_model_checkpoint_path, dtype=None, device="cpu", save_text_model=False):
|
||||
# CLIP encoders
|
||||
logger.info(f"Loading CLIP text model")
|
||||
if save_text_model or text_model_checkpoint_path is None:
|
||||
logger.info(f"Loading CLIP text model from {CLIP_TEXT_MODEL_NAME}")
|
||||
text_model = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_MODEL_NAME)
|
||||
|
||||
if save_text_model:
|
||||
sd = text_model.state_dict()
|
||||
logger.info(f"Saving CLIP text model to {text_model_checkpoint_path}")
|
||||
save_file(sd, text_model_checkpoint_path)
|
||||
else:
|
||||
logger.info(f"Loading CLIP text model from {text_model_checkpoint_path}")
|
||||
|
||||
# copy from sdxl_model_util.py
|
||||
text_model2_cfg = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=1280,
|
||||
intermediate_size=5120,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=20,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=1280,
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.25.0.dev0",
|
||||
)
|
||||
with init_empty_weights():
|
||||
text_model = CLIPTextModelWithProjection(text_model2_cfg)
|
||||
|
||||
text_model_checkpoint = load_file(text_model_checkpoint_path)
|
||||
info = _load_state_dict_on_device(text_model, text_model_checkpoint, device, dtype=dtype)
|
||||
logger.info(info)
|
||||
|
||||
return text_model
|
||||
|
||||
|
||||
def load_stage_a_model(stage_a_checkpoint_path, dtype=None, device="cpu") -> sc.StageA:
|
||||
logger.info(f"Loading Stage A vqGAN from {stage_a_checkpoint_path}")
|
||||
stage_a = sc.StageA().to(device)
|
||||
stage_a_checkpoint = load_file(stage_a_checkpoint_path)
|
||||
info = stage_a.load_state_dict(
|
||||
stage_a_checkpoint if "state_dict" not in stage_a_checkpoint else stage_a_checkpoint["state_dict"]
|
||||
)
|
||||
logger.info(info)
|
||||
return stage_a
|
||||
|
||||
|
||||
def load_previewer_model(previewer_checkpoint_path, dtype=None, device="cpu") -> sc.Previewer:
|
||||
logger.info(f"Loading Previewer from {previewer_checkpoint_path}")
|
||||
previewer = sc.Previewer().to(device)
|
||||
previewer_checkpoint = load_file(previewer_checkpoint_path)
|
||||
info = previewer.load_state_dict(
|
||||
previewer_checkpoint if "state_dict" not in previewer_checkpoint else previewer_checkpoint["state_dict"]
|
||||
)
|
||||
logger.info(info)
|
||||
return previewer
|
||||
|
||||
|
||||
def convert_state_dict_mha_to_normal_attn(state_dict):
|
||||
# convert nn.MultiheadAttention to to_q/k/v and out_proj
|
||||
print("convert_state_dict_mha_to_normal_attn")
|
||||
for key in list(state_dict.keys()):
|
||||
if "attention.attn." in key:
|
||||
if "in_proj_bias" in key:
|
||||
value = state_dict.pop(key)
|
||||
qkv = torch.chunk(value, 3, dim=0)
|
||||
state_dict[key.replace("in_proj_bias", "to_q.bias")] = qkv[0]
|
||||
state_dict[key.replace("in_proj_bias", "to_k.bias")] = qkv[1]
|
||||
state_dict[key.replace("in_proj_bias", "to_v.bias")] = qkv[2]
|
||||
elif "in_proj_weight" in key:
|
||||
value = state_dict.pop(key)
|
||||
qkv = torch.chunk(value, 3, dim=0)
|
||||
state_dict[key.replace("in_proj_weight", "to_q.weight")] = qkv[0]
|
||||
state_dict[key.replace("in_proj_weight", "to_k.weight")] = qkv[1]
|
||||
state_dict[key.replace("in_proj_weight", "to_v.weight")] = qkv[2]
|
||||
elif "out_proj.bias" in key:
|
||||
value = state_dict.pop(key)
|
||||
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = value
|
||||
elif "out_proj.weight" in key:
|
||||
value = state_dict.pop(key)
|
||||
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = value
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_state_dict_normal_attn_to_mha(state_dict):
|
||||
# convert to_q/k/v and out_proj to nn.MultiheadAttention
|
||||
for key in list(state_dict.keys()):
|
||||
if "attention.attn." in key:
|
||||
if "to_q.bias" in key:
|
||||
q = state_dict.pop(key)
|
||||
k = state_dict.pop(key.replace("to_q.bias", "to_k.bias"))
|
||||
v = state_dict.pop(key.replace("to_q.bias", "to_v.bias"))
|
||||
state_dict[key.replace("to_q.bias", "in_proj_bias")] = torch.cat([q, k, v])
|
||||
elif "to_q.weight" in key:
|
||||
q = state_dict.pop(key)
|
||||
k = state_dict.pop(key.replace("to_q.weight", "to_k.weight"))
|
||||
v = state_dict.pop(key.replace("to_q.weight", "to_v.weight"))
|
||||
state_dict[key.replace("to_q.weight", "in_proj_weight")] = torch.cat([q, k, v])
|
||||
elif "out_proj.bias" in key:
|
||||
v = state_dict.pop(key)
|
||||
state_dict[key.replace("out_proj.bias", "out_proj.bias")] = v
|
||||
elif "out_proj.weight" in key:
|
||||
v = state_dict.pop(key)
|
||||
state_dict[key.replace("out_proj.weight", "out_proj.weight")] = v
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_sai_model_spec(args, lora=False):
|
||||
timestamp = time.time()
|
||||
|
||||
reso = args.resolution
|
||||
|
||||
title = args.metadata_title if args.metadata_title is not None else args.output_name
|
||||
|
||||
if args.min_timestep is not None or args.max_timestep is not None:
|
||||
min_time_step = args.min_timestep if args.min_timestep is not None else 0
|
||||
max_time_step = args.max_timestep if args.max_timestep is not None else 1000
|
||||
timesteps = (min_time_step, max_time_step)
|
||||
else:
|
||||
timesteps = None
|
||||
|
||||
metadata = sai_model_spec.build_metadata(
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
lora,
|
||||
False,
|
||||
timestamp,
|
||||
title=title,
|
||||
reso=reso,
|
||||
is_stable_diffusion_ckpt=False,
|
||||
author=args.metadata_author,
|
||||
description=args.metadata_description,
|
||||
license=args.metadata_license,
|
||||
tags=args.metadata_tags,
|
||||
timesteps=timesteps,
|
||||
clip_skip=args.clip_skip, # None or int
|
||||
stable_cascade=True,
|
||||
)
|
||||
return metadata
|
||||
|
||||
|
||||
def stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata):
|
||||
state_dict = stage_c.state_dict()
|
||||
if save_dtype is not None:
|
||||
state_dict = {k: v.to(save_dtype) for k, v in state_dict.items()}
|
||||
|
||||
state_dict = convert_state_dict_normal_attn_to_mha(state_dict)
|
||||
|
||||
save_file(state_dict, ckpt_file, metadata=sai_metadata)
|
||||
|
||||
# save text model
|
||||
if text_model is not None:
|
||||
text_model_sd = text_model.state_dict()
|
||||
|
||||
if save_dtype is not None:
|
||||
text_model_sd = {k: v.to(save_dtype) for k, v in text_model_sd.items()}
|
||||
|
||||
text_model_ckpt_file = os.path.splitext(ckpt_file)[0] + "_text_model.safetensors"
|
||||
save_file(text_model_sd, text_model_ckpt_file)
|
||||
|
||||
|
||||
def save_stage_c_model_on_epoch_end_or_stepwise(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
stage_c,
|
||||
text_model,
|
||||
):
|
||||
def stage_c_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = get_sai_model_spec(args)
|
||||
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
|
||||
|
||||
save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args, on_epoch_end, accelerator, True, True, epoch, num_train_epochs, global_step, stage_c_saver, None
|
||||
)
|
||||
|
||||
|
||||
def save_stage_c_model_on_end(
|
||||
args: argparse.Namespace,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
stage_c,
|
||||
text_model,
|
||||
):
|
||||
def stage_c_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = get_sai_model_spec(args)
|
||||
stage_c_saver_common(ckpt_file, stage_c, text_model, save_dtype, sai_metadata)
|
||||
|
||||
save_sd_model_on_train_end_common(args, True, True, epoch, global_step, stage_c_saver, None)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region sample generation
|
||||
|
||||
|
||||
def sample_images(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
previewer,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
stage_c,
|
||||
gdf,
|
||||
prompt_replacement=None,
|
||||
):
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
return
|
||||
else:
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
# sample_every_n_steps は無視する
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
return
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
if not os.path.isfile(args.sample_prompts):
|
||||
logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
||||
return
|
||||
|
||||
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
stage_c = accelerator.unwrap_model(stage_c)
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
# read prompts
|
||||
if args.sample_prompts.endswith(".txt"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif args.sample_prompts.endswith(".toml"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif args.sample_prompts.endswith(".json"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# preprocess prompts
|
||||
for i in range(len(prompts)):
|
||||
prompt_dict = prompts[i]
|
||||
if isinstance(prompt_dict, str):
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
prompts[i] = prompt_dict
|
||||
assert isinstance(prompt_dict, dict)
|
||||
|
||||
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
||||
prompt_dict["enum"] = i
|
||||
prompt_dict.pop("subset", None)
|
||||
|
||||
# save random state to restore later
|
||||
rng_state = torch.get_rng_state()
|
||||
cuda_rng_state = None
|
||||
try:
|
||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
stage_c,
|
||||
previewer,
|
||||
gdf,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
prompt_replacement,
|
||||
)
|
||||
else:
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
||||
per_process_prompts = [] # list of lists
|
||||
for i in range(distributed_state.num_processes):
|
||||
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
||||
|
||||
with torch.no_grad():
|
||||
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
||||
for prompt_dict in prompt_dict_lists[0]:
|
||||
sample_image_inference(
|
||||
accelerator,
|
||||
args,
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
stage_c,
|
||||
previewer,
|
||||
gdf,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
# I'm not sure which of these is the correct way to clear the memory, but accelerator's device is used in the pipeline, so I'm using it here.
|
||||
# with torch.cuda.device(torch.cuda.current_device()):
|
||||
# torch.cuda.empty_cache()
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
if cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
|
||||
|
||||
def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
tokenizer,
|
||||
text_model,
|
||||
stage_c,
|
||||
previewer,
|
||||
gdf,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
prompt_replacement,
|
||||
):
|
||||
assert isinstance(prompt_dict, dict)
|
||||
negative_prompt = prompt_dict.get("negative_prompt")
|
||||
sample_steps = prompt_dict.get("sample_steps", 20)
|
||||
width = prompt_dict.get("width", 1024)
|
||||
height = prompt_dict.get("height", 1024)
|
||||
scale = prompt_dict.get("scale", 4)
|
||||
seed = prompt_dict.get("seed")
|
||||
# controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
# sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
else:
|
||||
# True random sample image generation
|
||||
torch.seed()
|
||||
torch.cuda.seed()
|
||||
|
||||
height = max(64, height - height % 8) # round to divisible by 8
|
||||
width = max(64, width - width % 8) # round to divisible by 8
|
||||
logger.info(f"prompt: {prompt}")
|
||||
logger.info(f"negative_prompt: {negative_prompt}")
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
logger.info(f"scale: {scale}")
|
||||
# logger.info(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
logger.info(f"seed: {seed}")
|
||||
|
||||
negative_prompt = "" if negative_prompt is None else negative_prompt
|
||||
cfg = scale
|
||||
timesteps = sample_steps
|
||||
shift = 2
|
||||
t_start = 1.0
|
||||
|
||||
stage_c_latent_shape, _ = calculate_latent_sizes(height, width, batch_size=1)
|
||||
|
||||
# PREPARE CONDITIONS
|
||||
input_ids = tokenizer(
|
||||
[prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
cond_text, cond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
|
||||
|
||||
input_ids = tokenizer(
|
||||
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
uncond_text, uncond_pooled = get_hidden_states_stable_cascade(tokenizer.model_max_length, input_ids, tokenizer, text_model)
|
||||
|
||||
device = accelerator.device
|
||||
dtype = stage_c.dtype
|
||||
cond_text = cond_text.to(device, dtype=dtype)
|
||||
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
|
||||
uncond_text = uncond_text.to(device, dtype=dtype)
|
||||
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
|
||||
zero_img_emb = torch.zeros(1, 768, device=device)
|
||||
|
||||
# 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく
|
||||
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
|
||||
unconditions = {"clip_text_pooled": uncond_pooled, "clip": uncond_pooled, "clip_text": uncond_text, "clip_img": zero_img_emb}
|
||||
|
||||
with torch.no_grad(): # , torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampling_c = gdf.sample(
|
||||
stage_c,
|
||||
conditions,
|
||||
stage_c_latent_shape,
|
||||
unconditions,
|
||||
device=device,
|
||||
cfg=cfg,
|
||||
shift=shift,
|
||||
timesteps=timesteps,
|
||||
t_start=t_start,
|
||||
)
|
||||
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
|
||||
sampled_c = sampled_c
|
||||
|
||||
sampled_c = sampled_c.to(previewer.device, dtype=previewer.dtype)
|
||||
image = previewer(sampled_c)[0]
|
||||
image = torch.clamp(image, 0, 1)
|
||||
image = image.cpu().numpy().transpose(1, 2, 0)
|
||||
image = image * 255
|
||||
image = image.astype(np.uint8)
|
||||
image = Image.fromarray(image)
|
||||
|
||||
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
||||
# but adding 'enum' to the filename should be enough
|
||||
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||
seed_suffix = "" if seed is None else f"_{seed}"
|
||||
i: int = prompt_dict["enum"]
|
||||
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||
image.save(os.path.join(save_dir, img_filename))
|
||||
|
||||
# wandb有効時のみログを送信
|
||||
try:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
try:
|
||||
import wandb
|
||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||
except: # wandb 無効時
|
||||
pass
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def add_effnet_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--effnet_checkpoint_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="path to EfficientNet checkpoint / EfficientNetのチェックポイントのパス",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_text_model_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--text_model_checkpoint_path",
|
||||
type=str,
|
||||
help="path to CLIP text model checkpoint / CLIPテキストモデルのチェックポイントのパス",
|
||||
)
|
||||
parser.add_argument("--save_text_model", action="store_true", help="if specified, save text model to corresponding path")
|
||||
return parser
|
||||
|
||||
|
||||
def add_stage_a_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--stage_a_checkpoint_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="path to Stage A checkpoint / Stage Aのチェックポイントのパス",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_stage_b_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--stage_b_checkpoint_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="path to Stage B checkpoint / Stage Bのチェックポイントのパス",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_stage_c_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--stage_c_checkpoint_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="path to Stage C checkpoint / Stage Cのチェックポイントのパス",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_previewer_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--previewer_checkpoint_path",
|
||||
type=str,
|
||||
required=False,
|
||||
help="path to previewer checkpoint / previewerのチェックポイントのパス",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def add_training_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--adaptive_loss_weight",
|
||||
action="store_true",
|
||||
help="if specified, use adaptive loss weight. if not, use P2 loss weight"
|
||||
+ " / Adaptive Loss Weightを使用する。指定しない場合はP2 Loss Weightを使用する",
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
262
library/utils.py
262
library/utils.py
@@ -1,266 +1,6 @@
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
from typing import *
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
import diffusers.schedulers.scheduling_euler_ancestral_discrete
|
||||
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
|
||||
|
||||
def add_logging_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--console_log_level",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
help="Set the logging level, default is INFO / ログレベルを設定する。デフォルトはINFO",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--console_log_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Log to a file instead of stderr / 標準エラー出力ではなくファイルにログを出力する",
|
||||
)
|
||||
parser.add_argument("--console_log_simple", action="store_true", help="Simple log output / シンプルなログ出力")
|
||||
|
||||
|
||||
def setup_logging(args=None, log_level=None, reset=False):
|
||||
if logging.root.handlers:
|
||||
if reset:
|
||||
# remove all handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
else:
|
||||
return
|
||||
|
||||
# log_level can be set by the caller or by the args, the caller has priority. If not set, use INFO
|
||||
if log_level is None and args is not None:
|
||||
log_level = args.console_log_level
|
||||
if log_level is None:
|
||||
log_level = "INFO"
|
||||
log_level = getattr(logging, log_level)
|
||||
|
||||
msg_init = None
|
||||
if args is not None and args.console_log_file:
|
||||
handler = logging.FileHandler(args.console_log_file, mode="w")
|
||||
else:
|
||||
handler = None
|
||||
if not args or not args.console_log_simple:
|
||||
try:
|
||||
from rich.logging import RichHandler
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
handler = RichHandler(console=Console(stderr=True))
|
||||
except ImportError:
|
||||
# print("rich is not installed, using basic logging")
|
||||
msg_init = "rich is not installed, using basic logging"
|
||||
|
||||
if handler is None:
|
||||
handler = logging.StreamHandler(sys.stdout) # same as print
|
||||
handler.propagate = False
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
logging.root.setLevel(log_level)
|
||||
logging.root.addHandler(handler)
|
||||
|
||||
if msg_init is not None:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(msg_init)
|
||||
|
||||
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
|
||||
# region Gradual Latent hires fix
|
||||
|
||||
|
||||
class GradualLatent:
|
||||
def __init__(
|
||||
self,
|
||||
ratio,
|
||||
start_timesteps,
|
||||
every_n_steps,
|
||||
ratio_step,
|
||||
s_noise=1.0,
|
||||
gaussian_blur_ksize=None,
|
||||
gaussian_blur_sigma=0.5,
|
||||
gaussian_blur_strength=0.5,
|
||||
unsharp_target_x=True,
|
||||
):
|
||||
self.ratio = ratio
|
||||
self.start_timesteps = start_timesteps
|
||||
self.every_n_steps = every_n_steps
|
||||
self.ratio_step = ratio_step
|
||||
self.s_noise = s_noise
|
||||
self.gaussian_blur_ksize = gaussian_blur_ksize
|
||||
self.gaussian_blur_sigma = gaussian_blur_sigma
|
||||
self.gaussian_blur_strength = gaussian_blur_strength
|
||||
self.unsharp_target_x = unsharp_target_x
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"GradualLatent(ratio={self.ratio}, start_timesteps={self.start_timesteps}, "
|
||||
+ f"every_n_steps={self.every_n_steps}, ratio_step={self.ratio_step}, s_noise={self.s_noise}, "
|
||||
+ f"gaussian_blur_ksize={self.gaussian_blur_ksize}, gaussian_blur_sigma={self.gaussian_blur_sigma}, gaussian_blur_strength={self.gaussian_blur_strength}, "
|
||||
+ f"unsharp_target_x={self.unsharp_target_x})"
|
||||
)
|
||||
|
||||
def apply_unshark_mask(self, x: torch.Tensor):
|
||||
if self.gaussian_blur_ksize is None:
|
||||
return x
|
||||
blurred = transforms.functional.gaussian_blur(x, self.gaussian_blur_ksize, self.gaussian_blur_sigma)
|
||||
# mask = torch.sigmoid((x - blurred) * self.gaussian_blur_strength)
|
||||
mask = (x - blurred) * self.gaussian_blur_strength
|
||||
sharpened = x + mask
|
||||
return sharpened
|
||||
|
||||
def interpolate(self, x: torch.Tensor, resized_size, unsharp=True):
|
||||
org_dtype = x.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.float()
|
||||
|
||||
x = torch.nn.functional.interpolate(x, size=resized_size, mode="bicubic", align_corners=False).to(dtype=org_dtype)
|
||||
|
||||
# apply unsharp mask / アンシャープマスクを適用する
|
||||
if unsharp and self.gaussian_blur_ksize:
|
||||
x = self.apply_unshark_mask(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.resized_size = None
|
||||
self.gradual_latent = None
|
||||
|
||||
def set_gradual_latent_params(self, size, gradual_latent: GradualLatent):
|
||||
self.resized_size = size
|
||||
self.gradual_latent = gradual_latent
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a
|
||||
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`,
|
||||
[`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
|
||||
otherwise a tuple is returned where the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
|
||||
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep."
|
||||
),
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
# logger.warning(
|
||||
print(
|
||||
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
||||
"See `StableDiffusionPipeline` for a usage example."
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = sample - sigma * model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
# * c_out + input * c_skip
|
||||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||||
elif self.config.prediction_type == "sample":
|
||||
raise NotImplementedError("prediction_type not implemented yet: sample")
|
||||
else:
|
||||
raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")
|
||||
|
||||
sigma_from = self.sigmas[self.step_index]
|
||||
sigma_to = self.sigmas[self.step_index + 1]
|
||||
sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma
|
||||
|
||||
dt = sigma_down - sigma
|
||||
|
||||
device = model_output.device
|
||||
if self.resized_size is None:
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
||||
model_output.shape, dtype=model_output.dtype, device=device, generator=generator
|
||||
)
|
||||
s_noise = 1.0
|
||||
else:
|
||||
print("resized_size", self.resized_size, "model_output.shape", model_output.shape, "sample.shape", sample.shape)
|
||||
s_noise = self.gradual_latent.s_noise
|
||||
|
||||
if self.gradual_latent.unsharp_target_x:
|
||||
prev_sample = sample + derivative * dt
|
||||
prev_sample = self.gradual_latent.interpolate(prev_sample, self.resized_size)
|
||||
else:
|
||||
sample = self.gradual_latent.interpolate(sample, self.resized_size)
|
||||
derivative = self.gradual_latent.interpolate(derivative, self.resized_size, unsharp=False)
|
||||
prev_sample = sample + derivative * dt
|
||||
|
||||
noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
|
||||
(model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
|
||||
dtype=model_output.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
prev_sample = prev_sample + noise * sigma_up * s_noise
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
|
||||
# endregion
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
@@ -2,13 +2,10 @@ import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main(file):
|
||||
logger.info(f"loading: {file}")
|
||||
print(f"loading: {file}")
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
sd = load_file(file)
|
||||
else:
|
||||
|
||||
@@ -2,10 +2,7 @@ import os
|
||||
from typing import Optional, List, Type
|
||||
import torch
|
||||
from library import sdxl_original_unet
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
||||
SKIP_INPUT_BLOCKS = False
|
||||
@@ -128,7 +125,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
return
|
||||
|
||||
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
||||
# logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
||||
# print(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
||||
cx = self.conditioning1(cond_image)
|
||||
if not self.is_conv2d:
|
||||
# reshape / b,c,h,w -> b,h*w,c
|
||||
@@ -158,7 +155,7 @@ class LLLiteModule(torch.nn.Module):
|
||||
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
||||
if self.use_zeros_for_batch_uncond:
|
||||
cx[0::2] = 0.0 # uncond is zero
|
||||
# logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
||||
# print(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
||||
|
||||
# downで入力の次元数を削減し、conditioning image embeddingと結合する
|
||||
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
||||
@@ -289,7 +286,7 @@ class ControlNetLLLite(torch.nn.Module):
|
||||
|
||||
# create module instances
|
||||
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
|
||||
logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
|
||||
print(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
|
||||
|
||||
def forward(self, x):
|
||||
return x # dummy
|
||||
@@ -322,7 +319,7 @@ class ControlNetLLLite(torch.nn.Module):
|
||||
return info
|
||||
|
||||
def apply_to(self):
|
||||
logger.info("applying LLLite for U-Net...")
|
||||
print("applying LLLite for U-Net...")
|
||||
for module in self.unet_modules:
|
||||
module.apply_to()
|
||||
self.add_module(module.lllite_name, module)
|
||||
@@ -377,19 +374,19 @@ if __name__ == "__main__":
|
||||
# sdxl_original_unet.USE_REENTRANT = False
|
||||
|
||||
# test shape etc
|
||||
logger.info("create unet")
|
||||
print("create unet")
|
||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||
unet.to("cuda").to(torch.float16)
|
||||
|
||||
logger.info("create ControlNet-LLLite")
|
||||
print("create ControlNet-LLLite")
|
||||
control_net = ControlNetLLLite(unet, 32, 64)
|
||||
control_net.apply_to()
|
||||
control_net.to("cuda")
|
||||
|
||||
logger.info(control_net)
|
||||
print(control_net)
|
||||
|
||||
# logger.info number of parameters
|
||||
logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}")
|
||||
# print number of parameters
|
||||
print("number of parameters", sum(p.numel() for p in control_net.parameters() if p.requires_grad))
|
||||
|
||||
input()
|
||||
|
||||
@@ -401,12 +398,12 @@ if __name__ == "__main__":
|
||||
|
||||
# # visualize
|
||||
# import torchviz
|
||||
# logger.info("run visualize")
|
||||
# print("run visualize")
|
||||
# controlnet.set_control(conditioning_image)
|
||||
# output = unet(x, t, ctx, y)
|
||||
# logger.info("make_dot")
|
||||
# print("make_dot")
|
||||
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
||||
# logger.info("render")
|
||||
# print("render")
|
||||
# image.format = "svg" # "png"
|
||||
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
||||
# input()
|
||||
@@ -417,12 +414,12 @@ if __name__ == "__main__":
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
||||
|
||||
logger.info("start training")
|
||||
print("start training")
|
||||
steps = 10
|
||||
|
||||
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
|
||||
for step in range(steps):
|
||||
logger.info(f"step {step}")
|
||||
print(f"step {step}")
|
||||
|
||||
batch_size = 1
|
||||
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
||||
@@ -442,7 +439,7 @@ if __name__ == "__main__":
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
logger.info(f"{sample_param}")
|
||||
print(sample_param)
|
||||
|
||||
# from safetensors.torch import save_file
|
||||
|
||||
|
||||
@@ -6,10 +6,7 @@ import re
|
||||
from typing import Optional, List, Type
|
||||
import torch
|
||||
from library import sdxl_original_unet
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
||||
SKIP_INPUT_BLOCKS = False
|
||||
@@ -273,7 +270,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
|
||||
|
||||
# create module instances
|
||||
self.lllite_modules = apply_to_modules(self, target_modules)
|
||||
logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
|
||||
print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
|
||||
|
||||
# def prepare_optimizer_params(self):
|
||||
def prepare_params(self):
|
||||
@@ -284,8 +281,8 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
|
||||
train_params.append(p)
|
||||
else:
|
||||
non_train_params.append(p)
|
||||
logger.info(f"count of trainable parameters: {len(train_params)}")
|
||||
logger.info(f"count of non-trainable parameters: {len(non_train_params)}")
|
||||
print(f"count of trainable parameters: {len(train_params)}")
|
||||
print(f"count of non-trainable parameters: {len(non_train_params)}")
|
||||
|
||||
for p in non_train_params:
|
||||
p.requires_grad_(False)
|
||||
@@ -391,7 +388,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
|
||||
matches = pattern.findall(module_name)
|
||||
if matches is not None:
|
||||
for m in matches:
|
||||
logger.info(f"{module_name} {m}")
|
||||
print(module_name, m)
|
||||
module_name = module_name.replace(m, m.replace("_", "@"))
|
||||
module_name = module_name.replace("_", ".")
|
||||
module_name = module_name.replace("@", "_")
|
||||
@@ -410,7 +407,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
|
||||
|
||||
|
||||
def replace_unet_linear_and_conv2d():
|
||||
logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
|
||||
print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
|
||||
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
|
||||
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
|
||||
|
||||
@@ -422,10 +419,10 @@ if __name__ == "__main__":
|
||||
replace_unet_linear_and_conv2d()
|
||||
|
||||
# test shape etc
|
||||
logger.info("create unet")
|
||||
print("create unet")
|
||||
unet = SdxlUNet2DConditionModelControlNetLLLite()
|
||||
|
||||
logger.info("enable ControlNet-LLLite")
|
||||
print("enable ControlNet-LLLite")
|
||||
unet.apply_lllite(32, 64, None, False, 1.0)
|
||||
unet.to("cuda") # .to(torch.float16)
|
||||
|
||||
@@ -442,14 +439,14 @@ if __name__ == "__main__":
|
||||
# unet_sd[converted_key] = model_sd[key]
|
||||
|
||||
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
|
||||
# logger.info(info)
|
||||
# print(info)
|
||||
|
||||
# logger.info(unet)
|
||||
# print(unet)
|
||||
|
||||
# logger.info number of parameters
|
||||
# print number of parameters
|
||||
params = unet.prepare_params()
|
||||
logger.info(f"number of parameters {sum(p.numel() for p in params)}")
|
||||
# logger.info("type any key to continue")
|
||||
print("number of parameters", sum(p.numel() for p in params))
|
||||
# print("type any key to continue")
|
||||
# input()
|
||||
|
||||
unet.set_use_memory_efficient_attention(True, False)
|
||||
@@ -458,12 +455,12 @@ if __name__ == "__main__":
|
||||
|
||||
# # visualize
|
||||
# import torchviz
|
||||
# logger.info("run visualize")
|
||||
# print("run visualize")
|
||||
# controlnet.set_control(conditioning_image)
|
||||
# output = unet(x, t, ctx, y)
|
||||
# logger.info("make_dot")
|
||||
# print("make_dot")
|
||||
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
||||
# logger.info("render")
|
||||
# print("render")
|
||||
# image.format = "svg" # "png"
|
||||
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
||||
# input()
|
||||
@@ -474,13 +471,13 @@ if __name__ == "__main__":
|
||||
|
||||
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
||||
|
||||
logger.info("start training")
|
||||
print("start training")
|
||||
steps = 10
|
||||
batch_size = 1
|
||||
|
||||
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
|
||||
for step in range(steps):
|
||||
logger.info(f"step {step}")
|
||||
print(f"step {step}")
|
||||
|
||||
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
||||
x = torch.randn(batch_size, 4, 128, 128).cuda()
|
||||
@@ -497,9 +494,9 @@ if __name__ == "__main__":
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
logger.info(sample_param)
|
||||
print(sample_param)
|
||||
|
||||
# from safetensors.torch import save_file
|
||||
|
||||
# logger.info("save weights")
|
||||
# print("save weights")
|
||||
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)
|
||||
|
||||
@@ -12,15 +12,10 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from diffusers import AutoencoderKL
|
||||
from transformers import CLIPTextModel
|
||||
from typing import List, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DyLoRAModule(torch.nn.Module):
|
||||
"""
|
||||
@@ -170,15 +165,7 @@ class DyLoRAModule(torch.nn.Module):
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
|
||||
def create_network(
|
||||
multiplier: float,
|
||||
network_dim: Optional[int],
|
||||
network_alpha: Optional[float],
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
||||
unet,
|
||||
**kwargs,
|
||||
):
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
@@ -195,7 +182,6 @@ def create_network(
|
||||
conv_alpha = 1.0
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
|
||||
if unit is not None:
|
||||
unit = int(unit)
|
||||
else:
|
||||
@@ -237,7 +223,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
# logger.info(f"{lora_name} {value.size()} {dim}")
|
||||
# print(lora_name, value.size(), dim)
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
@@ -281,11 +267,11 @@ class DyLoRANetwork(torch.nn.Module):
|
||||
self.apply_to_conv = apply_to_conv
|
||||
|
||||
if modules_dim is not None:
|
||||
logger.info("create LoRA network from weights")
|
||||
print(f"create LoRA network from weights")
|
||||
else:
|
||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
|
||||
if self.apply_to_conv:
|
||||
logger.info("apply LoRA to Conv2d with kernel size (3,3).")
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3).")
|
||||
|
||||
# create module instances
|
||||
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
|
||||
@@ -320,23 +306,9 @@ class DyLoRANetwork(torch.nn.Module):
|
||||
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
|
||||
loras.append(lora)
|
||||
return loras
|
||||
|
||||
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
||||
|
||||
self.text_encoder_loras = []
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if len(text_encoders) > 1:
|
||||
index = i + 1
|
||||
logger.info(f"create LoRA for Text Encoder {index}")
|
||||
else:
|
||||
index = None
|
||||
logger.info("create LoRA for Text Encoder")
|
||||
|
||||
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
|
||||
# self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
@@ -344,7 +316,7 @@ class DyLoRANetwork(torch.nn.Module):
|
||||
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras = create_modules(True, unet, target_modules)
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
@@ -364,12 +336,12 @@ class DyLoRANetwork(torch.nn.Module):
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
logger.info("enable LoRA for text encoder")
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info("enable LoRA for U-Net")
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
@@ -387,12 +359,12 @@ class DyLoRANetwork(torch.nn.Module):
|
||||
apply_unet = True
|
||||
|
||||
if apply_text_encoder:
|
||||
logger.info("enable LoRA for text encoder")
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info("enable LoRA for U-Net")
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
@@ -403,7 +375,7 @@ class DyLoRANetwork(torch.nn.Module):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
logger.info(f"weights are merged")
|
||||
print(f"weights are merged")
|
||||
"""
|
||||
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
|
||||
@@ -10,10 +10,7 @@ from safetensors.torch import load_file, save_file, safe_open
|
||||
from tqdm import tqdm
|
||||
from library import train_util, model_util
|
||||
import numpy as np
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_state_dict(file_name):
|
||||
if model_util.is_safetensors(file_name):
|
||||
@@ -43,13 +40,13 @@ def split_lora_model(lora_sd, unit):
|
||||
rank = value.size()[0]
|
||||
if rank > max_rank:
|
||||
max_rank = rank
|
||||
logger.info(f"Max rank: {max_rank}")
|
||||
print(f"Max rank: {max_rank}")
|
||||
|
||||
rank = unit
|
||||
split_models = []
|
||||
new_alpha = None
|
||||
while rank < max_rank:
|
||||
logger.info(f"Splitting rank {rank}")
|
||||
print(f"Splitting rank {rank}")
|
||||
new_sd = {}
|
||||
for key, value in lora_sd.items():
|
||||
if "lora_down" in key:
|
||||
@@ -60,7 +57,7 @@ def split_lora_model(lora_sd, unit):
|
||||
# なぜかscaleするとおかしくなる……
|
||||
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
|
||||
# scale = math.sqrt(this_rank / rank) # rank is > unit
|
||||
# logger.info(key, value.size(), this_rank, rank, value, scale)
|
||||
# print(key, value.size(), this_rank, rank, value, scale)
|
||||
# new_alpha = value * scale # always same
|
||||
# new_sd[key] = new_alpha
|
||||
new_sd[key] = value
|
||||
@@ -72,10 +69,10 @@ def split_lora_model(lora_sd, unit):
|
||||
|
||||
|
||||
def split(args):
|
||||
logger.info("loading Model...")
|
||||
print("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model)
|
||||
|
||||
logger.info("Splitting Model...")
|
||||
print("Splitting Model...")
|
||||
original_rank, split_models = split_lora_model(lora_sd, args.unit)
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
@@ -97,7 +94,7 @@ def split(args):
|
||||
filename, ext = os.path.splitext(args.save_to)
|
||||
model_file_name = filename + f"-{new_rank:04d}{ext}"
|
||||
|
||||
logger.info(f"saving model to: {model_file_name}")
|
||||
print(f"saving model to: {model_file_name}")
|
||||
save_to_file(model_file_name, state_dict, new_metadata)
|
||||
|
||||
|
||||
|
||||
@@ -11,10 +11,7 @@ from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec, model_util, sdxl_model_util
|
||||
import lora
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# CLAMP_QUANTILE = 0.99
|
||||
# MIN_DIFF = 1e-1
|
||||
@@ -69,14 +66,14 @@ def svd(
|
||||
|
||||
# load models
|
||||
if not sdxl:
|
||||
logger.info(f"loading original SD model : {model_org}")
|
||||
print(f"loading original SD model : {model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||
text_encoders_o = [text_encoder_o]
|
||||
if load_dtype is not None:
|
||||
text_encoder_o = text_encoder_o.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
logger.info(f"loading tuned SD model : {model_tuned}")
|
||||
print(f"loading tuned SD model : {model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||
text_encoders_t = [text_encoder_t]
|
||||
if load_dtype is not None:
|
||||
@@ -88,7 +85,7 @@ def svd(
|
||||
device_org = load_original_model_to if load_original_model_to else "cpu"
|
||||
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
|
||||
|
||||
logger.info(f"loading original SDXL model : {model_org}")
|
||||
print(f"loading original SDXL model : {model_org}")
|
||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
||||
)
|
||||
@@ -98,7 +95,7 @@ def svd(
|
||||
text_encoder_o2 = text_encoder_o2.to(load_dtype)
|
||||
unet_o = unet_o.to(load_dtype)
|
||||
|
||||
logger.info(f"loading original SDXL model : {model_tuned}")
|
||||
print(f"loading original SDXL model : {model_tuned}")
|
||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
||||
)
|
||||
@@ -138,7 +135,7 @@ def svd(
|
||||
# Text Encoder might be same
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||
text_encoder_different = True
|
||||
logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||
|
||||
diffs[lora_name] = diff
|
||||
|
||||
@@ -147,7 +144,7 @@ def svd(
|
||||
del text_encoder
|
||||
|
||||
if not text_encoder_different:
|
||||
logger.warning("Text encoder is same. Extract U-Net only.")
|
||||
print("Text encoder is same. Extract U-Net only.")
|
||||
lora_network_o.text_encoder_loras = []
|
||||
diffs = {} # clear diffs
|
||||
|
||||
@@ -169,7 +166,7 @@ def svd(
|
||||
del unet_t
|
||||
|
||||
# make LoRA with svd
|
||||
logger.info("calculating by svd")
|
||||
print("calculating by svd")
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name, mat in tqdm(list(diffs.items())):
|
||||
@@ -188,7 +185,7 @@ def svd(
|
||||
if device:
|
||||
mat = mat.to(device)
|
||||
|
||||
# logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
|
||||
if conv2d:
|
||||
@@ -233,7 +230,7 @@ def svd(
|
||||
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
|
||||
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
logger.info(f"Loading extracted LoRA weights: {info}")
|
||||
print(f"Loading extracted LoRA weights: {info}")
|
||||
|
||||
dir_name = os.path.dirname(save_to)
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
@@ -260,7 +257,7 @@ def svd(
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
||||
logger.info(f"LoRA weights are saved to: {save_to}")
|
||||
print(f"LoRA weights are saved to: {save_to}")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
176
networks/lora.py
176
networks/lora.py
@@ -11,12 +11,7 @@ from transformers import CLIPTextModel
|
||||
import numpy as np
|
||||
import torch
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
||||
@@ -51,7 +46,7 @@ class LoRAModule(torch.nn.Module):
|
||||
# if limit_rank:
|
||||
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
||||
# if self.lora_dim != lora_dim:
|
||||
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||
# else:
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
@@ -182,7 +177,7 @@ class LoRAInfModule(LoRAModule):
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + self.multiplier * conved * self.scale
|
||||
|
||||
# set weight to org_module
|
||||
@@ -221,7 +216,7 @@ class LoRAInfModule(LoRAModule):
|
||||
self.region_mask = None
|
||||
|
||||
def default_forward(self, x):
|
||||
# logger.info(f"default_forward {self.lora_name} {x.size()}")
|
||||
# print("default_forward", self.lora_name, x.size())
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
||||
def forward(self, x):
|
||||
@@ -250,8 +245,7 @@ class LoRAInfModule(LoRAModule):
|
||||
if mask is None:
|
||||
# raise ValueError(f"mask is None for resolution {area}")
|
||||
# emb_layers in SDXL doesn't have mask
|
||||
# if "emb" not in self.lora_name:
|
||||
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
|
||||
# print(f"mask is None for resolution {area}, {x.size()}")
|
||||
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
|
||||
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
|
||||
if len(x.size()) != 4:
|
||||
@@ -269,8 +263,6 @@ class LoRAInfModule(LoRAModule):
|
||||
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
mask = self.get_mask_for_x(lx)
|
||||
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
||||
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
|
||||
# mask = mask.squeeze(-1)
|
||||
lx = lx * mask
|
||||
|
||||
x = self.org_forward(x)
|
||||
@@ -299,7 +291,7 @@ class LoRAInfModule(LoRAModule):
|
||||
if has_real_uncond:
|
||||
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
||||
|
||||
# logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}")
|
||||
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
||||
return query
|
||||
|
||||
def sub_prompt_forward(self, x):
|
||||
@@ -314,7 +306,7 @@ class LoRAInfModule(LoRAModule):
|
||||
lx = x[emb_idx :: self.network.num_sub_prompts]
|
||||
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
||||
|
||||
# logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}")
|
||||
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
||||
|
||||
x = self.org_forward(x)
|
||||
x[emb_idx :: self.network.num_sub_prompts] += lx
|
||||
@@ -322,7 +314,7 @@ class LoRAInfModule(LoRAModule):
|
||||
return x
|
||||
|
||||
def to_out_forward(self, x):
|
||||
# logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}")
|
||||
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
||||
|
||||
if self.network.is_last_network:
|
||||
masks = [None] * self.network.num_sub_prompts
|
||||
@@ -340,7 +332,7 @@ class LoRAInfModule(LoRAModule):
|
||||
)
|
||||
self.network.shared[self.lora_name] = (lx, masks)
|
||||
|
||||
# logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
|
||||
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
||||
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
||||
|
||||
@@ -359,7 +351,7 @@ class LoRAInfModule(LoRAModule):
|
||||
if has_real_uncond:
|
||||
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
||||
|
||||
# logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
|
||||
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||
# if num_sub_prompts > num of LoRAs, fill with zero
|
||||
for i in range(len(masks)):
|
||||
if masks[i] is None:
|
||||
@@ -382,7 +374,7 @@ class LoRAInfModule(LoRAModule):
|
||||
x1 = x1 + lx1
|
||||
out[self.network.batch_size + i] = x1
|
||||
|
||||
# logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}")
|
||||
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
|
||||
return out
|
||||
|
||||
|
||||
@@ -519,7 +511,7 @@ def get_block_dims_and_alphas(
|
||||
len(block_dims) == num_total_blocks
|
||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
logger.warning(
|
||||
print(
|
||||
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
|
||||
)
|
||||
block_dims = [network_dim] * num_total_blocks
|
||||
@@ -530,7 +522,7 @@ def get_block_dims_and_alphas(
|
||||
len(block_alphas) == num_total_blocks
|
||||
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
logger.warning(
|
||||
print(
|
||||
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
||||
)
|
||||
block_alphas = [network_alpha] * num_total_blocks
|
||||
@@ -550,13 +542,13 @@ def get_block_dims_and_alphas(
|
||||
else:
|
||||
if conv_alpha is None:
|
||||
conv_alpha = 1.0
|
||||
logger.warning(
|
||||
print(
|
||||
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
||||
)
|
||||
conv_block_alphas = [conv_alpha] * num_total_blocks
|
||||
else:
|
||||
if conv_dim is not None:
|
||||
logger.warning(
|
||||
print(
|
||||
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
||||
)
|
||||
conv_block_dims = [conv_dim] * num_total_blocks
|
||||
@@ -596,7 +588,7 @@ def get_block_lr_weight(
|
||||
elif name == "zeros":
|
||||
return [0.0 + base_lr] * max_len
|
||||
else:
|
||||
logger.error(
|
||||
print(
|
||||
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
||||
% (name)
|
||||
)
|
||||
@@ -608,14 +600,14 @@ def get_block_lr_weight(
|
||||
up_lr_weight = get_list(up_lr_weight)
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
||||
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
||||
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
||||
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
||||
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
||||
up_lr_weight = up_lr_weight[:max_len]
|
||||
down_lr_weight = down_lr_weight[:max_len]
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
||||
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
||||
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
||||
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
||||
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
||||
|
||||
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
||||
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
||||
@@ -623,24 +615,24 @@ def get_block_lr_weight(
|
||||
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
||||
|
||||
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
||||
logger.info("apply block learning rate / 階層別学習率を適用します。")
|
||||
print("apply block learning rate / 階層別学習率を適用します。")
|
||||
if down_lr_weight != None:
|
||||
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
||||
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
|
||||
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
|
||||
else:
|
||||
logger.info("down_lr_weight: all 1.0, すべて1.0")
|
||||
print("down_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
if mid_lr_weight != None:
|
||||
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
||||
logger.info(f"mid_lr_weight: {mid_lr_weight}")
|
||||
print("mid_lr_weight:", mid_lr_weight)
|
||||
else:
|
||||
logger.info("mid_lr_weight: 1.0")
|
||||
print("mid_lr_weight: 1.0")
|
||||
|
||||
if up_lr_weight != None:
|
||||
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
||||
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
|
||||
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
|
||||
else:
|
||||
logger.info("up_lr_weight: all 1.0, すべて1.0")
|
||||
print("up_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
return down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
|
||||
@@ -721,7 +713,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
# logger.info(lora_name, value.size(), dim)
|
||||
# print(lora_name, value.size(), dim)
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
@@ -796,26 +788,20 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
if modules_dim is not None:
|
||||
logger.info(f"create LoRA network from weights")
|
||||
print(f"create LoRA network from weights")
|
||||
elif block_dims is not None:
|
||||
logger.info(f"create LoRA network from block_dims")
|
||||
logger.info(
|
||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||
)
|
||||
logger.info(f"block_dims: {block_dims}")
|
||||
logger.info(f"block_alphas: {block_alphas}")
|
||||
print(f"create LoRA network from block_dims")
|
||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
print(f"block_dims: {block_dims}")
|
||||
print(f"block_alphas: {block_alphas}")
|
||||
if conv_block_dims is not None:
|
||||
logger.info(f"conv_block_dims: {conv_block_dims}")
|
||||
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
||||
print(f"conv_block_dims: {conv_block_dims}")
|
||||
print(f"conv_block_alphas: {conv_block_alphas}")
|
||||
else:
|
||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
logger.info(
|
||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||
)
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
if self.conv_lora_dim is not None:
|
||||
logger.info(
|
||||
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
||||
)
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -841,14 +827,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
is_group_conv2d = is_conv2d and child_module.groups > 1
|
||||
|
||||
# if is_group_conv2d:
|
||||
# logger.info(f"skip group conv2d: {name}.{child_name}")
|
||||
# continue
|
||||
|
||||
if is_linear or (is_conv2d and not is_group_conv2d):
|
||||
lora_name = prefix + "." + name + ("." + child_name if child_name else "")
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
dim = None
|
||||
@@ -905,36 +886,31 @@ class LoRANetwork(torch.nn.Module):
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if len(text_encoders) > 1:
|
||||
index = i + 1
|
||||
logger.info(f"create LoRA for Text Encoder {index}:")
|
||||
print(f"create LoRA for Text Encoder {index}:")
|
||||
else:
|
||||
index = None
|
||||
logger.info(f"create LoRA for Text Encoder:")
|
||||
print(f"create LoRA for Text Encoder:")
|
||||
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
# XXX temporary solution for Stable Cascade Stage C: replace all modules
|
||||
if "StageC" in unet.__class__.__name__:
|
||||
logger.info("replace all modules for Stable Cascade Stage C")
|
||||
target_modules = ["Linear", "Conv2d"]
|
||||
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
skipped = skipped_te + skipped_un
|
||||
if varbose and len(skipped) > 0:
|
||||
logger.warning(
|
||||
print(
|
||||
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
||||
)
|
||||
for name in skipped:
|
||||
logger.info(f"\t{name}")
|
||||
print(f"\t{name}")
|
||||
|
||||
self.up_lr_weight: List[float] = None
|
||||
self.down_lr_weight: List[float] = None
|
||||
@@ -952,10 +928,6 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def set_enabled(self, is_enabled):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.enabled = is_enabled
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
@@ -969,12 +941,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
logger.info("enable LoRA for text encoder")
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info("enable LoRA for U-Net")
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
@@ -996,12 +968,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
apply_unet = True
|
||||
|
||||
if apply_text_encoder:
|
||||
logger.info("enable LoRA for text encoder")
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info("enable LoRA for U-Net")
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
@@ -1012,7 +984,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
logger.info(f"weights are merged")
|
||||
print(f"weights are merged")
|
||||
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
|
||||
def set_block_lr_weight(
|
||||
@@ -1143,7 +1115,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.set_network(self)
|
||||
|
||||
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
|
||||
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
||||
self.batch_size = batch_size
|
||||
self.num_sub_prompts = num_sub_prompts
|
||||
self.current_size = (height, width)
|
||||
@@ -1158,7 +1130,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
device = ref_weight.device
|
||||
|
||||
def resize_add(mh, mw):
|
||||
# logger.info(mh, mw, mh * mw)
|
||||
# print(mh, mw, mh * mw)
|
||||
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
||||
m = m.to(device, dtype=dtype)
|
||||
mask_dic[mh * mw] = m
|
||||
@@ -1169,13 +1141,6 @@ class LoRANetwork(torch.nn.Module):
|
||||
resize_add(h, w)
|
||||
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
||||
resize_add(h + h % 2, w + w % 2)
|
||||
|
||||
# deep shrink
|
||||
if ds_ratio is not None:
|
||||
hd = int(h * ds_ratio)
|
||||
wd = int(w * ds_ratio)
|
||||
resize_add(hd, wd)
|
||||
|
||||
h = (h + 1) // 2
|
||||
w = (w + 1) // 2
|
||||
|
||||
@@ -1260,3 +1225,40 @@ class LoRANetwork(torch.nn.Module):
|
||||
norms.append(scalednorm.item())
|
||||
|
||||
return keys_scaled, sum(norms) / len(norms), max(norms)
|
||||
|
||||
# region application weight
|
||||
|
||||
def get_number_of_blocks(self):
|
||||
# only for SDXL
|
||||
return 20
|
||||
|
||||
def has_text_encoder_block(self):
|
||||
return self.text_encoder_loras is not None and len(self.text_encoder_loras) > 0
|
||||
|
||||
def set_block_wise_weights(self, weights):
|
||||
if self.text_encoder_loras:
|
||||
for lora in self.text_encoder_loras:
|
||||
lora.multiplier = weights[0]
|
||||
|
||||
for lora in self.unet_loras:
|
||||
# determine block index
|
||||
key = lora.lora_name[10:] # remove "lora_unet_"
|
||||
if key.startswith("input_blocks"):
|
||||
block_index = int(key.split("_")[2]) + 1 # 1-9
|
||||
elif key.startswith("middle_block"):
|
||||
block_index = 10 # int(key.split("_")[2]) + 10
|
||||
elif key.startswith("output_blocks"):
|
||||
block_index = int(key.split("_")[2]) + 11 # 11-19
|
||||
else:
|
||||
print(f"unknown block: {key}")
|
||||
block_index = 0
|
||||
|
||||
lora.multiplier = weights[block_index]
|
||||
|
||||
# print(f"{lora.lora_name} block index: {block_index}, weight: {lora.multiplier}")
|
||||
# print(f"set block-wise weights to {weights}")
|
||||
|
||||
# TODO LoRA の weight をあらかじめ計算しておいて multiplier を掛けるだけにすると速くなるはず
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -9,15 +9,8 @@ from diffusers import UNet2DConditionModel
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
init_ipex()
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def make_unet_conversion_map() -> Dict[str, str]:
|
||||
unet_conversion_map_layer = []
|
||||
@@ -255,7 +248,7 @@ def create_network_from_weights(
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
# logger.info(f"{lora_name} {value.size()} {dim}")
|
||||
# print(lora_name, value.size(), dim)
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
@@ -298,12 +291,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
logger.info("create LoRA network from weights")
|
||||
print(f"create LoRA network from weights")
|
||||
|
||||
# convert SDXL Stability AI's U-Net modules to Diffusers
|
||||
converted = self.convert_unet_modules(modules_dim, modules_alpha)
|
||||
if converted:
|
||||
logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
||||
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -338,7 +331,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
if lora_name not in modules_dim:
|
||||
# logger.info(f"skipped {lora_name} (not found in modules_dim)")
|
||||
# print(f"skipped {lora_name} (not found in modules_dim)")
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
@@ -369,18 +362,18 @@ class LoRANetwork(torch.nn.Module):
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
if len(skipped_te) > 0:
|
||||
logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
||||
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
||||
|
||||
# extend U-Net target modules to include Conv2d 3x3
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras: List[LoRAModule]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
if len(skipped_un) > 0:
|
||||
logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
||||
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
@@ -427,11 +420,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
logger.info("enable LoRA for text encoder")
|
||||
print("enable LoRA for text encoder")
|
||||
for lora in self.text_encoder_loras:
|
||||
lora.apply_to(multiplier)
|
||||
if apply_unet:
|
||||
logger.info("enable LoRA for U-Net")
|
||||
print("enable LoRA for U-Net")
|
||||
for lora in self.unet_loras:
|
||||
lora.apply_to(multiplier)
|
||||
|
||||
@@ -440,16 +433,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora.unapply_to()
|
||||
|
||||
def merge_to(self, multiplier=1.0):
|
||||
logger.info("merge LoRA weights to original weights")
|
||||
print("merge LoRA weights to original weights")
|
||||
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
||||
lora.merge_to(multiplier)
|
||||
logger.info(f"weights are merged")
|
||||
print(f"weights are merged")
|
||||
|
||||
def restore_from(self, multiplier=1.0):
|
||||
logger.info("restore LoRA weights from original weights")
|
||||
print("restore LoRA weights from original weights")
|
||||
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
||||
lora.restore_from(multiplier)
|
||||
logger.info(f"weights are restored")
|
||||
print(f"weights are restored")
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
||||
@@ -470,7 +463,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
my_state_dict = self.state_dict()
|
||||
for key in state_dict.keys():
|
||||
if state_dict[key].size() != my_state_dict[key].size():
|
||||
# logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
|
||||
# print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
|
||||
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
|
||||
|
||||
return super().load_state_dict(state_dict, strict)
|
||||
@@ -483,7 +476,7 @@ if __name__ == "__main__":
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
device = get_preferred_device()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
|
||||
@@ -497,7 +490,7 @@ if __name__ == "__main__":
|
||||
image_prefix = args.model_id.replace("/", "_") + "_"
|
||||
|
||||
# load Diffusers model
|
||||
logger.info(f"load model from {args.model_id}")
|
||||
print(f"load model from {args.model_id}")
|
||||
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
|
||||
if args.sdxl:
|
||||
# use_safetensors=True does not work with 0.18.2
|
||||
@@ -510,7 +503,7 @@ if __name__ == "__main__":
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
|
||||
|
||||
# load LoRA weights
|
||||
logger.info(f"load LoRA weights from {args.lora_weights}")
|
||||
print(f"load LoRA weights from {args.lora_weights}")
|
||||
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -519,10 +512,10 @@ if __name__ == "__main__":
|
||||
lora_sd = torch.load(args.lora_weights)
|
||||
|
||||
# create by LoRA weights and load weights
|
||||
logger.info(f"create LoRA network")
|
||||
print(f"create LoRA network")
|
||||
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
|
||||
|
||||
logger.info(f"load LoRA network weights")
|
||||
print(f"load LoRA network weights")
|
||||
lora_network.load_state_dict(lora_sd)
|
||||
|
||||
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
|
||||
@@ -551,34 +544,34 @@ if __name__ == "__main__":
|
||||
random.seed(seed)
|
||||
|
||||
# create image with original weights
|
||||
logger.info(f"create image with original weights")
|
||||
print(f"create image with original weights")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "original.png")
|
||||
|
||||
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
|
||||
logger.info(f"apply LoRA network to the model")
|
||||
print(f"apply LoRA network to the model")
|
||||
lora_network.apply_to(multiplier=1.0)
|
||||
|
||||
logger.info(f"create image with applied LoRA")
|
||||
print(f"create image with applied LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "applied_lora.png")
|
||||
|
||||
# unapply LoRA network to the model
|
||||
logger.info(f"unapply LoRA network to the model")
|
||||
print(f"unapply LoRA network to the model")
|
||||
lora_network.unapply_to()
|
||||
|
||||
logger.info(f"create image with unapplied LoRA")
|
||||
print(f"create image with unapplied LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "unapplied_lora.png")
|
||||
|
||||
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
|
||||
logger.info(f"merge LoRA network to the model")
|
||||
print(f"merge LoRA network to the model")
|
||||
lora_network.merge_to(multiplier=1.0)
|
||||
|
||||
logger.info(f"create image with LoRA")
|
||||
print(f"create image with LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "merged_lora.png")
|
||||
@@ -586,31 +579,31 @@ if __name__ == "__main__":
|
||||
# restore (unmerge) LoRA weights: numerically unstable
|
||||
# マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない
|
||||
# 保存したstate_dictから元の重みを復元するのが確実
|
||||
logger.info(f"restore (unmerge) LoRA weights")
|
||||
print(f"restore (unmerge) LoRA weights")
|
||||
lora_network.restore_from(multiplier=1.0)
|
||||
|
||||
logger.info(f"create image without LoRA")
|
||||
print(f"create image without LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "unmerged_lora.png")
|
||||
|
||||
# restore original weights
|
||||
logger.info(f"restore original weights")
|
||||
print(f"restore original weights")
|
||||
pipe.unet.load_state_dict(org_unet_sd)
|
||||
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
|
||||
if args.sdxl:
|
||||
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
|
||||
|
||||
logger.info(f"create image with restored original weights")
|
||||
print(f"create image with restored original weights")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "restore_original.png")
|
||||
|
||||
# use convenience function to merge LoRA weights
|
||||
logger.info(f"merge LoRA weights with convenience function")
|
||||
print(f"merge LoRA weights with convenience function")
|
||||
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
|
||||
|
||||
logger.info(f"create image with merged LoRA weights")
|
||||
print(f"create image with merged LoRA weights")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "convenience_merged_lora.png")
|
||||
|
||||
@@ -14,10 +14,7 @@ from transformers import CLIPTextModel
|
||||
import numpy as np
|
||||
import torch
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
||||
@@ -52,7 +49,7 @@ class LoRAModule(torch.nn.Module):
|
||||
# if limit_rank:
|
||||
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
||||
# if self.lora_dim != lora_dim:
|
||||
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||
# print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
||||
# else:
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
@@ -200,7 +197,7 @@ class LoRAInfModule(LoRAModule):
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + self.multiplier * conved * self.scale
|
||||
|
||||
# set weight to org_module
|
||||
@@ -239,7 +236,7 @@ class LoRAInfModule(LoRAModule):
|
||||
self.region_mask = None
|
||||
|
||||
def default_forward(self, x):
|
||||
# logger.info("default_forward", self.lora_name, x.size())
|
||||
# print("default_forward", self.lora_name, x.size())
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
||||
def forward(self, x):
|
||||
@@ -281,7 +278,7 @@ class LoRAInfModule(LoRAModule):
|
||||
# apply mask for LoRA result
|
||||
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
mask = self.get_mask_for_x(lx)
|
||||
# logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
||||
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
||||
lx = lx * mask
|
||||
|
||||
x = self.org_forward(x)
|
||||
@@ -310,7 +307,7 @@ class LoRAInfModule(LoRAModule):
|
||||
if has_real_uncond:
|
||||
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
||||
|
||||
# logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
||||
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
||||
return query
|
||||
|
||||
def sub_prompt_forward(self, x):
|
||||
@@ -325,7 +322,7 @@ class LoRAInfModule(LoRAModule):
|
||||
lx = x[emb_idx :: self.network.num_sub_prompts]
|
||||
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
||||
|
||||
# logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
||||
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
||||
|
||||
x = self.org_forward(x)
|
||||
x[emb_idx :: self.network.num_sub_prompts] += lx
|
||||
@@ -333,7 +330,7 @@ class LoRAInfModule(LoRAModule):
|
||||
return x
|
||||
|
||||
def to_out_forward(self, x):
|
||||
# logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
||||
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
||||
|
||||
if self.network.is_last_network:
|
||||
masks = [None] * self.network.num_sub_prompts
|
||||
@@ -351,7 +348,7 @@ class LoRAInfModule(LoRAModule):
|
||||
)
|
||||
self.network.shared[self.lora_name] = (lx, masks)
|
||||
|
||||
# logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
||||
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
||||
|
||||
@@ -370,7 +367,7 @@ class LoRAInfModule(LoRAModule):
|
||||
if has_real_uncond:
|
||||
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
||||
|
||||
# logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||
# for i in range(len(masks)):
|
||||
# if masks[i] is None:
|
||||
# masks[i] = torch.zeros_like(masks[-1])
|
||||
@@ -392,7 +389,7 @@ class LoRAInfModule(LoRAModule):
|
||||
x1 = x1 + lx1
|
||||
out[self.network.batch_size + i] = x1
|
||||
|
||||
# logger.info("to_out_forward", x.size(), out.size(), has_real_uncond)
|
||||
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
|
||||
return out
|
||||
|
||||
|
||||
@@ -529,7 +526,7 @@ def get_block_dims_and_alphas(
|
||||
len(block_dims) == num_total_blocks
|
||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
||||
print(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
||||
block_dims = [network_dim] * num_total_blocks
|
||||
|
||||
if block_alphas is not None:
|
||||
@@ -538,7 +535,7 @@ def get_block_dims_and_alphas(
|
||||
len(block_alphas) == num_total_blocks
|
||||
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
||||
else:
|
||||
logger.warning(
|
||||
print(
|
||||
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
||||
)
|
||||
block_alphas = [network_alpha] * num_total_blocks
|
||||
@@ -558,13 +555,13 @@ def get_block_dims_and_alphas(
|
||||
else:
|
||||
if conv_alpha is None:
|
||||
conv_alpha = 1.0
|
||||
logger.warning(
|
||||
print(
|
||||
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
||||
)
|
||||
conv_block_alphas = [conv_alpha] * num_total_blocks
|
||||
else:
|
||||
if conv_dim is not None:
|
||||
logger.warning(
|
||||
print(
|
||||
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
||||
)
|
||||
conv_block_dims = [conv_dim] * num_total_blocks
|
||||
@@ -604,7 +601,7 @@ def get_block_lr_weight(
|
||||
elif name == "zeros":
|
||||
return [0.0 + base_lr] * max_len
|
||||
else:
|
||||
logger.error(
|
||||
print(
|
||||
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
||||
% (name)
|
||||
)
|
||||
@@ -616,14 +613,14 @@ def get_block_lr_weight(
|
||||
up_lr_weight = get_list(up_lr_weight)
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
||||
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
||||
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
||||
print("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
||||
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
||||
up_lr_weight = up_lr_weight[:max_len]
|
||||
down_lr_weight = down_lr_weight[:max_len]
|
||||
|
||||
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
||||
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
||||
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
||||
print("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
||||
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
||||
|
||||
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
||||
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
||||
@@ -631,24 +628,24 @@ def get_block_lr_weight(
|
||||
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
||||
|
||||
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
||||
logger.info("apply block learning rate / 階層別学習率を適用します。")
|
||||
print("apply block learning rate / 階層別学習率を適用します。")
|
||||
if down_lr_weight != None:
|
||||
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
||||
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
|
||||
print("down_lr_weight (shallower -> deeper, 浅い層->深い層):", down_lr_weight)
|
||||
else:
|
||||
logger.info("down_lr_weight: all 1.0, すべて1.0")
|
||||
print("down_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
if mid_lr_weight != None:
|
||||
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
||||
logger.info(f"mid_lr_weight: {mid_lr_weight}")
|
||||
print("mid_lr_weight:", mid_lr_weight)
|
||||
else:
|
||||
logger.info("mid_lr_weight: 1.0")
|
||||
print("mid_lr_weight: 1.0")
|
||||
|
||||
if up_lr_weight != None:
|
||||
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
||||
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
|
||||
print("up_lr_weight (deeper -> shallower, 深い層->浅い層):", up_lr_weight)
|
||||
else:
|
||||
logger.info("up_lr_weight: all 1.0, すべて1.0")
|
||||
print("up_lr_weight: all 1.0, すべて1.0")
|
||||
|
||||
return down_lr_weight, mid_lr_weight, up_lr_weight
|
||||
|
||||
@@ -729,7 +726,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
# logger.info(lora_name, value.size(), dim)
|
||||
# print(lora_name, value.size(), dim)
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
@@ -804,20 +801,20 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
if modules_dim is not None:
|
||||
logger.info(f"create LoRA network from weights")
|
||||
print(f"create LoRA network from weights")
|
||||
elif block_dims is not None:
|
||||
logger.info(f"create LoRA network from block_dims")
|
||||
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
logger.info(f"block_dims: {block_dims}")
|
||||
logger.info(f"block_alphas: {block_alphas}")
|
||||
print(f"create LoRA network from block_dims")
|
||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
print(f"block_dims: {block_dims}")
|
||||
print(f"block_alphas: {block_alphas}")
|
||||
if conv_block_dims is not None:
|
||||
logger.info(f"conv_block_dims: {conv_block_dims}")
|
||||
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
||||
print(f"conv_block_dims: {conv_block_dims}")
|
||||
print(f"conv_block_alphas: {conv_block_alphas}")
|
||||
else:
|
||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||
print(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||
if self.conv_lora_dim is not None:
|
||||
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -902,15 +899,15 @@ class LoRANetwork(torch.nn.Module):
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
if len(text_encoders) > 1:
|
||||
index = i + 1
|
||||
logger.info(f"create LoRA for Text Encoder {index}:")
|
||||
print(f"create LoRA for Text Encoder {index}:")
|
||||
else:
|
||||
index = None
|
||||
logger.info(f"create LoRA for Text Encoder:")
|
||||
print(f"create LoRA for Text Encoder:")
|
||||
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
@@ -918,15 +915,15 @@ class LoRANetwork(torch.nn.Module):
|
||||
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
skipped = skipped_te + skipped_un
|
||||
if varbose and len(skipped) > 0:
|
||||
logger.warning(
|
||||
print(
|
||||
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
||||
)
|
||||
for name in skipped:
|
||||
logger.info(f"\t{name}")
|
||||
print(f"\t{name}")
|
||||
|
||||
self.up_lr_weight: List[float] = None
|
||||
self.down_lr_weight: List[float] = None
|
||||
@@ -957,12 +954,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
logger.info("enable LoRA for text encoder")
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info("enable LoRA for U-Net")
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
@@ -984,12 +981,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
apply_unet = True
|
||||
|
||||
if apply_text_encoder:
|
||||
logger.info("enable LoRA for text encoder")
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
logger.info("enable LoRA for U-Net")
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
@@ -1000,7 +997,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
logger.info(f"weights are merged")
|
||||
print(f"weights are merged")
|
||||
|
||||
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
|
||||
def set_block_lr_weight(
|
||||
@@ -1147,7 +1144,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
device = ref_weight.device
|
||||
|
||||
def resize_add(mh, mw):
|
||||
# logger.info(mh, mw, mh * mw)
|
||||
# print(mh, mw, mh * mw)
|
||||
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
||||
m = m.to(device, dtype=dtype)
|
||||
mask_dic[mh * mw] = m
|
||||
|
||||
@@ -5,34 +5,27 @@ from library import model_util
|
||||
import library.train_util as train_util
|
||||
import argparse
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
init_ipex()
|
||||
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
||||
|
||||
DEVICE = get_preferred_device()
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def interrogate(args):
|
||||
weights_dtype = torch.float16
|
||||
|
||||
# いろいろ準備する
|
||||
logger.info(f"loading SD model: {args.sd_model}")
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
args.pretrained_model_name_or_path = args.sd_model
|
||||
args.vae = None
|
||||
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
|
||||
|
||||
logger.info(f"loading LoRA: {args.model}")
|
||||
print(f"loading LoRA: {args.model}")
|
||||
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||
|
||||
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
||||
@@ -42,11 +35,11 @@ def interrogate(args):
|
||||
has_te_weight = True
|
||||
break
|
||||
if not has_te_weight:
|
||||
logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
|
||||
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
|
||||
return
|
||||
del vae
|
||||
|
||||
logger.info("loading tokenizer")
|
||||
print("loading tokenizer")
|
||||
if args.v2:
|
||||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
||||
else:
|
||||
@@ -60,7 +53,7 @@ def interrogate(args):
|
||||
# トークンをひとつひとつ当たっていく
|
||||
token_id_start = 0
|
||||
token_id_end = max(tokenizer.all_special_ids)
|
||||
logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")
|
||||
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
|
||||
|
||||
def get_all_embeddings(text_encoder):
|
||||
embs = []
|
||||
@@ -86,24 +79,24 @@ def interrogate(args):
|
||||
embs.extend(encoder_hidden_states)
|
||||
return torch.stack(embs)
|
||||
|
||||
logger.info("get original text encoder embeddings.")
|
||||
print("get original text encoder embeddings.")
|
||||
orig_embs = get_all_embeddings(text_encoder)
|
||||
|
||||
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
||||
info = network.load_state_dict(weights_sd, strict=False)
|
||||
logger.info(f"Loading LoRA weights: {info}")
|
||||
print(f"Loading LoRA weights: {info}")
|
||||
|
||||
network.to(DEVICE, dtype=weights_dtype)
|
||||
network.eval()
|
||||
|
||||
del unet
|
||||
|
||||
logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
||||
logger.info("get text encoder embeddings with lora.")
|
||||
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
||||
print("get text encoder embeddings with lora.")
|
||||
lora_embs = get_all_embeddings(text_encoder)
|
||||
|
||||
# 比べる:とりあえず単純に差分の絶対値で
|
||||
logger.info("comparing...")
|
||||
print("comparing...")
|
||||
diffs = {}
|
||||
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
|
||||
diff = torch.mean(torch.abs(orig_emb - lora_emb))
|
||||
|
||||
@@ -7,10 +7,7 @@ from safetensors.torch import load_file, save_file
|
||||
from library import sai_model_spec, train_util
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
@@ -64,10 +61,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
name_to_module[lora_name] = child_module
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
print(f"loading: {model}")
|
||||
lora_sd, _ = load_state_dict(model, merge_dtype)
|
||||
|
||||
logger.info(f"merging...")
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
@@ -76,10 +73,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
# find original module for this lora
|
||||
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
logger.info(f"no module found for LoRA weight: {key}")
|
||||
print(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
module = name_to_module[module_name]
|
||||
# logger.info(f"apply {key} to {module}")
|
||||
# print(f"apply {key} to {module}")
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
@@ -107,7 +104,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
@@ -121,7 +118,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||
v2 = None
|
||||
base_model = None
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
print(f"loading: {model}")
|
||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||
|
||||
if lora_metadata is not None:
|
||||
@@ -154,10 +151,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
|
||||
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||
|
||||
# merge
|
||||
logger.info(f"merging...")
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if "alpha" in key:
|
||||
continue
|
||||
@@ -199,8 +196,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||
merged_sd[key_down] = merged_sd[key_down][perm]
|
||||
merged_sd[key_up] = merged_sd[key_up][:,perm]
|
||||
|
||||
logger.info("merged model")
|
||||
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||
print("merged model")
|
||||
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||
|
||||
# check all dims are same
|
||||
dims_list = list(set(base_dims.values()))
|
||||
@@ -242,7 +239,7 @@ def merge(args):
|
||||
save_dtype = merge_dtype
|
||||
|
||||
if args.sd_model is not None:
|
||||
logger.info(f"loading SD model: {args.sd_model}")
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||
|
||||
@@ -267,18 +264,18 @@ def merge(args):
|
||||
)
|
||||
if args.v2:
|
||||
# TODO read sai modelspec
|
||||
logger.warning(
|
||||
print(
|
||||
"Cannot determine if model is for v-prediction, so save metadata as v-prediction / modelがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
||||
)
|
||||
|
||||
logger.info(f"saving SD model to: {args.save_to}")
|
||||
print(f"saving SD model to: {args.save_to}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, sai_metadata, save_dtype, vae
|
||||
)
|
||||
else:
|
||||
state_dict, metadata, v2 = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
||||
|
||||
logger.info(f"calculating hashes and creating metadata...")
|
||||
print(f"calculating hashes and creating metadata...")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
@@ -292,12 +289,12 @@ def merge(args):
|
||||
)
|
||||
if v2:
|
||||
# TODO read sai modelspec
|
||||
logger.warning(
|
||||
print(
|
||||
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
||||
)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
|
||||
@@ -6,10 +6,7 @@ import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
@@ -57,10 +54,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
name_to_module[lora_name] = child_module
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
logger.info(f"merging...")
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
@@ -69,10 +66,10 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
# find original module for this lora
|
||||
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
logger.info(f"no module found for LoRA weight: {key}")
|
||||
print(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
module = name_to_module[module_name]
|
||||
# logger.info(f"apply {key} to {module}")
|
||||
# print(f"apply {key} to {module}")
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
@@ -99,10 +96,10 @@ def merge_lora_models(models, ratios, merge_dtype):
|
||||
alpha = None
|
||||
dim = None
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
logger.info(f"merging...")
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
if key in merged_sd:
|
||||
@@ -120,7 +117,7 @@ def merge_lora_models(models, ratios, merge_dtype):
|
||||
dim = lora_sd[key].size()[0]
|
||||
merged_sd[key] = lora_sd[key] * ratio
|
||||
|
||||
logger.info(f"dim (rank): {dim}, alpha: {alpha}")
|
||||
print(f"dim (rank): {dim}, alpha: {alpha}")
|
||||
if alpha is None:
|
||||
alpha = dim
|
||||
|
||||
@@ -145,21 +142,19 @@ def merge(args):
|
||||
save_dtype = merge_dtype
|
||||
|
||||
if args.sd_model is not None:
|
||||
logger.info(f"loading SD model: {args.sd_model}")
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||
|
||||
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"saving SD model to: {args.save_to}")
|
||||
print(f"\nsaving SD model to: {args.save_to}")
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||
args.sd_model, 0, 0, save_dtype, vae)
|
||||
else:
|
||||
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
|
||||
logger.info(f"")
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
print(f"\nsaving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
||||
|
||||
|
||||
@@ -8,10 +8,7 @@ from transformers import CLIPTextModel
|
||||
import numpy as np
|
||||
import torch
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
||||
@@ -240,7 +237,7 @@ class OFTNetwork(torch.nn.Module):
|
||||
self.dim = dim
|
||||
self.alpha = alpha
|
||||
|
||||
logger.info(
|
||||
print(
|
||||
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
|
||||
)
|
||||
|
||||
@@ -261,7 +258,7 @@ class OFTNetwork(torch.nn.Module):
|
||||
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
|
||||
oft_name = prefix + "." + name + "." + child_name
|
||||
oft_name = oft_name.replace(".", "_")
|
||||
# logger.info(oft_name)
|
||||
# print(oft_name)
|
||||
|
||||
oft = module_class(
|
||||
oft_name,
|
||||
@@ -282,7 +279,7 @@ class OFTNetwork(torch.nn.Module):
|
||||
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
|
||||
logger.info(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
|
||||
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
@@ -319,7 +316,7 @@ class OFTNetwork(torch.nn.Module):
|
||||
|
||||
# TODO refactor to common function with apply_to
|
||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||
logger.info("enable OFT for U-Net")
|
||||
print("enable OFT for U-Net")
|
||||
|
||||
for oft in self.unet_ofts:
|
||||
sd_for_lora = {}
|
||||
@@ -329,7 +326,7 @@ class OFTNetwork(torch.nn.Module):
|
||||
oft.load_state_dict(sd_for_lora, False)
|
||||
oft.merge_to()
|
||||
|
||||
logger.info(f"weights are merged")
|
||||
print(f"weights are merged")
|
||||
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
@@ -341,11 +338,11 @@ class OFTNetwork(torch.nn.Module):
|
||||
for oft in ofts:
|
||||
params.extend(oft.parameters())
|
||||
|
||||
# logger.info num of params
|
||||
# print num of params
|
||||
num_params = 0
|
||||
for p in params:
|
||||
num_params += p.numel()
|
||||
logger.info(f"OFT params: {num_params}")
|
||||
print(f"OFT params: {num_params}")
|
||||
return params
|
||||
|
||||
param_data = {"params": enumerate_params(self.unet_ofts)}
|
||||
|
||||
@@ -2,91 +2,80 @@
|
||||
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||
# Thanks to cloneofsimo
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
from tqdm import tqdm
|
||||
from library import train_util, model_util
|
||||
import numpy as np
|
||||
|
||||
from library import train_util
|
||||
from library import model_util
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
# Model save and load functions
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
with safe_open(file_name, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
sd = torch.load(file_name, map_location="cpu")
|
||||
metadata = None
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
with safe_open(file_name, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
metadata = None
|
||||
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
|
||||
return sd, metadata
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(state_dict, file_name, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file_name)
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(model, file_name, metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
# Indexing functions
|
||||
|
||||
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S)-1))
|
||||
|
||||
return index
|
||||
return index
|
||||
|
||||
|
||||
def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
S_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0) / S_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
S_squared = S.pow(2)
|
||||
s_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S)-1))
|
||||
|
||||
return index
|
||||
return index
|
||||
|
||||
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv / target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv/target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S)-1))
|
||||
|
||||
return index
|
||||
return index
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
||||
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
@@ -103,17 +92,17 @@ def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
|
||||
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
|
||||
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
|
||||
del U, S, Vh, weight
|
||||
@@ -124,7 +113,7 @@ def merge_conv(lora_down, lora_up, device):
|
||||
in_rank, in_size, kernel_size, k_ = lora_down.shape
|
||||
out_size, out_rank, _, _ = lora_up.shape
|
||||
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
|
||||
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
@@ -138,274 +127,236 @@ def merge_linear(lora_down, lora_up, device):
|
||||
in_rank, in_size = lora_down.shape
|
||||
out_size, out_rank = lora_up.shape
|
||||
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
|
||||
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
|
||||
weight = lora_up @ lora_down
|
||||
del lora_up, lora_down
|
||||
return weight
|
||||
|
||||
|
||||
|
||||
# Calculate new rank
|
||||
|
||||
|
||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict = {}
|
||||
|
||||
if dynamic_method == "sv_ratio":
|
||||
if dynamic_method=="sv_ratio":
|
||||
# Calculate new dim and alpha based off ratio
|
||||
new_rank = index_sv_ratio(S, dynamic_param) + 1
|
||||
new_alpha = float(scale * new_rank)
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
elif dynamic_method == "sv_cumulative":
|
||||
elif dynamic_method=="sv_cumulative":
|
||||
# Calculate new dim and alpha based off cumulative sum
|
||||
new_rank = index_sv_cumulative(S, dynamic_param) + 1
|
||||
new_alpha = float(scale * new_rank)
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
elif dynamic_method == "sv_fro":
|
||||
elif dynamic_method=="sv_fro":
|
||||
# Calculate new dim and alpha based off sqrt sum of squares
|
||||
new_rank = index_sv_fro(S, dynamic_param) + 1
|
||||
new_alpha = float(scale * new_rank)
|
||||
new_alpha = float(scale*new_rank)
|
||||
else:
|
||||
new_rank = rank
|
||||
new_alpha = float(scale * new_rank)
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
||||
|
||||
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
||||
new_rank = 1
|
||||
new_alpha = float(scale * new_rank)
|
||||
elif new_rank > rank: # cap max rank at rank
|
||||
new_alpha = float(scale*new_rank)
|
||||
elif new_rank > rank: # cap max rank at rank
|
||||
new_rank = rank
|
||||
new_alpha = float(scale * new_rank)
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
|
||||
# Calculate resize info
|
||||
s_sum = torch.sum(torch.abs(S))
|
||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||
|
||||
|
||||
S_squared = S.pow(2)
|
||||
s_fro = torch.sqrt(torch.sum(S_squared))
|
||||
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
||||
fro_percent = float(s_red_fro / s_fro)
|
||||
fro_percent = float(s_red_fro/s_fro)
|
||||
|
||||
param_dict["new_rank"] = new_rank
|
||||
param_dict["new_alpha"] = new_alpha
|
||||
param_dict["sum_retained"] = (s_rank) / s_sum
|
||||
param_dict["sum_retained"] = (s_rank)/s_sum
|
||||
param_dict["fro_retained"] = fro_percent
|
||||
param_dict["max_ratio"] = S[0] / S[new_rank - 1]
|
||||
param_dict["max_ratio"] = S[0]/S[new_rank - 1]
|
||||
|
||||
return param_dict
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
verbose_str = "\n"
|
||||
fro_list = []
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
verbose_str = "\n"
|
||||
fro_list = []
|
||||
|
||||
# Extract loaded lora dim and alpha
|
||||
for key, value in lora_sd.items():
|
||||
if network_alpha is None and "alpha" in key:
|
||||
network_alpha = value
|
||||
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
|
||||
network_dim = value.size()[0]
|
||||
if network_alpha is not None and network_dim is not None:
|
||||
break
|
||||
if network_alpha is None:
|
||||
network_alpha = network_dim
|
||||
# Extract loaded lora dim and alpha
|
||||
for key, value in lora_sd.items():
|
||||
if network_alpha is None and 'alpha' in key:
|
||||
network_alpha = value
|
||||
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
||||
network_dim = value.size()[0]
|
||||
if network_alpha is not None and network_dim is not None:
|
||||
break
|
||||
if network_alpha is None:
|
||||
network_alpha = network_dim
|
||||
|
||||
scale = network_alpha / network_dim
|
||||
scale = network_alpha/network_dim
|
||||
|
||||
if dynamic_method:
|
||||
logger.info(
|
||||
f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
|
||||
)
|
||||
if dynamic_method:
|
||||
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
||||
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
|
||||
o_lora_sd = lora_sd.copy()
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
o_lora_sd = lora_sd.copy()
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
|
||||
with torch.no_grad():
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
weight_name = None
|
||||
if "lora_down" in key:
|
||||
block_down_name = key.rsplit(".lora_down", 1)[0]
|
||||
weight_name = key.rsplit(".", 1)[-1]
|
||||
lora_down_weight = value
|
||||
else:
|
||||
continue
|
||||
with torch.no_grad():
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
weight_name = None
|
||||
if 'lora_down' in key:
|
||||
block_down_name = key.rsplit('.lora_down', 1)[0]
|
||||
weight_name = key.rsplit(".", 1)[-1]
|
||||
lora_down_weight = value
|
||||
else:
|
||||
continue
|
||||
|
||||
# find corresponding lora_up and alpha
|
||||
block_up_name = block_down_name
|
||||
lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
|
||||
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)
|
||||
# find corresponding lora_up and alpha
|
||||
block_up_name = block_down_name
|
||||
lora_up_weight = lora_sd.get(block_up_name + '.lora_up.' + weight_name, None)
|
||||
lora_alpha = lora_sd.get(block_down_name + '.alpha', None)
|
||||
|
||||
weights_loaded = lora_down_weight is not None and lora_up_weight is not None
|
||||
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
|
||||
|
||||
if weights_loaded:
|
||||
if weights_loaded:
|
||||
|
||||
conv2d = len(lora_down_weight.size()) == 4
|
||||
if lora_alpha is None:
|
||||
scale = 1.0
|
||||
else:
|
||||
scale = lora_alpha / lora_down_weight.size()[0]
|
||||
conv2d = (len(lora_down_weight.size()) == 4)
|
||||
if lora_alpha is None:
|
||||
scale = 1.0
|
||||
else:
|
||||
scale = lora_alpha/lora_down_weight.size()[0]
|
||||
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_conv_rank, dynamic_method, dynamic_param, device, scale)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
if conv2d:
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
|
||||
if verbose:
|
||||
max_ratio = param_dict["max_ratio"]
|
||||
sum_retained = param_dict["sum_retained"]
|
||||
fro_retained = param_dict["fro_retained"]
|
||||
if not np.isnan(fro_retained):
|
||||
fro_list.append(float(fro_retained))
|
||||
if verbose:
|
||||
max_ratio = param_dict['max_ratio']
|
||||
sum_retained = param_dict['sum_retained']
|
||||
fro_retained = param_dict['fro_retained']
|
||||
if not np.isnan(fro_retained):
|
||||
fro_list.append(float(fro_retained))
|
||||
|
||||
verbose_str += f"{block_down_name:75} | "
|
||||
verbose_str += (
|
||||
f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||
)
|
||||
verbose_str+=f"{block_down_name:75} | "
|
||||
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||
|
||||
if verbose and dynamic_method:
|
||||
verbose_str += f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||
else:
|
||||
verbose_str += "\n"
|
||||
if verbose and dynamic_method:
|
||||
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||
else:
|
||||
verbose_str+=f"\n"
|
||||
|
||||
new_alpha = param_dict["new_alpha"]
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
|
||||
new_alpha = param_dict['new_alpha']
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
|
||||
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
weights_loaded = False
|
||||
del param_dict
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
weights_loaded = False
|
||||
del param_dict
|
||||
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
logger.info("resizing complete")
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
print("resizing complete")
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
|
||||
|
||||
def resize(args):
|
||||
if args.save_to is None or not (
|
||||
args.save_to.endswith(".ckpt")
|
||||
or args.save_to.endswith(".pt")
|
||||
or args.save_to.endswith(".pth")
|
||||
or args.save_to.endswith(".safetensors")
|
||||
):
|
||||
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
||||
if args.save_to is None or not (args.save_to.endswith('.ckpt') or args.save_to.endswith('.pt') or args.save_to.endswith('.pth') or args.save_to.endswith('.safetensors')):
|
||||
raise Exception("The --save_to argument must be specified and must be a .ckpt , .pt, .pth or .safetensors file.")
|
||||
|
||||
args.new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
if p == "fp16":
|
||||
return torch.float16
|
||||
if p == "bf16":
|
||||
return torch.bfloat16
|
||||
return None
|
||||
if args.dynamic_method and not args.dynamic_param:
|
||||
raise Exception("If using dynamic_method, then dynamic_param is required")
|
||||
|
||||
if args.dynamic_method and not args.dynamic_param:
|
||||
raise Exception("If using dynamic_method, then dynamic_param is required")
|
||||
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
merge_dtype = str_to_dtype("float") # matmul method above only seems to work in float32
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
print("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||
|
||||
logger.info("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||
print("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
||||
|
||||
logger.info("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(
|
||||
lora_sd, args.new_rank, args.new_conv_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose
|
||||
)
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
if not args.dynamic_method:
|
||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
||||
metadata["ss_network_dim"] = str(args.new_rank)
|
||||
metadata["ss_network_alpha"] = str(new_alpha)
|
||||
else:
|
||||
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
|
||||
metadata["ss_network_dim"] = 'Dynamic'
|
||||
metadata["ss_network_alpha"] = 'Dynamic'
|
||||
|
||||
if not args.dynamic_method:
|
||||
conv_desc = "" if args.new_rank == args.new_conv_rank else f" (conv: {args.new_conv_rank})"
|
||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}{conv_desc}; {comment}"
|
||||
metadata["ss_network_dim"] = str(args.new_rank)
|
||||
metadata["ss_network_alpha"] = str(new_alpha)
|
||||
else:
|
||||
metadata["ss_training_comment"] = (
|
||||
f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
|
||||
)
|
||||
metadata["ss_network_dim"] = "Dynamic"
|
||||
metadata["ss_network_alpha"] = "Dynamic"
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, save_dtype, metadata)
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--save_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "float", "fp16", "bf16"],
|
||||
help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat",
|
||||
)
|
||||
parser.add_argument("--new_rank", type=int, default=4, help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||
parser.add_argument(
|
||||
"--new_conv_rank",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose", action="store_true", help="Display verbose resizing information / rank変更時の詳細情報を出力する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic_method",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank",
|
||||
)
|
||||
parser.add_argument("--dynamic_param", type=float, default=None, help="Specify target for dynamic reduction")
|
||||
|
||||
return parser
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, float if omitted / 保存時の精度、未指定時はfloat")
|
||||
parser.add_argument("--new_rank", type=int, default=4,
|
||||
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--model", type=str, default=None,
|
||||
help="LoRA model to resize at to new rank: ckpt or safetensors file / 読み込むLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument("--verbose", action="store_true",
|
||||
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
||||
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
|
||||
parser.add_argument("--dynamic_param", type=float, default=None,
|
||||
help="Specify target for dynamic reduction")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
if __name__ == '__main__':
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
resize(args)
|
||||
args = parser.parse_args()
|
||||
resize(args)
|
||||
|
||||
@@ -8,10 +8,7 @@ from tqdm import tqdm
|
||||
from library import sai_model_spec, sdxl_model_util, train_util
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == ".safetensors":
|
||||
@@ -69,10 +66,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
||||
name_to_module[lora_name] = child_module
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
print(f"loading: {model}")
|
||||
lora_sd, _ = load_state_dict(model, merge_dtype)
|
||||
|
||||
logger.info(f"merging...")
|
||||
print(f"merging...")
|
||||
for key in tqdm(lora_sd.keys()):
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
@@ -81,10 +78,10 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
||||
# find original module for this lora
|
||||
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
logger.info(f"no module found for LoRA weight: {key}")
|
||||
print(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
module = name_to_module[module_name]
|
||||
# logger.info(f"apply {key} to {module}")
|
||||
# print(f"apply {key} to {module}")
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
@@ -95,7 +92,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
||||
|
||||
# W <- W + U * D
|
||||
weight = module.weight
|
||||
# logger.info(module_name, down_weight.size(), up_weight.size())
|
||||
# print(module_name, down_weight.size(), up_weight.size())
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
@@ -110,7 +107,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
@@ -124,7 +121,7 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||
v2 = None
|
||||
base_model = None
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
print(f"loading: {model}")
|
||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||
|
||||
if lora_metadata is not None:
|
||||
@@ -157,10 +154,10 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
|
||||
logger.info(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||
|
||||
# merge
|
||||
logger.info(f"merging...")
|
||||
print(f"merging...")
|
||||
for key in tqdm(lora_sd.keys()):
|
||||
if "alpha" in key:
|
||||
continue
|
||||
@@ -203,8 +200,8 @@ def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
|
||||
merged_sd[key_down] = merged_sd[key_down][perm]
|
||||
merged_sd[key_up] = merged_sd[key_up][:,perm]
|
||||
|
||||
logger.info("merged model")
|
||||
logger.info(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||
print("merged model")
|
||||
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||
|
||||
# check all dims are same
|
||||
dims_list = list(set(base_dims.values()))
|
||||
@@ -246,7 +243,7 @@ def merge(args):
|
||||
save_dtype = merge_dtype
|
||||
|
||||
if args.sd_model is not None:
|
||||
logger.info(f"loading SD model: {args.sd_model}")
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
|
||||
(
|
||||
text_model1,
|
||||
@@ -268,14 +265,14 @@ def merge(args):
|
||||
None, False, False, True, False, False, time.time(), title=title, merged_from=merged_from
|
||||
)
|
||||
|
||||
logger.info(f"saving SD model to: {args.save_to}")
|
||||
print(f"saving SD model to: {args.save_to}")
|
||||
sdxl_model_util.save_stable_diffusion_checkpoint(
|
||||
args.save_to, text_model1, text_model2, unet, 0, 0, ckpt_info, vae, logit_scale, sai_metadata, save_dtype
|
||||
)
|
||||
else:
|
||||
state_dict, metadata = merge_lora_models(args.models, args.ratios, merge_dtype, args.concat, args.shuffle)
|
||||
|
||||
logger.info(f"calculating hashes and creating metadata...")
|
||||
print(f"calculating hashes and creating metadata...")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
@@ -289,7 +286,7 @@ def merge(args):
|
||||
)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
|
||||
@@ -5,12 +5,7 @@ import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library import sai_model_spec, train_util
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
@@ -43,12 +38,12 @@ def save_to_file(file_name, state_dict, dtype, metadata):
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
||||
logger.info(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
||||
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
||||
merged_sd = {}
|
||||
v2 = None
|
||||
base_model = None
|
||||
for model, ratio in zip(models, ratios):
|
||||
logger.info(f"loading: {model}")
|
||||
print(f"loading: {model}")
|
||||
lora_sd, lora_metadata = load_state_dict(model, merge_dtype)
|
||||
|
||||
if lora_metadata is not None:
|
||||
@@ -58,7 +53,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
||||
base_model = lora_metadata.get(train_util.SS_METADATA_KEY_BASE_MODEL_VERSION, None)
|
||||
|
||||
# merge
|
||||
logger.info(f"merging...")
|
||||
print(f"merging...")
|
||||
for key in tqdm(list(lora_sd.keys())):
|
||||
if "lora_down" not in key:
|
||||
continue
|
||||
@@ -75,7 +70,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
||||
out_dim = up_weight.size()[0]
|
||||
conv2d = len(down_weight.size()) == 4
|
||||
kernel_size = None if not conv2d else down_weight.size()[2:4]
|
||||
# logger.info(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
|
||||
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
|
||||
|
||||
# make original weight if not exist
|
||||
if lora_module_name not in merged_sd:
|
||||
@@ -112,7 +107,7 @@ def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dty
|
||||
merged_sd[lora_module_name] = weight
|
||||
|
||||
# extract from merged weights
|
||||
logger.info("extract new lora...")
|
||||
print("extract new lora...")
|
||||
merged_lora_sd = {}
|
||||
with torch.no_grad():
|
||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||
@@ -190,7 +185,7 @@ def merge(args):
|
||||
args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype
|
||||
)
|
||||
|
||||
logger.info(f"calculating hashes and creating metadata...")
|
||||
print(f"calculating hashes and creating metadata...")
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
@@ -205,12 +200,12 @@ def merge(args):
|
||||
)
|
||||
if v2:
|
||||
# TODO read sai modelspec
|
||||
logger.warning(
|
||||
print(
|
||||
"Cannot determine if LoRA is for v-prediction, so save metadata as v-prediction / LoRAがv-prediction用か否か不明なため、仮にv-prediction用としてmetadataを保存します"
|
||||
)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
logger.info(f"saving model to: {args.save_to}")
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ diffusers[torch]==0.25.0
|
||||
ftfy==6.1.1
|
||||
# albumentations==1.3.0
|
||||
opencv-python==4.7.0.68
|
||||
einops==0.7.0
|
||||
einops==0.6.1
|
||||
pytorch-lightning==1.9.0
|
||||
# bitsandbytes==0.39.1
|
||||
tensorboard==2.10.1
|
||||
@@ -29,7 +29,5 @@ huggingface-hub==0.20.1
|
||||
# protobuf==3.20.3
|
||||
# open clip for SDXL
|
||||
open-clip-torch==2.20.0
|
||||
# For logging
|
||||
rich==13.7.0
|
||||
# for kohya_ss library
|
||||
-e .
|
||||
|
||||
628
sdxl_gen_img.py
628
sdxl_gen_img.py
File diff suppressed because it is too large
Load Diff
@@ -8,9 +8,10 @@ import os
|
||||
import random
|
||||
from einops import repeat
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -22,10 +23,6 @@ from safetensors.torch import load_file
|
||||
|
||||
from library import model_util, sdxl_model_util
|
||||
import networks.lora as lora
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# scheduler: このあたりの設定はSD1/2と同じでいいらしい
|
||||
# scheduler: The settings around here seem to be the same as SD1/2
|
||||
@@ -88,7 +85,7 @@ if __name__ == "__main__":
|
||||
guidance_scale = 7
|
||||
seed = None # 1
|
||||
|
||||
DEVICE = get_preferred_device()
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.float16 # bfloat16 may work
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -143,7 +140,7 @@ if __name__ == "__main__":
|
||||
|
||||
vae_dtype = DTYPE
|
||||
if DTYPE == torch.float16:
|
||||
logger.info("use float32 for vae")
|
||||
print("use float32 for vae")
|
||||
vae_dtype = torch.float32
|
||||
vae.to(DEVICE, dtype=vae_dtype)
|
||||
vae.eval()
|
||||
@@ -190,7 +187,7 @@ if __name__ == "__main__":
|
||||
emb1 = get_timestep_embedding(torch.FloatTensor([original_height, original_width]).unsqueeze(0), 256)
|
||||
emb2 = get_timestep_embedding(torch.FloatTensor([crop_top, crop_left]).unsqueeze(0), 256)
|
||||
emb3 = get_timestep_embedding(torch.FloatTensor([target_height, target_width]).unsqueeze(0), 256)
|
||||
# logger.info("emb1", emb1.shape)
|
||||
# print("emb1", emb1.shape)
|
||||
c_vector = torch.cat([emb1, emb2, emb3], dim=1).to(DEVICE, dtype=DTYPE)
|
||||
uc_vector = c_vector.clone().to(DEVICE, dtype=DTYPE) # ちょっとここ正しいかどうかわからない I'm not sure if this is right
|
||||
|
||||
@@ -220,7 +217,7 @@ if __name__ == "__main__":
|
||||
|
||||
enc_out = text_model2(tokens, output_hidden_states=True, return_dict=True)
|
||||
text_embedding2_penu = enc_out["hidden_states"][-2]
|
||||
# logger.info("hidden_states2", text_embedding2_penu.shape)
|
||||
# print("hidden_states2", text_embedding2_penu.shape)
|
||||
text_embedding2_pool = enc_out["text_embeds"] # do not support Textual Inversion
|
||||
|
||||
# 連結して終了 concat and finish
|
||||
@@ -229,7 +226,7 @@ if __name__ == "__main__":
|
||||
|
||||
# cond
|
||||
c_ctx, c_ctx_pool = call_text_encoder(prompt, prompt2)
|
||||
# logger.info(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
||||
# print(c_ctx.shape, c_ctx_p.shape, c_vector.shape)
|
||||
c_vector = torch.cat([c_ctx_pool, c_vector], dim=1)
|
||||
|
||||
# uncond
|
||||
@@ -326,4 +323,4 @@ if __name__ == "__main__":
|
||||
seed = int(seed)
|
||||
generate_image(prompt, prompt2, negative_prompt, seed)
|
||||
|
||||
logger.info("Done!")
|
||||
print("Done!")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# training with captions
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
@@ -8,9 +9,10 @@ from typing import List
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
@@ -18,14 +20,6 @@ from diffusers import DDPMScheduler
|
||||
from library import sdxl_model_util
|
||||
|
||||
import library.train_util as train_util
|
||||
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.config_util as config_util
|
||||
import library.sdxl_train_util as sdxl_train_util
|
||||
from library.config_util import (
|
||||
@@ -97,11 +91,8 @@ def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
assert (
|
||||
not args.weighted_captions
|
||||
), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
assert not args.weighted_captions, "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
assert (
|
||||
not args.train_text_encoder or not args.cache_text_encoder_outputs
|
||||
), "cache_text_encoder_outputs is not supported when training text encoder / text encoderを学習するときはcache_text_encoder_outputsはサポートされていません"
|
||||
@@ -126,18 +117,18 @@ def train(args):
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
print("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -148,7 +139,7 @@ def train(args):
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
print("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -178,7 +169,7 @@ def train(args):
|
||||
train_util.debug_dataset(train_dataset_group, True)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
print(
|
||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||
)
|
||||
return
|
||||
@@ -194,7 +185,7 @@ def train(args):
|
||||
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
@@ -261,7 +252,9 @@ def train(args):
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -354,8 +347,8 @@ def train(args):
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
@@ -370,9 +363,7 @@ def train(args):
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
@@ -416,7 +407,8 @@ def train(args):
|
||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||
text_encoder1.to("cpu", dtype=torch.float32)
|
||||
text_encoder2.to("cpu", dtype=torch.float32)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# make sure Text Encoders are on GPU
|
||||
text_encoder1.to(accelerator.device)
|
||||
@@ -441,9 +433,7 @@ def train(args):
|
||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||
# accelerator.print(
|
||||
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||
# )
|
||||
@@ -463,7 +453,7 @@ def train(args):
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
@@ -547,7 +537,7 @@ def train(args):
|
||||
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# logger.info("text encoder outputs verified")
|
||||
# print("text encoder outputs verified")
|
||||
|
||||
# get size embeddings
|
||||
orig_size = batch["original_sizes_hw"]
|
||||
@@ -734,13 +724,12 @@ def train(args):
|
||||
logit_scale,
|
||||
ckpt_info,
|
||||
)
|
||||
logger.info("model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
@@ -763,9 +752,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する"
|
||||
)
|
||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@@ -12,9 +13,10 @@ from types import SimpleNamespace
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@@ -43,12 +45,6 @@ from library.custom_train_functions import (
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
import networks.control_net_lllite_for_train as control_net_lllite_for_train
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
@@ -69,7 +65,6 @@ def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_user_config = args.dataset_config is not None
|
||||
@@ -83,11 +78,11 @@ def train(args):
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
if use_user_config:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
@@ -119,7 +114,7 @@ def train(args):
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
print(
|
||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||
)
|
||||
return
|
||||
@@ -129,9 +124,7 @@ def train(args):
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
else:
|
||||
logger.warning(
|
||||
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
|
||||
)
|
||||
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
@@ -139,7 +132,7 @@ def train(args):
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
@@ -171,7 +164,9 @@ def train(args):
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -236,14 +231,14 @@ def train(args):
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
trainable_params = list(unet.prepare_params())
|
||||
logger.info(f"trainable params count: {len(trainable_params)}")
|
||||
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||
print(f"trainable params count: {len(trainable_params)}")
|
||||
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
@@ -259,9 +254,7 @@ def train(args):
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
@@ -298,7 +291,8 @@ def train(args):
|
||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||
text_encoder1.to("cpu", dtype=torch.float32)
|
||||
text_encoder2.to("cpu", dtype=torch.float32)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# make sure Text Encoders are on GPU
|
||||
text_encoder1.to(accelerator.device)
|
||||
@@ -329,10 +323,8 @@ def train(args):
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
@@ -349,7 +341,7 @@ def train(args):
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
@@ -556,13 +548,12 @@ def train(args):
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, unet, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
logger.info("model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
@@ -578,12 +569,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
||||
)
|
||||
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
|
||||
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
|
||||
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
|
||||
parser.add_argument(
|
||||
"--network_dropout",
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@@ -9,9 +10,10 @@ from types import SimpleNamespace
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@@ -39,12 +41,6 @@ from library.custom_train_functions import (
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
import networks.control_net_lllite as control_net_lllite
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
@@ -65,7 +61,6 @@ def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
sdxl_train_util.verify_sdxl_training_args(args)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_user_config = args.dataset_config is not None
|
||||
@@ -79,11 +74,11 @@ def train(args):
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
if use_user_config:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
@@ -115,7 +110,7 @@ def train(args):
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
print(
|
||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||
)
|
||||
return
|
||||
@@ -125,9 +120,7 @@ def train(args):
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
else:
|
||||
logger.warning(
|
||||
"WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません"
|
||||
)
|
||||
print("WARNING: random_crop is not supported yet for ControlNet training / ControlNetの学習ではrandom_cropはまだサポートされていません")
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
@@ -135,7 +128,7 @@ def train(args):
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
@@ -170,7 +163,9 @@ def train(args):
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -204,14 +199,14 @@ def train(args):
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
trainable_params = list(network.prepare_optimizer_params())
|
||||
logger.info(f"trainable params count: {len(trainable_params)}")
|
||||
logger.info(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||
print(f"trainable params count: {len(trainable_params)}")
|
||||
print(f"number of trainable parameters: {sum(p.numel() for p in trainable_params if p.requires_grad)}")
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
@@ -227,9 +222,7 @@ def train(args):
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
@@ -271,7 +264,8 @@ def train(args):
|
||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||
text_encoder1.to("cpu", dtype=torch.float32)
|
||||
text_encoder2.to("cpu", dtype=torch.float32)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
# make sure Text Encoders are on GPU
|
||||
text_encoder1.to(accelerator.device)
|
||||
@@ -302,10 +296,8 @@ def train(args):
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
@@ -524,13 +516,12 @@ def train(args):
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
logger.info("model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
@@ -546,12 +537,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
||||
)
|
||||
parser.add_argument("--cond_emb_dim", type=int, default=None, help="conditioning embedding dimension / 条件付け埋め込みの次元数")
|
||||
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
|
||||
parser.add_argument("--network_dim", type=int, default=None, help="network dimensions (rank) / モジュールの次元数")
|
||||
parser.add_argument(
|
||||
"--network_dropout",
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
import train_network
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
@@ -62,12 +60,13 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
if args.cache_text_encoder_outputs:
|
||||
if not args.lowram:
|
||||
# メモリ消費を減らす
|
||||
logger.info("move vae and unet to cpu to save memory")
|
||||
print("move vae and unet to cpu to save memory")
|
||||
org_vae_device = vae.device
|
||||
org_unet_device = unet.device
|
||||
vae.to("cpu")
|
||||
unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||
with accelerator.autocast():
|
||||
@@ -82,10 +81,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if not args.lowram:
|
||||
logger.info("move vae and unet back to original device")
|
||||
print("move vae and unet back to original device")
|
||||
vae.to(org_vae_device)
|
||||
unet.to(org_unet_device)
|
||||
else:
|
||||
@@ -143,7 +143,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
# assert ((encoder_hidden_states1.to("cpu") - ehs1.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((encoder_hidden_states2.to("cpu") - ehs2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# assert ((pool2.to("cpu") - p2.to(dtype=weight_dtype)).abs().max() > 1e-2).sum() <= b_size * 2
|
||||
# logger.info("text encoder outputs verified")
|
||||
# print("text encoder outputs verified")
|
||||
|
||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
|
||||
|
||||
@@ -2,11 +2,10 @@ import argparse
|
||||
import os
|
||||
|
||||
import regex
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex
|
||||
init_ipex()
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
import open_clip
|
||||
from library import sdxl_model_util, sdxl_train_util, train_util
|
||||
|
||||
|
||||
@@ -1,367 +0,0 @@
|
||||
import argparse
|
||||
import importlib
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from safetensors.torch import load_file, save_file
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTextConfig
|
||||
from PIL import Image
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
import library.stable_cascade as sc
|
||||
import library.stable_cascade_utils as sc_utils
|
||||
import library.device_utils as device_utils
|
||||
from library import train_util
|
||||
from library.sdxl_model_util import _load_state_dict_on_device
|
||||
|
||||
|
||||
def main(args):
|
||||
device = device_utils.get_preferred_device()
|
||||
|
||||
loading_device = device if not args.lowvram else "cpu"
|
||||
text_model_device = "cpu"
|
||||
|
||||
dtype = torch.float32
|
||||
if args.bf16:
|
||||
dtype = torch.bfloat16
|
||||
elif args.fp16:
|
||||
dtype = torch.float16
|
||||
|
||||
text_model_dtype = torch.float32
|
||||
|
||||
# EfficientNet encoder
|
||||
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
|
||||
effnet.eval().requires_grad_(False).to(loading_device)
|
||||
|
||||
generator_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
generator_c.eval().requires_grad_(False).to(loading_device)
|
||||
# if args.xformers or args.sdpa:
|
||||
print(f"Stage C: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
|
||||
generator_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
|
||||
|
||||
generator_b = sc_utils.load_stage_b_model(args.stage_b_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
generator_b.eval().requires_grad_(False).to(loading_device)
|
||||
# if args.xformers or args.sdpa:
|
||||
print(f"Stage B: use_xformers_or_sdpa: {args.xformers} {args.sdpa}")
|
||||
generator_b.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
|
||||
|
||||
# CLIP encoders
|
||||
tokenizer = sc_utils.load_tokenizer(args)
|
||||
|
||||
text_model = sc_utils.load_clip_text_model(
|
||||
args.text_model_checkpoint_path, text_model_dtype, text_model_device, args.save_text_model
|
||||
)
|
||||
text_model = text_model.requires_grad_(False).to(text_model_dtype).to(text_model_device)
|
||||
|
||||
# image_model = (
|
||||
# CLIPVisionModelWithProjection.from_pretrained(clip_image_model_name).requires_grad_(False).to(dtype).to(device)
|
||||
# )
|
||||
|
||||
# vqGAN
|
||||
stage_a = sc_utils.load_stage_a_model(args.stage_a_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
stage_a.eval().requires_grad_(False)
|
||||
|
||||
# previewer
|
||||
if args.previewer_checkpoint_path is not None:
|
||||
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=dtype, device=loading_device)
|
||||
previewer.eval().requires_grad_(False)
|
||||
else:
|
||||
previewer = None
|
||||
|
||||
# LoRA
|
||||
if args.network_module:
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
network_args = args.network_args[i]
|
||||
# TODO escape special chars
|
||||
network_args = network_args.split(";")
|
||||
for net_arg in network_args:
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
if args.network_weights is None or len(args.network_weights) <= i:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, effnet, text_model, generator_c, for_inference=True, **net_kwargs
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
mergeable = network.is_mergeable()
|
||||
assert mergeable, "not-mergeable network is not supported yet."
|
||||
|
||||
network.merge_to(text_model, generator_c, weights_sd, dtype, device)
|
||||
|
||||
# 謎のクラス gdf
|
||||
gdf_c = sc.GDF(
|
||||
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
|
||||
input_scaler=sc.VPScaler(),
|
||||
target=sc.EpsilonTarget(),
|
||||
noise_cond=sc.CosineTNoiseCond(),
|
||||
loss_weight=None,
|
||||
)
|
||||
gdf_b = sc.GDF(
|
||||
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
|
||||
input_scaler=sc.VPScaler(),
|
||||
target=sc.EpsilonTarget(),
|
||||
noise_cond=sc.CosineTNoiseCond(),
|
||||
loss_weight=None,
|
||||
)
|
||||
|
||||
# Stage C Parameters
|
||||
|
||||
# extras.sampling_configs["cfg"] = 4
|
||||
# extras.sampling_configs["shift"] = 2
|
||||
# extras.sampling_configs["timesteps"] = 20
|
||||
# extras.sampling_configs["t_start"] = 1.0
|
||||
|
||||
# # Stage B Parameters
|
||||
# extras_b.sampling_configs["cfg"] = 1.1
|
||||
# extras_b.sampling_configs["shift"] = 1
|
||||
# extras_b.sampling_configs["timesteps"] = 10
|
||||
# extras_b.sampling_configs["t_start"] = 1.0
|
||||
b_cfg = 1.1
|
||||
b_shift = 1
|
||||
b_timesteps = 10
|
||||
b_t_start = 1.0
|
||||
|
||||
# caption = "Cinematic photo of an anthropomorphic penguin sitting in a cafe reading a book and having a coffee"
|
||||
# height, width = 1024, 1024
|
||||
|
||||
while True:
|
||||
print("type caption:")
|
||||
# if Ctrl+Z is pressed, it will raise EOFError
|
||||
try:
|
||||
caption = input()
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
caption = caption.strip()
|
||||
if caption == "":
|
||||
continue
|
||||
|
||||
# parse options: '--w' and '--h' for size, '--l' for cfg, '--s' for timesteps, '--f' for shift. if not specified, use default values
|
||||
# e.g. "caption --w 4 --h 4 --l 20 --s 20 --f 1.0"
|
||||
|
||||
tokens = caption.split()
|
||||
width = height = 1024
|
||||
cfg = 4
|
||||
timesteps = 20
|
||||
shift = 2
|
||||
t_start = 1.0
|
||||
negative_prompt = ""
|
||||
seed = None
|
||||
|
||||
caption_tokens = []
|
||||
i = 0
|
||||
while i < len(tokens):
|
||||
token = tokens[i]
|
||||
if i == len(tokens) - 1:
|
||||
caption_tokens.append(token)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if token == "--w":
|
||||
width = int(tokens[i + 1])
|
||||
elif token == "--h":
|
||||
height = int(tokens[i + 1])
|
||||
elif token == "--l":
|
||||
cfg = float(tokens[i + 1])
|
||||
elif token == "--s":
|
||||
timesteps = int(tokens[i + 1])
|
||||
elif token == "--f":
|
||||
shift = float(tokens[i + 1])
|
||||
elif token == "--t":
|
||||
t_start = float(tokens[i + 1])
|
||||
elif token == "--n":
|
||||
negative_prompt = tokens[i + 1]
|
||||
elif token == "--d":
|
||||
seed = int(tokens[i + 1])
|
||||
else:
|
||||
caption_tokens.append(token)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
i += 2
|
||||
|
||||
caption = " ".join(caption_tokens)
|
||||
|
||||
stage_c_latent_shape, stage_b_latent_shape = sc_utils.calculate_latent_sizes(height, width, batch_size=1)
|
||||
|
||||
# PREPARE CONDITIONS
|
||||
# cond_text, cond_pooled = sc.get_clip_conditions([caption], None, tokenizer, text_model)
|
||||
input_ids = tokenizer(
|
||||
[caption], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
cond_text, cond_pooled = train_util.get_hidden_states_stable_cascade(
|
||||
tokenizer.model_max_length, input_ids, tokenizer, text_model
|
||||
)
|
||||
cond_text = cond_text.to(device, dtype=dtype)
|
||||
cond_pooled = cond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
|
||||
# uncond_text, uncond_pooled = sc.get_clip_conditions([""], None, tokenizer, text_model)
|
||||
input_ids = tokenizer(
|
||||
[negative_prompt], truncation=True, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
|
||||
)["input_ids"].to(text_model.device)
|
||||
uncond_text, uncond_pooled = train_util.get_hidden_states_stable_cascade(
|
||||
tokenizer.model_max_length, input_ids, tokenizer, text_model
|
||||
)
|
||||
uncond_text = uncond_text.to(device, dtype=dtype)
|
||||
uncond_pooled = uncond_pooled.unsqueeze(1).to(device, dtype=dtype)
|
||||
|
||||
zero_img_emb = torch.zeros(1, 768, device=device)
|
||||
|
||||
# 辞書にしたくないけど GDF から先の変更が面倒だからとりあえず辞書にしておく
|
||||
conditions = {"clip_text_pooled": cond_pooled, "clip": cond_pooled, "clip_text": cond_text, "clip_img": zero_img_emb}
|
||||
unconditions = {
|
||||
"clip_text_pooled": uncond_pooled,
|
||||
"clip": uncond_pooled,
|
||||
"clip_text": uncond_text,
|
||||
"clip_img": zero_img_emb,
|
||||
}
|
||||
conditions_b = {}
|
||||
conditions_b.update(conditions)
|
||||
unconditions_b = {}
|
||||
unconditions_b.update(unconditions)
|
||||
|
||||
# seed everything
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
|
||||
if args.lowvram:
|
||||
generator_c = generator_c.to(device)
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampling_c = gdf_c.sample(
|
||||
generator_c,
|
||||
conditions,
|
||||
stage_c_latent_shape,
|
||||
unconditions,
|
||||
device=device,
|
||||
cfg=cfg,
|
||||
shift=shift,
|
||||
timesteps=timesteps,
|
||||
t_start=t_start,
|
||||
)
|
||||
for sampled_c, _, _ in tqdm(sampling_c, total=timesteps):
|
||||
sampled_c = sampled_c
|
||||
|
||||
conditions_b["effnet"] = sampled_c
|
||||
unconditions_b["effnet"] = torch.zeros_like(sampled_c)
|
||||
|
||||
if previewer is not None:
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
preview = previewer(sampled_c)
|
||||
preview = preview.clamp(0, 1)
|
||||
preview = preview.permute(0, 2, 3, 1).squeeze(0)
|
||||
preview = preview.detach().float().cpu().numpy()
|
||||
preview = Image.fromarray((preview * 255).astype(np.uint8))
|
||||
|
||||
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
preview.save(os.path.join(args.outdir, f"preview_{timestamp_str}.png"))
|
||||
|
||||
if args.lowvram:
|
||||
generator_c = generator_c.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
generator_b = generator_b.to(device)
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampling_b = gdf_b.sample(
|
||||
generator_b,
|
||||
conditions_b,
|
||||
stage_b_latent_shape,
|
||||
unconditions_b,
|
||||
device=device,
|
||||
cfg=b_cfg,
|
||||
shift=b_shift,
|
||||
timesteps=b_timesteps,
|
||||
t_start=b_t_start,
|
||||
)
|
||||
for sampled_b, _, _ in tqdm(sampling_b, total=b_t_start):
|
||||
sampled_b = sampled_b
|
||||
|
||||
if args.lowvram:
|
||||
generator_b = generator_b.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
stage_a = stage_a.to(device)
|
||||
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
|
||||
sampled = stage_a.decode(sampled_b).float()
|
||||
# print(sampled.shape, sampled.min(), sampled.max())
|
||||
|
||||
if args.lowvram:
|
||||
stage_a = stage_a.to(loading_device)
|
||||
device_utils.clean_memory_on_device(device)
|
||||
|
||||
# float 0-1 to PIL Image
|
||||
sampled = sampled.clamp(0, 1)
|
||||
sampled = sampled.mul(255).to(dtype=torch.uint8)
|
||||
sampled = sampled.permute(0, 2, 3, 1)
|
||||
sampled = sampled.cpu().numpy()
|
||||
sampled = Image.fromarray(sampled[0])
|
||||
|
||||
timestamp_str = time.strftime("%Y%m%d_%H%M%S")
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
sampled.save(os.path.join(args.outdir, f"sampled_{timestamp_str}.png"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
sc_utils.add_effnet_arguments(parser)
|
||||
train_util.add_tokenizer_arguments(parser)
|
||||
sc_utils.add_stage_a_arguments(parser)
|
||||
sc_utils.add_stage_b_arguments(parser)
|
||||
sc_utils.add_stage_c_arguments(parser)
|
||||
sc_utils.add_previewer_arguments(parser)
|
||||
sc_utils.add_text_model_arguments(parser)
|
||||
parser.add_argument("--bf16", action="store_true")
|
||||
parser.add_argument("--fp16", action="store_true")
|
||||
parser.add_argument("--xformers", action="store_true")
|
||||
parser.add_argument("--sdpa", action="store_true")
|
||||
parser.add_argument("--outdir", type=str, default="../outputs", help="dir to write results to / 生成画像の出力先")
|
||||
parser.add_argument("--lowvram", action="store_true", help="if specified, use low VRAM mode")
|
||||
parser.add_argument(
|
||||
"--network_module",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="additional network module to use / 追加ネットワークを使う時そのモジュール名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_weights", type=str, default=None, nargs="*", help="additional network weights to load / 追加ネットワークの重み"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_mul", type=float, default=None, nargs="*", help="additional network multiplier / 追加ネットワークの効果の倍率"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_args",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,564 +0,0 @@
|
||||
# training with captions
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
from typing import List
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.sdxl_train_util import add_sdxl_training_arguments
|
||||
import library.stable_cascade_utils as sc_utils
|
||||
import library.stable_cascade as sc
|
||||
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import library.config_util as config_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
# assert (
|
||||
# not args.weighted_captions
|
||||
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
|
||||
|
||||
# TODO add assertions for other unsupported options
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
tokenizer = sc_utils.load_tokenizer(args)
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer])
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, [tokenizer])
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset_group, True)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
||||
)
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert (
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
loading_device = accelerator.device if args.lowram else "cpu"
|
||||
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, loading_device)
|
||||
stage_c = sc_utils.load_stage_c_model(args.stage_c_checkpoint_path, device=loading_device) # dtype is as it is
|
||||
text_encoder1 = sc_utils.load_clip_text_model(args.text_model_checkpoint_path, dtype=weight_dtype, device=loading_device)
|
||||
|
||||
if args.sample_at_first or args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
|
||||
# Previewer is small enough to be loaded on CPU
|
||||
previewer = sc_utils.load_previewer_model(args.previewer_checkpoint_path, dtype=torch.float32, device="cpu")
|
||||
previewer.eval()
|
||||
else:
|
||||
previewer = None
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
stage_c.set_use_xformers_or_sdpa(args.xformers, args.sdpa)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
effnet.to(accelerator.device, dtype=effnet_dtype)
|
||||
effnet.requires_grad_(False)
|
||||
effnet.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(
|
||||
effnet,
|
||||
args.vae_batch_size,
|
||||
args.cache_latents_to_disk,
|
||||
accelerator.is_main_process,
|
||||
train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX,
|
||||
32,
|
||||
)
|
||||
effnet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
if args.gradient_checkpointing:
|
||||
accelerator.print("enable gradient checkpointing")
|
||||
stage_c.set_gradient_checkpointing(True)
|
||||
|
||||
train_stage_c = args.learning_rate > 0
|
||||
train_text_encoder1 = False
|
||||
|
||||
if args.train_text_encoder:
|
||||
accelerator.print("enable text encoder training")
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder1.gradient_checkpointing_enable()
|
||||
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
|
||||
train_text_encoder1 = lr_te1 > 0
|
||||
assert (
|
||||
train_text_encoder1
|
||||
), "text_encoder1 learning rate is 0. Please set a positive value / text_encoder1の学習率が0です。正の値を設定してください。"
|
||||
|
||||
if not train_text_encoder1:
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder1.requires_grad_(train_text_encoder1)
|
||||
text_encoder1.train(train_text_encoder1)
|
||||
else:
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder1.requires_grad_(False)
|
||||
text_encoder1.eval()
|
||||
|
||||
# TextEncoderの出力をキャッシュする
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
train_dataset_group.cache_text_encoder_outputs(
|
||||
(tokenizer,),
|
||||
(text_encoder1,),
|
||||
accelerator.device,
|
||||
None,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if not cache_latents:
|
||||
effnet.requires_grad_(False)
|
||||
effnet.eval()
|
||||
effnet.to(accelerator.device, dtype=effnet_dtype)
|
||||
|
||||
stage_c.requires_grad_(True)
|
||||
if not train_stage_c:
|
||||
stage_c.to(accelerator.device, dtype=weight_dtype) # because of stage_c will not be prepared
|
||||
|
||||
training_models = []
|
||||
params_to_optimize = []
|
||||
if train_stage_c:
|
||||
training_models.append(stage_c)
|
||||
params_to_optimize.append({"params": list(stage_c.parameters()), "lr": args.learning_rate})
|
||||
|
||||
if train_text_encoder1:
|
||||
training_models.append(text_encoder1)
|
||||
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for params in params_to_optimize:
|
||||
for p in params["params"]:
|
||||
n_params += p.numel()
|
||||
|
||||
accelerator.print(f"train stage-C: {train_stage_c}, text_encoder1: {train_text_encoder1}")
|
||||
accelerator.print(f"number of models: {len(training_models)}")
|
||||
accelerator.print(f"number of trainable parameters: {n_params}")
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
||||
if args.full_fp16:
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
stage_c.to(weight_dtype)
|
||||
text_encoder1.to(weight_dtype)
|
||||
elif args.full_bf16:
|
||||
assert (
|
||||
args.mixed_precision == "bf16"
|
||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
stage_c.to(weight_dtype)
|
||||
text_encoder1.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if train_stage_c:
|
||||
stage_c = accelerator.prepare(stage_c)
|
||||
if train_text_encoder1:
|
||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
|
||||
text_encoder1.to("cpu", dtype=torch.float32)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
else:
|
||||
# make sure Text Encoders are on GPU
|
||||
text_encoder1.to(accelerator.device)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
|
||||
# resumeする
|
||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
# accelerator.print(
|
||||
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||
# )
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
# 謎のクラス GDF
|
||||
gdf = sc.GDF(
|
||||
schedule=sc.CosineSchedule(clamp_range=[0.0001, 0.9999]),
|
||||
input_scaler=sc.VPScaler(),
|
||||
target=sc.EpsilonTarget(),
|
||||
noise_cond=sc.CosineTNoiseCond(),
|
||||
loss_weight=sc.AdaptiveLossWeight() if args.adaptive_loss_weight else sc.P2LossWeight(),
|
||||
)
|
||||
|
||||
# 以下2つの変数は、どうもデフォルトのままっぽい
|
||||
# gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
|
||||
# gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
sc_utils.sample_images(accelerator, args, 0, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
# XXX Effnet preprocessing is included in encode method
|
||||
latents = effnet.encode(batch["images"].to(effnet_dtype)).latent_dist.sample().to(weight_dtype)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
|
||||
# # debug: decode latent with previewer and save it
|
||||
# import time
|
||||
# import numpy as np
|
||||
# from PIL import Image
|
||||
# ts = time.time()
|
||||
# images = previewer(latents.to(previewer.device, dtype=previewer.dtype))
|
||||
# for i, img in enumerate(images):
|
||||
# img = img.detach().cpu().numpy().transpose(1, 2, 0)
|
||||
# img = np.clip(img, 0, 1)
|
||||
# img = (img * 255).astype(np.uint8)
|
||||
# img = Image.fromarray(img)
|
||||
# img.save(f"logs/previewer_{i}_{ts}.png")
|
||||
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
input_ids1 = batch["input_ids"]
|
||||
with torch.set_grad_enabled(args.train_text_encoder):
|
||||
# Get the text embedding for conditioning
|
||||
# TODO support weighted captions
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
# unwrap_model is fine for models not wrapped by accelerator
|
||||
encoder_hidden_states, pool = train_util.get_hidden_states_stable_cascade(
|
||||
args.max_token_length,
|
||||
input_ids1,
|
||||
tokenizer,
|
||||
text_encoder1,
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
accelerator,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||
pool = batch["text_encoder_pool2_list"].to(accelerator.device).to(weight_dtype)
|
||||
|
||||
pool = pool.unsqueeze(1) # add extra dimension b,1280 -> b,1,1280
|
||||
|
||||
# FORWARD PASS
|
||||
with torch.no_grad():
|
||||
noised, noise, target, logSNR, noise_cond, loss_weight = gdf.diffuse(latents, shift=1, loss_shift=1)
|
||||
|
||||
zero_img_emb = torch.zeros(noised.shape[0], 768, device=accelerator.device)
|
||||
with accelerator.autocast():
|
||||
pred = stage_c(
|
||||
noised, noise_cond, clip_text=encoder_hidden_states, clip_text_pooled=pool, clip_img=zero_img_emb
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(pred, target, reduction="none").mean(dim=[1, 2, 3])
|
||||
loss_adjusted = (loss * loss_weight).mean()
|
||||
|
||||
if args.adaptive_loss_weight:
|
||||
gdf.loss_weight.update_buckets(logSNR, loss) # use loss instead of loss_adjusted
|
||||
|
||||
accelerator.backward(loss_adjusted)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = []
|
||||
for m in training_models:
|
||||
params_to_clip.extend(m.parameters())
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
sc_utils.sample_images(accelerator, args, None, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(stage_c),
|
||||
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
|
||||
)
|
||||
|
||||
current_loss = loss_adjusted.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
||||
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
sc_utils.save_stage_c_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(stage_c),
|
||||
accelerator.unwrap_model(text_encoder1) if train_text_encoder1 else None,
|
||||
)
|
||||
|
||||
sc_utils.sample_images(accelerator, args, epoch + 1, global_step, previewer, tokenizer, text_encoder1, stage_c, gdf)
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
# if is_main_process:
|
||||
stage_c = accelerator.unwrap_model(stage_c)
|
||||
text_encoder1 = accelerator.unwrap_model(text_encoder1)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state: # and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
sc_utils.save_stage_c_model_on_end(
|
||||
args, save_dtype, epoch, global_step, stage_c, text_encoder1 if train_text_encoder1 else None
|
||||
)
|
||||
logger.info("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
sc_utils.add_effnet_arguments(parser)
|
||||
sc_utils.add_stage_c_arguments(parser)
|
||||
sc_utils.add_text_model_arguments(parser)
|
||||
sc_utils.add_previewer_arguments(parser)
|
||||
sc_utils.add_training_arguments(parser)
|
||||
train_util.add_tokenizer_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
train_util.add_optimizer_arguments(parser)
|
||||
config_util.add_config_arguments(parser)
|
||||
add_sdxl_training_arguments(parser) # cache text encoder outputs
|
||||
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
parser.add_argument(
|
||||
"--learning_rate_te1",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder / text encoderの学習率",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
@@ -16,10 +16,7 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
@@ -44,18 +41,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
print("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -66,7 +63,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
print("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -93,7 +90,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
@@ -101,7 +98,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
logger.info("load model")
|
||||
print("load model")
|
||||
if args.sdxl:
|
||||
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||
else:
|
||||
@@ -116,8 +113,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
# dataloaderを準備する
|
||||
train_dataset_group.set_caching_mode("latents")
|
||||
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
@@ -155,7 +152,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
|
||||
if args.skip_existing:
|
||||
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
|
||||
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
|
||||
print(f"Skipping {image_info.latents_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_infos.append(image_info)
|
||||
|
||||
@@ -16,10 +16,7 @@ from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
@@ -51,18 +48,18 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
print("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -73,7 +70,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
print("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -98,14 +95,14 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, _ = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
logger.info("load model")
|
||||
print("load model")
|
||||
if args.sdxl:
|
||||
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
||||
text_encoders = [text_encoder1, text_encoder2]
|
||||
@@ -121,8 +118,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
# dataloaderを準備する
|
||||
train_dataset_group.set_caching_mode("text")
|
||||
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
@@ -150,7 +147,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
|
||||
if args.skip_existing:
|
||||
if os.path.exists(image_info.text_encoder_outputs_npz):
|
||||
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
|
||||
print(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_info.input_ids1 = input_ids1
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import argparse
|
||||
import cv2
|
||||
|
||||
import logging
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def canny(args):
|
||||
img = cv2.imread(args.input)
|
||||
@@ -14,7 +10,7 @@ def canny(args):
|
||||
# canny_img = 255 - canny_img
|
||||
|
||||
cv2.imwrite(args.output, canny_img)
|
||||
logger.info("done!")
|
||||
print("done!")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -6,10 +6,7 @@ import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
import library.model_util as model_util
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert(args):
|
||||
# 引数を確認する
|
||||
@@ -33,7 +30,7 @@ def convert(args):
|
||||
|
||||
# モデルを読み込む
|
||||
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
||||
logger.info(f"loading {msg}: {args.model_to_load}")
|
||||
print(f"loading {msg}: {args.model_to_load}")
|
||||
|
||||
if is_load_ckpt:
|
||||
v2_model = args.v2
|
||||
@@ -51,13 +48,13 @@ def convert(args):
|
||||
if args.v1 == args.v2:
|
||||
# 自動判定する
|
||||
v2_model = unet.config.cross_attention_dim == 1024
|
||||
logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
|
||||
print("checking model version: model is " + ("v2" if v2_model else "v1"))
|
||||
else:
|
||||
v2_model = not args.v1
|
||||
|
||||
# 変換して保存する
|
||||
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
||||
logger.info(f"converting and saving as {msg}: {args.model_to_save}")
|
||||
print(f"converting and saving as {msg}: {args.model_to_save}")
|
||||
|
||||
if is_save_ckpt:
|
||||
original_model = args.model_to_load if is_load_ckpt else None
|
||||
@@ -73,15 +70,15 @@ def convert(args):
|
||||
save_dtype=save_dtype,
|
||||
vae=vae,
|
||||
)
|
||||
logger.info(f"model saved. total converted state_dict keys: {key_count}")
|
||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||
else:
|
||||
logger.info(
|
||||
print(
|
||||
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
|
||||
)
|
||||
model_util.save_diffusers_checkpoint(
|
||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||
)
|
||||
logger.info("model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
@@ -15,10 +15,6 @@ import os
|
||||
from anime_face_detector import create_detector
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KP_REYE = 11
|
||||
KP_LEYE = 19
|
||||
@@ -28,7 +24,7 @@ SCORE_THRES = 0.90
|
||||
|
||||
def detect_faces(detector, image, min_size):
|
||||
preds = detector(image) # bgr
|
||||
# logger.info(len(preds))
|
||||
# print(len(preds))
|
||||
|
||||
faces = []
|
||||
for pred in preds:
|
||||
@@ -82,7 +78,7 @@ def process(args):
|
||||
assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
|
||||
|
||||
# アニメ顔検出モデルを読み込む
|
||||
logger.info("loading face detector.")
|
||||
print("loading face detector.")
|
||||
detector = create_detector('yolov3')
|
||||
|
||||
# cropの引数を解析する
|
||||
@@ -101,7 +97,7 @@ def process(args):
|
||||
crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
|
||||
|
||||
# 画像を処理する
|
||||
logger.info("processing.")
|
||||
print("processing.")
|
||||
output_extension = ".png"
|
||||
|
||||
os.makedirs(args.dst_dir, exist_ok=True)
|
||||
@@ -115,7 +111,7 @@ def process(args):
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
||||
if image.shape[2] == 4:
|
||||
logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
|
||||
print(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
|
||||
image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
|
||||
|
||||
h, w = image.shape[:2]
|
||||
@@ -148,11 +144,11 @@ def process(args):
|
||||
# 顔サイズを基準にリサイズする
|
||||
scale = args.resize_face_size / face_size
|
||||
if scale < cur_crop_width / w:
|
||||
logger.warning(
|
||||
print(
|
||||
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
||||
scale = cur_crop_width / w
|
||||
if scale < cur_crop_height / h:
|
||||
logger.warning(
|
||||
print(
|
||||
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
||||
scale = cur_crop_height / h
|
||||
elif crop_h_ratio is not None:
|
||||
@@ -161,10 +157,10 @@ def process(args):
|
||||
else:
|
||||
# 切り出しサイズ指定あり
|
||||
if w < cur_crop_width:
|
||||
logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
|
||||
print(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
|
||||
scale = cur_crop_width / w
|
||||
if h < cur_crop_height:
|
||||
logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
|
||||
print(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
|
||||
scale = cur_crop_height / h
|
||||
if args.resize_fit:
|
||||
scale = max(cur_crop_width / w, cur_crop_height / h)
|
||||
@@ -202,7 +198,7 @@ def process(args):
|
||||
face_img = face_img[y:y + cur_crop_height]
|
||||
|
||||
# # debug
|
||||
# logger.info(path, cx, cy, angle)
|
||||
# print(path, cx, cy, angle)
|
||||
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
|
||||
# cv2.imshow("image", crp)
|
||||
# if cv2.waitKey() == 27:
|
||||
|
||||
@@ -11,16 +11,10 @@ from typing import Dict, List
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
init_ipex()
|
||||
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
||||
@@ -222,7 +216,7 @@ class Upscaler(nn.Module):
|
||||
upsampled_images = upsampled_images / 127.5 - 1.0
|
||||
|
||||
# convert upsample images to latents with batch size
|
||||
# logger.info("Encoding upsampled (LANCZOS4) images...")
|
||||
# print("Encoding upsampled (LANCZOS4) images...")
|
||||
upsampled_latents = []
|
||||
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
|
||||
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
|
||||
@@ -233,7 +227,7 @@ class Upscaler(nn.Module):
|
||||
upsampled_latents = torch.cat(upsampled_latents, dim=0)
|
||||
|
||||
# upscale (refine) latents with this model with batch size
|
||||
logger.info("Upscaling latents...")
|
||||
print("Upscaling latents...")
|
||||
upscaled_latents = []
|
||||
for i in range(0, upsampled_latents.shape[0], batch_size):
|
||||
with torch.no_grad():
|
||||
@@ -248,7 +242,7 @@ def create_upscaler(**kwargs):
|
||||
weights = kwargs["weights"]
|
||||
model = Upscaler()
|
||||
|
||||
logger.info(f"Loading weights from {weights}...")
|
||||
print(f"Loading weights from {weights}...")
|
||||
if os.path.splitext(weights)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -261,20 +255,20 @@ def create_upscaler(**kwargs):
|
||||
|
||||
# another interface: upscale images with a model for given images from command line
|
||||
def upscale_images(args: argparse.Namespace):
|
||||
DEVICE = get_preferred_device()
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
us_dtype = torch.float16 # TODO: support fp32/bf16
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# load VAE with Diffusers
|
||||
assert args.vae_path is not None, "VAE path is required"
|
||||
logger.info(f"Loading VAE from {args.vae_path}...")
|
||||
print(f"Loading VAE from {args.vae_path}...")
|
||||
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
|
||||
vae.to(DEVICE, dtype=us_dtype)
|
||||
|
||||
# prepare model
|
||||
logger.info("Preparing model...")
|
||||
print("Preparing model...")
|
||||
upscaler: Upscaler = create_upscaler(weights=args.weights)
|
||||
# logger.info("Loading weights from", args.weights)
|
||||
# print("Loading weights from", args.weights)
|
||||
# upscaler.load_state_dict(torch.load(args.weights))
|
||||
upscaler.eval()
|
||||
upscaler.to(DEVICE, dtype=us_dtype)
|
||||
@@ -309,14 +303,14 @@ def upscale_images(args: argparse.Namespace):
|
||||
image_debug.save(dest_file_name)
|
||||
|
||||
# upscale
|
||||
logger.info("Upscaling...")
|
||||
print("Upscaling...")
|
||||
upscaled_latents = upscaler.upscale(
|
||||
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
|
||||
)
|
||||
upscaled_latents /= 0.18215
|
||||
|
||||
# decode with batch
|
||||
logger.info("Decoding...")
|
||||
print("Decoding...")
|
||||
upscaled_images = []
|
||||
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
|
||||
with torch.no_grad():
|
||||
|
||||
@@ -5,10 +5,7 @@ import torch
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_unet_key(key):
|
||||
# VAE or TextEncoder, the last one is for SDXL
|
||||
@@ -48,10 +45,10 @@ def merge(args):
|
||||
# check if all models are safetensors
|
||||
for model in args.models:
|
||||
if not model.endswith("safetensors"):
|
||||
logger.info(f"Model {model} is not a safetensors model")
|
||||
print(f"Model {model} is not a safetensors model")
|
||||
exit()
|
||||
if not os.path.isfile(model):
|
||||
logger.info(f"Model {model} does not exist")
|
||||
print(f"Model {model} does not exist")
|
||||
exit()
|
||||
|
||||
assert args.ratios is None or len(args.models) == len(args.ratios), "ratios must be the same length as models"
|
||||
@@ -68,7 +65,7 @@ def merge(args):
|
||||
|
||||
if merged_sd is None:
|
||||
# load first model
|
||||
logger.info(f"Loading model {model}, ratio = {ratio}...")
|
||||
print(f"Loading model {model}, ratio = {ratio}...")
|
||||
merged_sd = {}
|
||||
with safe_open(model, framework="pt", device=args.device) as f:
|
||||
for key in tqdm(f.keys()):
|
||||
@@ -84,11 +81,11 @@ def merge(args):
|
||||
value = ratio * value.to(dtype) # first model's value * ratio
|
||||
merged_sd[key] = value
|
||||
|
||||
logger.info(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
|
||||
print(f"Model has {len(merged_sd)} keys " + ("(UNet only)" if args.unet_only else ""))
|
||||
continue
|
||||
|
||||
# load other models
|
||||
logger.info(f"Loading model {model}, ratio = {ratio}...")
|
||||
print(f"Loading model {model}, ratio = {ratio}...")
|
||||
|
||||
with safe_open(model, framework="pt", device=args.device) as f:
|
||||
model_keys = f.keys()
|
||||
@@ -96,7 +93,7 @@ def merge(args):
|
||||
_, new_key = replace_text_encoder_key(key)
|
||||
if new_key not in merged_sd:
|
||||
if args.show_skipped and new_key not in first_model_keys:
|
||||
logger.info(f"Skip: {new_key}")
|
||||
print(f"Skip: {new_key}")
|
||||
continue
|
||||
|
||||
value = f.get_tensor(key)
|
||||
@@ -107,7 +104,7 @@ def merge(args):
|
||||
for key in merged_sd.keys():
|
||||
if key in model_keys:
|
||||
continue
|
||||
logger.warning(f"Key {key} not in model {model}, use first model's value")
|
||||
print(f"Key {key} not in model {model}, use first model's value")
|
||||
if key in supplementary_key_ratios:
|
||||
supplementary_key_ratios[key] += ratio
|
||||
else:
|
||||
@@ -115,7 +112,7 @@ def merge(args):
|
||||
|
||||
# add supplementary keys' value (including VAE and TextEncoder)
|
||||
if len(supplementary_key_ratios) > 0:
|
||||
logger.info("add first model's value")
|
||||
print("add first model's value")
|
||||
with safe_open(args.models[0], framework="pt", device=args.device) as f:
|
||||
for key in tqdm(f.keys()):
|
||||
_, new_key = replace_text_encoder_key(key)
|
||||
@@ -123,7 +120,7 @@ def merge(args):
|
||||
continue
|
||||
|
||||
if is_unet_key(new_key): # not VAE or TextEncoder
|
||||
logger.warning(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
|
||||
print(f"Key {new_key} not in all models, ratio = {supplementary_key_ratios[new_key]}")
|
||||
|
||||
value = f.get_tensor(key) # original key
|
||||
|
||||
@@ -137,7 +134,7 @@ def merge(args):
|
||||
if not output_file.endswith(".safetensors"):
|
||||
output_file = output_file + ".safetensors"
|
||||
|
||||
logger.info(f"Saving to {output_file}...")
|
||||
print(f"Saving to {output_file}...")
|
||||
|
||||
# convert to save_dtype
|
||||
for k in merged_sd.keys():
|
||||
@@ -145,7 +142,7 @@ def merge(args):
|
||||
|
||||
save_file(merged_sd, output_file)
|
||||
|
||||
logger.info("Done!")
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -7,10 +7,7 @@ from safetensors.torch import load_file
|
||||
from library.original_unet import UNet2DConditionModel, SampleOutput
|
||||
|
||||
import library.model_util as model_util
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ControlNetInfo(NamedTuple):
|
||||
unet: Any
|
||||
@@ -54,7 +51,7 @@ def load_control_net(v2, unet, model):
|
||||
|
||||
# control sdからキー変換しつつU-Netに対応する部分のみ取り出し、DiffusersのU-Netに読み込む
|
||||
# state dictを読み込む
|
||||
logger.info(f"ControlNet: loading control SD model : {model}")
|
||||
print(f"ControlNet: loading control SD model : {model}")
|
||||
|
||||
if model_util.is_safetensors(model):
|
||||
ctrl_sd_sd = load_file(model)
|
||||
@@ -64,7 +61,7 @@ def load_control_net(v2, unet, model):
|
||||
|
||||
# 重みをU-Netに読み込めるようにする。ControlNetはSD版のstate dictなので、それを読み込む
|
||||
is_difference = "difference" in ctrl_sd_sd
|
||||
logger.info(f"ControlNet: loading difference: {is_difference}")
|
||||
print("ControlNet: loading difference:", is_difference)
|
||||
|
||||
# ControlNetには存在しないキーがあるので、まず現在のU-NetでSD版の全keyを作っておく
|
||||
# またTransfer Controlの元weightとなる
|
||||
@@ -92,13 +89,13 @@ def load_control_net(v2, unet, model):
|
||||
# ControlNetのU-Netを作成する
|
||||
ctrl_unet = UNet2DConditionModel(**unet_config)
|
||||
info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
|
||||
logger.info(f"ControlNet: loading Control U-Net: {info}")
|
||||
print("ControlNet: loading Control U-Net:", info)
|
||||
|
||||
# U-Net以外のControlNetを作成する
|
||||
# TODO support middle only
|
||||
ctrl_net = ControlNet()
|
||||
info = ctrl_net.load_state_dict(zero_conv_sd)
|
||||
logger.info("ControlNet: loading ControlNet: {info}")
|
||||
print("ControlNet: loading ControlNet:", info)
|
||||
|
||||
ctrl_unet.to(unet.device, dtype=unet.dtype)
|
||||
ctrl_net.to(unet.device, dtype=unet.dtype)
|
||||
@@ -120,7 +117,7 @@ def load_preprocess(prep_type: str):
|
||||
|
||||
return canny
|
||||
|
||||
logger.info(f"Unsupported prep type: {prep_type}")
|
||||
print("Unsupported prep type:", prep_type)
|
||||
return None
|
||||
|
||||
|
||||
@@ -177,26 +174,13 @@ def call_unet_and_control_net(
|
||||
cnet_idx = step % cnet_cnt
|
||||
cnet_info = control_nets[cnet_idx]
|
||||
|
||||
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
if cnet_info.ratio < current_ratio:
|
||||
return original_unet(sample, timestep, encoder_hidden_states)
|
||||
|
||||
guided_hint = guided_hints[cnet_idx]
|
||||
|
||||
# gradual latent support: match the size of guided_hint to the size of sample
|
||||
if guided_hint.shape[-2:] != sample.shape[-2:]:
|
||||
# print(f"guided_hint.shape={guided_hint.shape}, sample.shape={sample.shape}")
|
||||
org_dtype = guided_hint.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
guided_hint = guided_hint.to(torch.float32)
|
||||
guided_hint = torch.nn.functional.interpolate(guided_hint, size=sample.shape[-2:], mode="bicubic")
|
||||
if org_dtype == torch.bfloat16:
|
||||
guided_hint = guided_hint.to(org_dtype)
|
||||
|
||||
guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
|
||||
outs = unet_forward(
|
||||
True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net
|
||||
)
|
||||
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states_for_control_net)
|
||||
outs = [o * cnet_info.weight for o in outs]
|
||||
|
||||
# U-Net
|
||||
@@ -208,7 +192,7 @@ def call_unet_and_control_net(
|
||||
# ControlNet
|
||||
cnet_outs_list = []
|
||||
for i, cnet_info in enumerate(control_nets):
|
||||
# logger.info(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
if cnet_info.ratio < current_ratio:
|
||||
continue
|
||||
guided_hint = guided_hints[i]
|
||||
@@ -248,7 +232,7 @@ def unet_forward(
|
||||
upsample_size = None
|
||||
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
logger.info("Forward upsample size to force interpolation output size.")
|
||||
print("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# 1. time
|
||||
|
||||
@@ -6,10 +6,7 @@ import shutil
|
||||
import math
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
|
||||
# Split the max_resolution string by "," and strip any whitespaces
|
||||
@@ -86,7 +83,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
|
||||
|
||||
proc = "Resized" if current_pixels > max_pixels else "Saved"
|
||||
logger.info(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||
|
||||
# If other files with same basename, copy them with resolution suffix
|
||||
if copy_associated_files:
|
||||
@@ -97,7 +94,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
|
||||
continue
|
||||
for max_resolution in max_resolutions:
|
||||
new_asoc_file = base + '+' + max_resolution + ext
|
||||
logger.info(f"Copy {asoc_file} as {new_asoc_file}")
|
||||
print(f"Copy {asoc_file} as {new_asoc_file}")
|
||||
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
import json
|
||||
import argparse
|
||||
from safetensors import safe_open
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, required=True)
|
||||
@@ -14,10 +10,10 @@ with safe_open(args.model, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
|
||||
if metadata is None:
|
||||
logger.error("No metadata found")
|
||||
print("No metadata found")
|
||||
else:
|
||||
# metadata is json dict, but not pretty printed
|
||||
# sort by key and pretty print
|
||||
print(json.dumps(metadata, indent=4, sort_keys=True))
|
||||
|
||||
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
# Stable Cascadeのlatentsをdiskにキャッシュする
|
||||
# cache latents of Stable Cascade to disk
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from multiprocessing import Value
|
||||
import os
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import stable_cascade_utils as sc_utils
|
||||
from library import config_util
|
||||
from library import train_util
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
# check cache latents arg
|
||||
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
|
||||
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
# tokenizerを準備する:datasetを動かすために必要
|
||||
tokenizer = sc_utils.load_tokenizer(args)
|
||||
tokenizers = [tokenizer]
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
|
||||
|
||||
# datasetのcache_latentsを呼ばなければ、生の画像が返る
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, _ = train_util.prepare_dtype(args)
|
||||
effnet_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
logger.info("load model")
|
||||
effnet = sc_utils.load_effnet(args.effnet_checkpoint_path, accelerator.device)
|
||||
effnet.to(accelerator.device, dtype=effnet_dtype)
|
||||
effnet.requires_grad_(False)
|
||||
effnet.eval()
|
||||
|
||||
# dataloaderを準備する
|
||||
train_dataset_group.set_caching_mode("latents")
|
||||
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
|
||||
train_dataloader = accelerator.prepare(train_dataloader)
|
||||
|
||||
# データ取得のためのループ
|
||||
for batch in tqdm(train_dataloader):
|
||||
b_size = len(batch["images"])
|
||||
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
|
||||
flip_aug = batch["flip_aug"]
|
||||
random_crop = batch["random_crop"]
|
||||
bucket_reso = batch["bucket_reso"]
|
||||
|
||||
# バッチを分割して処理する
|
||||
for i in range(0, b_size, vae_batch_size):
|
||||
images = batch["images"][i : i + vae_batch_size]
|
||||
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
|
||||
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
|
||||
|
||||
image_infos = []
|
||||
for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
|
||||
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
|
||||
image_info.image = image
|
||||
image_info.bucket_reso = bucket_reso
|
||||
image_info.resized_size = resized_size
|
||||
image_info.latents_npz = os.path.splitext(absolute_path)[0] + train_util.STABLE_CASCADE_LATENTS_CACHE_SUFFIX
|
||||
|
||||
if args.skip_existing:
|
||||
if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug, 32):
|
||||
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_infos.append(image_info)
|
||||
|
||||
if len(image_infos) > 0:
|
||||
train_util.cache_batch_latents(effnet, True, image_infos, flip_aug, random_crop)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_tokenizer_arguments(parser)
|
||||
sc_utils.add_effnet_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
config_util.add_config_arguments(parser)
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 Effnet in mixed precision (use float Effnet) / mixed precisionでも fp16/bf16 Effnetを使わずfloat Effnetを使う",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
cache_to_disk(args)
|
||||
@@ -1,183 +0,0 @@
|
||||
# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
|
||||
|
||||
import argparse
|
||||
import math
|
||||
from multiprocessing import Value
|
||||
import os
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from library import config_util
|
||||
from library import train_util
|
||||
from library import sdxl_train_util
|
||||
from library import stable_cascade_utils as sc_utils
|
||||
from library.config_util import (
|
||||
ConfigSanitizer,
|
||||
BlueprintGenerator,
|
||||
)
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
# check cache arg
|
||||
assert (
|
||||
args.cache_text_encoder_outputs_to_disk
|
||||
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
|
||||
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed) # 乱数系列を初期化する
|
||||
|
||||
# tokenizerを準備する:datasetを動かすために必要
|
||||
tokenizer = sc_utils.load_tokenizer(args)
|
||||
tokenizers = [tokenizer]
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
||||
args.train_data_dir, args.reg_data_dir
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
"subsets": [
|
||||
{
|
||||
"image_dir": args.train_data_dir,
|
||||
"metadata_file": args.in_json,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
|
||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||
else:
|
||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
|
||||
|
||||
current_epoch = Value("i", 0)
|
||||
current_step = Value("i", 0)
|
||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, _ = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
logger.info("load model")
|
||||
text_encoder = sc_utils.load_clip_text_model(
|
||||
args.text_model_checkpoint_path, weight_dtype, accelerator.device, args.save_text_model
|
||||
)
|
||||
text_encoders = [text_encoder]
|
||||
for text_encoder in text_encoders:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder.eval()
|
||||
|
||||
# dataloaderを準備する
|
||||
train_dataset_group.set_caching_mode("text")
|
||||
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
# acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
|
||||
train_dataloader = accelerator.prepare(train_dataloader)
|
||||
|
||||
# データ取得のためのループ
|
||||
for batch in tqdm(train_dataloader):
|
||||
absolute_paths = batch["absolute_paths"]
|
||||
input_ids1_list = batch["input_ids1_list"]
|
||||
|
||||
image_infos = []
|
||||
for absolute_path, input_ids1 in zip(absolute_paths, input_ids1_list):
|
||||
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
|
||||
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + sc_utils.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
|
||||
image_info
|
||||
|
||||
if args.skip_existing:
|
||||
if os.path.exists(image_info.text_encoder_outputs_npz):
|
||||
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
|
||||
continue
|
||||
|
||||
image_info.input_ids1 = input_ids1
|
||||
image_infos.append(image_info)
|
||||
|
||||
if len(image_infos) > 0:
|
||||
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
|
||||
train_util.cache_batch_text_encoder_outputs(
|
||||
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, None, weight_dtype
|
||||
)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_tokenizer_arguments(parser)
|
||||
sc_utils.add_text_model_arguments(parser)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
config_util.add_config_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
parser.add_argument(
|
||||
"--skip_existing",
|
||||
action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
cache_to_disk(args)
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
@@ -9,9 +10,10 @@ from types import SimpleNamespace
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@@ -33,12 +35,6 @@ from library.custom_train_functions import (
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO 他のスクリプトと共通化する
|
||||
@@ -60,7 +56,6 @@ def train(args):
|
||||
# training_started_at = time.time()
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_user_config = args.dataset_config is not None
|
||||
@@ -74,11 +69,11 @@ def train(args):
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
if use_user_config:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
@@ -108,7 +103,7 @@ def train(args):
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
print(
|
||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||
)
|
||||
return
|
||||
@@ -119,7 +114,7 @@ def train(args):
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
@@ -224,8 +219,10 @@ def train(args):
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
@@ -239,8 +236,8 @@ def train(args):
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
@@ -256,9 +253,7 @@ def train(args):
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
@@ -314,10 +309,8 @@ def train(args):
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
# logger.info(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
@@ -339,7 +332,7 @@ def train(args):
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
@@ -574,13 +567,12 @@ def train(args):
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, controlnet, force_sync_upload=True)
|
||||
|
||||
logger.info("model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
|
||||
40
train_db.py
40
train_db.py
@@ -1,6 +1,7 @@
|
||||
# DreamBooth training
|
||||
# XXX dropped option: fine_tune
|
||||
|
||||
import gc
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
@@ -9,9 +10,10 @@ from multiprocessing import Value
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
@@ -33,12 +35,6 @@ from library.custom_train_functions import (
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# perlin_noise,
|
||||
|
||||
@@ -46,7 +42,6 @@ logger = logging.getLogger(__name__)
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, False)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
@@ -59,11 +54,11 @@ def train(args):
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
@@ -98,13 +93,13 @@ def train(args):
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
print("prepare accelerator")
|
||||
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
logger.warning(
|
||||
print(
|
||||
f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong"
|
||||
)
|
||||
logger.warning(
|
||||
print(
|
||||
f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です"
|
||||
)
|
||||
|
||||
@@ -143,7 +138,9 @@ def train(args):
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -180,8 +177,8 @@ def train(args):
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
@@ -196,9 +193,7 @@ def train(args):
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
@@ -269,7 +264,7 @@ def train(args):
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
@@ -454,13 +449,12 @@ def train(args):
|
||||
train_util.save_sd_model_on_train_end(
|
||||
args, src_path, save_stable_diffusion_format, use_safetensors, save_dtype, epoch, global_step, text_encoder, unet, vae
|
||||
)
|
||||
logger.info("model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, False, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
@@ -10,13 +11,13 @@ from multiprocessing import Value
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
init_ipex()
|
||||
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DDPMScheduler
|
||||
from library import model_util
|
||||
@@ -40,12 +41,6 @@ from library.custom_train_functions import (
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NetworkTrainer:
|
||||
@@ -141,7 +136,6 @@ class NetworkTrainer:
|
||||
training_started_at = time.time()
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
@@ -159,18 +153,18 @@ class NetworkTrainer:
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||
if use_user_config:
|
||||
logger.info(f"Loading dataset config from {args.dataset_config}")
|
||||
print(f"Loading dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
print(
|
||||
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
)
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
logger.info("Using DreamBooth method.")
|
||||
print("Using DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -181,7 +175,7 @@ class NetworkTrainer:
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Training with captions.")
|
||||
print("Training with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -210,7 +204,7 @@ class NetworkTrainer:
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
print(
|
||||
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)"
|
||||
)
|
||||
return
|
||||
@@ -223,7 +217,7 @@ class NetworkTrainer:
|
||||
self.assert_extra_args(args, train_dataset_group)
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("preparing accelerator")
|
||||
print("preparing accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
@@ -272,7 +266,9 @@ class NetworkTrainer:
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -314,7 +310,7 @@ class NetworkTrainer:
|
||||
if hasattr(network, "prepare_network"):
|
||||
network.prepare_network(args)
|
||||
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
||||
logger.warning(
|
||||
print(
|
||||
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
|
||||
)
|
||||
args.scale_weight_norms = False
|
||||
@@ -349,8 +345,8 @@ class NetworkTrainer:
|
||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
@@ -942,13 +938,12 @@ class NetworkTrainer:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
logger.info("model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
@@ -956,9 +951,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない"
|
||||
)
|
||||
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
||||
parser.add_argument(
|
||||
"--save_model_as",
|
||||
type=str,
|
||||
@@ -970,17 +963,10 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
||||
|
||||
parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み")
|
||||
parser.add_argument("--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール")
|
||||
parser.add_argument(
|
||||
"--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_module", type=str, default=None, help="network module to train / 学習対象のネットワークのモジュール"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_dim",
|
||||
type=int,
|
||||
default=None,
|
||||
help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)",
|
||||
"--network_dim", type=int, default=None, help="network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_alpha",
|
||||
@@ -995,25 +981,14 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="Drops neurons out of training every step (0 or None is default behavior (no dropout), 1 would drop all neurons) / 訓練時に毎ステップでニューロンをdropする(0またはNoneはdropoutなし、1は全ニューロンをdropout)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_args",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="additional arguments for network (key=value) / ネットワークへの追加の引数",
|
||||
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
||||
)
|
||||
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
|
||||
parser.add_argument(
|
||||
"--network_train_text_encoder_only", action="store_true", help="only training Text Encoder part / Text Encoder関連部分のみ学習する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_train_text_encoder_only",
|
||||
action="store_true",
|
||||
help="only training Text Encoder part / Text Encoder関連部分のみ学習する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--training_comment",
|
||||
type=str,
|
||||
default=None,
|
||||
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列",
|
||||
"--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dim_from_weights",
|
||||
|
||||
1039
train_network_appl_weights.py
Normal file
1039
train_network_appl_weights.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,13 +1,15 @@
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
from multiprocessing import Value
|
||||
import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
@@ -30,12 +32,6 @@ from library.custom_train_functions import (
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
@@ -172,7 +168,6 @@ class TextualInversionTrainer:
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
|
||||
@@ -183,7 +178,7 @@ class TextualInversionTrainer:
|
||||
tokenizers = tokenizer_or_list if isinstance(tokenizer_or_list, list) else [tokenizer_or_list]
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
@@ -293,7 +288,7 @@ class TextualInversionTrainer:
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Train with captions.")
|
||||
print("Train with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -368,7 +363,9 @@ class TextualInversionTrainer:
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -385,8 +382,8 @@ class TextualInversionTrainer:
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
@@ -739,13 +736,12 @@ class TextualInversionTrainer:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, updated_embs_list, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
logger.info("model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, False)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
@@ -761,9 +757,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
|
||||
)
|
||||
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
|
||||
parser.add_argument(
|
||||
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
|
||||
)
|
||||
@@ -773,9 +767,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
|
||||
)
|
||||
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
||||
parser.add_argument(
|
||||
"--use_object_template",
|
||||
action="store_true",
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import toml
|
||||
from multiprocessing import Value
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
from library.ipex_interop import init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
@@ -34,12 +36,6 @@ from library.custom_train_functions import (
|
||||
)
|
||||
import library.original_unet as original_unet
|
||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
@@ -98,13 +94,12 @@ def train(args):
|
||||
if args.output_name is None:
|
||||
args.output_name = args.token_string
|
||||
use_template = args.use_object_template or args.use_style_template
|
||||
setup_logging(args, reset=True)
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
|
||||
logger.warning(
|
||||
print(
|
||||
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
|
||||
)
|
||||
assert (
|
||||
@@ -119,7 +114,7 @@ def train(args):
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
# acceleratorを準備する
|
||||
logger.info("prepare accelerator")
|
||||
print("prepare accelerator")
|
||||
accelerator = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
@@ -132,7 +127,7 @@ def train(args):
|
||||
if args.init_word is not None:
|
||||
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
||||
logger.warning(
|
||||
print(
|
||||
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
|
||||
)
|
||||
else:
|
||||
@@ -146,7 +141,7 @@ def train(args):
|
||||
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
||||
|
||||
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||
logger.info(f"tokens are added: {token_ids}")
|
||||
print(f"tokens are added: {token_ids}")
|
||||
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||
|
||||
@@ -174,7 +169,7 @@ def train(args):
|
||||
|
||||
tokenizer.add_tokens(token_strings_XTI)
|
||||
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
|
||||
logger.info(f"tokens are added (XTI): {token_ids_XTI}")
|
||||
print(f"tokens are added (XTI): {token_ids_XTI}")
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
@@ -183,7 +178,7 @@ def train(args):
|
||||
if init_token_ids is not None:
|
||||
for i, token_id in enumerate(token_ids_XTI):
|
||||
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
|
||||
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
|
||||
# load weights
|
||||
if args.weights is not None:
|
||||
@@ -191,22 +186,22 @@ def train(args):
|
||||
assert len(token_ids) == len(
|
||||
embeddings
|
||||
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
||||
# logger.info(token_ids, embeddings.size())
|
||||
# print(token_ids, embeddings.size())
|
||||
for token_id, embedding in zip(token_ids_XTI, embeddings):
|
||||
token_embeds[token_id] = embedding
|
||||
# logger.info(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
logger.info(f"weighs loaded")
|
||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
print(f"weighs loaded")
|
||||
|
||||
logger.info(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||
|
||||
# データセットを準備する
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
print(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.info(
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
", ".join(ignored)
|
||||
)
|
||||
@@ -214,14 +209,14 @@ def train(args):
|
||||
else:
|
||||
use_dreambooth_method = args.in_json is None
|
||||
if use_dreambooth_method:
|
||||
logger.info("Use DreamBooth method.")
|
||||
print("Use DreamBooth method.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||
]
|
||||
}
|
||||
else:
|
||||
logger.info("Train with captions.")
|
||||
print("Train with captions.")
|
||||
user_config = {
|
||||
"datasets": [
|
||||
{
|
||||
@@ -245,7 +240,7 @@ def train(args):
|
||||
|
||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||
if use_template:
|
||||
logger.info(f"use template for training captions. is object: {args.use_object_template}")
|
||||
print(f"use template for training captions. is object: {args.use_object_template}")
|
||||
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
||||
replace_to = " ".join(token_strings)
|
||||
captions = []
|
||||
@@ -269,7 +264,7 @@ def train(args):
|
||||
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
||||
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
@@ -291,7 +286,9 @@ def train(args):
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -300,13 +297,13 @@ def train(args):
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
logger.info("prepare optimizer, data loader etc.")
|
||||
print("prepare optimizer, data loader etc.")
|
||||
trainable_params = text_encoder.get_input_embeddings().parameters()
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset_group,
|
||||
batch_size=1,
|
||||
@@ -321,9 +318,7 @@ def train(args):
|
||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||
)
|
||||
logger.info(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
@@ -337,7 +332,7 @@ def train(args):
|
||||
)
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||
# logger.info(len(index_no_updates), torch.sum(index_no_updates))
|
||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
@@ -375,17 +370,15 @@ def train(args):
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
logger.info("running training / 学習開始")
|
||||
logger.info(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
logger.info(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
logger.info(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
logger.info(f" num epochs / epoch数: {num_train_epochs}")
|
||||
logger.info(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
logger.info(
|
||||
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||
)
|
||||
logger.info(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
print("running training / 学習開始")
|
||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
@@ -400,20 +393,17 @@ def train(args):
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, embs, steps, epoch_no, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
logger.info("")
|
||||
logger.info(f"saving checkpoint: {ckpt_file}")
|
||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
save_weights(ckpt_file, embs, save_dtype)
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
|
||||
@@ -421,13 +411,12 @@ def train(args):
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
logger.info(f"removing old checkpoint: {old_ckpt_file}")
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
logger.info("")
|
||||
logger.info(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
text_encoder.train()
|
||||
@@ -597,7 +586,7 @@ def train(args):
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(ckpt_name, updated_embs, global_step, num_train_epochs, force_sync_upload=True)
|
||||
|
||||
logger.info("model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def save_weights(file, updated_embs, save_dtype):
|
||||
@@ -658,7 +647,6 @@ def load_weights(file):
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True, False)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
@@ -674,9 +662,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み"
|
||||
)
|
||||
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
|
||||
parser.add_argument(
|
||||
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
|
||||
)
|
||||
@@ -686,9 +672,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可"
|
||||
)
|
||||
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
||||
parser.add_argument(
|
||||
"--use_object_template",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user