mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
126 Commits
wuerstchen
...
v0.7.1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f93bf10f0 | ||
|
|
1db5d790ed | ||
|
|
ab716302e4 | ||
|
|
fd7f27f044 | ||
|
|
1a36f9dc65 | ||
|
|
c2497877ca | ||
|
|
3b5c1a1d4b | ||
|
|
9a2e385f12 | ||
|
|
7080e1a11c | ||
|
|
0a52b83c6a | ||
|
|
11ed8e2a6d | ||
|
|
bb20c09a9a | ||
|
|
04ef8d395f | ||
|
|
0676f1a86f | ||
|
|
6b7823df07 | ||
|
|
2186e417ba | ||
|
|
1519e3067c | ||
|
|
35e5424255 | ||
|
|
8c7d05afd2 | ||
|
|
f8360a4831 | ||
|
|
8556b9d7f5 | ||
|
|
3efd90b2ad | ||
|
|
7adcd9cd1a | ||
|
|
aff05e043f | ||
|
|
ff2c0c192e | ||
|
|
d309a27a51 | ||
|
|
471d274803 | ||
|
|
35f4c9b5c7 | ||
|
|
034a49c69d | ||
|
|
3b6825d7e2 | ||
|
|
bb5ae389f7 | ||
|
|
4a2cef887c | ||
|
|
42750f7846 | ||
|
|
d31aa143f4 | ||
|
|
710e777a92 | ||
|
|
912dca8f65 | ||
|
|
db84530074 | ||
|
|
72bbaac96d | ||
|
|
5713d63dc5 | ||
|
|
d653e594c2 | ||
|
|
dd7bb33ab6 | ||
|
|
a9c6182b3f | ||
|
|
3d70137d31 | ||
|
|
bce9a081db | ||
|
|
46cf41cc93 | ||
|
|
81a440c8e8 | ||
|
|
f24a3b5282 | ||
|
|
383b4a2c3e | ||
|
|
df59822a27 | ||
|
|
0908c5414d | ||
|
|
ee46134fa7 | ||
|
|
39bb319d4c | ||
|
|
1bdd83a85f | ||
|
|
1624c239c2 | ||
|
|
4a913ce61e | ||
|
|
764e333fa2 | ||
|
|
c61e3bf4c9 | ||
|
|
fc8649d80f | ||
|
|
0fb9ecf1f3 | ||
|
|
97958400fb | ||
|
|
6d6d86260b | ||
|
|
c856ea4249 | ||
|
|
d0923d6710 | ||
|
|
f312522cef | ||
|
|
da5a144589 | ||
|
|
2c1e669bd8 | ||
|
|
e20e9f61ac | ||
|
|
6b3148fd3f | ||
|
|
95ae56bd22 | ||
|
|
990192d077 | ||
|
|
f3e69531c3 | ||
|
|
0cb3272bda | ||
|
|
6231aa91e2 | ||
|
|
489b728dbc | ||
|
|
583e2b2d01 | ||
|
|
5dc2a0d3fd | ||
|
|
2c731418ad | ||
|
|
5c150675bf | ||
|
|
fea810b437 | ||
|
|
96d877be90 | ||
|
|
40d917b0fe | ||
|
|
e72020ae01 | ||
|
|
01d929ee2a | ||
|
|
cf876fcdb4 | ||
|
|
291c29caaf | ||
|
|
01e00ac1b0 | ||
|
|
a9ed4ed8a8 | ||
|
|
9d6a5a0c79 | ||
|
|
fb97a7aab1 | ||
|
|
1cefb2a753 | ||
|
|
63992b81c8 | ||
|
|
d8f68674fb | ||
|
|
9d00c8eea2 | ||
|
|
0d21925bdf | ||
|
|
efef5c8ead | ||
|
|
3d2bb1a8f1 | ||
|
|
837a4dddb8 | ||
|
|
b2626bc7a9 | ||
|
|
202f2c3292 | ||
|
|
2a23713f71 | ||
|
|
681034d001 | ||
|
|
17813ff5b4 | ||
|
|
3e81bd6b67 | ||
|
|
23ae358e0f | ||
|
|
f611726364 | ||
|
|
33ee0acd35 | ||
|
|
8b79e3b06c | ||
|
|
cf49e912fc | ||
|
|
66741c035c | ||
|
|
406511c333 | ||
|
|
8a2d68d63e | ||
|
|
07d297fdbe | ||
|
|
0d4e8b50d0 | ||
|
|
1d7c5c2a98 | ||
|
|
0faa350175 | ||
|
|
8a7509db75 | ||
|
|
025368f51c | ||
|
|
5fe52ed322 | ||
|
|
8b247a330b | ||
|
|
d6f458fcb3 | ||
|
|
b8b84021e5 | ||
|
|
70fe7e18be | ||
|
|
9378da3c82 | ||
|
|
a4857fa764 | ||
|
|
592014923f | ||
|
|
6d06b215bf |
4
.github/workflows/typos.yml
vendored
4
.github/workflows/typos.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.16.15
|
||||
uses: crate-ci/typos@v1.16.26
|
||||
|
||||
101
README.md
101
README.md
@@ -249,32 +249,99 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
||||
|
||||
## Change History
|
||||
|
||||
### Oct 1. 2023 / 2023/10/1
|
||||
### Dec 24, 2023 / 2023/12/24
|
||||
|
||||
- SDXL training is now available in the main branch. The sdxl branch is merged into the main branch.
|
||||
- Fixed to work `tools/convert_diffusers20_original_sd.py`. Thanks to Disty0! PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016)
|
||||
|
||||
- [SAI Model Spec](https://github.com/Stability-AI/ModelSpec) metadata is now supported partially. `hash_sha256` is not supported yet.
|
||||
- The main items are set automatically.
|
||||
- You can set title, author, description, license and tags with `--metadata_xxx` options in each training script.
|
||||
- Merging scripts also support minimum SAI Model Spec metadata. See the help message for the usage.
|
||||
- Metadata editor will be available soon.
|
||||
- `tools/convert_diffusers20_original_sd.py` が動かなくなっていたのが修正されました。Disty0 氏に感謝します。 PR [#1016](https://github.com/kohya-ss/sd-scripts/pull/1016)
|
||||
|
||||
- `bitsandbytes` is now optional. Please install it if you want to use it. The insructions are in the later section.
|
||||
|
||||
- `albumentations` is not required anymore.
|
||||
### Dec 21, 2023 / 2023/12/21
|
||||
|
||||
- `--v_pred_like_loss ratio` option is added. This option adds the loss like v-prediction loss in SDXL training. `0.1` means that the loss is added 10% of the v-prediction loss. The default value is None (disabled).
|
||||
- In v-prediction, the loss is higher in the early timesteps (near the noise). This option can be used to increase the loss in the early timesteps.
|
||||
- The issues in multi-GPU training are fixed. Thanks to Isotr0py! PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) and [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000)
|
||||
- `--ddp_gradient_as_bucket_view` and `--ddp_bucket_view`options are added to `sdxl_train.py`. Please specify these options for multi-GPU training.
|
||||
- IPEX support is updated. Thanks to Disty0!
|
||||
- Fixed the bug that the size of the bucket becomes less than `min_bucket_reso`. Thanks to Cauldrath! PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008)
|
||||
- `--sample_at_first` option is added to each training script. This option is useful to generate images at the first step, before training. Thanks to shirayu! PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907)
|
||||
- `--ss` option is added to the sampling prompt in training. You can specify the scheduler for the sampling like `--ss euler_a`. Thanks to shirayu! PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906)
|
||||
- `keep_tokens_separator` is added to the dataset config. This option is useful to keep (prevent from shuffling) the tokens in the captions. See [#975](https://github.com/kohya-ss/sd-scripts/pull/975) for details. Thanks to Linaqruf!
|
||||
- You can specify the separator with an option like `--keep_tokens_separator "|||"` or with `keep_tokens_separator: "|||"` in `.toml`. The tokens before `|||` are not shuffled.
|
||||
- Attention processor hook is added. See [#961](https://github.com/kohya-ss/sd-scripts/pull/961) for details. Thanks to rockerBOO!
|
||||
- The optimizer `PagedAdamW` is added. Thanks to xzuyn! PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955)
|
||||
- NaN replacement in SDXL VAE is sped up. Thanks to liubo0902! PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009)
|
||||
- Fixed the path error in `finetune/make_captions.py`. Thanks to CjangCjengh! PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986)
|
||||
|
||||
- Arbitrary options can be used for Diffusers' schedulers. For example `--lr_scheduler_args "lr_end=1e-8"`.
|
||||
- マルチGPUでの学習の不具合を修正しました。Isotr0py 氏に感謝します。 PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) および [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000)
|
||||
- `sdxl_train.py` に `--ddp_gradient_as_bucket_view` と `--ddp_bucket_view` オプションが追加されました。マルチGPUでの学習時にはこれらのオプションを指定してください。
|
||||
- IPEX サポートが更新されました。Disty0 氏に感謝します。
|
||||
- Aspect Ratio Bucketing で bucket のサイズが `min_bucket_reso` 未満になる不具合を修正しました。Cauldrath 氏に感謝します。 PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008)
|
||||
- 各学習スクリプトに `--sample_at_first` オプションが追加されました。学習前に画像を生成することで、学習結果が比較しやすくなります。shirayu 氏に感謝します。 PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907)
|
||||
- 学習時のプロンプトに `--ss` オプションが追加されました。`--ss euler_a` のようにスケジューラを指定できます。shirayu 氏に感謝します。 PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906)
|
||||
- データセット設定に `keep_tokens_separator` が追加されました。キャプション内のトークンをどの位置までシャッフルしないかを指定できます。詳細は [#975](https://github.com/kohya-ss/sd-scripts/pull/975) を参照してください。Linaqruf 氏に感謝します。
|
||||
- オプションで `--keep_tokens_separator "|||"` のように指定するか、`.toml` で `keep_tokens_separator: "|||"` のように指定します。`|||` の前のトークンはシャッフルされません。
|
||||
- Attention processor hook が追加されました。詳細は [#961](https://github.com/kohya-ss/sd-scripts/pull/961) を参照してください。rockerBOO 氏に感謝します。
|
||||
- オプティマイザ `PagedAdamW` が追加されました。xzuyn 氏に感謝します。 PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955)
|
||||
- 学習時、SDXL VAE で NaN が発生した時の置き換えが高速化されました。liubo0902 氏に感謝します。 PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009)
|
||||
- `finetune/make_captions.py` で相対パス指定時のエラーが修正されました。CjangCjengh 氏に感謝します。 PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986)
|
||||
|
||||
- LoRA-FA is added experimentally. Specify `--network_module networks.lora_fa` option instead of `--network_module networks.lora`. The trained model can be used as a normal LoRA model.
|
||||
- JPEG XL is supported. [#786](https://github.com/kohya-ss/sd-scripts/pull/786)
|
||||
- Input perturbation noise is added. See [#798](https://github.com/kohya-ss/sd-scripts/pull/798) for details.
|
||||
- Dataset subset now has `caption_prefix` and `caption_suffix` options. The strings are added to the beginning and the end of the captions before shuffling. You can specify the options in `.toml`.
|
||||
- Intel ARC support with IPEX is added. [#825](https://github.com/kohya-ss/sd-scripts/pull/825)
|
||||
### Dec 3, 2023 / 2023/12/3
|
||||
|
||||
- `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
|
||||
- Min SNR Gamma with V-predicition (SD 2.1) is fixed. Thanks to feffy380! PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934)
|
||||
- See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details.
|
||||
- `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
|
||||
- The default values are same as the previous version.
|
||||
- Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py`.
|
||||
- `--ds_timesteps_1` and `--ds_timesteps_2` options denote the timesteps of the Deep Shrink for the first and second stages.
|
||||
- `--ds_depth_1` and `--ds_depth_2` options denote the depth (block index) of the Deep Shrink for the first and second stages.
|
||||
- `--ds_ratio` option denotes the ratio of the Deep Shrink. `0.5` means the half of the original latent size for the Deep Shrink.
|
||||
- `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available.
|
||||
|
||||
- `finetune\tag_images_by_wd14_tagger.py` で `--caption_separator` オプションでカンマ以外の区切り文字を指定できるようになりました。KohakuBlueleaf 氏に感謝します。 PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
|
||||
- V-predicition (SD 2.1) での Min SNR Gamma が修正されました。feffy380 氏に感謝します。 PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934)
|
||||
- 詳細は [#673](https://github.com/kohya-ss/sd-scripts/issues/673) を参照してください。
|
||||
- `networks/extract_lora_from_models.py` に `--min_diff` と `--clamp_quantile` オプションが追加されました。wkpark 氏に感謝します。 PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
|
||||
- デフォルト値は前のバージョンと同じです。
|
||||
- `sdxl_gen_img.py` と `gen_img_diffusers.py` で Deep Shrink hires fix をサポートしました。
|
||||
- `--ds_timesteps_1` と `--ds_timesteps_2` オプションは Deep Shrink の第一段階と第二段階の timesteps を指定します。
|
||||
- `--ds_depth_1` と `--ds_depth_2` オプションは Deep Shrink の第一段階と第二段階の深さ(ブロックの index)を指定します。
|
||||
- `--ds_ratio` オプションは Deep Shrink の比率を指定します。`0.5` を指定すると Deep Shrink 適用時の latent は元のサイズの半分になります。
|
||||
- `--dst1`、`--dst2`、`--dsd1`、`--dsd2`、`--dsr` プロンプトオプションも使用できます。
|
||||
|
||||
### Nov 5, 2023 / 2023/11/5
|
||||
|
||||
- `sdxl_train.py` now supports different learning rates for each Text Encoder.
|
||||
- Example:
|
||||
- `--learning_rate 1e-6`: train U-Net only
|
||||
- `--train_text_encoder --learning_rate 1e-6`: train U-Net and two Text Encoders with the same learning rate (same as the previous version)
|
||||
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train U-Net and two Text Encoders with the different learning rates
|
||||
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 1e-6 --learning_rate_te2 1e-6`: train two Text Encoders only
|
||||
- `--train_text_encoder --learning_rate 1e-6 --learning_rate_te1 1e-6 --learning_rate_te2 0`: train U-Net and one Text Encoder only
|
||||
- `--train_text_encoder --learning_rate 0 --learning_rate_te1 0 --learning_rate_te2 1e-6`: train one Text Encoder only
|
||||
|
||||
- `train_db.py` and `fine_tune.py` now support different learning rates for Text Encoder. Specify with `--learning_rate_te` option.
|
||||
- To train Text Encoder with `fine_tune.py`, specify `--train_text_encoder` option too. `train_db.py` trains Text Encoder by default.
|
||||
|
||||
- Fixed the bug that Text Encoder is not trained when block lr is specified in `sdxl_train.py`.
|
||||
|
||||
- Debiased Estimation loss is added to each training script. Thanks to sdbds!
|
||||
- Specify `--debiased_estimation_loss` option to enable it. See PR [#889](https://github.com/kohya-ss/sd-scripts/pull/889) for details.
|
||||
- Training of Text Encoder is improved in `train_network.py` and `sdxl_train_network.py`. Thanks to KohakuBlueleaf! PR [#895](https://github.com/kohya-ss/sd-scripts/pull/895)
|
||||
- The moving average of the loss is now displayed in the progress bar in each training script. Thanks to shirayu! PR [#899](https://github.com/kohya-ss/sd-scripts/pull/899)
|
||||
- PagedAdamW32bit optimizer is supported. Specify `--optimizer_type=PagedAdamW32bit`. Thanks to xzuyn! PR [#900](https://github.com/kohya-ss/sd-scripts/pull/900)
|
||||
- Other bug fixes and improvements.
|
||||
|
||||
- `sdxl_train.py` で、二つのText Encoderそれぞれに独立した学習率が指定できるようになりました。サンプルは上の英語版を参照してください。
|
||||
- `train_db.py` および `fine_tune.py` で Text Encoder に別の学習率を指定できるようになりました。`--learning_rate_te` オプションで指定してください。
|
||||
- `fine_tune.py` で Text Encoder を学習するには `--train_text_encoder` オプションをあわせて指定してください。`train_db.py` はデフォルトで学習します。
|
||||
- `sdxl_train.py` で block lr を指定すると Text Encoder が学習されない不具合を修正しました。
|
||||
- Debiased Estimation loss が各学習スクリプトに追加されました。sdbsd 氏に感謝します。
|
||||
- `--debiased_estimation_loss` を指定すると有効になります。詳細は PR [#889](https://github.com/kohya-ss/sd-scripts/pull/889) を参照してください。
|
||||
- `train_network.py` と `sdxl_train_network.py` でText Encoderの学習が改善されました。KohakuBlueleaf 氏に感謝します。 PR [#895](https://github.com/kohya-ss/sd-scripts/pull/895)
|
||||
- 各学習スクリプトで移動平均のlossがプログレスバーに表示されるようになりました。shirayu 氏に感謝します。 PR [#899](https://github.com/kohya-ss/sd-scripts/pull/899)
|
||||
- PagedAdamW32bit オプティマイザがサポートされました。`--optimizer_type=PagedAdamW32bit` と指定してください。xzuyn 氏に感謝します。 PR [#900](https://github.com/kohya-ss/sd-scripts/pull/900)
|
||||
- その他のバグ修正と改善。
|
||||
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ
|
||||
|
||||
サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
|
||||
|
||||
- `--sample_at_first`
|
||||
|
||||
学習開始前にサンプル出力します。学習前との比較ができます。
|
||||
|
||||
- `--sample_prompts`
|
||||
|
||||
サンプル出力用プロンプトのファイルを指定します。
|
||||
|
||||
58
fine_tune.py
58
fine_tune.py
@@ -10,10 +10,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -32,6 +35,7 @@ from library.custom_train_functions import (
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
|
||||
|
||||
@@ -192,14 +196,20 @@ def train(args):
|
||||
|
||||
for m in training_models:
|
||||
m.requires_grad_(True)
|
||||
params = []
|
||||
for m in training_models:
|
||||
params.extend(m.parameters())
|
||||
params_to_optimize = params
|
||||
|
||||
trainable_params = []
|
||||
if args.learning_rate_te is None or not args.train_text_encoder:
|
||||
for m in training_models:
|
||||
trainable_params.extend(m.parameters())
|
||||
else:
|
||||
trainable_params = [
|
||||
{"params": list(unet.parameters()), "lr": args.learning_rate},
|
||||
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
|
||||
]
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
@@ -243,9 +253,6 @@ def train(args):
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
@@ -288,6 +295,10 @@ def train(args):
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
@@ -295,7 +306,6 @@ def train(args):
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
loss_total = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||
@@ -339,15 +349,17 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred:
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
@@ -396,26 +408,20 @@ def train(args):
|
||||
|
||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if (
|
||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
||||
): # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
# TODO moving averageにする
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / (step + 1)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
@@ -474,6 +480,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
parser.add_argument(
|
||||
"--learning_rate_te",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ import torch
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
from blip.blip import blip_decoder
|
||||
from blip.blip import blip_decoder, is_url
|
||||
import library.train_util as train_util
|
||||
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -76,6 +76,8 @@ def main(args):
|
||||
cwd = os.getcwd()
|
||||
print("Current Working Directory is: ", cwd)
|
||||
os.chdir("finetune")
|
||||
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
|
||||
args.caption_weights = os.path.join("..", args.caption_weights)
|
||||
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
|
||||
@@ -52,6 +52,9 @@ def collate_fn_remove_corrupted(batch):
|
||||
|
||||
|
||||
def main(args):
|
||||
r"""
|
||||
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
|
||||
|
||||
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
||||
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
||||
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
||||
@@ -65,6 +68,7 @@ def main(args):
|
||||
return input_ids
|
||||
|
||||
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
||||
"""
|
||||
|
||||
print(f"load images from {args.train_data_dir}")
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
@@ -81,7 +85,7 @@ def main(args):
|
||||
def run_batch(path_imgs):
|
||||
imgs = [im for _, im in path_imgs]
|
||||
|
||||
curr_batch_size[0] = len(path_imgs)
|
||||
# curr_batch_size[0] = len(path_imgs)
|
||||
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
||||
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
||||
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
@@ -215,7 +215,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)",
|
||||
)
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
|
||||
parser.add_argument(
|
||||
"--bucket_reso_steps",
|
||||
type=int,
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
import argparse
|
||||
import csv
|
||||
import glob
|
||||
import os
|
||||
|
||||
from PIL import Image
|
||||
import cv2
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from tensorflow.keras.models import load_model
|
||||
from huggingface_hub import hf_hub_download
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import library.train_util as train_util
|
||||
|
||||
# from wd14 tagger
|
||||
@@ -20,6 +18,7 @@ IMAGE_SIZE = 448
|
||||
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
|
||||
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
||||
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
||||
FILES_ONNX = ["model.onnx"]
|
||||
SUB_DIR = "variables"
|
||||
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
||||
CSV_FILE = FILES[-1]
|
||||
@@ -81,7 +80,10 @@ def main(args):
|
||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||
if not os.path.exists(args.model_dir) or args.force_download:
|
||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||
for file in FILES:
|
||||
files = FILES
|
||||
if args.onnx:
|
||||
files += FILES_ONNX
|
||||
for file in files:
|
||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
@@ -96,7 +98,46 @@ def main(args):
|
||||
print("using existing wd14 tagger model")
|
||||
|
||||
# 画像を読み込む
|
||||
model = load_model(args.model_dir)
|
||||
if args.onnx:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
|
||||
onnx_path = f"{args.model_dir}/model.onnx"
|
||||
print("Running wd14 tagger with onnx")
|
||||
print(f"loading onnx model: {onnx_path}")
|
||||
|
||||
if not os.path.exists(onnx_path):
|
||||
raise Exception(
|
||||
f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
|
||||
+ " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
|
||||
)
|
||||
|
||||
model = onnx.load(onnx_path)
|
||||
input_name = model.graph.input[0].name
|
||||
try:
|
||||
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
|
||||
except:
|
||||
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
|
||||
|
||||
if args.batch_size != batch_size and type(batch_size) != str:
|
||||
# some rebatch model may use 'N' as dynamic axes
|
||||
print(
|
||||
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
||||
)
|
||||
args.batch_size = batch_size
|
||||
|
||||
del model
|
||||
|
||||
ort_sess = ort.InferenceSession(
|
||||
onnx_path,
|
||||
providers=["CUDAExecutionProvider"]
|
||||
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||
else ["CPUExecutionProvider"],
|
||||
)
|
||||
else:
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
model = load_model(f"{args.model_dir}")
|
||||
|
||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||
@@ -119,13 +160,21 @@ def main(args):
|
||||
|
||||
tag_freq = {}
|
||||
|
||||
undesired_tags = set(args.undesired_tags.split(","))
|
||||
caption_separator = args.caption_separator
|
||||
stripped_caption_separator = caption_separator.strip()
|
||||
undesired_tags = set(args.undesired_tags.split(stripped_caption_separator))
|
||||
|
||||
def run_batch(path_imgs):
|
||||
imgs = np.array([im for _, im in path_imgs])
|
||||
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
if args.onnx:
|
||||
if len(imgs) < args.batch_size:
|
||||
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
|
||||
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||
probs = probs[: len(path_imgs)]
|
||||
else:
|
||||
probs = model(imgs, training=False)
|
||||
probs = probs.numpy()
|
||||
|
||||
for (image_path, _), prob in zip(path_imgs, probs):
|
||||
# 最初の4つはratingなので無視する
|
||||
@@ -147,7 +196,7 @@ def main(args):
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
general_tag_text += ", " + tag_name
|
||||
general_tag_text += caption_separator + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
elif i >= len(general_tags) and p >= args.character_threshold:
|
||||
tag_name = character_tags[i - len(general_tags)]
|
||||
@@ -156,18 +205,36 @@ def main(args):
|
||||
|
||||
if tag_name not in undesired_tags:
|
||||
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
|
||||
character_tag_text += ", " + tag_name
|
||||
character_tag_text += caption_separator + tag_name
|
||||
combined_tags.append(tag_name)
|
||||
|
||||
# 先頭のカンマを取る
|
||||
if len(general_tag_text) > 0:
|
||||
general_tag_text = general_tag_text[2:]
|
||||
general_tag_text = general_tag_text[len(caption_separator) :]
|
||||
if len(character_tag_text) > 0:
|
||||
character_tag_text = character_tag_text[2:]
|
||||
character_tag_text = character_tag_text[len(caption_separator) :]
|
||||
|
||||
tag_text = ", ".join(combined_tags)
|
||||
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
|
||||
|
||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
||||
tag_text = caption_separator.join(combined_tags)
|
||||
|
||||
if args.append_tags:
|
||||
# Check if file exists
|
||||
if os.path.exists(caption_file):
|
||||
with open(caption_file, "rt", encoding="utf-8") as f:
|
||||
# Read file and remove new lines
|
||||
existing_content = f.read().strip("\n") # Remove newlines
|
||||
|
||||
# Split the content into tags and store them in a list
|
||||
existing_tags = [tag.strip() for tag in existing_content.split(stripped_caption_separator) if tag.strip()]
|
||||
|
||||
# Check and remove repeating tags in tag_text
|
||||
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
|
||||
|
||||
# Create new tag_text
|
||||
tag_text = caption_separator.join(existing_tags + new_tags)
|
||||
|
||||
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||
f.write(tag_text + "\n")
|
||||
if args.debug:
|
||||
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
|
||||
@@ -283,12 +350,21 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
|
||||
)
|
||||
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
||||
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
|
||||
parser.add_argument(
|
||||
"--caption_separator",
|
||||
type=str,
|
||||
default=", ",
|
||||
help="Separator for captions, include space if needed / キャプションの区切り文字、必要ならスペースを含めてください",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# スペルミスしていたオプションを復元する
|
||||
|
||||
@@ -65,10 +65,13 @@ import re
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -102,7 +105,7 @@ import library.train_util as train_util
|
||||
from networks.lora import LoRANetwork
|
||||
import tools.original_control_net as original_control_net
|
||||
from tools.original_control_net import ControlNetInfo
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
from library.original_unet import UNet2DConditionModel, InferUNet2DConditionModel
|
||||
from library.original_unet import FlashAttentionFunction
|
||||
|
||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||
@@ -375,7 +378,7 @@ class PipelineLike:
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
unet: InferUNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
clip_skip: int,
|
||||
clip_model: CLIPModel,
|
||||
@@ -954,7 +957,7 @@ class PipelineLike:
|
||||
text_emb_last = torch.stack(text_emb_last)
|
||||
else:
|
||||
text_emb_last = text_embeddings
|
||||
|
||||
|
||||
for i, t in enumerate(tqdm(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
|
||||
@@ -2193,6 +2196,7 @@ def main(args):
|
||||
)
|
||||
original_unet.load_state_dict(unet.state_dict())
|
||||
unet = original_unet
|
||||
unet: InferUNet2DConditionModel = InferUNet2DConditionModel(unet)
|
||||
|
||||
# VAEを読み込む
|
||||
if args.vae is not None:
|
||||
@@ -2349,13 +2353,20 @@ def main(args):
|
||||
vae = sli_vae
|
||||
del sli_vae
|
||||
vae.to(dtype).to(device)
|
||||
vae.eval()
|
||||
|
||||
text_encoder.to(dtype).to(device)
|
||||
unet.to(dtype).to(device)
|
||||
|
||||
text_encoder.eval()
|
||||
unet.eval()
|
||||
|
||||
if clip_model is not None:
|
||||
clip_model.to(dtype).to(device)
|
||||
clip_model.eval()
|
||||
if vgg16_model is not None:
|
||||
vgg16_model.to(dtype).to(device)
|
||||
vgg16_model.eval()
|
||||
|
||||
# networkを組み込む
|
||||
if args.network_module:
|
||||
@@ -2363,12 +2374,19 @@ def main(args):
|
||||
network_default_muls = []
|
||||
network_pre_calc = args.network_pre_calc
|
||||
|
||||
# merge関連の引数を統合する
|
||||
if args.network_merge:
|
||||
network_merge = len(args.network_module) # all networks are merged
|
||||
elif args.network_merge_n_models:
|
||||
network_merge = args.network_merge_n_models
|
||||
else:
|
||||
network_merge = 0
|
||||
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
network_default_muls.append(network_mul)
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
@@ -2379,31 +2397,32 @@ def main(args):
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
if args.network_weights and i < len(args.network_weights):
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
else:
|
||||
if args.network_weights is None or len(args.network_weights) <= i:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
mergeable = network.is_mergeable()
|
||||
if args.network_merge and not mergeable:
|
||||
if network_merge and not mergeable:
|
||||
print("network is not mergiable. ignore merge option.")
|
||||
|
||||
if not args.network_merge or not mergeable:
|
||||
if not mergeable or i >= network_merge:
|
||||
# not merging
|
||||
network.apply_to(text_encoder, unet)
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
print(f"weights are loaded: {info}")
|
||||
@@ -2417,6 +2436,7 @@ def main(args):
|
||||
network.backup_weights()
|
||||
|
||||
networks.append(network)
|
||||
network_default_muls.append(network_mul)
|
||||
else:
|
||||
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
|
||||
|
||||
@@ -2489,6 +2509,10 @@ def main(args):
|
||||
if args.diffusers_xformers:
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# Deep Shrink
|
||||
if args.ds_depth_1 is not None:
|
||||
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
|
||||
|
||||
# Extended Textual Inversion および Textual Inversionを処理する
|
||||
if args.XTI_embeddings:
|
||||
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||
@@ -2712,9 +2736,18 @@ def main(args):
|
||||
|
||||
size = None
|
||||
for i, network in enumerate(networks):
|
||||
if i < 3:
|
||||
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
|
||||
np_mask = np.array(mask_images[0])
|
||||
np_mask = np_mask[:, :, i]
|
||||
|
||||
if args.network_regional_mask_max_color_codes:
|
||||
# カラーコードでマスクを指定する
|
||||
ch0 = (i + 1) & 1
|
||||
ch1 = ((i + 1) >> 1) & 1
|
||||
ch2 = ((i + 1) >> 2) & 1
|
||||
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
|
||||
np_mask = np_mask.astype(np.uint8) * 255
|
||||
else:
|
||||
np_mask = np_mask[:, :, i]
|
||||
size = np_mask.shape
|
||||
else:
|
||||
np_mask = np.full(size, 255, dtype=np.uint8)
|
||||
@@ -3064,6 +3097,13 @@ def main(args):
|
||||
clip_prompt = None
|
||||
network_muls = None
|
||||
|
||||
# Deep Shrink
|
||||
ds_depth_1 = None # means no override
|
||||
ds_timesteps_1 = args.ds_timesteps_1
|
||||
ds_depth_2 = args.ds_depth_2
|
||||
ds_timesteps_2 = args.ds_timesteps_2
|
||||
ds_ratio = args.ds_ratio
|
||||
|
||||
prompt_args = raw_prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
@@ -3135,10 +3175,51 @@ def main(args):
|
||||
print(f"network mul: {network_muls}")
|
||||
continue
|
||||
|
||||
# Deep Shrink
|
||||
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink depth 1
|
||||
ds_depth_1 = int(m.group(1))
|
||||
print(f"deep shrink depth 1: {ds_depth_1}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink timesteps 1
|
||||
ds_timesteps_1 = int(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink depth 2
|
||||
ds_depth_2 = int(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink depth 2: {ds_depth_2}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink timesteps 2
|
||||
ds_timesteps_2 = int(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink ratio
|
||||
ds_ratio = float(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink ratio: {ds_ratio}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
|
||||
# override Deep Shrink
|
||||
if ds_depth_1 is not None:
|
||||
if ds_depth_1 < 0:
|
||||
ds_depth_1 = args.ds_depth_1 or 3
|
||||
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
|
||||
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# 数が足りないなら前のをそのまま使う
|
||||
@@ -3367,10 +3448,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
||||
)
|
||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||
parser.add_argument(
|
||||
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
|
||||
)
|
||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||
parser.add_argument(
|
||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_regional_mask_max_color_codes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--textual_inversion_embeddings",
|
||||
type=str,
|
||||
@@ -3479,6 +3569,30 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
# )
|
||||
|
||||
# Deep Shrink
|
||||
parser.add_argument(
|
||||
"--ds_depth_1",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Enable Deep Shrink with this depth 1, valid values are 0 to 3 / Deep Shrinkをこのdepthで有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds_timesteps_1",
|
||||
type=int,
|
||||
default=650,
|
||||
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
|
||||
)
|
||||
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
|
||||
parser.add_argument(
|
||||
"--ds_timesteps_2",
|
||||
type=int,
|
||||
default=650,
|
||||
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -51,7 +51,9 @@ class BaseSubsetParams:
|
||||
image_dir: Optional[str] = None
|
||||
num_repeats: int = 1
|
||||
shuffle_caption: bool = False
|
||||
caption_separator: str = ',',
|
||||
keep_tokens: int = 0
|
||||
keep_tokens_separator: str = None,
|
||||
color_aug: bool = False
|
||||
flip_aug: bool = False
|
||||
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
||||
@@ -159,6 +161,7 @@ class ConfigSanitizer:
|
||||
"random_crop": bool,
|
||||
"shuffle_caption": bool,
|
||||
"keep_tokens": int,
|
||||
"keep_tokens_separator": str,
|
||||
"token_warmup_min": int,
|
||||
"token_warmup_step": Any(float,int),
|
||||
"caption_prefix": str,
|
||||
@@ -460,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
|
||||
num_repeats: {subset.num_repeats}
|
||||
shuffle_caption: {subset.shuffle_caption}
|
||||
keep_tokens: {subset.keep_tokens}
|
||||
keep_tokens_separator: {subset.keep_tokens_separator}
|
||||
caption_dropout_rate: {subset.caption_dropout_rate}
|
||||
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
||||
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
||||
|
||||
@@ -57,10 +57,13 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
||||
noise_scheduler.alphas_cumprod = alphas_cumprod
|
||||
|
||||
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
||||
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
||||
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
||||
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
||||
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
|
||||
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
||||
if v_prediction:
|
||||
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
|
||||
else:
|
||||
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
|
||||
loss = loss * snr_weight
|
||||
return loss
|
||||
|
||||
@@ -86,6 +89,12 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
|
||||
loss = loss + loss / scale * v_pred_like_loss
|
||||
return loss
|
||||
|
||||
def apply_debiased_estimation(loss, timesteps, noise_scheduler):
|
||||
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
||||
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
||||
weight = 1/torch.sqrt(snr_t)
|
||||
loss = weight * loss
|
||||
return loss
|
||||
|
||||
# TODO train_utilと分散しているのでどちらかに寄せる
|
||||
|
||||
@@ -108,6 +117,11 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
|
||||
default=None,
|
||||
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debiased_estimation_loss",
|
||||
action="store_true",
|
||||
help="debiased estimation loss / debiased estimation loss",
|
||||
)
|
||||
if support_weighted_captions:
|
||||
parser.add_argument(
|
||||
"--weighted_captions",
|
||||
|
||||
@@ -4,13 +4,12 @@ import contextlib
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
from .hijacks import ipex_hijacks
|
||||
from .attention import attention_init
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
def ipex_init(): # pylint: disable=too-many-statements
|
||||
try:
|
||||
#Replace cuda with xpu:
|
||||
# Replace cuda with xpu:
|
||||
torch.cuda.current_device = torch.xpu.current_device
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.device = torch.xpu.device
|
||||
@@ -30,6 +29,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||
torch.Tensor.cuda = torch.Tensor.xpu
|
||||
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
||||
@@ -90,9 +90,9 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||
torch.cuda.__file__ = torch.xpu.__file__
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
|
||||
#Memory:
|
||||
# Memory:
|
||||
torch.cuda.memory = torch.xpu.memory
|
||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||
torch.xpu.empty_cache = lambda: None
|
||||
@@ -112,7 +112,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
|
||||
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
|
||||
|
||||
#RNG:
|
||||
# RNG:
|
||||
torch.cuda.get_rng_state = torch.xpu.get_rng_state
|
||||
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
|
||||
torch.cuda.set_rng_state = torch.xpu.set_rng_state
|
||||
@@ -123,7 +123,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.seed_all = torch.xpu.seed_all
|
||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||
|
||||
#AMP:
|
||||
# AMP:
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = contextlib.nullcontext()
|
||||
@@ -138,12 +138,12 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||
|
||||
#C
|
||||
# C
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||
ipex._C._DeviceProperties.major = 2023
|
||||
ipex._C._DeviceProperties.minor = 2
|
||||
|
||||
#Fix functions with ipex:
|
||||
# Fix functions with ipex:
|
||||
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||
torch._utils._get_available_device_type = lambda: "xpu"
|
||||
torch.has_cuda = True
|
||||
@@ -156,20 +156,14 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.get_device_properties.minor = 7
|
||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||
if hasattr(torch.xpu, 'getDeviceIdListForCard'):
|
||||
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
|
||||
torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
|
||||
else:
|
||||
torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
|
||||
torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
|
||||
|
||||
ipex_hijacks()
|
||||
attention_init()
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
except Exception as e:
|
||||
return False, e
|
||||
return True, None
|
||||
|
||||
@@ -4,11 +4,8 @@ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unuse
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
original_torch_bmm = torch.bmm
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
|
||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
def torch_bmm_32_bit(input, mat2, *, out=None):
|
||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
||||
block_multiply = input.element_size()
|
||||
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
|
||||
@@ -17,28 +14,27 @@ def torch_bmm(input, mat2, *, out=None):
|
||||
split_slice_size = batch_size_attention
|
||||
if block_size > 4:
|
||||
do_split = True
|
||||
#Find something divisible with the input_tokens
|
||||
# Find something divisible with the input_tokens
|
||||
while (split_slice_size * slice_block_size) > 4:
|
||||
split_slice_size = split_slice_size // 2
|
||||
if split_slice_size <= 1:
|
||||
split_slice_size = 1
|
||||
break
|
||||
split_2_slice_size = input_tokens
|
||||
if split_slice_size * slice_block_size > 4:
|
||||
slice_block_size_2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
|
||||
do_split_2 = True
|
||||
# Find something divisible with the input_tokens
|
||||
while (split_2_slice_size * slice_block_size_2) > 4:
|
||||
split_2_slice_size = split_2_slice_size // 2
|
||||
if split_2_slice_size <= 1:
|
||||
split_2_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split_2 = False
|
||||
else:
|
||||
do_split = False
|
||||
|
||||
split_2_slice_size = input_tokens
|
||||
if split_slice_size * slice_block_size > 4:
|
||||
slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
|
||||
do_split_2 = True
|
||||
#Find something divisible with the input_tokens
|
||||
while (split_2_slice_size * slice_block_size2) > 4:
|
||||
split_2_slice_size = split_2_slice_size // 2
|
||||
if split_2_slice_size <= 1:
|
||||
split_2_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split_2 = False
|
||||
|
||||
if do_split:
|
||||
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
@@ -64,45 +60,54 @@ def torch_bmm(input, mat2, *, out=None):
|
||||
return hidden_states
|
||||
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
||||
if len(query.shape) == 3:
|
||||
batch_size_attention, query_tokens, shape_four = query.shape
|
||||
shape_one = 1
|
||||
no_shape_one = True
|
||||
batch_size_attention, query_tokens, shape_three = query.shape
|
||||
shape_four = 1
|
||||
else:
|
||||
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
|
||||
no_shape_one = False
|
||||
batch_size_attention, query_tokens, shape_three, shape_four = query.shape
|
||||
|
||||
block_multiply = query.element_size()
|
||||
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
|
||||
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * block_multiply
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
|
||||
split_slice_size = batch_size_attention
|
||||
if block_size > 4:
|
||||
do_split = True
|
||||
#Find something divisible with the shape_one
|
||||
# Find something divisible with the batch_size_attention
|
||||
while (split_slice_size * slice_block_size) > 4:
|
||||
split_slice_size = split_slice_size // 2
|
||||
if split_slice_size <= 1:
|
||||
split_slice_size = 1
|
||||
break
|
||||
split_2_slice_size = query_tokens
|
||||
if split_slice_size * slice_block_size > 4:
|
||||
slice_block_size_2 = split_slice_size * shape_three * shape_four / 1024 / 1024 * block_multiply
|
||||
do_split_2 = True
|
||||
# Find something divisible with the query_tokens
|
||||
while (split_2_slice_size * slice_block_size_2) > 4:
|
||||
split_2_slice_size = split_2_slice_size // 2
|
||||
if split_2_slice_size <= 1:
|
||||
split_2_slice_size = 1
|
||||
break
|
||||
split_3_slice_size = shape_three
|
||||
if split_2_slice_size * slice_block_size_2 > 4:
|
||||
slice_block_size_3 = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * block_multiply
|
||||
do_split_3 = True
|
||||
# Find something divisible with the shape_three
|
||||
while (split_3_slice_size * slice_block_size_3) > 4:
|
||||
split_3_slice_size = split_3_slice_size // 2
|
||||
if split_3_slice_size <= 1:
|
||||
split_3_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split_3 = False
|
||||
else:
|
||||
do_split_2 = False
|
||||
else:
|
||||
do_split = False
|
||||
|
||||
split_2_slice_size = query_tokens
|
||||
if split_slice_size * slice_block_size > 4:
|
||||
slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
|
||||
do_split_2 = True
|
||||
#Find something divisible with the batch_size_attention
|
||||
while (split_2_slice_size * slice_block_size2) > 4:
|
||||
split_2_slice_size = split_2_slice_size // 2
|
||||
if split_2_slice_size <= 1:
|
||||
split_2_slice_size = 1
|
||||
break
|
||||
else:
|
||||
do_split_2 = False
|
||||
|
||||
if do_split:
|
||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
||||
for i in range(batch_size_attention // split_slice_size):
|
||||
@@ -112,7 +117,18 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_2 = i2 * split_2_slice_size
|
||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
||||
if no_shape_one:
|
||||
if do_split_3:
|
||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
||||
start_idx_3 = i3 * split_3_slice_size
|
||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
@@ -120,38 +136,16 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
||||
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
|
||||
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
if no_shape_one:
|
||||
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||
query[:, start_idx:end_idx],
|
||||
key[:, start_idx:end_idx],
|
||||
value[:, start_idx:end_idx],
|
||||
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
||||
query[start_idx:end_idx],
|
||||
key[start_idx:end_idx],
|
||||
value[start_idx:end_idx],
|
||||
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
||||
dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
else:
|
||||
return original_scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def attention_init():
|
||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
||||
torch.bmm = torch_bmm
|
||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import diffusers #0.21.1 # pylint: disable=import-error
|
||||
import diffusers #0.24.0 # pylint: disable=import-error
|
||||
from diffusers.models.attention_processor import Attention
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
@@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
||||
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
||||
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
||||
@@ -96,7 +97,10 @@ def unscale_(self, optimizer):
|
||||
|
||||
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
||||
assert self._scale is not None
|
||||
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
|
||||
if device_supports_fp64:
|
||||
inv_scale = self._scale.double().reciprocal().float()
|
||||
else:
|
||||
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
|
||||
found_inf = torch.full(
|
||||
(1,), 0.0, dtype=torch.float32, device=self._scale.device
|
||||
)
|
||||
|
||||
@@ -89,6 +89,7 @@ def ipex_autocast(*args, **kwargs):
|
||||
else:
|
||||
return original_autocast(*args, **kwargs)
|
||||
|
||||
# Embedding BF16
|
||||
original_torch_cat = torch.cat
|
||||
def torch_cat(tensor, *args, **kwargs):
|
||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
||||
@@ -96,6 +97,7 @@ def torch_cat(tensor, *args, **kwargs):
|
||||
else:
|
||||
return original_torch_cat(tensor, *args, **kwargs)
|
||||
|
||||
# Latent antialias:
|
||||
original_interpolate = torch.nn.functional.interpolate
|
||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||
if antialias or align_corners is not None:
|
||||
@@ -115,19 +117,54 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
|
||||
else:
|
||||
return original_linalg_solve(A, B, *args, **kwargs)
|
||||
|
||||
if torch.xpu.has_fp64_dtype():
|
||||
original_torch_bmm = torch.bmm
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
else:
|
||||
# 64 bit attention workarounds for Alchemist:
|
||||
try:
|
||||
from .attention import torch_bmm_32_bit as original_torch_bmm
|
||||
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
original_torch_bmm = torch.bmm
|
||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||
|
||||
# dtype errors:
|
||||
def torch_bmm(input, mat2, *, out=None):
|
||||
if input.dtype != mat2.dtype:
|
||||
mat2 = mat2.to(input.dtype)
|
||||
return original_torch_bmm(input, mat2, out=out)
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
||||
if query.dtype != key.dtype:
|
||||
key = key.to(dtype=query.dtype)
|
||||
if query.dtype != value.dtype:
|
||||
value = value.to(dtype=query.dtype)
|
||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
||||
|
||||
@property
|
||||
def is_cuda(self):
|
||||
return self.device.type == 'xpu'
|
||||
|
||||
def ipex_hijacks():
|
||||
CondFunc('torch.tensor',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.Tensor.to',
|
||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
||||
CondFunc('torch.Tensor.cuda',
|
||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
||||
CondFunc('torch.UntypedStorage.__init__',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.UntypedStorage.cuda',
|
||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
||||
CondFunc('torch.empty',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.load',
|
||||
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
|
||||
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
|
||||
CondFunc('torch.randn',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
@@ -137,17 +174,23 @@ def ipex_hijacks():
|
||||
CondFunc('torch.zeros',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.tensor',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.linspace',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.load',
|
||||
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
|
||||
orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
|
||||
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
|
||||
if hasattr(torch.xpu, "Generator"):
|
||||
CondFunc('torch.Generator',
|
||||
lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
|
||||
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
|
||||
else:
|
||||
CondFunc('torch.Generator',
|
||||
lambda orig_func, device=None: orig_func(return_xpu(device)),
|
||||
lambda orig_func, device=None: check_device(device))
|
||||
|
||||
CondFunc('torch.Generator',
|
||||
lambda orig_func, device=None: torch.xpu.Generator(device),
|
||||
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
|
||||
|
||||
# TiledVAE and ControlNet:
|
||||
CondFunc('torch.batch_norm',
|
||||
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
|
||||
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
|
||||
@@ -159,38 +202,51 @@ def ipex_hijacks():
|
||||
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
|
||||
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
|
||||
|
||||
#Functions with dtype errors:
|
||||
# Functions with dtype errors:
|
||||
CondFunc('torch.nn.modules.GroupNorm.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
# Training:
|
||||
CondFunc('torch.nn.modules.linear.Linear.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
# BF16:
|
||||
CondFunc('torch.nn.functional.layer_norm',
|
||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
|
||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
weight is not None and input.dtype != weight.data.dtype)
|
||||
# SwinIR BF16:
|
||||
CondFunc('torch.nn.functional.pad',
|
||||
lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
|
||||
lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
|
||||
|
||||
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
|
||||
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
CondFunc('torch.from_numpy',
|
||||
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
|
||||
lambda orig_func, ndarray: ndarray.dtype == float)
|
||||
|
||||
#Broken functions when torch.cuda.is_available is True:
|
||||
# Broken functions when torch.cuda.is_available is True:
|
||||
# Pin Memory:
|
||||
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
|
||||
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
|
||||
lambda orig_func, *args, **kwargs: True)
|
||||
|
||||
#Functions that make compile mad with CondFunc:
|
||||
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
|
||||
# Functions that make compile mad with CondFunc:
|
||||
torch.nn.DataParallel = DummyDataParallel
|
||||
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
|
||||
|
||||
torch.autocast = ipex_autocast
|
||||
torch.cat = torch_cat
|
||||
torch.linalg.solve = linalg_solve
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
torch.backends.cuda.sdp_kernel = return_null_context
|
||||
torch.UntypedStorage.is_cuda = is_cuda
|
||||
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
torch.linalg.solve = linalg_solve
|
||||
|
||||
torch.bmm = torch_bmm
|
||||
torch.cat = torch_cat
|
||||
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
||||
|
||||
@@ -1307,19 +1307,19 @@ def load_vae(vae_id, dtype):
|
||||
|
||||
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
|
||||
max_width, max_height = max_reso
|
||||
max_area = (max_width // divisible) * (max_height // divisible)
|
||||
max_area = max_width * max_height
|
||||
|
||||
resos = set()
|
||||
|
||||
size = int(math.sqrt(max_area)) * divisible
|
||||
resos.add((size, size))
|
||||
width = int(math.sqrt(max_area) // divisible) * divisible
|
||||
resos.add((width, width))
|
||||
|
||||
size = min_size
|
||||
while size <= max_size:
|
||||
width = size
|
||||
height = min(max_size, (max_area // (width // divisible)) * divisible)
|
||||
resos.add((width, height))
|
||||
resos.add((height, width))
|
||||
width = min_size
|
||||
while width <= max_size:
|
||||
height = min(max_size, int((max_area // width) // divisible) * divisible)
|
||||
if height >= min_size:
|
||||
resos.add((width, height))
|
||||
resos.add((height, width))
|
||||
|
||||
# # make additional resos
|
||||
# if width >= height and width - divisible >= min_size:
|
||||
@@ -1329,7 +1329,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
|
||||
# resos.add((width, height - divisible))
|
||||
# resos.add((height - divisible, width))
|
||||
|
||||
size += divisible
|
||||
width += divisible
|
||||
|
||||
resos = list(resos)
|
||||
resos.sort()
|
||||
|
||||
@@ -361,6 +361,23 @@ def get_timestep_embedding(
|
||||
return emb
|
||||
|
||||
|
||||
# Deep Shrink: We do not common this function, because minimize dependencies.
|
||||
def resize_like(x, target, mode="bicubic", align_corners=False):
|
||||
org_dtype = x.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.to(torch.float32)
|
||||
|
||||
if x.shape[-2:] != target.shape[-2:]:
|
||||
if mode == "nearest":
|
||||
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
|
||||
else:
|
||||
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
|
||||
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.to(org_dtype)
|
||||
return x
|
||||
|
||||
|
||||
class SampleOutput:
|
||||
def __init__(self, sample):
|
||||
self.sample = sample
|
||||
@@ -569,6 +586,9 @@ class CrossAttention(nn.Module):
|
||||
self.use_memory_efficient_attention_mem_eff = False
|
||||
self.use_sdpa = False
|
||||
|
||||
# Attention processor
|
||||
self.processor = None
|
||||
|
||||
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
||||
self.use_memory_efficient_attention_xformers = xformers
|
||||
self.use_memory_efficient_attention_mem_eff = mem_eff
|
||||
@@ -590,7 +610,28 @@ class CrossAttention(nn.Module):
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def forward(self, hidden_states, context=None, mask=None):
|
||||
def set_processor(self):
|
||||
return self.processor
|
||||
|
||||
def get_processor(self):
|
||||
return self.processor
|
||||
|
||||
def forward(self, hidden_states, context=None, mask=None, **kwargs):
|
||||
if self.processor is not None:
|
||||
(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
attention_mask,
|
||||
) = translate_attention_names_from_diffusers(
|
||||
hidden_states=hidden_states, context=context, mask=mask, **kwargs
|
||||
)
|
||||
return self.processor(
|
||||
attn=self,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=context,
|
||||
attention_mask=mask,
|
||||
**kwargs
|
||||
)
|
||||
if self.use_memory_efficient_attention_xformers:
|
||||
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
||||
if self.use_memory_efficient_attention_mem_eff:
|
||||
@@ -703,6 +744,21 @@ class CrossAttention(nn.Module):
|
||||
out = self.to_out[0](out)
|
||||
return out
|
||||
|
||||
def translate_attention_names_from_diffusers(
|
||||
hidden_states: torch.FloatTensor,
|
||||
context: Optional[torch.FloatTensor] = None,
|
||||
mask: Optional[torch.FloatTensor] = None,
|
||||
# HF naming
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None
|
||||
):
|
||||
# translate from hugging face diffusers
|
||||
context = context if context is not None else encoder_hidden_states
|
||||
|
||||
# translate from hugging face diffusers
|
||||
mask = mask if mask is not None else attention_mask
|
||||
|
||||
return hidden_states, context, mask
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
@@ -1130,6 +1186,7 @@ class UpBlock2D(nn.Module):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -1221,6 +1278,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -1331,7 +1389,7 @@ class UNet2DConditionModel(nn.Module):
|
||||
self.out_channels = OUT_CHANNELS
|
||||
|
||||
self.sample_size = sample_size
|
||||
self.prepare_config()
|
||||
self.prepare_config(sample_size=sample_size)
|
||||
|
||||
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
|
||||
|
||||
@@ -1418,8 +1476,8 @@ class UNet2DConditionModel(nn.Module):
|
||||
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
||||
|
||||
# region diffusers compatibility
|
||||
def prepare_config(self):
|
||||
self.config = SimpleNamespace()
|
||||
def prepare_config(self, *args, **kwargs):
|
||||
self.config = SimpleNamespace(**kwargs)
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
@@ -1519,7 +1577,6 @@ class UNet2DConditionModel(nn.Module):
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
||||
@@ -1604,3 +1661,255 @@ class UNet2DConditionModel(nn.Module):
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
return timesteps
|
||||
|
||||
|
||||
class InferUNet2DConditionModel:
|
||||
def __init__(self, original_unet: UNet2DConditionModel):
|
||||
self.delegate = original_unet
|
||||
|
||||
# override original model's forward method: because forward is not called by `__call__`
|
||||
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
||||
self.delegate.forward = self.forward
|
||||
|
||||
# override original model's up blocks' forward method
|
||||
for up_block in self.delegate.up_blocks:
|
||||
if up_block.__class__.__name__ == "UpBlock2D":
|
||||
|
||||
def resnet_wrapper(func, block):
|
||||
def forward(*args, **kwargs):
|
||||
return func(block, *args, **kwargs)
|
||||
|
||||
return forward
|
||||
|
||||
up_block.forward = resnet_wrapper(self.up_block_forward, up_block)
|
||||
|
||||
elif up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||
|
||||
def cross_attn_up_wrapper(func, block):
|
||||
def forward(*args, **kwargs):
|
||||
return func(block, *args, **kwargs)
|
||||
|
||||
return forward
|
||||
|
||||
up_block.forward = cross_attn_up_wrapper(self.cross_attn_up_block_forward, up_block)
|
||||
|
||||
# Deep Shrink
|
||||
self.ds_depth_1 = None
|
||||
self.ds_depth_2 = None
|
||||
self.ds_timesteps_1 = None
|
||||
self.ds_timesteps_2 = None
|
||||
self.ds_ratio = None
|
||||
|
||||
# call original model's methods
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.delegate, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.delegate(*args, **kwargs)
|
||||
|
||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||
if ds_depth_1 is None:
|
||||
print("Deep Shrink is disabled.")
|
||||
self.ds_depth_1 = None
|
||||
self.ds_timesteps_1 = None
|
||||
self.ds_depth_2 = None
|
||||
self.ds_timesteps_2 = None
|
||||
self.ds_ratio = None
|
||||
else:
|
||||
print(
|
||||
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||
)
|
||||
self.ds_depth_1 = ds_depth_1
|
||||
self.ds_timesteps_1 = ds_timesteps_1
|
||||
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
||||
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
||||
self.ds_ratio = ds_ratio
|
||||
|
||||
def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
for resnet in _self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# Deep Shrink
|
||||
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
||||
hidden_states = resize_like(hidden_states, res_hidden_states)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if _self.upsamplers is not None:
|
||||
for upsampler in _self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def cross_attn_up_block_forward(
|
||||
self,
|
||||
_self,
|
||||
hidden_states,
|
||||
res_hidden_states_tuple,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
):
|
||||
for resnet, attn in zip(_self.resnets, _self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# Deep Shrink
|
||||
if res_hidden_states.shape[-2:] != hidden_states.shape[-2:]:
|
||||
hidden_states = resize_like(hidden_states, res_hidden_states)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
if _self.upsamplers is not None:
|
||||
for upsampler in _self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[Dict, Tuple]:
|
||||
r"""
|
||||
current implementation is a copy of `UNet2DConditionModel.forward()` with Deep Shrink.
|
||||
"""
|
||||
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a dict instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
`SampleOutput` or `tuple`:
|
||||
`SampleOutput` if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
_self = self.delegate
|
||||
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
# デフォルトではサンプルは「2^アップサンプルの数」、つまり64の倍数である必要がある
|
||||
# ただそれ以外のサイズにも対応できるように、必要ならアップサンプルのサイズを変更する
|
||||
# 多分画質が悪くなるので、64で割り切れるようにしておくのが良い
|
||||
default_overall_up_factor = 2**_self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
# 64で割り切れないときはupsamplerにサイズを伝える
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
# logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
timesteps = _self.handle_unusual_timesteps(sample, timesteps) # 変な時だけ処理
|
||||
|
||||
t_emb = _self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
# timestepsは重みを含まないので常にfloat32のテンソルを返す
|
||||
# しかしtime_embeddingはfp16で動いているかもしれないので、ここでキャストする必要がある
|
||||
# time_projでキャストしておけばいいんじゃね?
|
||||
t_emb = t_emb.to(dtype=_self.dtype)
|
||||
emb = _self.time_embedding(t_emb)
|
||||
|
||||
# 2. pre-process
|
||||
sample = _self.conv_in(sample)
|
||||
|
||||
down_block_res_samples = (sample,)
|
||||
for depth, downsample_block in enumerate(_self.down_blocks):
|
||||
# Deep Shrink
|
||||
if self.ds_depth_1 is not None:
|
||||
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||
self.ds_depth_2 is not None
|
||||
and depth == self.ds_depth_2
|
||||
and timesteps[0] < self.ds_timesteps_1
|
||||
and timesteps[0] >= self.ds_timesteps_2
|
||||
):
|
||||
org_dtype = sample.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
sample = sample.to(torch.float32)
|
||||
sample = F.interpolate(sample, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
||||
|
||||
# downblockはforwardで必ずencoder_hidden_statesを受け取るようにしても良さそうだけど、
|
||||
# まあこちらのほうがわかりやすいかもしれない
|
||||
if downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# skip connectionにControlNetの出力を追加する
|
||||
if down_block_additional_residuals is not None:
|
||||
down_block_res_samples = list(down_block_res_samples)
|
||||
for i in range(len(down_block_res_samples)):
|
||||
down_block_res_samples[i] += down_block_additional_residuals[i]
|
||||
down_block_res_samples = tuple(down_block_res_samples)
|
||||
|
||||
# 4. mid
|
||||
sample = _self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
# ControlNetの出力を追加する
|
||||
if mid_block_additional_residual is not None:
|
||||
sample += mid_block_additional_residual
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(_self.up_blocks):
|
||||
is_final_block = i == len(_self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # skip connection
|
||||
|
||||
# if we have not reached the final block and need to forward the upsample size, we do it here
|
||||
# 前述のように最後のブロック以外ではupsample_sizeを伝える
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
upsample_size=upsample_size,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
|
||||
)
|
||||
|
||||
# 6. post-process
|
||||
sample = _self.conv_norm_out(sample)
|
||||
sample = _self.conv_act(sample)
|
||||
sample = _self.conv_out(sample)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return SampleOutput(sample=sample)
|
||||
|
||||
@@ -133,6 +133,12 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
||||
|
||||
# temporary workaround for text_projection.weight.weight for Playground-v2
|
||||
if "text_projection.weight.weight" in new_sd:
|
||||
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
|
||||
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
|
||||
del new_sd["text_projection.weight.weight"]
|
||||
|
||||
return new_sd, logit_scale
|
||||
|
||||
|
||||
@@ -258,7 +264,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
||||
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
|
||||
elif k.startswith("conditioner.embedders.1.model."):
|
||||
te2_sd[k] = state_dict.pop(k)
|
||||
|
||||
|
||||
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
|
||||
if "text_model.embeddings.position_ids" not in te1_sd:
|
||||
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
|
||||
import math
|
||||
from types import SimpleNamespace
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
@@ -266,6 +266,23 @@ def get_timestep_embedding(
|
||||
return emb
|
||||
|
||||
|
||||
# Deep Shrink: We do not common this function, because minimize dependencies.
|
||||
def resize_like(x, target, mode="bicubic", align_corners=False):
|
||||
org_dtype = x.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.to(torch.float32)
|
||||
|
||||
if x.shape[-2:] != target.shape[-2:]:
|
||||
if mode == "nearest":
|
||||
x = F.interpolate(x, size=target.shape[-2:], mode=mode)
|
||||
else:
|
||||
x = F.interpolate(x, size=target.shape[-2:], mode=mode, align_corners=align_corners)
|
||||
|
||||
if org_dtype == torch.bfloat16:
|
||||
x = x.to(org_dtype)
|
||||
return x
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
if self.weight.dtype != torch.float32:
|
||||
@@ -1077,6 +1094,7 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
|
||||
# h = x.type(self.dtype)
|
||||
h = x
|
||||
|
||||
for module in self.input_blocks:
|
||||
h = call_module(module, h, emb, context)
|
||||
hs.append(h)
|
||||
@@ -1093,6 +1111,121 @@ class SdxlUNet2DConditionModel(nn.Module):
|
||||
return h
|
||||
|
||||
|
||||
class InferSdxlUNet2DConditionModel:
|
||||
def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
|
||||
self.delegate = original_unet
|
||||
|
||||
# override original model's forward method: because forward is not called by `__call__`
|
||||
# overriding `__call__` is not enough, because nn.Module.forward has a special handling
|
||||
self.delegate.forward = self.forward
|
||||
|
||||
# Deep Shrink
|
||||
self.ds_depth_1 = None
|
||||
self.ds_depth_2 = None
|
||||
self.ds_timesteps_1 = None
|
||||
self.ds_timesteps_2 = None
|
||||
self.ds_ratio = None
|
||||
|
||||
# call original model's methods
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.delegate, name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.delegate(*args, **kwargs)
|
||||
|
||||
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
|
||||
if ds_depth_1 is None:
|
||||
print("Deep Shrink is disabled.")
|
||||
self.ds_depth_1 = None
|
||||
self.ds_timesteps_1 = None
|
||||
self.ds_depth_2 = None
|
||||
self.ds_timesteps_2 = None
|
||||
self.ds_ratio = None
|
||||
else:
|
||||
print(
|
||||
f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
|
||||
)
|
||||
self.ds_depth_1 = ds_depth_1
|
||||
self.ds_timesteps_1 = ds_timesteps_1
|
||||
self.ds_depth_2 = ds_depth_2 if ds_depth_2 is not None else -1
|
||||
self.ds_timesteps_2 = ds_timesteps_2 if ds_timesteps_2 is not None else 1000
|
||||
self.ds_ratio = ds_ratio
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
r"""
|
||||
current implementation is a copy of `SdxlUNet2DConditionModel.forward()` with Deep Shrink.
|
||||
"""
|
||||
_self = self.delegate
|
||||
|
||||
# broadcast timesteps to batch dimension
|
||||
timesteps = timesteps.expand(x.shape[0])
|
||||
|
||||
hs = []
|
||||
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
|
||||
t_emb = t_emb.to(x.dtype)
|
||||
emb = _self.time_embed(t_emb)
|
||||
|
||||
assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"
|
||||
assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}"
|
||||
# assert x.dtype == _self.dtype
|
||||
emb = emb + _self.label_emb(y)
|
||||
|
||||
def call_module(module, h, emb, context):
|
||||
x = h
|
||||
for layer in module:
|
||||
# print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None)
|
||||
if isinstance(layer, ResnetBlock2D):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, Transformer2DModel):
|
||||
x = layer(x, context)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
# h = x.type(self.dtype)
|
||||
h = x
|
||||
|
||||
for depth, module in enumerate(_self.input_blocks):
|
||||
# Deep Shrink
|
||||
if self.ds_depth_1 is not None:
|
||||
if (depth == self.ds_depth_1 and timesteps[0] >= self.ds_timesteps_1) or (
|
||||
self.ds_depth_2 is not None
|
||||
and depth == self.ds_depth_2
|
||||
and timesteps[0] < self.ds_timesteps_1
|
||||
and timesteps[0] >= self.ds_timesteps_2
|
||||
):
|
||||
# print("downsample", h.shape, self.ds_ratio)
|
||||
org_dtype = h.dtype
|
||||
if org_dtype == torch.bfloat16:
|
||||
h = h.to(torch.float32)
|
||||
h = F.interpolate(h, scale_factor=self.ds_ratio, mode="bicubic", align_corners=False).to(org_dtype)
|
||||
|
||||
h = call_module(module, h, emb, context)
|
||||
hs.append(h)
|
||||
|
||||
h = call_module(_self.middle_block, h, emb, context)
|
||||
|
||||
for module in _self.output_blocks:
|
||||
# Deep Shrink
|
||||
if self.ds_depth_1 is not None:
|
||||
if hs[-1].shape[-2:] != h.shape[-2:]:
|
||||
# print("upsample", h.shape, hs[-1].shape)
|
||||
h = resize_like(h, hs[-1])
|
||||
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = call_module(module, h, emb, context)
|
||||
|
||||
# Deep Shrink: in case of depth 0
|
||||
if self.ds_depth_1 == 0 and h.shape[-2:] != x.shape[-2:]:
|
||||
# print("upsample", h.shape, x.shape)
|
||||
h = resize_like(h, x)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = call_module(_self.out, h, emb, context)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
|
||||
@@ -51,8 +51,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ def cat_h(sliced):
|
||||
return x
|
||||
|
||||
|
||||
def resblock_forward(_self, num_slices, input_tensor, temb):
|
||||
def resblock_forward(_self, num_slices, input_tensor, temb, **kwargs):
|
||||
assert _self.upsample is None and _self.downsample is None
|
||||
assert _self.norm1.num_groups == _self.norm2.num_groups
|
||||
assert temb is None
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import ast
|
||||
import asyncio
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
import pathlib
|
||||
@@ -18,7 +19,7 @@ from typing import (
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
||||
import gc
|
||||
import glob
|
||||
import math
|
||||
@@ -96,6 +97,7 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
# JPEG-XL on Linux
|
||||
try:
|
||||
from jxlpy import JXLImagePlugin
|
||||
|
||||
@@ -103,6 +105,14 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
# JPEG-XL on Windows
|
||||
try:
|
||||
import pillow_jxl
|
||||
|
||||
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
||||
except:
|
||||
pass
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
@@ -139,6 +149,13 @@ class ImageInfo:
|
||||
|
||||
class BucketManager:
|
||||
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
||||
if max_size is not None:
|
||||
if max_reso is not None:
|
||||
assert max_size >= max_reso[0], "the max_size should be larger than the width of max_reso"
|
||||
assert max_size >= max_reso[1], "the max_size should be larger than the height of max_reso"
|
||||
if min_size is not None:
|
||||
assert max_size >= min_size, "the max_size should be larger than the min_size"
|
||||
|
||||
self.no_upscale = no_upscale
|
||||
if max_reso is None:
|
||||
self.max_reso = None
|
||||
@@ -332,7 +349,9 @@ class BaseSubset:
|
||||
image_dir: Optional[str],
|
||||
num_repeats: int,
|
||||
shuffle_caption: bool,
|
||||
caption_separator: str,
|
||||
keep_tokens: int,
|
||||
keep_tokens_separator: str,
|
||||
color_aug: bool,
|
||||
flip_aug: bool,
|
||||
face_crop_aug_range: Optional[Tuple[float, float]],
|
||||
@@ -348,7 +367,9 @@ class BaseSubset:
|
||||
self.image_dir = image_dir
|
||||
self.num_repeats = num_repeats
|
||||
self.shuffle_caption = shuffle_caption
|
||||
self.caption_separator = caption_separator
|
||||
self.keep_tokens = keep_tokens
|
||||
self.keep_tokens_separator = keep_tokens_separator
|
||||
self.color_aug = color_aug
|
||||
self.flip_aug = flip_aug
|
||||
self.face_crop_aug_range = face_crop_aug_range
|
||||
@@ -374,7 +395,9 @@ class DreamBoothSubset(BaseSubset):
|
||||
caption_extension: str,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator: str,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -393,7 +416,9 @@ class DreamBoothSubset(BaseSubset):
|
||||
image_dir,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -426,7 +451,9 @@ class FineTuningSubset(BaseSubset):
|
||||
metadata_file: str,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -445,7 +472,9 @@ class FineTuningSubset(BaseSubset):
|
||||
image_dir,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -475,7 +504,9 @@ class ControlNetSubset(BaseSubset):
|
||||
caption_extension: str,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -494,7 +525,9 @@ class ControlNetSubset(BaseSubset):
|
||||
image_dir,
|
||||
num_repeats,
|
||||
shuffle_caption,
|
||||
caption_separator,
|
||||
keep_tokens,
|
||||
keep_tokens_separator,
|
||||
color_aug,
|
||||
flip_aug,
|
||||
face_crop_aug_range,
|
||||
@@ -629,15 +662,33 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
caption = ""
|
||||
else:
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
fixed_tokens = []
|
||||
flex_tokens = []
|
||||
if (
|
||||
hasattr(subset, "keep_tokens_separator")
|
||||
and subset.keep_tokens_separator
|
||||
and subset.keep_tokens_separator in caption
|
||||
):
|
||||
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
|
||||
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
|
||||
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
|
||||
else:
|
||||
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
|
||||
flex_tokens = tokens[:]
|
||||
if subset.keep_tokens > 0:
|
||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
||||
flex_tokens = tokens[subset.keep_tokens :]
|
||||
|
||||
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||
tokens_len = (
|
||||
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
|
||||
math.floor(
|
||||
(self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))
|
||||
)
|
||||
+ subset.token_warmup_min
|
||||
)
|
||||
tokens = tokens[:tokens_len]
|
||||
flex_tokens = flex_tokens[:tokens_len]
|
||||
|
||||
def dropout_tags(tokens):
|
||||
if subset.caption_tag_dropout_rate <= 0:
|
||||
@@ -648,12 +699,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
l.append(token)
|
||||
return l
|
||||
|
||||
fixed_tokens = []
|
||||
flex_tokens = tokens[:]
|
||||
if subset.keep_tokens > 0:
|
||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
||||
flex_tokens = tokens[subset.keep_tokens :]
|
||||
|
||||
if subset.shuffle_caption:
|
||||
random.shuffle(flex_tokens)
|
||||
|
||||
@@ -1697,7 +1742,9 @@ class ControlNetDataset(BaseDataset):
|
||||
subset.caption_extension,
|
||||
subset.num_repeats,
|
||||
subset.shuffle_caption,
|
||||
subset.caption_separator,
|
||||
subset.keep_tokens,
|
||||
subset.keep_tokens_separator,
|
||||
subset.color_aug,
|
||||
subset.flip_aug,
|
||||
subset.face_crop_aug_range,
|
||||
@@ -2640,7 +2687,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
"--optimizer_type",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
|
||||
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
|
||||
)
|
||||
|
||||
# backward compatibility
|
||||
@@ -2846,6 +2893,22 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
|
||||
) # TODO move to SDXL training, because it is not supported by SD1/2
|
||||
parser.add_argument(
|
||||
"--ddp_timeout",
|
||||
type=int,
|
||||
default=None,
|
||||
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddp_gradient_as_bucket_view",
|
||||
action="store_true",
|
||||
help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ddp_static_graph",
|
||||
action="store_true",
|
||||
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_skip",
|
||||
type=int,
|
||||
@@ -2872,6 +2935,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wandb_run_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the specific wandb session / wandb ログに表示される特定の実行の名前",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log_tracker_config",
|
||||
type=str,
|
||||
@@ -2948,6 +3017,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
|
||||
)
|
||||
parser.add_argument("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する")
|
||||
parser.add_argument(
|
||||
"--sample_every_n_epochs",
|
||||
type=int,
|
||||
@@ -3081,9 +3151,8 @@ def add_dataset_arguments(
|
||||
):
|
||||
# dataset common
|
||||
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument(
|
||||
"--shuffle_caption", action="store_true", help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする"
|
||||
)
|
||||
parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする")
|
||||
parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字")
|
||||
parser.add_argument(
|
||||
"--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"
|
||||
)
|
||||
@@ -3099,6 +3168,13 @@ def add_dataset_arguments(
|
||||
default=0,
|
||||
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep_tokens_separator",
|
||||
type=str,
|
||||
default="",
|
||||
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
|
||||
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_prefix",
|
||||
type=str,
|
||||
@@ -3350,7 +3426,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
|
||||
|
||||
|
||||
def get_optimizer(args, trainable_params):
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
|
||||
|
||||
optimizer_type = args.optimizer_type
|
||||
if args.use_8bit_adam:
|
||||
@@ -3454,6 +3530,34 @@ def get_optimizer(args, trainable_params):
|
||||
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "PagedAdamW".lower():
|
||||
print(f"use PagedAdamW optimizer | {optimizer_kwargs}")
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
|
||||
try:
|
||||
optimizer_class = bnb.optim.PagedAdamW
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
|
||||
)
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "PagedAdamW32bit".lower():
|
||||
print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}")
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
|
||||
try:
|
||||
optimizer_class = bnb.optim.PagedAdamW32bit
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"No PagedAdamW32bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamW32bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
|
||||
)
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "SGDNesterov".lower():
|
||||
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
|
||||
if "momentum" not in optimizer_kwargs:
|
||||
@@ -3772,11 +3876,19 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
if args.wandb_api_key is not None:
|
||||
wandb.login(key=args.wandb_api_key)
|
||||
|
||||
kwargs_handlers = (
|
||||
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
|
||||
DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph)
|
||||
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
|
||||
else None,
|
||||
)
|
||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=log_with,
|
||||
project_dir=logging_dir,
|
||||
kwargs_handlers=kwargs_handlers,
|
||||
)
|
||||
return accelerator
|
||||
|
||||
@@ -3845,17 +3957,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
# TODO remove this function in the future
|
||||
def transform_if_model_is_DDP(text_encoder, unet, network=None):
|
||||
# Transform text_encoder, unet and network from DistributedDataParallel
|
||||
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
|
||||
|
||||
|
||||
def transform_models_if_DDP(models):
|
||||
# Transform text_encoder, unet and network from DistributedDataParallel
|
||||
return [model.module if type(model) == DDP else model for model in models if model is not None]
|
||||
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
@@ -3879,8 +3980,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
@@ -3992,6 +4091,7 @@ def get_hidden_states_sdxl(
|
||||
text_encoder1: CLIPTextModel,
|
||||
text_encoder2: CLIPTextModelWithProjection,
|
||||
weight_dtype: Optional[str] = None,
|
||||
accelerator: Optional[Accelerator] = None,
|
||||
):
|
||||
# input_ids: b,n,77 -> b*n, 77
|
||||
b_size = input_ids1.size()[0]
|
||||
@@ -4007,7 +4107,8 @@ def get_hidden_states_sdxl(
|
||||
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
||||
|
||||
# pool2 = enc_out["text_embeds"]
|
||||
pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
||||
unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
|
||||
pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
||||
|
||||
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
||||
n_size = 1 if max_token_length is None else max_token_length // 75
|
||||
@@ -4366,6 +4467,29 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
|
||||
return noise, noisy_latents, timesteps
|
||||
|
||||
|
||||
def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
|
||||
names = []
|
||||
if including_unet:
|
||||
names.append("unet")
|
||||
names.append("text_encoder1")
|
||||
names.append("text_encoder2")
|
||||
|
||||
append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
|
||||
|
||||
|
||||
def append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names):
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
|
||||
for lr_index in range(len(lrs)):
|
||||
name = names[lr_index]
|
||||
logs["lr/" + name] = float(lrs[lr_index])
|
||||
|
||||
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
|
||||
logs["lr/d*lr/" + name] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
|
||||
)
|
||||
|
||||
|
||||
# scheduler:
|
||||
SCHEDULER_LINEAR_START = 0.00085
|
||||
SCHEDULER_LINEAR_END = 0.0120
|
||||
@@ -4373,13 +4497,119 @@ SCHEDULER_TIMESTEPS = 1000
|
||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
|
||||
def get_my_scheduler(
|
||||
*,
|
||||
sample_sampler: str,
|
||||
v_parameterization: bool,
|
||||
):
|
||||
sched_init_args = {}
|
||||
if sample_sampler == "ddim":
|
||||
scheduler_cls = DDIMScheduler
|
||||
elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
||||
scheduler_cls = DDPMScheduler
|
||||
elif sample_sampler == "pndm":
|
||||
scheduler_cls = PNDMScheduler
|
||||
elif sample_sampler == "lms" or sample_sampler == "k_lms":
|
||||
scheduler_cls = LMSDiscreteScheduler
|
||||
elif sample_sampler == "euler" or sample_sampler == "k_euler":
|
||||
scheduler_cls = EulerDiscreteScheduler
|
||||
elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a":
|
||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||
elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
|
||||
scheduler_cls = DPMSolverMultistepScheduler
|
||||
sched_init_args["algorithm_type"] = sample_sampler
|
||||
elif sample_sampler == "dpmsingle":
|
||||
scheduler_cls = DPMSolverSinglestepScheduler
|
||||
elif sample_sampler == "heun":
|
||||
scheduler_cls = HeunDiscreteScheduler
|
||||
elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2":
|
||||
scheduler_cls = KDPM2DiscreteScheduler
|
||||
elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
|
||||
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
||||
else:
|
||||
scheduler_cls = DDIMScheduler
|
||||
|
||||
if v_parameterization:
|
||||
sched_init_args["prediction_type"] = "v_prediction"
|
||||
|
||||
scheduler = scheduler_cls(
|
||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||
beta_start=SCHEDULER_LINEAR_START,
|
||||
beta_end=SCHEDULER_LINEAR_END,
|
||||
beta_schedule=SCHEDLER_SCHEDULE,
|
||||
**sched_init_args,
|
||||
)
|
||||
|
||||
# clip_sample=Trueにする
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||
# print("set clip_sample to True")
|
||||
scheduler.config.clip_sample = True
|
||||
|
||||
return scheduler
|
||||
|
||||
|
||||
def sample_images(*args, **kwargs):
|
||||
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||
|
||||
|
||||
def line_to_prompt_dict(line: str) -> dict:
|
||||
# subset of gen_img_diffusers
|
||||
prompt_args = line.split(" --")
|
||||
prompt_dict = {}
|
||||
prompt_dict["prompt"] = prompt_args[0]
|
||||
|
||||
for parg in prompt_args:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict["width"] = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict["height"] = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict["seed"] = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
|
||||
continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
prompt_dict["scale"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
prompt_dict["negative_prompt"] = m.group(1)
|
||||
continue
|
||||
|
||||
m = re.match(r"ss (.+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict["sample_sampler"] = m.group(1)
|
||||
continue
|
||||
|
||||
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict["controlnet_image"] = m.group(1)
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
|
||||
return prompt_dict
|
||||
|
||||
|
||||
def sample_images_common(
|
||||
pipe_class,
|
||||
accelerator,
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
epoch,
|
||||
steps,
|
||||
@@ -4394,15 +4624,19 @@ def sample_images_common(
|
||||
"""
|
||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||
"""
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
# sample_every_n_steps は無視する
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
# sample_every_n_steps は無視する
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
return
|
||||
|
||||
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
if not os.path.isfile(args.sample_prompts):
|
||||
@@ -4412,6 +4646,13 @@ def sample_images_common(
|
||||
org_vae_device = vae.device # CPUにいるはず
|
||||
vae.to(device)
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
if isinstance(text_encoder, (list, tuple)):
|
||||
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
|
||||
else:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
# read prompts
|
||||
|
||||
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||
@@ -4429,56 +4670,19 @@ def sample_images_common(
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
# schedulerを用意する
|
||||
sched_init_args = {}
|
||||
if args.sample_sampler == "ddim":
|
||||
scheduler_cls = DDIMScheduler
|
||||
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
||||
scheduler_cls = DDPMScheduler
|
||||
elif args.sample_sampler == "pndm":
|
||||
scheduler_cls = PNDMScheduler
|
||||
elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
|
||||
scheduler_cls = LMSDiscreteScheduler
|
||||
elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
|
||||
scheduler_cls = EulerDiscreteScheduler
|
||||
elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
|
||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
|
||||
scheduler_cls = DPMSolverMultistepScheduler
|
||||
sched_init_args["algorithm_type"] = args.sample_sampler
|
||||
elif args.sample_sampler == "dpmsingle":
|
||||
scheduler_cls = DPMSolverSinglestepScheduler
|
||||
elif args.sample_sampler == "heun":
|
||||
scheduler_cls = HeunDiscreteScheduler
|
||||
elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
|
||||
scheduler_cls = KDPM2DiscreteScheduler
|
||||
elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
|
||||
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
||||
else:
|
||||
scheduler_cls = DDIMScheduler
|
||||
|
||||
if args.v_parameterization:
|
||||
sched_init_args["prediction_type"] = "v_prediction"
|
||||
|
||||
scheduler = scheduler_cls(
|
||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||
beta_start=SCHEDULER_LINEAR_START,
|
||||
beta_end=SCHEDULER_LINEAR_END,
|
||||
beta_schedule=SCHEDLER_SCHEDULE,
|
||||
**sched_init_args,
|
||||
schedulers: dict = {}
|
||||
default_scheduler = get_my_scheduler(
|
||||
sample_sampler=args.sample_sampler,
|
||||
v_parameterization=args.v_parameterization,
|
||||
)
|
||||
|
||||
# clip_sample=Trueにする
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||
# print("set clip_sample to True")
|
||||
scheduler.config.clip_sample = True
|
||||
schedulers[args.sample_sampler] = default_scheduler
|
||||
|
||||
pipeline = pipe_class(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
scheduler=default_scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
@@ -4494,78 +4698,37 @@ def sample_images_common(
|
||||
|
||||
with torch.no_grad():
|
||||
# with accelerator.autocast():
|
||||
for i, prompt in enumerate(prompts):
|
||||
for i, prompt_dict in enumerate(prompts):
|
||||
if not accelerator.is_main_process:
|
||||
continue
|
||||
|
||||
if isinstance(prompt, dict):
|
||||
negative_prompt = prompt.get("negative_prompt")
|
||||
sample_steps = prompt.get("sample_steps", 30)
|
||||
width = prompt.get("width", 512)
|
||||
height = prompt.get("height", 512)
|
||||
scale = prompt.get("scale", 7.5)
|
||||
seed = prompt.get("seed")
|
||||
controlnet_image = prompt.get("controlnet_image")
|
||||
prompt = prompt.get("prompt")
|
||||
else:
|
||||
# prompt = prompt.strip()
|
||||
# if len(prompt) == 0 or prompt[0] == "#":
|
||||
# continue
|
||||
if isinstance(prompt_dict, str):
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
|
||||
# subset of gen_img_diffusers
|
||||
prompt_args = prompt.split(" --")
|
||||
prompt = prompt_args[0]
|
||||
negative_prompt = None
|
||||
sample_steps = 30
|
||||
width = height = 512
|
||||
scale = 7.5
|
||||
seed = None
|
||||
controlnet_image = None
|
||||
for parg in prompt_args:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
width = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
height = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
seed = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
||||
continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
scale = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
negative_prompt = m.group(1)
|
||||
continue
|
||||
|
||||
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
controlnet_image = m.group(1)
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
assert isinstance(prompt_dict, dict)
|
||||
negative_prompt = prompt_dict.get("negative_prompt")
|
||||
sample_steps = prompt_dict.get("sample_steps", 30)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
scale = prompt_dict.get("scale", 7.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scheduler = schedulers.get(sampler_name)
|
||||
if scheduler is None:
|
||||
scheduler = get_my_scheduler(
|
||||
sample_sampler=sampler_name,
|
||||
v_parameterization=args.v_parameterization,
|
||||
)
|
||||
schedulers[sampler_name] = scheduler
|
||||
pipeline.scheduler = scheduler
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
@@ -4583,6 +4746,9 @@ def sample_images_common(
|
||||
print(f"width: {width}")
|
||||
print(f"sample_steps: {sample_steps}")
|
||||
print(f"scale: {scale}")
|
||||
print(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
print(f"seed: {seed}")
|
||||
with accelerator.autocast():
|
||||
latents = pipeline(
|
||||
prompt=prompt,
|
||||
@@ -4676,3 +4842,21 @@ class collator_class:
|
||||
dataset.set_current_epoch(self.current_epoch.value)
|
||||
dataset.set_current_step(self.current_step.value)
|
||||
return examples[0]
|
||||
|
||||
|
||||
class LossRecorder:
|
||||
def __init__(self):
|
||||
self.loss_list: List[float] = []
|
||||
self.loss_total: float = 0.0
|
||||
|
||||
def add(self, *, epoch: int, step: int, loss: float) -> None:
|
||||
if epoch == 0:
|
||||
self.loss_list.append(loss)
|
||||
else:
|
||||
self.loss_total -= self.loss_list[step]
|
||||
self.loss_list[step] = loss
|
||||
self.loss_total += loss
|
||||
|
||||
@property
|
||||
def moving_average(self) -> float:
|
||||
return self.loss_total / len(self.loss_list)
|
||||
|
||||
@@ -13,8 +13,8 @@ from library import sai_model_spec, model_util, sdxl_model_util
|
||||
import lora
|
||||
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
MIN_DIFF = 1e-1
|
||||
# CLAMP_QUANTILE = 0.99
|
||||
# MIN_DIFF = 1e-1
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
@@ -29,7 +29,21 @@ def save_to_file(file_name, model, state_dict, dtype):
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def svd(args):
|
||||
def svd(
|
||||
model_org=None,
|
||||
model_tuned=None,
|
||||
save_to=None,
|
||||
dim=4,
|
||||
v2=None,
|
||||
sdxl=None,
|
||||
conv_dim=None,
|
||||
v_parameterization=None,
|
||||
device=None,
|
||||
save_precision=None,
|
||||
clamp_quantile=0.99,
|
||||
min_diff=0.01,
|
||||
no_metadata=False,
|
||||
):
|
||||
def str_to_dtype(p):
|
||||
if p == "float":
|
||||
return torch.float
|
||||
@@ -39,44 +53,42 @@ def svd(args):
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
assert args.v2 != args.sdxl or (
|
||||
not args.v2 and not args.sdxl
|
||||
), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
||||
if args.v_parameterization is None:
|
||||
args.v_parameterization = args.v2
|
||||
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
||||
if v_parameterization is None:
|
||||
v_parameterization = v2
|
||||
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
save_dtype = str_to_dtype(save_precision)
|
||||
|
||||
# load models
|
||||
if not args.sdxl:
|
||||
print(f"loading original SD model : {args.model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
||||
if not sdxl:
|
||||
print(f"loading original SD model : {model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
||||
text_encoders_o = [text_encoder_o]
|
||||
print(f"loading tuned SD model : {args.model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
||||
print(f"loading tuned SD model : {model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
||||
text_encoders_t = [text_encoder_t]
|
||||
model_version = model_util.get_model_version_str_for_sd1_sd2(args.v2, args.v_parameterization)
|
||||
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
||||
else:
|
||||
print(f"loading original SDXL model : {args.model_org}")
|
||||
print(f"loading original SDXL model : {model_org}")
|
||||
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_org, "cpu"
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
|
||||
)
|
||||
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
||||
print(f"loading original SDXL model : {args.model_tuned}")
|
||||
print(f"loading original SDXL model : {model_tuned}")
|
||||
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, args.model_tuned, "cpu"
|
||||
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
|
||||
)
|
||||
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
||||
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
||||
|
||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||
if args.conv_dim is None:
|
||||
if conv_dim is None:
|
||||
kwargs = {}
|
||||
else:
|
||||
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
|
||||
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
|
||||
|
||||
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_o, unet_o, **kwargs)
|
||||
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoders_t, unet_t, **kwargs)
|
||||
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
|
||||
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
|
||||
assert len(lora_network_o.text_encoder_loras) == len(
|
||||
lora_network_t.text_encoder_loras
|
||||
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||
@@ -91,9 +103,9 @@ def svd(args):
|
||||
diff = module_t.weight - module_o.weight
|
||||
|
||||
# Text Encoder might be same
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > MIN_DIFF:
|
||||
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
||||
text_encoder_different = True
|
||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {MIN_DIFF}")
|
||||
print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
||||
|
||||
diff = diff.float()
|
||||
diffs[lora_name] = diff
|
||||
@@ -120,16 +132,16 @@ def svd(args):
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name, mat in tqdm(list(diffs.items())):
|
||||
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||
conv2d = len(mat.size()) == 4
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
|
||||
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
|
||||
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
|
||||
if args.device:
|
||||
mat = mat.to(args.device)
|
||||
if device:
|
||||
mat = mat.to(device)
|
||||
|
||||
# print(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
@@ -149,7 +161,7 @@ def svd(args):
|
||||
Vh = Vh[:rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
hi_val = torch.quantile(dist, clamp_quantile)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
@@ -178,34 +190,32 @@ def svd(args):
|
||||
info = lora_network_save.load_state_dict(lora_sd)
|
||||
print(f"Loading extracted LoRA weights: {info}")
|
||||
|
||||
dir_name = os.path.dirname(args.save_to)
|
||||
dir_name = os.path.dirname(save_to)
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
|
||||
# minimum metadata
|
||||
net_kwargs = {}
|
||||
if args.conv_dim is not None:
|
||||
net_kwargs["conv_dim"] = args.conv_dim
|
||||
net_kwargs["conv_alpha"] = args.conv_dim
|
||||
if conv_dim is not None:
|
||||
net_kwargs["conv_dim"] = str(conv_dim)
|
||||
net_kwargs["conv_alpha"] = str(float(conv_dim))
|
||||
|
||||
metadata = {
|
||||
"ss_v2": str(args.v2),
|
||||
"ss_v2": str(v2),
|
||||
"ss_base_model_version": model_version,
|
||||
"ss_network_module": "networks.lora",
|
||||
"ss_network_dim": str(args.dim),
|
||||
"ss_network_alpha": str(args.dim),
|
||||
"ss_network_dim": str(dim),
|
||||
"ss_network_alpha": str(float(dim)),
|
||||
"ss_network_args": json.dumps(net_kwargs),
|
||||
}
|
||||
|
||||
if not args.no_metadata:
|
||||
title = os.path.splitext(os.path.basename(args.save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(
|
||||
None, args.v2, args.v_parameterization, args.sdxl, True, False, time.time(), title=title
|
||||
)
|
||||
if not no_metadata:
|
||||
title = os.path.splitext(os.path.basename(save_to))[0]
|
||||
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
|
||||
metadata.update(sai_metadata)
|
||||
|
||||
lora_network_save.save_weights(args.save_to, save_dtype, metadata)
|
||||
print(f"LoRA weights are saved to: {args.save_to}")
|
||||
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
||||
print(f"LoRA weights are saved to: {save_to}")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
@@ -213,7 +223,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
||||
parser.add_argument(
|
||||
"--v_parameterization",
|
||||
type=bool,
|
||||
action="store_true",
|
||||
default=None,
|
||||
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
|
||||
)
|
||||
@@ -231,16 +241,22 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--model_org",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_tuned",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||
parser.add_argument(
|
||||
@@ -250,6 +266,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
||||
)
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument(
|
||||
"--clamp_quantile",
|
||||
type=float,
|
||||
default=0.99,
|
||||
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_diff",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
|
||||
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_metadata",
|
||||
action="store_true",
|
||||
@@ -264,4 +293,4 @@ if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
svd(args)
|
||||
svd(**vars(args))
|
||||
|
||||
430
networks/oft.py
Normal file
430
networks/oft.py
Normal file
@@ -0,0 +1,430 @@
|
||||
# OFT network module
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from diffusers import AutoencoderKL
|
||||
from transformers import CLIPTextModel
|
||||
import numpy as np
|
||||
import torch
|
||||
import re
|
||||
|
||||
|
||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||
|
||||
|
||||
class OFTModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
oft_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
dim=4,
|
||||
alpha=1,
|
||||
):
|
||||
"""
|
||||
dim -> num blocks
|
||||
alpha -> constraint
|
||||
"""
|
||||
super().__init__()
|
||||
self.oft_name = oft_name
|
||||
|
||||
self.num_blocks = dim
|
||||
|
||||
if "Linear" in org_module.__class__.__name__:
|
||||
out_dim = org_module.out_features
|
||||
elif "Conv" in org_module.__class__.__name__:
|
||||
out_dim = org_module.out_channels
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().numpy()
|
||||
self.constraint = alpha * out_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha))
|
||||
|
||||
self.block_size = out_dim // self.num_blocks
|
||||
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
|
||||
|
||||
self.out_dim = out_dim
|
||||
self.shape = org_module.weight.shape
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = [org_module] # moduleにならないようにlistに入れる
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module[0].forward
|
||||
self.org_module[0].forward = self.forward
|
||||
|
||||
def get_weight(self, multiplier=None):
|
||||
if multiplier is None:
|
||||
multiplier = self.multiplier
|
||||
|
||||
block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
|
||||
norm_Q = torch.norm(block_Q.flatten())
|
||||
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
|
||||
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
|
||||
|
||||
block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
|
||||
R = torch.block_diag(*block_R_weighted)
|
||||
|
||||
return R
|
||||
|
||||
def forward(self, x, scale=None):
|
||||
x = self.org_forward(x)
|
||||
if self.multiplier == 0.0:
|
||||
return x
|
||||
|
||||
R = self.get_weight().to(x.device, dtype=x.dtype)
|
||||
if x.dim() == 4:
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = torch.matmul(x, R)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
else:
|
||||
x = torch.matmul(x, R)
|
||||
return x
|
||||
|
||||
|
||||
class OFTInfModule(OFTModule):
|
||||
def __init__(
|
||||
self,
|
||||
oft_name,
|
||||
org_module: torch.nn.Module,
|
||||
multiplier=1.0,
|
||||
dim=4,
|
||||
alpha=1,
|
||||
**kwargs,
|
||||
):
|
||||
# no dropout for inference
|
||||
super().__init__(oft_name, org_module, multiplier, dim, alpha)
|
||||
self.enabled = True
|
||||
self.network: OFTNetwork = None
|
||||
|
||||
def set_network(self, network):
|
||||
self.network = network
|
||||
|
||||
def forward(self, x, scale=None):
|
||||
if not self.enabled:
|
||||
return self.org_forward(x)
|
||||
return super().forward(x, scale)
|
||||
|
||||
def merge_to(self, multiplier=None, sign=1):
|
||||
R = self.get_weight(multiplier) * sign
|
||||
|
||||
# get org weight
|
||||
org_sd = self.org_module[0].state_dict()
|
||||
org_weight = org_sd["weight"]
|
||||
R = R.to(org_weight.device, dtype=org_weight.dtype)
|
||||
|
||||
if org_weight.dim() == 4:
|
||||
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
|
||||
else:
|
||||
weight = torch.einsum("oi, op -> pi", org_weight, R)
|
||||
|
||||
# set weight to org_module
|
||||
org_sd["weight"] = weight
|
||||
self.org_module[0].load_state_dict(org_sd)
|
||||
|
||||
|
||||
def create_network(
|
||||
multiplier: float,
|
||||
network_dim: Optional[int],
|
||||
network_alpha: Optional[float],
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
||||
unet,
|
||||
neuron_dropout: Optional[float] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
enable_all_linear = kwargs.get("enable_all_linear", None)
|
||||
enable_conv = kwargs.get("enable_conv", None)
|
||||
if enable_all_linear is not None:
|
||||
enable_all_linear = bool(enable_all_linear)
|
||||
if enable_conv is not None:
|
||||
enable_conv = bool(enable_conv)
|
||||
|
||||
network = OFTNetwork(
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
dim=network_dim,
|
||||
alpha=network_alpha,
|
||||
enable_all_linear=enable_all_linear,
|
||||
enable_conv=enable_conv,
|
||||
varbose=True,
|
||||
)
|
||||
return network
|
||||
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file, safe_open
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# check dim, alpha and if weights have for conv2d
|
||||
dim = None
|
||||
alpha = None
|
||||
has_conv2d = None
|
||||
all_linear = None
|
||||
for name, param in weights_sd.items():
|
||||
if name.endswith(".alpha"):
|
||||
if alpha is None:
|
||||
alpha = param.item()
|
||||
else:
|
||||
if dim is None:
|
||||
dim = param.size()[0]
|
||||
if has_conv2d is None and param.dim() == 4:
|
||||
has_conv2d = True
|
||||
if all_linear is None:
|
||||
if param.dim() == 3 and "attn" not in name:
|
||||
all_linear = True
|
||||
if dim is not None and alpha is not None and has_conv2d is not None:
|
||||
break
|
||||
if has_conv2d is None:
|
||||
has_conv2d = False
|
||||
if all_linear is None:
|
||||
all_linear = False
|
||||
|
||||
module_class = OFTInfModule if for_inference else OFTModule
|
||||
network = OFTNetwork(
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
dim=dim,
|
||||
alpha=alpha,
|
||||
enable_all_linear=all_linear,
|
||||
enable_conv=has_conv2d,
|
||||
module_class=module_class,
|
||||
)
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
class OFTNetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
|
||||
UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
OFT_PREFIX_UNET = "oft_unet" # これ変えないほうがいいかな
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
||||
unet,
|
||||
multiplier: float = 1.0,
|
||||
dim: int = 4,
|
||||
alpha: float = 1,
|
||||
enable_all_linear: Optional[bool] = False,
|
||||
enable_conv: Optional[bool] = False,
|
||||
module_class: Type[object] = OFTModule,
|
||||
varbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
self.dim = dim
|
||||
self.alpha = alpha
|
||||
|
||||
print(
|
||||
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
|
||||
)
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
root_module: torch.nn.Module,
|
||||
target_replace_modules: List[torch.nn.Module],
|
||||
) -> List[OFTModule]:
|
||||
prefix = self.OFT_PREFIX_UNET
|
||||
ofts = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = "Linear" in child_module.__class__.__name__
|
||||
is_conv2d = "Conv2d" in child_module.__class__.__name__
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
|
||||
oft_name = prefix + "." + name + "." + child_name
|
||||
oft_name = oft_name.replace(".", "_")
|
||||
# print(oft_name)
|
||||
|
||||
oft = module_class(
|
||||
oft_name,
|
||||
child_module,
|
||||
self.multiplier,
|
||||
dim,
|
||||
alpha,
|
||||
)
|
||||
ofts.append(oft)
|
||||
return ofts
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
if enable_all_linear:
|
||||
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
|
||||
else:
|
||||
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
|
||||
if enable_conv:
|
||||
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
|
||||
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
for oft in self.unet_ofts:
|
||||
assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
|
||||
names.add(oft.oft_name)
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for oft in self.unet_ofts:
|
||||
oft.multiplier = self.multiplier
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
assert apply_unet, "apply_unet must be True"
|
||||
|
||||
for oft in self.unet_ofts:
|
||||
oft.apply_to()
|
||||
self.add_module(oft.oft_name, oft)
|
||||
|
||||
# マージできるかどうかを返す
|
||||
def is_mergeable(self):
|
||||
return True
|
||||
|
||||
# TODO refactor to common function with apply_to
|
||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||
print("enable OFT for U-Net")
|
||||
|
||||
for oft in self.unet_ofts:
|
||||
sd_for_lora = {}
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(oft.oft_name):
|
||||
sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
|
||||
oft.load_state_dict(sd_for_lora, False)
|
||||
oft.merge_to()
|
||||
|
||||
print(f"weights are merged")
|
||||
|
||||
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def enumerate_params(ofts):
|
||||
params = []
|
||||
for oft in ofts:
|
||||
params.extend(oft.parameters())
|
||||
|
||||
# print num of params
|
||||
num_params = 0
|
||||
for p in params:
|
||||
num_params += p.numel()
|
||||
print(f"OFT params: {num_params}")
|
||||
return params
|
||||
|
||||
param_data = {"params": enumerate_params(self.unet_ofts)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
return all_params
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
pass
|
||||
|
||||
def prepare_grad_etc(self, text_encoder, unet):
|
||||
self.requires_grad_(True)
|
||||
|
||||
def on_epoch_start(self, text_encoder, unet):
|
||||
self.train()
|
||||
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
from library import train_util
|
||||
|
||||
# Precalculate model hashes to save time on indexing
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
def backup_weights(self):
|
||||
# 重みのバックアップを行う
|
||||
ofts: List[OFTInfModule] = self.unet_ofts
|
||||
for oft in ofts:
|
||||
org_module = oft.org_module[0]
|
||||
if not hasattr(org_module, "_lora_org_weight"):
|
||||
sd = org_module.state_dict()
|
||||
org_module._lora_org_weight = sd["weight"].detach().clone()
|
||||
org_module._lora_restored = True
|
||||
|
||||
def restore_weights(self):
|
||||
# 重みのリストアを行う
|
||||
ofts: List[OFTInfModule] = self.unet_ofts
|
||||
for oft in ofts:
|
||||
org_module = oft.org_module[0]
|
||||
if not org_module._lora_restored:
|
||||
sd = org_module.state_dict()
|
||||
sd["weight"] = org_module._lora_org_weight
|
||||
org_module.load_state_dict(sd)
|
||||
org_module._lora_restored = True
|
||||
|
||||
def pre_calculation(self):
|
||||
# 事前計算を行う
|
||||
ofts: List[OFTInfModule] = self.unet_ofts
|
||||
for oft in ofts:
|
||||
org_module = oft.org_module[0]
|
||||
oft.merge_to()
|
||||
# sd = org_module.state_dict()
|
||||
# org_weight = sd["weight"]
|
||||
# lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
||||
# sd["weight"] = org_weight + lora_weight
|
||||
# assert sd["weight"].shape == org_weight.shape
|
||||
# org_module.load_state_dict(sd)
|
||||
|
||||
org_module._lora_restored = False
|
||||
oft.enabled = False
|
||||
@@ -19,8 +19,14 @@ huggingface-hub==0.15.1
|
||||
# requests==2.28.2
|
||||
# timm==0.6.12
|
||||
# fairscale==0.4.13
|
||||
# for WD14 captioning
|
||||
# for WD14 captioning (tensorflow)
|
||||
# tensorflow==2.10.1
|
||||
# for WD14 captioning (onnx)
|
||||
# onnx==1.14.1
|
||||
# onnxruntime-gpu==1.16.0
|
||||
# onnxruntime==1.16.0
|
||||
# this is for onnx:
|
||||
# protobuf==3.20.3
|
||||
# open clip for SDXL
|
||||
open-clip-torch==2.20.0
|
||||
# for kohya_ss library
|
||||
|
||||
172
sdxl_gen_img.py
172
sdxl_gen_img.py
@@ -17,10 +17,13 @@ import re
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -54,7 +57,7 @@ import library.train_util as train_util
|
||||
import library.sdxl_model_util as sdxl_model_util
|
||||
import library.sdxl_train_util as sdxl_train_util
|
||||
from networks.lora import LoRANetwork
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
from library.sdxl_original_unet import InferSdxlUNet2DConditionModel
|
||||
from library.original_unet import FlashAttentionFunction
|
||||
from networks.control_net_lllite import ControlNetLLLite
|
||||
|
||||
@@ -287,7 +290,7 @@ class PipelineLike:
|
||||
vae: AutoencoderKL,
|
||||
text_encoders: List[CLIPTextModel],
|
||||
tokenizers: List[CLIPTokenizer],
|
||||
unet: SdxlUNet2DConditionModel,
|
||||
unet: InferSdxlUNet2DConditionModel,
|
||||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
clip_skip: int,
|
||||
):
|
||||
@@ -325,7 +328,7 @@ class PipelineLike:
|
||||
self.vae = vae
|
||||
self.text_encoders = text_encoders
|
||||
self.tokenizers = tokenizers
|
||||
self.unet: SdxlUNet2DConditionModel = unet
|
||||
self.unet: InferSdxlUNet2DConditionModel = unet
|
||||
self.scheduler = scheduler
|
||||
self.safety_checker = None
|
||||
|
||||
@@ -501,7 +504,8 @@ class PipelineLike:
|
||||
uncond_embeddings = tes_uncond_embs[0]
|
||||
for i in range(1, len(tes_text_embs)):
|
||||
text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048
|
||||
uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048
|
||||
if do_classifier_free_guidance:
|
||||
uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_scale is None:
|
||||
@@ -564,9 +568,11 @@ class PipelineLike:
|
||||
text_pool = clip_vision_embeddings # replace: same as ComfyUI (?)
|
||||
|
||||
c_vector = torch.cat([text_pool, c_vector], dim=1)
|
||||
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
|
||||
|
||||
vector_embeddings = torch.cat([uc_vector, c_vector])
|
||||
if do_classifier_free_guidance:
|
||||
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
|
||||
vector_embeddings = torch.cat([uc_vector, c_vector])
|
||||
else:
|
||||
vector_embeddings = c_vector
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, self.device)
|
||||
@@ -1368,6 +1374,7 @@ def main(args):
|
||||
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
|
||||
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, dtype
|
||||
)
|
||||
unet: InferSdxlUNet2DConditionModel = InferSdxlUNet2DConditionModel(unet)
|
||||
|
||||
# xformers、Hypernetwork対応
|
||||
if not args.diffusers_xformers:
|
||||
@@ -1523,10 +1530,14 @@ def main(args):
|
||||
print("set vae_dtype to float32")
|
||||
vae_dtype = torch.float32
|
||||
vae.to(vae_dtype).to(device)
|
||||
vae.eval()
|
||||
|
||||
text_encoder1.to(dtype).to(device)
|
||||
text_encoder2.to(dtype).to(device)
|
||||
unet.to(dtype).to(device)
|
||||
text_encoder1.eval()
|
||||
text_encoder2.eval()
|
||||
unet.eval()
|
||||
|
||||
# networkを組み込む
|
||||
if args.network_module:
|
||||
@@ -1534,12 +1545,20 @@ def main(args):
|
||||
network_default_muls = []
|
||||
network_pre_calc = args.network_pre_calc
|
||||
|
||||
# merge関連の引数を統合する
|
||||
if args.network_merge:
|
||||
network_merge = len(args.network_module) # all networks are merged
|
||||
elif args.network_merge_n_models:
|
||||
network_merge = args.network_merge_n_models
|
||||
else:
|
||||
network_merge = 0
|
||||
print(f"network_merge: {network_merge}")
|
||||
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
network_default_muls.append(network_mul)
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
@@ -1550,31 +1569,32 @@ def main(args):
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
if args.network_weights and i < len(args.network_weights):
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
else:
|
||||
if args.network_weights is None or len(args.network_weights) <= i:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, [text_encoder1, text_encoder2], unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
mergeable = network.is_mergeable()
|
||||
if args.network_merge and not mergeable:
|
||||
if network_merge and not mergeable:
|
||||
print("network is not mergiable. ignore merge option.")
|
||||
|
||||
if not args.network_merge or not mergeable:
|
||||
if not mergeable or i >= network_merge:
|
||||
# not merging
|
||||
network.apply_to([text_encoder1, text_encoder2], unet)
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
print(f"weights are loaded: {info}")
|
||||
@@ -1588,6 +1608,7 @@ def main(args):
|
||||
network.backup_weights()
|
||||
|
||||
networks.append(network)
|
||||
network_default_muls.append(network_mul)
|
||||
else:
|
||||
network.merge_to([text_encoder1, text_encoder2], unet, weights_sd, dtype, device)
|
||||
|
||||
@@ -1683,6 +1704,10 @@ def main(args):
|
||||
if args.diffusers_xformers:
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# Deep Shrink
|
||||
if args.ds_depth_1 is not None:
|
||||
unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)
|
||||
|
||||
# Textual Inversionを処理する
|
||||
if args.textual_inversion_embeddings:
|
||||
token_ids_embeds1 = []
|
||||
@@ -1864,9 +1889,18 @@ def main(args):
|
||||
|
||||
size = None
|
||||
for i, network in enumerate(networks):
|
||||
if i < 3:
|
||||
if (i < 3 and args.network_regional_mask_max_color_codes is None) or i < args.network_regional_mask_max_color_codes:
|
||||
np_mask = np.array(mask_images[0])
|
||||
np_mask = np_mask[:, :, i]
|
||||
|
||||
if args.network_regional_mask_max_color_codes:
|
||||
# カラーコードでマスクを指定する
|
||||
ch0 = (i + 1) & 1
|
||||
ch1 = ((i + 1) >> 1) & 1
|
||||
ch2 = ((i + 1) >> 2) & 1
|
||||
np_mask = np.all(np_mask == np.array([ch0, ch1, ch2]) * 255, axis=2)
|
||||
np_mask = np_mask.astype(np.uint8) * 255
|
||||
else:
|
||||
np_mask = np_mask[:, :, i]
|
||||
size = np_mask.shape
|
||||
else:
|
||||
np_mask = np.full(size, 255, dtype=np.uint8)
|
||||
@@ -2264,6 +2298,13 @@ def main(args):
|
||||
clip_prompt = None
|
||||
network_muls = None
|
||||
|
||||
# Deep Shrink
|
||||
ds_depth_1 = None # means no override
|
||||
ds_timesteps_1 = args.ds_timesteps_1
|
||||
ds_depth_2 = args.ds_depth_2
|
||||
ds_timesteps_2 = args.ds_timesteps_2
|
||||
ds_ratio = args.ds_ratio
|
||||
|
||||
prompt_args = raw_prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
@@ -2371,10 +2412,51 @@ def main(args):
|
||||
print(f"network mul: {network_muls}")
|
||||
continue
|
||||
|
||||
# Deep Shrink
|
||||
m = re.match(r"dsd1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink depth 1
|
||||
ds_depth_1 = int(m.group(1))
|
||||
print(f"deep shrink depth 1: {ds_depth_1}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dst1 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink timesteps 1
|
||||
ds_timesteps_1 = int(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink timesteps 1: {ds_timesteps_1}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dsd2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink depth 2
|
||||
ds_depth_2 = int(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink depth 2: {ds_depth_2}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dst2 ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink timesteps 2
|
||||
ds_timesteps_2 = int(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink timesteps 2: {ds_timesteps_2}")
|
||||
continue
|
||||
|
||||
m = re.match(r"dsr ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # deep shrink ratio
|
||||
ds_ratio = float(m.group(1))
|
||||
ds_depth_1 = ds_depth_1 if ds_depth_1 is not None else -1 # -1 means override
|
||||
print(f"deep shrink ratio: {ds_ratio}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
|
||||
# override Deep Shrink
|
||||
if ds_depth_1 is not None:
|
||||
if ds_depth_1 < 0:
|
||||
ds_depth_1 = args.ds_depth_1 or 3
|
||||
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)
|
||||
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# 数が足りないなら前のをそのまま使う
|
||||
@@ -2615,10 +2697,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--network_args", type=str, default=None, nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数"
|
||||
)
|
||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||
parser.add_argument(
|
||||
"--network_merge_n_models", type=int, default=None, help="merge this number of networks / この数だけネットワークをマージする"
|
||||
)
|
||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||
parser.add_argument(
|
||||
"--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--network_regional_mask_max_color_codes",
|
||||
type=int,
|
||||
default=None,
|
||||
help="max color codes for regional mask (default is None, mask by channel) / regional maskの最大色数(デフォルトはNoneでチャンネルごとのマスク)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--textual_inversion_embeddings",
|
||||
type=str,
|
||||
@@ -2703,6 +2794,31 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="enable CLIP Vision Conditioning for img2img with this strength / img2imgでCLIP Vision Conditioningを有効にしてこのstrengthで処理する",
|
||||
)
|
||||
|
||||
# Deep Shrink
|
||||
parser.add_argument(
|
||||
"--ds_depth_1",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Enable Deep Shrink with this depth 1, valid values are 0 to 8 / Deep Shrinkをこのdepthで有効にする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds_timesteps_1",
|
||||
type=int,
|
||||
default=650,
|
||||
help="Apply Deep Shrink depth 1 until this timesteps / Deep Shrink depth 1を適用するtimesteps",
|
||||
)
|
||||
parser.add_argument("--ds_depth_2", type=int, default=None, help="Deep Shrink depth 2 / Deep Shrinkのdepth 2")
|
||||
parser.add_argument(
|
||||
"--ds_timesteps_2",
|
||||
type=int,
|
||||
default=650,
|
||||
help="Apply Deep Shrink depth 2 until this timesteps / Deep Shrink depth 2を適用するtimesteps",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
|
||||
)
|
||||
|
||||
# # parser.add_argument(
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
# )
|
||||
|
||||
171
sdxl_train.py
171
sdxl_train.py
@@ -10,10 +10,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -34,6 +37,7 @@ from library.custom_train_functions import (
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
|
||||
@@ -70,33 +74,22 @@ def get_block_params_to_optimize(unet: SdxlUNet2DConditionModel, block_lrs: List
|
||||
|
||||
|
||||
def append_block_lr_to_logs(block_lrs, logs, lr_scheduler, optimizer_type):
|
||||
lrs = lr_scheduler.get_last_lr()
|
||||
|
||||
lr_index = 0
|
||||
names = []
|
||||
block_index = 0
|
||||
while lr_index < len(lrs):
|
||||
while block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR + 2:
|
||||
if block_index < UNET_NUM_BLOCKS_FOR_BLOCK_LR:
|
||||
name = f"block{block_index}"
|
||||
if block_lrs[block_index] == 0:
|
||||
block_index += 1
|
||||
continue
|
||||
names.append(f"block{block_index}")
|
||||
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR:
|
||||
name = "text_encoder1"
|
||||
names.append("text_encoder1")
|
||||
elif block_index == UNET_NUM_BLOCKS_FOR_BLOCK_LR + 1:
|
||||
name = "text_encoder2"
|
||||
else:
|
||||
raise ValueError(f"unexpected block_index: {block_index}")
|
||||
names.append("text_encoder2")
|
||||
|
||||
block_index += 1
|
||||
|
||||
logs["lr/" + name] = float(lrs[lr_index])
|
||||
|
||||
if optimizer_type.lower().startswith("DAdapt".lower()) or optimizer_type.lower() == "Prodigy".lower():
|
||||
logs["lr/d*lr/" + name] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[lr_index]["d"] * lr_scheduler.optimizers[-1].param_groups[lr_index]["lr"]
|
||||
)
|
||||
|
||||
lr_index += 1
|
||||
train_util.append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
|
||||
|
||||
|
||||
def train(args):
|
||||
@@ -271,10 +264,11 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 学習を準備する:モデルを適切な状態にする
|
||||
training_models = []
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
training_models.append(unet)
|
||||
train_unet = args.learning_rate > 0
|
||||
train_text_encoder1 = False
|
||||
train_text_encoder2 = False
|
||||
|
||||
if args.train_text_encoder:
|
||||
# TODO each option for two text encoders?
|
||||
@@ -282,10 +276,23 @@ def train(args):
|
||||
if args.gradient_checkpointing:
|
||||
text_encoder1.gradient_checkpointing_enable()
|
||||
text_encoder2.gradient_checkpointing_enable()
|
||||
training_models.append(text_encoder1)
|
||||
training_models.append(text_encoder2)
|
||||
# set require_grad=True later
|
||||
lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train
|
||||
lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train
|
||||
train_text_encoder1 = lr_te1 > 0
|
||||
train_text_encoder2 = lr_te2 > 0
|
||||
|
||||
# caching one text encoder output is not supported
|
||||
if not train_text_encoder1:
|
||||
text_encoder1.to(weight_dtype)
|
||||
if not train_text_encoder2:
|
||||
text_encoder2.to(weight_dtype)
|
||||
text_encoder1.requires_grad_(train_text_encoder1)
|
||||
text_encoder2.requires_grad_(train_text_encoder2)
|
||||
text_encoder1.train(train_text_encoder1)
|
||||
text_encoder2.train(train_text_encoder2)
|
||||
else:
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
text_encoder1.requires_grad_(False)
|
||||
text_encoder2.requires_grad_(False)
|
||||
text_encoder1.eval()
|
||||
@@ -294,7 +301,7 @@ def train(args):
|
||||
# TextEncoderの出力をキャッシュする
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
with torch.no_grad():
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
train_dataset_group.cache_text_encoder_outputs(
|
||||
(tokenizer1, tokenizer2),
|
||||
(text_encoder1, text_encoder2),
|
||||
@@ -310,30 +317,33 @@ def train(args):
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
|
||||
for m in training_models:
|
||||
m.requires_grad_(True)
|
||||
unet.requires_grad_(train_unet)
|
||||
if not train_unet:
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
|
||||
|
||||
if block_lrs is None:
|
||||
params = []
|
||||
for m in training_models:
|
||||
params.extend(m.parameters())
|
||||
params_to_optimize = params
|
||||
training_models = []
|
||||
params_to_optimize = []
|
||||
if train_unet:
|
||||
training_models.append(unet)
|
||||
if block_lrs is None:
|
||||
params_to_optimize.append({"params": list(unet.parameters()), "lr": args.learning_rate})
|
||||
else:
|
||||
params_to_optimize.extend(get_block_params_to_optimize(unet, block_lrs))
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for p in params:
|
||||
if train_text_encoder1:
|
||||
training_models.append(text_encoder1)
|
||||
params_to_optimize.append({"params": list(text_encoder1.parameters()), "lr": args.learning_rate_te1 or args.learning_rate})
|
||||
if train_text_encoder2:
|
||||
training_models.append(text_encoder2)
|
||||
params_to_optimize.append({"params": list(text_encoder2.parameters()), "lr": args.learning_rate_te2 or args.learning_rate})
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for params in params_to_optimize:
|
||||
for p in params["params"]:
|
||||
n_params += p.numel()
|
||||
else:
|
||||
params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs) # U-Net
|
||||
for m in training_models[1:]: # Text Encoders if exists
|
||||
params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
|
||||
|
||||
# calculate number of trainable parameters
|
||||
n_params = 0
|
||||
for params in params_to_optimize:
|
||||
for p in params["params"]:
|
||||
n_params += p.numel()
|
||||
|
||||
accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
|
||||
accelerator.print(f"number of models: {len(training_models)}")
|
||||
accelerator.print(f"number of trainable parameters: {n_params}")
|
||||
|
||||
@@ -385,18 +395,17 @@ def train(args):
|
||||
text_encoder2.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
if train_text_encoder1:
|
||||
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
|
||||
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
|
||||
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||
if train_text_encoder2:
|
||||
text_encoder2 = accelerator.prepare(text_encoder2)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
(unet,) = train_util.transform_models_if_DDP([unet])
|
||||
text_encoder1.to(weight_dtype)
|
||||
text_encoder2.to(weight_dtype)
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# TextEncoderの出力をキャッシュするときにはCPUへ移動する
|
||||
if args.cache_text_encoder_outputs:
|
||||
@@ -452,6 +461,12 @@ def train(args):
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
@@ -459,10 +474,9 @@ def train(args):
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
loss_total = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
||||
with accelerator.accumulate(*training_models):
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
@@ -473,7 +487,7 @@ def train(args):
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
@@ -494,6 +508,7 @@ def train(args):
|
||||
# else:
|
||||
input_ids1 = input_ids1.to(accelerator.device)
|
||||
input_ids2 = input_ids2.to(accelerator.device)
|
||||
# unwrap_model is fine for models not wrapped by accelerator
|
||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||
args.max_token_length,
|
||||
input_ids1,
|
||||
@@ -503,6 +518,7 @@ def train(args):
|
||||
text_encoder1,
|
||||
text_encoder2,
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||
@@ -548,7 +564,12 @@ def train(args):
|
||||
|
||||
target = noise
|
||||
|
||||
if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.v_pred_like_loss:
|
||||
if (
|
||||
args.min_snr_gamma
|
||||
or args.scale_v_pred_loss_like_noise_pred
|
||||
or args.v_pred_like_loss
|
||||
or args.debiased_estimation_loss
|
||||
):
|
||||
# do not mean over batch dimension for snr weight or scale v-pred loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
@@ -559,6 +580,8 @@ def train(args):
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # mean over batch dimension
|
||||
else:
|
||||
@@ -620,29 +643,22 @@ def train(args):
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss}
|
||||
if block_lrs is None:
|
||||
logs["lr"] = float(lr_scheduler.get_last_lr()[0])
|
||||
if (
|
||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
||||
): # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=train_unet)
|
||||
else:
|
||||
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type)
|
||||
append_block_lr_to_logs(block_lrs, logs, lr_scheduler, args.optimizer_type) # U-Net is included in block_lrs
|
||||
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
# TODO moving averageにする
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / (step + 1)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
@@ -726,6 +742,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--learning_rate_te1",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder 1 (ViT-L) / text encoder 1 (ViT-L)の学習率",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate_te2",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder 2 (BiG-G) / text encoder 2 (BiG-G)の学習率",
|
||||
)
|
||||
|
||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||
parser.add_argument(
|
||||
|
||||
@@ -44,6 +44,7 @@ from library.custom_train_functions import (
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
import networks.control_net_lllite_for_train as control_net_lllite_for_train
|
||||
|
||||
@@ -282,9 +283,6 @@ def train(args):
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare (train_network here only)
|
||||
unet = train_util.transform_models_if_DDP([unet])[0]
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
else:
|
||||
@@ -350,8 +348,7 @@ def train(args):
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
del train_dataset_group
|
||||
|
||||
# function for saving/removing
|
||||
@@ -397,7 +394,7 @@ def train(args):
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
@@ -460,11 +457,13 @@ def train(args):
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
@@ -500,14 +499,9 @@ def train(args):
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
@@ -518,7 +512,7 @@ def train(args):
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -40,6 +40,7 @@ from library.custom_train_functions import (
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
import networks.control_net_lllite as control_net_lllite
|
||||
|
||||
@@ -253,9 +254,6 @@ def train(args):
|
||||
)
|
||||
network: control_net_lllite.ControlNetLLLite
|
||||
|
||||
# transform DDP after prepare (train_network here only)
|
||||
unet, network = train_util.transform_models_if_DDP([unet, network])
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
else:
|
||||
@@ -323,8 +321,7 @@ def train(args):
|
||||
"lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
del train_dataset_group
|
||||
|
||||
# function for saving/removing
|
||||
@@ -366,7 +363,7 @@ def train(args):
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
|
||||
|
||||
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
|
||||
@@ -430,11 +427,13 @@ def train(args):
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
@@ -470,14 +469,9 @@ def train(args):
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
@@ -488,7 +482,7 @@ def train(args):
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -70,14 +73,16 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
dataset.cache_text_encoder_outputs(
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||
with accelerator.autocast():
|
||||
dataset.cache_text_encoder_outputs(
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
accelerator.is_main_process,
|
||||
)
|
||||
|
||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||
@@ -121,6 +126,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
text_encoders[0],
|
||||
text_encoders[1],
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||
|
||||
@@ -64,6 +64,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
||||
text_encoders[0],
|
||||
text_encoders[1],
|
||||
None if not args.full_fp16 else weight_dtype,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ def convert(args):
|
||||
is_load_ckpt = os.path.isfile(args.model_to_load)
|
||||
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
||||
|
||||
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
||||
assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
||||
# assert (
|
||||
# is_save_ckpt or args.reference_model is not None
|
||||
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
||||
@@ -34,10 +34,12 @@ def convert(args):
|
||||
|
||||
if is_load_ckpt:
|
||||
v2_model = args.v2
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection)
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
|
||||
v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
|
||||
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
|
||||
)
|
||||
text_encoder = pipe.text_encoder
|
||||
vae = pipe.vae
|
||||
@@ -57,15 +59,26 @@ def convert(args):
|
||||
if is_save_ckpt:
|
||||
original_model = args.model_to_load if is_load_ckpt else None
|
||||
key_count = model_util.save_stable_diffusion_checkpoint(
|
||||
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae
|
||||
v2_model,
|
||||
args.model_to_save,
|
||||
text_encoder,
|
||||
unet,
|
||||
original_model,
|
||||
args.epoch,
|
||||
args.global_step,
|
||||
None if args.metadata is None else eval(args.metadata),
|
||||
save_dtype=save_dtype,
|
||||
vae=vae,
|
||||
)
|
||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||
else:
|
||||
print(f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}")
|
||||
print(
|
||||
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
|
||||
)
|
||||
model_util.save_diffusers_checkpoint(
|
||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||
)
|
||||
print(f"model saved.")
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
@@ -77,7 +90,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unet_use_linear_projection", action="store_true", help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)"
|
||||
"--unet_use_linear_projection",
|
||||
action="store_true",
|
||||
help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fp16",
|
||||
@@ -99,6 +114,18 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument(
|
||||
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--metadata",
|
||||
type=str,
|
||||
default=None,
|
||||
help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference_model",
|
||||
type=str,
|
||||
|
||||
@@ -11,10 +11,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -335,10 +338,11 @@ def train(args):
|
||||
init_kwargs = {}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
accelerator.init_trackers(
|
||||
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
del train_dataset_group
|
||||
|
||||
# function for saving/removing
|
||||
@@ -372,6 +376,11 @@ def train(args):
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(
|
||||
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
|
||||
)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
@@ -450,7 +459,7 @@ def train(args):
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
@@ -500,14 +509,9 @@ def train(args):
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
@@ -518,7 +522,7 @@ def train(args):
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
65
train_db.py
65
train_db.py
@@ -11,10 +11,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -35,6 +38,7 @@ from library.custom_train_functions import (
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
|
||||
# perlin_noise,
|
||||
@@ -108,6 +112,7 @@ def train(args):
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
@@ -132,7 +137,7 @@ def train(args):
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
@@ -163,11 +168,17 @@ def train(args):
|
||||
# 学習に必要なクラスを準備する
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
if train_text_encoder:
|
||||
# wightout list, adamw8bit is crashed
|
||||
trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
|
||||
if args.learning_rate_te is None:
|
||||
# wightout list, adamw8bit is crashed
|
||||
trainable_params = list(itertools.chain(unet.parameters(), text_encoder.parameters()))
|
||||
else:
|
||||
trainable_params = [
|
||||
{"params": list(unet.parameters()), "lr": args.learning_rate},
|
||||
{"params": list(text_encoder.parameters()), "lr": args.learning_rate_te},
|
||||
]
|
||||
else:
|
||||
trainable_params = unet.parameters()
|
||||
|
||||
|
||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||
|
||||
# dataloaderを準備する
|
||||
@@ -215,9 +226,6 @@ def train(args):
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
if not train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
||||
|
||||
@@ -264,8 +272,10 @@ def train(args):
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
@@ -333,9 +343,11 @@ def train(args):
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
@@ -383,30 +395,20 @@ def train(args):
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
if (
|
||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
||||
): # tracking d*lr value
|
||||
logs["lr/d*lr"] = (
|
||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||
)
|
||||
logs = {"loss": current_loss}
|
||||
train_util.append_lr_to_logs(logs, lr_scheduler, args.optimizer_type, including_unet=True)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
@@ -464,6 +466,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
config_util.add_config_arguments(parser)
|
||||
custom_train_functions.add_custom_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--learning_rate_te",
|
||||
type=float,
|
||||
default=None,
|
||||
help="learning rate for text encoder, default is same as unet / Text Encoderの学習率、デフォルトはunetと同じ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_token_padding",
|
||||
action="store_true",
|
||||
@@ -475,6 +483,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
110
train_network.py
110
train_network.py
@@ -12,6 +12,7 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
@@ -43,6 +44,7 @@ from library.custom_train_functions import (
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
|
||||
|
||||
@@ -108,6 +110,9 @@ class NetworkTrainer:
|
||||
def is_text_encoder_outputs_cached(self, args):
|
||||
return False
|
||||
|
||||
def is_train_text_encoder(self, args):
|
||||
return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(
|
||||
self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
|
||||
):
|
||||
@@ -123,6 +128,11 @@ class NetworkTrainer:
|
||||
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
|
||||
return noise_pred
|
||||
|
||||
def all_reduce_network(self, accelerator, network):
|
||||
for param in network.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
@@ -283,7 +293,10 @@ class NetworkTrainer:
|
||||
if args.dim_from_weights:
|
||||
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
|
||||
else:
|
||||
# LyCORIS will work with this...
|
||||
if "dropout" not in net_kwargs:
|
||||
# workaround for LyCORIS (;^ω^)
|
||||
net_kwargs["dropout"] = args.network_dropout
|
||||
|
||||
network = network_module.create_network(
|
||||
1.0,
|
||||
args.network_dim,
|
||||
@@ -306,7 +319,7 @@ class NetworkTrainer:
|
||||
args.scale_weight_norms = False
|
||||
|
||||
train_unet = not args.network_train_text_encoder_only
|
||||
train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
|
||||
train_text_encoder = self.is_train_text_encoder(args)
|
||||
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
||||
|
||||
if args.network_weights is not None:
|
||||
@@ -383,44 +396,20 @@ class NetworkTrainer:
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
# TODO めちゃくちゃ冗長なのでコードを整理する
|
||||
if train_unet and train_text_encoder:
|
||||
if len(text_encoders) > 1:
|
||||
unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
text_encoder = text_encoders = [t_enc1, t_enc2]
|
||||
del t_enc1, t_enc2
|
||||
else:
|
||||
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
text_encoders = [text_encoder]
|
||||
elif train_unet:
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
elif train_text_encoder:
|
||||
if len(text_encoders) > 1:
|
||||
t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
text_encoder = text_encoders = [t_enc1, t_enc2]
|
||||
del t_enc1, t_enc2
|
||||
else:
|
||||
text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
text_encoders = [text_encoder]
|
||||
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
else:
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# transform DDP after prepare (train_network here only)
|
||||
text_encoders = train_util.transform_models_if_DDP(text_encoders)
|
||||
unet, network = train_util.transform_models_if_DDP([unet, network])
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
if train_text_encoder:
|
||||
if len(text_encoders) > 1:
|
||||
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
||||
else:
|
||||
text_encoder = accelerator.prepare(text_encoder)
|
||||
text_encoders = [text_encoder]
|
||||
else:
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
# according to TI example in Diffusers, train is required
|
||||
@@ -442,7 +431,7 @@ class NetworkTrainer:
|
||||
|
||||
del t_enc
|
||||
|
||||
network.prepare_grad_etc(text_encoder, unet)
|
||||
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
||||
vae.requires_grad_(False)
|
||||
@@ -525,6 +514,7 @@ class NetworkTrainer:
|
||||
"ss_min_snr_gamma": args.min_snr_gamma,
|
||||
"ss_scale_weight_norms": args.scale_weight_norms,
|
||||
"ss_ip_noise_gamma": args.ip_noise_gamma,
|
||||
"ss_debiased_estimation": bool(args.debiased_estimation_loss),
|
||||
}
|
||||
|
||||
if use_user_config:
|
||||
@@ -694,19 +684,20 @@ class NetworkTrainer:
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
if args.wandb_run_name:
|
||||
init_kwargs['wandb'] = {'name': args.wandb_run_name}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers(
|
||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_list = []
|
||||
loss_total = 0.0
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
del train_dataset_group
|
||||
|
||||
# callback for step start
|
||||
if hasattr(network, "on_step_start"):
|
||||
on_step_start = network.on_step_start
|
||||
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
|
||||
on_step_start = accelerator.unwrap_model(network).on_step_start
|
||||
else:
|
||||
on_step_start = lambda *args, **kwargs: None
|
||||
|
||||
@@ -734,6 +725,9 @@ class NetworkTrainer:
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# For --sample_at_first
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
@@ -741,7 +735,7 @@ class NetworkTrainer:
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
network.on_epoch_start(text_encoder, unet)
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
@@ -758,11 +752,11 @@ class NetworkTrainer:
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||
latents = torch.nan_to_num(latents, 0, out=latents)
|
||||
latents = latents * self.vae_scale_factor
|
||||
b_size = latents.shape[0]
|
||||
|
||||
with torch.set_grad_enabled(train_text_encoder):
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
text_encoder_conds = get_weighted_text_embeddings(
|
||||
@@ -803,17 +797,20 @@ class NetworkTrainer:
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = network.get_trainable_params()
|
||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
@@ -821,7 +818,7 @@ class NetworkTrainer:
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
|
||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
@@ -851,14 +848,9 @@ class NetworkTrainer:
|
||||
remove_model(remove_ckpt_name)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if epoch == 0:
|
||||
loss_list.append(current_loss)
|
||||
else:
|
||||
loss_total -= loss_list[step]
|
||||
loss_list[step] = current_loss
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / len(loss_list)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
@@ -872,7 +864,7 @@ class NetworkTrainer:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(loss_list)}
|
||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -7,10 +7,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -32,6 +35,7 @@ from library.custom_train_functions import (
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
|
||||
imagenet_templates_small = [
|
||||
@@ -414,15 +418,11 @@ class TextualInversionTrainer:
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
# transform DDP after prepare
|
||||
text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
|
||||
|
||||
elif len(text_encoders) == 2:
|
||||
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
# transform DDP after prepare
|
||||
text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
|
||||
|
||||
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
|
||||
|
||||
@@ -528,6 +528,20 @@ class TextualInversionTrainer:
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# For --sample_at_first
|
||||
self.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
0,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
unet,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
@@ -577,11 +591,13 @@ class TextualInversionTrainer:
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from library.custom_train_functions import (
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
apply_debiased_estimation,
|
||||
)
|
||||
import library.original_unet as original_unet
|
||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||
@@ -332,9 +333,6 @@ def train(args):
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
@@ -468,9 +466,11 @@ def train(args):
|
||||
|
||||
loss = loss * loss_weights
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
|
||||
Reference in New Issue
Block a user