mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00

Compare commits

37 commits:

- f0ae7eea95
- b22b0a5c75
- c7a13c89c7
- 39a70f10bd
- a3c0e4cf44
- 9b13444b9c
- 0eb01dea55
- a3aa3b1712
- 95b5aed41b
- d9184ab21c
- e7dd77836d
- 4c5c486d28
- f403ac6132
- b39cf6e2c0
- 71b728d5fc
- f0ef81f865
- f68a48b354
- 7a0d2a2d45
- e13e503cbc
- 125039f491
- f2b300a221
- 9ab964d0b8
- 663aad2b0d
- 12d30afb39
- 107fa754e5
- a17d1180cb
- 014fd3d037
- b29c5a750c
- b612d0b091
- d94c0d70fe
- 045a3dbe48
- e45e272e9d
- 8590d5dbca
- 39aa390d2b
- 64bffe5238
- cebee02698
- bc9fc4ccee
README.md (143 lines changed)

@@ -124,112 +124,45 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

## Change History
- 19 Feb. 2023, 2023/2/19:
  - Added ``--use_lion_optimizer`` to each training script to use the [Lion optimizer](https://github.com/lucidrains/lion-pytorch).
    - Please install it with ``pip install lion-pytorch`` (it is not in ``requirements.txt`` currently).
  - Added a ``--lowram`` option to ``train_network.py``. It loads models to VRAM instead of RAM (useful for machines such as Colab and Kaggle that have more VRAM than RAM). Thanks to Isotr0py!
    - The default behavior (without ``--lowram``) has reverted to the same as before 14 Feb.
  - Fixed the git commit hash so it is recorded correctly regardless of the working directory. Thanks to vladmandic!
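For reference, a minimal sketch of using the Lion optimizer directly (the training scripts wire this up for you when ``--use_lion_optimizer`` is set; the model below is just a stand-in):

```python
import torch
from lion_pytorch import Lion  # pip install lion-pytorch

model = torch.nn.Linear(16, 16)  # placeholder for the network being trained
optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)

loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
optimizer.step()
```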
- 23 Feb. 2023, 2023/2/23:
  - Fixed a training instability issue in ``train_network.py``.
    - ``fp16`` training is probably not affected by this issue.
    - Training SD2.x models with ``float`` precision now works correctly. Training with ``bf16`` may also be improved.
    - This issue seems to have been introduced in [PR#190](https://github.com/kohya-ss/sd-scripts/pull/190).
  - Added some metadata to LoRA models. Thanks to space-nuko!
  - An error is now raised if optimizer options conflict (e.g. ``--optimizer_type`` together with ``--use_8bit_adam``).
  - ControlNet is supported in ``gen_img_diffusers.py`` (no documentation yet).
- 16 Feb. 2023, 2023/2/16:
  - The noise offset is now recorded to the metadata. Thanks to space-nuko!
  - ``train_network.py`` and ``train_db.py`` now show a moving-average loss, which prevents the displayed loss from jumping at the start of each epoch. Thanks to shirayu!
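The per-epoch running mean can also be seen in the ``fine_tune.py`` diff further below; the idea, as a self-contained sketch with stand-in loss values:

```python
losses = [0.9, 0.4, 1.1, 0.5]  # stand-in per-step loss values
loss_total = 0.0
for step, current_loss in enumerate(losses):
    loss_total += current_loss
    avr_loss = loss_total / (step + 1)  # the value shown on the progress bar
    print(f"step {step}: avr_loss={avr_loss:.3f}")
```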
- 14 Feb. 2023, 2023/2/14:
  - Added multi-GPU training support to ``train_network.py``. Thanks to Isotr0py!
  - Added a ``--verbose`` option to ``resize_lora.py``. For details, see [this PR](https://github.com/kohya-ss/sd-scripts/pull/179). Thanks to mgz-dev!
  - The git commit hash is now added to the metadata for LoRA. Thanks to space-nuko!
  - Added a ``--noise_offset`` option to each training script.
    - This is an implementation of https://www.crosslabs.org//blog/diffusion-with-offset-noise
    - The option may improve the ability to generate darker/lighter images. It may also work with LoRA training.
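A minimal sketch of the offset-noise idea from the linked article: a shared random constant is added per sample and per channel to the training noise (names here are illustrative; the help text added to the training scripts recommends a value around 0.1):

```python
import torch

latents = torch.randn(4, 4, 64, 64)  # stand-in batch of latents
noise_offset = 0.1
noise = torch.randn_like(latents)
# shift each sample's/channel's noise by a shared random constant
noise += noise_offset * torch.randn(latents.shape[0], latents.shape[1], 1, 1)
```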
- 11 Feb. 2023, 2023/2/11:
  - ``lora_interrogator.py`` is added to the ``networks`` folder. See ``python networks\lora_interrogator.py -h`` for usage.
    - For LoRAs whose activation word is unknown, this script compares the output of the Text Encoder with and without the LoRA applied, to find out which tokens are affected. Hopefully you can figure out the activation word. LoRAs trained with captions affect a wide range of tokens, so they do not seem to be interrogatable.
    - The batch size can be fairly large (such as 64 or 128).
  - ``train_textual_inversion.py`` now supports multiple init words.
  - The following feature has been reverted to its previous behavior. Sorry for the confusion:
    > Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
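A sketch of the interrogation idea (not the script's actual code): embed each candidate token with the base Text Encoder and with the LoRA-patched one, and rank tokens by how far their embedding moved.

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

@torch.no_grad()
def token_embedding(te, token_id):
    ids = torch.tensor([[tokenizer.bos_token_id, token_id, tokenizer.eos_token_id]])
    return te(ids).last_hidden_state[0, 1]

# `lora_text_encoder` stands for the same model with the LoRA applied (omitted here):
# score = (token_embedding(lora_text_encoder, t) - token_embedding(text_encoder, t)).norm()
# tokens with the largest score are candidates for the activation word.
```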
- 10 Feb. 2023, 2023/2/10:
  - Updated ``requirements.txt`` to prevent upgrades with pip from taking a long time or failing.
  - ``resize_lora.py`` now keeps the metadata of the model. ``dimension is resized from ...`` is prepended to ``ss_training_comment``.
  - ``merge_lora.py`` now supports models with different ``alpha`` values. If there is a problem, the old version is available as ``merge_lora_old.py``.
  - ``svd_merge_lora.py`` is added. This script merges LoRA models of any rank (dim) and alpha, and approximates the merged result with SVD as a new LoRA of a specified rank (dim).
    - Note: the merging scripts currently erase the metadata.
  - ``resize_images_to_resolution.py`` now supports multibyte characters in filenames.
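The core of the SVD approach, as a simplified per-layer sketch (the real script also handles alpha scaling and conv layers):

```python
import torch

def svd_approx(weighted_deltas, new_rank):
    # weighted_deltas: list of full update matrices (up @ down, already scaled)
    merged = sum(weighted_deltas)                 # (out_dim, in_dim)
    U, S, Vh = torch.linalg.svd(merged, full_matrices=False)
    lora_up = U[:, :new_rank] * S[:new_rank]      # (out_dim, new_rank)
    lora_down = Vh[:new_rank, :]                  # (new_rank, in_dim)
    return lora_up, lora_down                     # lora_up @ lora_down ≈ merged
```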
- 9 Feb. 2023, 2023/2/9:
  - Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource!
    - ``--caption_dropout_rate`` specifies the dropout rate for captions (0 to 1.0; 0.1 means a 10% chance of dropout). If dropout occurs, the image is trained with an empty caption. Default is 0 (no dropout).
    - ``--caption_dropout_every_n_epochs`` specifies the epochs at which captions are dropped entirely. If ``3`` is specified, then in epochs 3, 6, 9, ..., images are trained with all captions empty. Default is None (no dropout).
    - ``--caption_tag_dropout_rate`` specifies the dropout rate for tags (comma-separated tokens) (0 to 1.0; 0.1 means a 10% chance of dropout). If dropout occurs, the tag is removed from the caption for that step. If the ``--keep_tokens`` option is set, those leading tokens (tags) are never dropped. Default is 0 (no dropout).
  - A bulk image downsampling script is added. Documentation is [here](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E7%94%BB%E5%83%8F%E3%83%AA%E3%82%B5%E3%82%A4%E3%82%BA%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%97%E3%83%88) (in Japanese). Thanks to bmaltais!
  - A typo checker is added. Thanks to shirayu!
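A sketch of how the three dropout options interact (illustrative, not the scripts' actual code):

```python
import random

def dropout_caption(caption, epoch, rate, every_n_epochs, tag_rate, keep_tokens):
    if every_n_epochs and (epoch + 1) % every_n_epochs == 0:
        return ""                                  # whole epoch trains captionless
    if random.random() < rate:
        return ""                                  # drop the entire caption
    tags = [t.strip() for t in caption.split(",")]
    kept = tags[:keep_tokens]                      # tags protected by --keep_tokens
    kept += [t for t in tags[keep_tokens:] if random.random() >= tag_rate]
    return ", ".join(kept)
```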
- 6 Feb. 2023, 2023/2/6:
  - ``--bucket_reso_steps`` and ``--bucket_no_upscale`` options are added to the training scripts (fine tuning, DreamBooth, LoRA and Textual Inversion) and ``prepare_buckets_latents.py``.
    - ``--bucket_reso_steps`` sets the step size of bucket resolutions in aspect ratio bucketing. The default is 64, the same as before.
      - Any value greater than or equal to 1 can be specified; 64 is highly recommended, and otherwise a value divisible by 8 is recommended.
      - If a value less than 64 is specified, padding will occur inside the U-Net. The result is unknown.
      - If you specify a value that is not divisible by 8, the remainder is truncated inside the VAE, because the size of the latent is 1/8 of the image size.
    - If the ``--bucket_no_upscale`` option is specified, images smaller than the bucket size are processed without upscaling.
      - Internally, a bucket smaller than the image size is created (for example, if the image is 300x300 and ``bucket_reso_steps=64``, the bucket is 256x256). The image is trimmed to fit. See the sketch after this list.
      - This implements [#130](https://github.com/kohya-ss/sd-scripts/issues/130).
      - Images with an area larger than the maximum size specified by ``--resolution`` are downsampled to the same area while keeping the aspect ratio, and the bucket is created based on that size.
  - The number of data items in each batch is now limited to the number of actual (non-duplicated) images. Because these options subdivide the buckets, a bucket may contain only a few actual images, so a batch could otherwise contain the same (duplicated) image several times.
    - For example, with 10 repeats, a bucket containing a single image, and a batch size of 10 or more, an epoch previously used one batch containing that image 10 times; it now uses 10 batches of size 1.
  - ``--random_crop`` now also works with buckets enabled.
    - Instead of always cropping the center of the image, the crop is shifted left, right, up and down, so training is expected to cover the edges of the image.
    - This implements discussion [#34](https://github.com/kohya-ss/sd-scripts/discussions/34).
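A minimal sketch of the ``--bucket_no_upscale`` rounding described above (the real bucketing also matches aspect ratios against ``--resolution``):

```python
def no_upscale_bucket(width, height, reso_steps=64):
    # round each side down to a multiple of reso_steps; the image is then
    # trimmed to the resulting bucket instead of being upscaled
    return width - width % reso_steps, height - height % reso_steps

print(no_upscale_bucket(300, 300))  # (256, 256), matching the example above
```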
- 22 Feb. 2023, 2023/2/22:
  - Refactored the optimizer options. Thanks to mgz-dev!
    - Added an ``--optimizer_type`` option to each training script. Please see the help. Japanese documentation is [here](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E3%82%AA%E3%83%97%E3%83%86%E3%82%A3%E3%83%9E%E3%82%A4%E3%82%B6%E3%81%AE%E6%8C%87%E5%AE%9A%E3%81%AB%E3%81%A4%E3%81%84%E3%81%A6).
    - ``--use_8bit_adam`` and ``--use_lion_optimizer`` still work, but they override the above option.
    - Added SGDNesterov and its 8-bit variant.
    - Added the [D-Adaptation](https://github.com/facebookresearch/dadaptation) optimizer. Thanks to BootsofLagrangian and all!
      - Please install it with ``pip install dadaptation`` (it is not in ``requirements.txt`` currently).
      - Please see https://github.com/kohya-ss/sd-scripts/issues/181 for details.
    - Added the AdaFactor optimizer. Thanks to Toshiaki!
  - Extra lr scheduler settings (``num_cycles`` etc.) now work in training scripts other than ``train_network.py``.
  - Added a ``--max_grad_norm`` option to each training script for gradient clipping. ``0.0`` disables clipping.
  - Symbolic links can now be loaded in each training script. Thanks to TkskKurumi!
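Optimizer-specific settings are passed as ``key=value`` pairs via ``--optimizer_args`` (for example ``--optimizer_type AdamW8bit --optimizer_args "weight_decay=0.01" "betas=0.9,0.999"``). The parsing rule, extracted here from the ``get_optimizer`` code shown further below, only understands bools, floats and tuples of floats:

```python
def parse_optimizer_args(optimizer_args):
    kwargs = {}
    for arg in optimizer_args:
        key, value = arg.split("=")
        parts = [v.lower() == "true" if v.lower() in ("true", "false") else float(v)
                 for v in value.split(",")]
        kwargs[key] = parts[0] if len(parts) == 1 else tuple(parts)
    return kwargs

print(parse_optimizer_args(["weight_decay=0.01", "betas=0.9,0.999"]))
# {'weight_decay': 0.01, 'betas': (0.9, 0.999)}
```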
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
fine_tune.py (39 lines changed)

@@ -149,27 +149,7 @@ def train(args):

    # prepare the classes needed for training
    print("prepare optimizer, data loader etc.")

-   # use 8-bit Adam
-   if args.use_8bit_adam:
-       try:
-           import bitsandbytes as bnb
-       except ImportError:
-           raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
-       print("use 8-bit Adam optimizer")
-       optimizer_class = bnb.optim.AdamW8bit
-   elif args.use_lion_optimizer:
-       try:
-           import lion_pytorch
-       except ImportError:
-           raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-       print("use Lion optimizer")
-       optimizer_class = lion_pytorch.Lion
-   else:
-       optimizer_class = torch.optim.AdamW
-
-   # beta and weight decay seem to be the defaults in both diffusers DreamBooth and DreamBooth SD, so the options are omitted for now
-   optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
+   _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)

    # prepare the dataloader
    # number of DataLoader processes: 0 means the main process

@@ -183,8 +163,9 @@ def train(args):
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # prepare the lr scheduler
-   lr_scheduler = diffusers.optimization.get_scheduler(
-       args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+   lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                               num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                               num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

    # experimental feature: train in fp16 including gradients; turn the whole model into fp16
    if args.full_fp16:

@@ -286,11 +267,11 @@ def train(args):
            loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

            accelerator.backward(loss)
-           if accelerator.sync_gradients:
+           if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                params_to_clip = []
                for m in training_models:
                    params_to_clip.extend(m.parameters())
-               accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+               accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

            optimizer.step()
            lr_scheduler.step()

@@ -303,9 +284,12 @@ def train(args):
            current_loss = loss.detach().item()  # this is a mean, so batch size should not matter
            if args.logging_dir is not None:
-               logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+               logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+               if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+                   logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d'] * lr_scheduler.optimizers[0].param_groups[0]['lr']
                accelerator.log(logs, step=global_step)

            # TODO change this to a moving average
            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}

@@ -315,7 +299,7 @@ def train(args):
                break

        if args.logging_dir is not None:
-           logs = {"epoch_loss": loss_total / len(train_dataloader)}
+           logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

@@ -351,6 +335,7 @@ if __name__ == '__main__':
    train_util.add_dataset_arguments(parser, False, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_sd_saving_arguments(parser)
+   train_util.add_optimizer_arguments(parser)

    parser.add_argument("--diffusers_xformers", action='store_true',
                        help='use xformers by diffusers / Diffusersでxformersを使用する')
gen_img_diffusers.py

@@ -47,7 +47,7 @@ VGG(
"""

import json
- from typing import List, Optional, Union
+ from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
import inspect

@@ -60,7 +60,6 @@ import math
import os
import random
import re
- from typing import Any, Callable, List, Optional, Union

import diffusers
import numpy as np

@@ -81,6 +80,8 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo

import library.model_util as model_util
+ import tools.original_control_net as original_control_net
+ from tools.original_control_net import ControlNetInfo

# Tokenizer: use the one provided in advance instead of loading it from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"

@@ -487,6 +488,9 @@ class PipelineLike():
        self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
        self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)

+       # ControlNet
+       self.control_nets: List[ControlNetInfo] = []

    # Textual Inversion
    def add_token_replacement(self, target_token_id, rep_token_ids):
        self.token_replacements[target_token_id] = rep_token_ids

@@ -500,7 +504,11 @@ class PipelineLike():
                new_tokens.append(token)
        return new_tokens

+   def set_control_nets(self, ctrl_nets):
+       self.control_nets = ctrl_nets

    # region parts that use xformers etc.: rewritten independently, so not relevant here

    def enable_xformers_memory_efficient_attention(self):
        r"""
        Enable memory efficient attention as implemented in xformers.

@@ -752,7 +760,7 @@ class PipelineLike():
            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)  # OK even with multiple prompts

-       if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None:
+       if self.clip_image_guidance_scale > 0 or self.vgg16_guidance_scale > 0 and clip_guide_images is not None or self.control_nets:
            if isinstance(clip_guide_images, PIL.Image.Image):
                clip_guide_images = [clip_guide_images]

@@ -765,7 +773,7 @@ class PipelineLike():
                image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
                if len(image_embeddings_clip) == 1:
                    image_embeddings_clip = image_embeddings_clip.repeat((batch_size, 1, 1, 1))
-           else:
+           elif self.vgg16_guidance_scale > 0:
                size = (width // VGG16_INPUT_RESIZE_DIV, height // VGG16_INPUT_RESIZE_DIV)  # 1/4 for now (too small?)
                clip_guide_images = [preprocess_vgg16_guide_image(im, size) for im in clip_guide_images]
                clip_guide_images = torch.cat(clip_guide_images, dim=0)

@@ -774,6 +782,10 @@ class PipelineLike():
                image_embeddings_vgg16 = self.vgg16_feat_model(clip_guide_images)['feat']
                if len(image_embeddings_vgg16) == 1:
                    image_embeddings_vgg16 = image_embeddings_vgg16.repeat((batch_size, 1, 1, 1))
+           else:
+               # reuse the guide images as the ControlNet hints
+               # preprocessing is done on the ControlNet side
+               pass

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps, self.device)

@@ -864,12 +876,21 @@ class PipelineLike():
            extra_step_kwargs["eta"] = eta

        num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1

+       if self.control_nets:
+           guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)

        for i, t in enumerate(tqdm(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
-           noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+           if self.control_nets:
+               noise_pred = original_control_net.call_unet_and_control_net(
+                   i, num_latent_input, self.unet, self.control_nets, guided_hints, i / len(timesteps), latent_model_input, t, text_embeddings).sample
+           else:
+               noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:

@@ -1817,6 +1838,34 @@ def preprocess_mask(mask):
# return text_encoder


+ class BatchDataBase(NamedTuple):
+     # data that does not need batch splitting
+     step: int
+     prompt: str
+     negative_prompt: str
+     seed: int
+     init_image: Any
+     mask_image: Any
+     clip_prompt: str
+     guide_image: Any
+
+
+ class BatchDataExt(NamedTuple):
+     # data that requires batch splitting
+     width: int
+     height: int
+     steps: int
+     scale: float
+     negative_scale: float
+     strength: float
+     network_muls: Tuple[float]
+
+
+ class BatchData(NamedTuple):
+     base: BatchDataBase
+     ext: BatchDataExt


def main(args):
    if args.fp16:
        dtype = torch.float16
@@ -1995,11 +2044,13 @@ def main(args):
    # incorporate the additional networks
    if args.network_module:
        networks = []
+       network_default_muls = []
        for i, network_module in enumerate(args.network_module):
            print("import network module:", network_module)
            imported_module = importlib.import_module(network_module)

            network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
+           network_default_muls.append(network_mul)

            net_kwargs = {}
            if args.network_args and i < len(args.network_args):

@@ -2014,7 +2065,7 @@ def main(args):
                network_weight = args.network_weights[i]
                print("load network weights from:", network_weight)

-               if model_util.is_safetensors(network_weight):
+               if model_util.is_safetensors(network_weight) and args.network_show_meta:
                    from safetensors.torch import safe_open
                    with safe_open(network_weight, framework="pt") as f:
                        metadata = f.metadata()

@@ -2037,6 +2088,18 @@ def main(args):
    else:
        networks = []

+   # set up ControlNet
+   control_nets: List[ControlNetInfo] = []
+   if args.control_net_models:
+       for i, model in enumerate(args.control_net_models):
+           prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
+           weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
+           ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
+
+           ctrl_unet, ctrl_net = original_control_net.load_control_net(args.v2, unet, model)
+           prep = original_control_net.load_preprocess(prep_type)
+           control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))

    if args.opt_channels_last:
        print(f"set optimizing: channels last")
        text_encoder.to(memory_format=torch.channels_last)

@@ -2050,9 +2113,14 @@ def main(args):
        if vgg16_model is not None:
            vgg16_model.to(memory_format=torch.channels_last)

+       for cn in control_nets:
+           cn.unet.to(memory_format=torch.channels_last)
+           cn.net.to(memory_format=torch.channels_last)

    pipe = PipelineLike(device, vae, text_encoder, tokenizer, unet, scheduler, args.clip_skip,
                        clip_model, args.clip_guidance_scale, args.clip_image_guidance_scale,
                        vgg16_model, args.vgg16_guidance_scale, args.vgg16_guidance_layer)
+   pipe.set_control_nets(control_nets)
    print("pipeline is ready.")

    if args.diffusers_xformers:

@@ -2186,9 +2254,12 @@ def main(args):

    prev_image = None  # for VGG16 guided
    if args.guide_image_path is not None:
-       print(f"load image for CLIP/VGG16 guidance: {args.guide_image_path}")
-       guide_images = load_images(args.guide_image_path)
-       print(f"loaded {len(guide_images)} guide images for CLIP/VGG16 guidance")
+       print(f"load image for CLIP/VGG16/ControlNet guidance: {args.guide_image_path}")
+       guide_images = []
+       for p in args.guide_image_path:
+           guide_images.extend(load_images(p))
+
+       print(f"loaded {len(guide_images)} guide images for guidance")
        if len(guide_images) == 0:
            print(f"No guide image, use previous generated image. / ガイド画像がありません。直前に生成した画像を使います: {args.image_path}")
            guide_images = None

@@ -2219,33 +2290,37 @@ def main(args):
        iter_seed = random.randint(0, 0x7fffffff)

    # batch processing function
-   def process_batch(batch, highres_fix, highres_1st=False):
+   def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
        batch_size = len(batch)

        # handle highres_fix
        if highres_fix and not highres_1st:
-           # build the 1st stage batch and process it
+           # build the 1st stage batch and process it at a reduced size
            print("process 1st stage1")
            batch_1st = []
-           for params1, (width, height, steps, scale, negative_scale, strength) in batch:
-               width_1st = int(width * args.highres_fix_scale + .5)
-               height_1st = int(height * args.highres_fix_scale + .5)
+           for base, ext in batch:
+               width_1st = int(ext.width * args.highres_fix_scale + .5)
+               height_1st = int(ext.height * args.highres_fix_scale + .5)
                width_1st = width_1st - width_1st % 32
                height_1st = height_1st - height_1st % 32
-               batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
+
+               ext_1st = BatchDataExt(width_1st, height_1st, args.highres_fix_steps, ext.scale,
+                                      ext.negative_scale, ext.strength, ext.network_muls)
+               batch_1st.append(BatchData(base, ext_1st))
            images_1st = process_batch(batch_1st, True, True)

            # build the 2nd stage batch and continue processing below
            print("process 2nd stage1")
            batch_2nd = []
-           for i, (b1, image) in enumerate(zip(batch, images_1st)):
-               image = image.resize((width, height), resample=PIL.Image.LANCZOS)
-               (step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
-               batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
+           for i, (bd, image) in enumerate(zip(batch, images_1st)):
+               image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS)  # use as the img2img input
+               bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:]), bd.ext)
+               batch_2nd.append(bd_2nd)
            batch = batch_2nd

-       (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
-        height, steps, scale, negative_scale, strength) = batch[0]
+       # take this batch's shared settings from its first entry
+       (step_first, _, _, _, init_image, mask_image, _, guide_image), \
+           (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
        noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)

        prompts = []

@@ -2295,9 +2370,13 @@ def main(args):
                all_masks_are_same = mask_images[-2] is mask_image

            if guide_image is not None:
-               guide_images.append(guide_image)
-               if i > 0 and all_guide_images_are_same:
-                   all_guide_images_are_same = guide_images[-2] is guide_image
+               if type(guide_image) is list:
+                   guide_images.extend(guide_image)
+                   all_guide_images_are_same = False
+               else:
+                   guide_images.append(guide_image)
+                   if i > 0 and all_guide_images_are_same:
+                       all_guide_images_are_same = guide_images[-2] is guide_image

        # make start code
        torch.manual_seed(seed)

@@ -2320,7 +2399,19 @@ def main(args):
        if guide_images is not None and all_guide_images_are_same:
            guide_images = guide_images[0]

+       # when ControlNet is used, resize the guide images
+       if control_nets:
+           # TODO choose the resample method
+           guide_images = guide_images if type(guide_images) == list else [guide_images]
+           guide_images = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in guide_images]
+           if len(guide_images) == 1:
+               guide_images = guide_images[0]

        # generate
+       if networks:
+           for n, m in zip(networks, network_muls if network_muls else network_default_muls):
+               n.set_multiplier(m)

        images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
                      output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
        if highres_1st and not args.highres_fix_save_1st:

@@ -2398,6 +2489,7 @@ def main(args):
            strength = 0.8 if args.strength is None else args.strength
            negative_prompt = ""
            clip_prompt = None
+           network_muls = None

            prompt_args = prompt.strip().split(' --')
            prompt = prompt_args[0]

@@ -2461,6 +2553,15 @@ def main(args):
                        clip_prompt = m.group(1)
                        print(f"clip prompt: {clip_prompt}")
                        continue

+                   m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
+                   if m:  # network multipliers
+                       network_muls = [float(v) for v in m.group(1).split(",")]
+                       while len(network_muls) < len(networks):
+                           network_muls.append(network_muls[-1])
+                       print(f"network mul: {network_muls}")
+                       continue

                except ValueError as ex:
                    print(f"Exception in parsing / 解析エラー: {parg}")
                    print(ex)

@@ -2498,7 +2599,12 @@ def main(args):
                mask_image = mask_images[global_step % len(mask_images)]

            if guide_images is not None:
-               guide_image = guide_images[global_step % len(guide_images)]
+               if control_nets:  # there may be more than one ControlNet
+                   c = len(control_nets)
+                   p = global_step % (len(guide_images) // c)
+                   guide_image = guide_images[p * c:p * c + c]
+               else:
+                   guide_image = guide_images[global_step % len(guide_images)]
            elif args.clip_image_guidance_scale > 0 or args.vgg16_guidance_scale > 0:
                if prev_image is None:
                    print("Generate 1st image without guide image.")

@@ -2506,9 +2612,8 @@ def main(args):
                print("Use previous image as guide image.")
                guide_image = prev_image

-           # TODO make this a named tuple or something
-           b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
-                 (width, height, steps, scale, negative_scale, strength))
+           b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
+                          BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
            if len(batch_data) > 0 and batch_data[-1][1] != b1[1]:  # do we need to split the batch?
                process_batch(batch_data, highres_fix)
                batch_data.clear()
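The batch-split check above compares the ``ext`` halves of two entries; because they are ``NamedTuple`` values, equality is plain field-by-field tuple equality, so a batch only accumulates prompts that share size, steps, scales, strength and network multipliers. A tiny illustration:

```python
from typing import NamedTuple

class Ext(NamedTuple):
    width: int
    height: int

print(Ext(512, 512) == Ext(512, 512))  # True: same settings, can share a batch
print(Ext(512, 512) == Ext(512, 768))  # False: forces the pending batch to be flushed
```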
@@ -2578,12 +2683,15 @@ if __name__ == '__main__':
    parser.add_argument("--opt_channels_last", action='store_true',
                        help='set channels last option to model / モデルにchannels lastを指定し最適化する')
    parser.add_argument("--network_module", type=str, default=None, nargs='*',
-                       help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
+                       help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
    parser.add_argument("--network_weights", type=str, default=None, nargs='*',
-                       help='Hypernetwork weights to load / Hypernetworkの重み')
-   parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
+                       help='additional network weights to load / 追加ネットワークの重み')
+   parser.add_argument("--network_mul", type=float, default=None, nargs='*',
+                       help='additional network multiplier / 追加ネットワークの効果の倍率')
    parser.add_argument("--network_args", type=str, default=None, nargs='*',
                        help='additional arguments for network (key=value) / ネットワークへの追加の引数')
+   parser.add_argument("--network_show_meta", action='store_true',
+                       help='show metadata of network model / ネットワークモデルのメタデータを表示する')
    parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
                        help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
    parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')

@@ -2597,7 +2705,8 @@ if __name__ == '__main__':
                        help='enable VGG16 guided SD by image, scale for guidance / 画像によるVGG16 guided SDを有効にしてこのscaleを適用する')
    parser.add_argument("--vgg16_guidance_layer", type=int, default=20,
                        help='layer of VGG16 to calculate contents guide (1~30, 20 for conv4_2) / VGG16のcontents guideに使うレイヤー番号 (1~30、20はconv4_2)')
-   parser.add_argument("--guide_image_path", type=str, default=None, help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
+   parser.add_argument("--guide_image_path", type=str, default=None, nargs="*",
+                       help="image to CLIP guidance / CLIP guided SDでガイドに使う画像")
    parser.add_argument("--highres_fix_scale", type=float, default=None,
                        help="enable highres fix, reso scale for 1st stage / highres fixを有効にして最初の解像度をこのscaleにする")
    parser.add_argument("--highres_fix_steps", type=int, default=28,

@@ -2607,5 +2716,13 @@ if __name__ == '__main__':
    parser.add_argument("--negative_scale", type=float, default=None,
                        help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する")

+   parser.add_argument("--control_net_models", type=str, default=None, nargs='*',
+                       help='ControlNet models to use / 使用するControlNetのモデル名')
+   parser.add_argument("--control_net_preps", type=str, default=None, nargs='*',
+                       help='ControlNet preprocess to use / 使用するControlNetのプリプロセス名')
+   parser.add_argument("--control_net_weights", type=float, default=None, nargs='*', help='ControlNet weights / ControlNetの重み')
+   parser.add_argument("--control_net_ratios", type=float, default=None, nargs='*',
+                       help='ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率')

    args = parser.parse_args()
    main(args)
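Putting the new options together, a hypothetical invocation (file names are placeholders, and ``...`` stands for the usual model and prompt options) might look like ``python gen_img_diffusers.py ... --control_net_models control_canny.safetensors --control_net_weights 1.0 --control_net_ratios 1.0 --guide_image_path canny_hint.png``. A hint image such as ``canny_hint.png`` can be produced with ``tools/canny.py`` below.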
library/train_util.py

@@ -1,10 +1,12 @@
# common functions for training

import argparse
+ import importlib
import json
import shutil
import time
from typing import Dict, List, NamedTuple, Tuple
+ from typing import Optional, Union
from accelerate import Accelerator
from torch.autograd.function import Function
import glob

@@ -17,9 +19,12 @@ from io import BytesIO

from tqdm import tqdm
import torch
+ from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer
+ import transformers
import diffusers
+ from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from diffusers import DDPMScheduler, StableDiffusionPipeline
import albumentations as albu
import numpy as np
@@ -1366,6 +1371,33 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):
                        help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")


+ def add_optimizer_arguments(parser: argparse.ArgumentParser):
+     parser.add_argument("--optimizer_type", type=str, default="",
+                         help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
+
+     # backward compatibility
+     parser.add_argument("--use_8bit_adam", action="store_true",
+                         help="use 8bit AdamW optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
+     parser.add_argument("--use_lion_optimizer", action="store_true",
+                         help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
+
+     parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
+     parser.add_argument("--max_grad_norm", default=1.0, type=float,
+                         help="Max gradient norm, 0 for no clipping / 勾配正規化の最大norm、0でclippingを行わない")
+
+     parser.add_argument("--optimizer_args", type=str, default=None, nargs='*',
+                         help="additional arguments for optimizer (like \"weight_decay=0.01 betas=0.9,0.999 ...\") / オプティマイザの追加引数(例: \"weight_decay=0.01 betas=0.9,0.999 ...\")")
+
+     parser.add_argument("--lr_scheduler", type=str, default="constant",
+                         help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup, adafactor")
+     parser.add_argument("--lr_warmup_steps", type=int, default=0,
+                         help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
+     parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
+                         help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
+     parser.add_argument("--lr_scheduler_power", type=float, default=1,
+                         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")


def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
    parser.add_argument("--output_dir", type=str, default=None,
                        help="directory to output trained model / 学習後のモデル出力先ディレクトリ")

@@ -1387,10 +1419,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
    parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
    parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
                        help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
-   parser.add_argument("--use_8bit_adam", action="store_true",
-                       help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
-   parser.add_argument("--use_lion_optimizer", action="store_true",
-                       help="use Lion optimizer (requires lion-pytorch) / Lionオプティマイザを使う( lion-pytorch のインストールが必要)")
    parser.add_argument("--mem_eff_attn", action="store_true",
                        help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
    parser.add_argument("--xformers", action="store_true",

@@ -1398,7 +1426,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
    parser.add_argument("--vae", type=str, default=None,
                        help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")

-   parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
    parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
    parser.add_argument("--max_train_epochs", type=int, default=None,
                        help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")

@@ -1419,10 +1446,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
    parser.add_argument("--logging_dir", type=str, default=None,
                        help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
    parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
-   parser.add_argument("--lr_scheduler", type=str, default="constant",
-                       help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
-   parser.add_argument("--lr_warmup_steps", type=int, default=0,
-                       help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
    parser.add_argument("--noise_offset", type=float, default=None,
                        help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)")
    parser.add_argument("--lowram", action="store_true",
@@ -1504,6 +1527,243 @@ def add_sd_saving_arguments(parser: argparse.ArgumentParser):

# region utils


+ def get_optimizer(args, trainable_params):
+     # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
+
+     optimizer_type = args.optimizer_type
+     if args.use_8bit_adam:
+         assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
+         assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
+         optimizer_type = "AdamW8bit"
+
+     elif args.use_lion_optimizer:
+         assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
+         optimizer_type = "Lion"
+
+     if optimizer_type is None or optimizer_type == "":
+         optimizer_type = "AdamW"
+     optimizer_type = optimizer_type.lower()
+
+     # parse the extra arguments: only bool, float and tuples of floats are supported
+     optimizer_kwargs = {}
+     if args.optimizer_args is not None and len(args.optimizer_args) > 0:
+         for arg in args.optimizer_args:
+             key, value = arg.split('=')
+
+             value = value.split(",")
+             for i in range(len(value)):
+                 if value[i].lower() == "true" or value[i].lower() == "false":
+                     value[i] = (value[i].lower() == "true")
+                 else:
+                     value[i] = float(value[i])
+             if len(value) == 1:
+                 value = value[0]
+             else:
+                 value = tuple(value)
+
+             optimizer_kwargs[key] = value
+     # print("optkwargs:", optimizer_kwargs)
+
+     lr = args.learning_rate
+
+     if optimizer_type == "AdamW8bit".lower():
+         try:
+             import bitsandbytes as bnb
+         except ImportError:
+             raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+         print(f"use 8-bit AdamW optimizer | {optimizer_kwargs}")
+         optimizer_class = bnb.optim.AdamW8bit
+         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+     elif optimizer_type == "SGDNesterov8bit".lower():
+         try:
+             import bitsandbytes as bnb
+         except ImportError:
+             raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+         print(f"use 8-bit SGD with Nesterov optimizer | {optimizer_kwargs}")
+         if "momentum" not in optimizer_kwargs:
+             print(f"8-bit SGD with Nesterov must be with momentum, set momentum to 0.9 / 8-bit SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
+             optimizer_kwargs["momentum"] = 0.9
+
+         optimizer_class = bnb.optim.SGD8bit
+         optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
+
+     elif optimizer_type == "Lion".lower():
+         try:
+             import lion_pytorch
+         except ImportError:
+             raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
+         print(f"use Lion optimizer | {optimizer_kwargs}")
+         optimizer_class = lion_pytorch.Lion
+         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+     elif optimizer_type == "SGDNesterov".lower():
+         print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
+         if "momentum" not in optimizer_kwargs:
+             print(f"SGD with Nesterov must be with momentum, set momentum to 0.9 / SGD with Nesterovはmomentum指定が必須のため0.9に設定します")
+             optimizer_kwargs["momentum"] = 0.9
+
+         optimizer_class = torch.optim.SGD
+         optimizer = optimizer_class(trainable_params, lr=lr, nesterov=True, **optimizer_kwargs)
+
+     elif optimizer_type == "DAdaptation".lower():
+         try:
+             import dadaptation
+         except ImportError:
+             raise ImportError("No dadaptation / dadaptation がインストールされていないようです")
+         print(f"use D-Adaptation Adam optimizer | {optimizer_kwargs}")
+
+         min_lr = lr
+         if type(trainable_params) == list and type(trainable_params[0]) == dict:
+             for group in trainable_params:
+                 min_lr = min(min_lr, group.get("lr", lr))
+
+         if min_lr <= 0.1:
+             print(
+                 f'learning rate is too low. If using dadaptation, set learning rate around 1.0 / 学習率が低すぎるようです。1.0前後の値を指定してください: {min_lr}')
+             print('recommend option: lr=1.0 / 推奨は1.0です')
+
+         optimizer_class = dadaptation.DAdaptAdam
+         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+     elif optimizer_type == "Adafactor".lower():
+         # check the arguments and correct them as needed
+         if "relative_step" not in optimizer_kwargs:
+             optimizer_kwargs["relative_step"] = True  # default
+         if not optimizer_kwargs["relative_step"] and optimizer_kwargs.get("warmup_init", False):
+             print(f"set relative_step to True because warmup_init is True / warmup_initがTrueのためrelative_stepをTrueにします")
+             optimizer_kwargs["relative_step"] = True
+         print(f"use Adafactor optimizer | {optimizer_kwargs}")
+
+         if optimizer_kwargs["relative_step"]:
+             print(f"relative_step is true / relative_stepがtrueです")
+             if lr != 0.0:
+                 print(f"learning rate is used as initial_lr / 指定したlearning rateはinitial_lrとして使用されます")
+             args.learning_rate = None
+
+             # when trainable_params is a list of param groups: remove their lr entries
+             if type(trainable_params) == list and type(trainable_params[0]) == dict:
+                 has_group_lr = False
+                 for group in trainable_params:
+                     p = group.pop("lr", None)
+                     has_group_lr = has_group_lr or (p is not None)
+
+                 if has_group_lr:
+                     # also invalidate the args, just in case; TODO this inverts the dependency, which is not ideal
+                     print(f"unet_lr and text_encoder_lr are ignored / unet_lrとtext_encoder_lrは無視されます")
+                     args.unet_lr = None
+                     args.text_encoder_lr = None
+
+             if args.lr_scheduler != "adafactor":
+                 print(f"use adafactor_scheduler / スケジューラにadafactor_schedulerを使用します")
+                 args.lr_scheduler = f"adafactor:{lr}"  # a bit hacky, but it works
+
+             lr = None
+         else:
+             if args.max_grad_norm != 0.0:
+                 print(f"because max_grad_norm is set, clip_grad_norm is enabled. consider set to 0 / max_grad_normが設定されているためclip_grad_normが有効になります。0に設定して無効にしたほうがいいかもしれません")
+             if args.lr_scheduler != "constant_with_warmup":
+                 print(f"constant_with_warmup will be good / スケジューラはconstant_with_warmupが良いかもしれません")
+             if optimizer_kwargs.get("clip_threshold", 1.0) != 1.0:
+                 print(f"clip_threshold=1.0 will be good / clip_thresholdは1.0が良いかもしれません")
+
+         optimizer_class = transformers.optimization.Adafactor
+         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+     elif optimizer_type == "AdamW".lower():
+         print(f"use AdamW optimizer | {optimizer_kwargs}")
+         optimizer_class = torch.optim.AdamW
+         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+     else:
+         # use an arbitrary optimizer
+         optimizer_type = args.optimizer_type  # the non-lowercased value (a bit awkward)
+         print(f"use {optimizer_type} | {optimizer_kwargs}")
+         if "." not in optimizer_type:
+             optimizer_module = torch.optim
+         else:
+             values = optimizer_type.split(".")
+             optimizer_module = importlib.import_module(".".join(values[:-1]))
+             optimizer_type = values[-1]
+
+         optimizer_class = getattr(optimizer_module, optimizer_type)
+         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
+     optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
+     optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()])
+
+     return optimizer_name, optimizer_args, optimizer
+
+
+ # Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimizer.get_scheduler
+ # code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
+ # which is a newer release of diffusers than currently packaged with sd-scripts
+ # This code can be removed when a newer diffusers version (v0.12.1 or greater) is tested and implemented into sd-scripts
+
+
+ def get_scheduler_fix(
+     name: Union[str, SchedulerType],
+     optimizer: Optimizer,
+     num_warmup_steps: Optional[int] = None,
+     num_training_steps: Optional[int] = None,
+     num_cycles: int = 1,
+     power: float = 1.0,
+ ):
+     """
+     Unified API to get any scheduler from its name.
+     Args:
+         name (`str` or `SchedulerType`):
+             The name of the scheduler to use.
+         optimizer (`torch.optim.Optimizer`):
+             The optimizer that will be used during training.
+         num_warmup_steps (`int`, *optional*):
+             The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+             optional), the function will raise an error if it's unset and the scheduler type requires it.
+         num_training_steps (`int`, *optional*):
+             The number of training steps to do. This is not required by all schedulers (hence the argument being
+             optional), the function will raise an error if it's unset and the scheduler type requires it.
+         num_cycles (`int`, *optional*):
+             The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
+         power (`float`, *optional*, defaults to 1.0):
+             Power factor. See the `POLYNOMIAL` scheduler.
+         last_epoch (`int`, *optional*, defaults to -1):
+             The index of the last epoch when resuming training.
+     """
+     if name.startswith("adafactor"):
+         assert type(optimizer) == transformers.optimization.Adafactor, f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
+         initial_lr = float(name.split(':')[1])
+         # print("adafactor scheduler init lr", initial_lr)
+         return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
+
+     name = SchedulerType(name)
+     schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+     if name == SchedulerType.CONSTANT:
+         return schedule_func(optimizer)
+
+     # All other schedulers require `num_warmup_steps`
+     if num_warmup_steps is None:
+         raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+     if name == SchedulerType.CONSTANT_WITH_WARMUP:
+         return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+     # All other schedulers require `num_training_steps`
+     if num_training_steps is None:
+         raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+     if name == SchedulerType.COSINE_WITH_RESTARTS:
+         return schedule_func(
+             optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
+         )
+
+     if name == SchedulerType.POLYNOMIAL:
+         return schedule_func(
+             optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
+         )
+
+     return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)


def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
    # backward compatibility
    if args.caption_extention is not None:
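Note the D-Adaptation branch above: because the optimizer adapts the step size itself, the learning rate acts as a multiplier, and the code warns when it is 0.1 or lower while recommending a value around 1.0 (e.g. ``--optimizer_type DAdaptation --learning_rate 1.0``).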
@@ -1592,13 +1852,19 @@ def prepare_dtype(args: argparse.Namespace):


 def load_target_model(args: argparse.Namespace, weight_dtype):
-  load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)  # determine SD or Diffusers
+  name_or_path = args.pretrained_model_name_or_path
+  name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
+  load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
   if load_stable_diffusion_format:
     print("load StableDiffusion checkpoint")
-    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
+    text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
   else:
     print("load Diffusers pretrained models")
-    pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
+    try:
+      pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
+    except EnvironmentError as ex:
+      print(
+          f"model is not found as a file or in Hugging Face, perhaps file name is wrong? / 指定したモデル名のファイル、またはHugging Faceのモデルが見つかりません。ファイル名が誤っているかもしれません: {name_or_path}")
     text_encoder = pipe.text_encoder
     vae = pipe.vae
     unet = pipe.unet

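The new os.readlink branch means a symlinked checkpoint is now resolved to its target before the file-vs-Diffusers decision. A small illustration (the paths are hypothetical):

```python
# Hypothetical illustration of the symlink resolution added above.
import os

name_or_path = "models/current.safetensors"  # suppose this is a symlink to v2/model.safetensors
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
print(os.path.isfile(name_or_path))  # True -> loaded as a Stable Diffusion checkpoint
```
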
@@ -126,6 +126,11 @@ class LoRANetwork(torch.nn.Module):
       assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
       names.add(lora.lora_name)

+  def set_multiplier(self, multiplier):
+    self.multiplier = multiplier
+    for lora in self.text_encoder_loras + self.unet_loras:
+      lora.multiplier = self.multiplier
+
   def load_weights(self, file):
     if os.path.splitext(file)[1] == '.safetensors':
       from safetensors.torch import load_file, safe_open

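A one-line sketch of how the new setter might be used (assuming `network` is a LoRANetwork instance):

```python
network.set_multiplier(0.8)  # rescale the contribution of every LoRA module at once
```
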
tools/canny.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import argparse
import cv2


def canny(args):
  img = cv2.imread(args.input)
  img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

  canny_img = cv2.Canny(img, args.thres1, args.thres2)
  # canny_img = 255 - canny_img

  cv2.imwrite(args.output, canny_img)
  print("done!")


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("--input", type=str, default=None, help="input path")
  parser.add_argument("--output", type=str, default=None, help="output path")
  parser.add_argument("--thres1", type=int, default=32, help="thres1 (lower Canny hysteresis threshold)")
  parser.add_argument("--thres2", type=int, default=224, help="thres2 (upper Canny hysteresis threshold)")

  args = parser.parse_args()
  canny(args)
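A typical invocation of this helper (file names are placeholders):

```
python tools/canny.py --input hint.png --output hint_canny.png --thres1 32 --thres2 224
```
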
tools/original_control_net.py (new file, 320 lines)
@@ -0,0 +1,320 @@
from typing import List, NamedTuple, Any
import numpy as np
import cv2
import torch
from safetensors.torch import load_file

from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

import library.model_util as model_util


class ControlNetInfo(NamedTuple):
  unet: Any
  net: Any
  prep: Any
  weight: float
  ratio: float


class ControlNet(torch.nn.Module):
  def __init__(self) -> None:
    super().__init__()

    # make control model
    self.control_model = torch.nn.Module()

    # one 1x1 conv per skip-connection output of the U-Net encoder
    dims = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
    zero_convs = torch.nn.ModuleList()
    for i, dim in enumerate(dims):
      sub_list = torch.nn.ModuleList([torch.nn.Conv2d(dim, dim, 1)])
      zero_convs.append(sub_list)
    self.control_model.add_module("zero_convs", zero_convs)

    middle_block_out = torch.nn.Conv2d(1280, 1280, 1)
    self.control_model.add_module("middle_block_out", torch.nn.ModuleList([middle_block_out]))

    # hint encoder: strided convs that bring the hint image down to latent resolution
    dims = [16, 16, 32, 32, 96, 96, 256, 320]
    strides = [1, 1, 2, 1, 2, 1, 2, 1]
    prev_dim = 3
    input_hint_block = torch.nn.Sequential()
    for i, (dim, stride) in enumerate(zip(dims, strides)):
      input_hint_block.append(torch.nn.Conv2d(prev_dim, dim, 3, stride, 1))
      if i < len(dims) - 1:
        input_hint_block.append(torch.nn.SiLU())
      prev_dim = dim
    self.control_model.add_module("input_hint_block", input_hint_block)

def load_control_net(v2, unet, model):
  device = unet.device

  # extract only the U-Net part of the control SD model while converting its keys, and load it into a Diffusers U-Net
  # load the state dict
  print(f"ControlNet: loading control SD model : {model}")

  if model_util.is_safetensors(model):
    ctrl_sd_sd = load_file(model)
  else:
    ctrl_sd_sd = torch.load(model, map_location='cpu')
    ctrl_sd_sd = ctrl_sd_sd.pop("state_dict", ctrl_sd_sd)

  # make the weights loadable into the U-Net; the ControlNet state dict is in SD format, so load that
  is_difference = "difference" in ctrl_sd_sd
  print("ControlNet: loading difference")

  # some keys do not exist in the ControlNet, so first build all SD-format keys from the current U-Net;
  # this also serves as the base weights for Transfer Control
  ctrl_unet_sd_sd = model_util.convert_unet_state_dict_to_sd(v2, unet.state_dict())

  # clone so the original U-Net is unaffected, and add the missing prefix
  for key in list(ctrl_unet_sd_sd.keys()):
    ctrl_unet_sd_sd["model.diffusion_model." + key] = ctrl_unet_sd_sd.pop(key).clone()

  zero_conv_sd = {}
  for key in list(ctrl_sd_sd.keys()):
    if key.startswith("control_"):
      unet_key = "model.diffusion_" + key[len("control_"):]
      if unet_key not in ctrl_unet_sd_sd:  # zero conv
        zero_conv_sd[key] = ctrl_sd_sd[key]
        continue
      if is_difference:  # Transfer Control
        ctrl_unet_sd_sd[unet_key] += ctrl_sd_sd[key].to(device, dtype=unet.dtype)
      else:
        ctrl_unet_sd_sd[unet_key] = ctrl_sd_sd[key].to(device, dtype=unet.dtype)

  unet_config = model_util.create_unet_diffusers_config(v2)
  ctrl_unet_du_sd = model_util.convert_ldm_unet_checkpoint(v2, ctrl_unet_sd_sd, unet_config)  # state dict of the Diffusers-format ControlNet

  # create the U-Net for the ControlNet
  ctrl_unet = UNet2DConditionModel(**unet_config)
  info = ctrl_unet.load_state_dict(ctrl_unet_du_sd)
  print("ControlNet: loading Control U-Net:", info)

  # create the non-U-Net part of the ControlNet
  # TODO support middle only
  ctrl_net = ControlNet()
  info = ctrl_net.load_state_dict(zero_conv_sd)
  print("ControlNet: loading ControlNet:", info)

  ctrl_unet.to(unet.device, dtype=unet.dtype)
  ctrl_net.to(unet.device, dtype=unet.dtype)
  return ctrl_unet, ctrl_net

def load_preprocess(prep_type: str):
  if prep_type is None or prep_type.lower() == "none":
    return None

  if prep_type.startswith("canny"):
    args = prep_type.split("_")
    th1 = int(args[1]) if len(args) >= 2 else 63
    th2 = int(args[2]) if len(args) >= 3 else 191

    def canny(img):
      img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
      return cv2.Canny(img, th1, th2)
    return canny

  print("Unsupported prep type:", prep_type)
  return None
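The prep_type string encodes the preprocessor and its thresholds, underscore-separated, for example:

```python
prep = load_preprocess("canny_63_191")  # Canny with thresholds 63 and 191 (also the defaults if omitted)
edges = prep(rgb_image)                 # rgb_image: an H x W x 3 uint8 array (placeholder name)
```
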

def preprocess_ctrl_net_hint_image(image):
  image = np.array(image).astype(np.float32) / 255.0
  image = image[:, :, ::-1].copy()  # rgb to bgr
  image = image[None].transpose(0, 3, 1, 2)  # nchw
  image = torch.from_numpy(image)
  return image  # 0 to 1

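For reference, the shape bookkeeping of the hint preprocessing (a sketch; the input values are dummies):

```python
import numpy as np

hint = np.zeros((512, 512, 3), dtype=np.uint8)  # H x W x 3 RGB image, values 0-255
t = preprocess_ctrl_net_hint_image(hint)        # -> torch.Size([1, 3, 512, 512]), float32 in [0, 1]
```
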
def get_guided_hints(control_nets: List[ControlNetInfo], num_latent_input, b_size, hints):
  guided_hints = []
  for i, cnet_info in enumerate(control_nets):
    # hints must be ordered as: cnet1 of image 1, cnet2 of image 1, cnet3 of image 1, cnet1 of image 2, cnet2 of image 2, ...
    b_hints = []
    if len(hints) == 1:  # use the same image as the hint for everything
      hint = hints[0]
      if cnet_info.prep is not None:
        hint = cnet_info.prep(hint)
      hint = preprocess_ctrl_net_hint_image(hint)
      b_hints = [hint for _ in range(b_size)]
    else:
      for bi in range(b_size):
        hint = hints[(bi * len(control_nets) + i) % len(hints)]
        if cnet_info.prep is not None:
          hint = cnet_info.prep(hint)
        hint = preprocess_ctrl_net_hint_image(hint)
        b_hints.append(hint)
    b_hints = torch.cat(b_hints, dim=0)
    b_hints = b_hints.to(cnet_info.unet.device, dtype=cnet_info.unet.dtype)

    guided_hint = cnet_info.net.control_model.input_hint_block(b_hints)
    guided_hints.append(guided_hint)
  return guided_hints


def call_unet_and_control_net(step, num_latent_input, original_unet, control_nets: List[ControlNetInfo], guided_hints, current_ratio, sample, timestep, encoder_hidden_states):
  # ControlNet
  # with multiple ControlNets, apply them alternately instead of merging their outputs
  cnet_cnt = len(control_nets)
  cnet_idx = step % cnet_cnt
  cnet_info = control_nets[cnet_idx]

  # print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
  if cnet_info.ratio < current_ratio:
    return original_unet(sample, timestep, encoder_hidden_states)

  guided_hint = guided_hints[cnet_idx]
  guided_hint = guided_hint.repeat((num_latent_input, 1, 1, 1))
  outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
  outs = [o * cnet_info.weight for o in outs]

  # U-Net
  return unet_forward(False, cnet_info.net, original_unet, None, outs, sample, timestep, encoder_hidden_states)

"""
|
||||
# これはmergeのバージョン
|
||||
# ControlNet
|
||||
cnet_outs_list = []
|
||||
for i, cnet_info in enumerate(control_nets):
|
||||
# print(current_ratio, cnet_info.prep, cnet_info.weight, cnet_info.ratio)
|
||||
if cnet_info.ratio < current_ratio:
|
||||
continue
|
||||
guided_hint = guided_hints[i]
|
||||
outs = unet_forward(True, cnet_info.net, cnet_info.unet, guided_hint, None, sample, timestep, encoder_hidden_states)
|
||||
for i in range(len(outs)):
|
||||
outs[i] *= cnet_info.weight
|
||||
|
||||
cnet_outs_list.append(outs)
|
||||
|
||||
count = len(cnet_outs_list)
|
||||
if count == 0:
|
||||
return original_unet(sample, timestep, encoder_hidden_states)
|
||||
|
||||
# sum of controlnets
|
||||
for i in range(1, count):
|
||||
cnet_outs_list[0] += cnet_outs_list[i]
|
||||
|
||||
# U-Net
|
||||
return unet_forward(False, cnet_info.net, original_unet, None, cnet_outs_list[0], sample, timestep, encoder_hidden_states)
|
||||
"""
|
||||
|
||||
|
||||
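Putting the pieces together, one denoising step might be wired up roughly like this (a hedged sketch: `unet`, `hint_image`, `timesteps`, `latent_input`, and `text_emb` are placeholders, and `current_ratio` is assumed to be the fraction of sampling completed so far):

```python
# Hypothetical wiring of the helpers above inside a sampler loop (not code from this file).
control_nets = [ControlNetInfo(ctrl_unet, ctrl_net, load_preprocess("canny"), weight=1.0, ratio=0.9)]
guided_hints = get_guided_hints(control_nets, num_latent_input=2, b_size=1, hints=[hint_image])

for step, t in enumerate(timesteps):
  current_ratio = step / len(timesteps)
  noise_pred = call_unet_and_control_net(step, 2, unet, control_nets, guided_hints,
                                         current_ratio, latent_input, t, text_emb).sample
```
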
def unet_forward(is_control_net, control_net: ControlNet, unet: UNet2DConditionModel, guided_hint, ctrl_outs, sample, timestep, encoder_hidden_states):
  # copy from UNet2DConditionModel
  default_overall_up_factor = 2**unet.num_upsamplers

  forward_upsample_size = False
  upsample_size = None

  if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
    print("Forward upsample size to force interpolation output size.")
    forward_upsample_size = True

  # 0. center input if necessary
  if unet.config.center_input_sample:
    sample = 2 * sample - 1.0

  # 1. time
  timesteps = timestep
  if not torch.is_tensor(timesteps):
    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
    # This would be a good case for the `match` statement (Python 3.10+)
    is_mps = sample.device.type == "mps"
    if isinstance(timestep, float):
      dtype = torch.float32 if is_mps else torch.float64
    else:
      dtype = torch.int32 if is_mps else torch.int64
    timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
  elif len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
  timesteps = timesteps.expand(sample.shape[0])

  t_emb = unet.time_proj(timesteps)

  # timesteps does not contain any weights and will always return f32 tensors
  # but time_embedding might actually be running in fp16. so we need to cast here.
  # there might be better ways to encapsulate this.
  t_emb = t_emb.to(dtype=unet.dtype)
  emb = unet.time_embedding(t_emb)

  outs = []  # output of ControlNet
  zc_idx = 0

  # 2. pre-process
  sample = unet.conv_in(sample)
  if is_control_net:
    sample += guided_hint
    outs.append(control_net.control_model.zero_convs[zc_idx][0](sample))  # , emb, encoder_hidden_states))
    zc_idx += 1

  # 3. down
  down_block_res_samples = (sample,)
  for downsample_block in unet.down_blocks:
    if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
      sample, res_samples = downsample_block(
          hidden_states=sample,
          temb=emb,
          encoder_hidden_states=encoder_hidden_states,
      )
    else:
      sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
    if is_control_net:
      for rs in res_samples:
        outs.append(control_net.control_model.zero_convs[zc_idx][0](rs))  # , emb, encoder_hidden_states))
        zc_idx += 1

    down_block_res_samples += res_samples

  # 4. mid
  sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
  if is_control_net:
    outs.append(control_net.control_model.middle_block_out[0](sample))
    return outs

  if not is_control_net:
    sample += ctrl_outs.pop()

  # 5. up
  for i, upsample_block in enumerate(unet.up_blocks):
    is_final_block = i == len(unet.up_blocks) - 1

    res_samples = down_block_res_samples[-len(upsample_block.resnets):]
    down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

    if not is_control_net and len(ctrl_outs) > 0:
      res_samples = list(res_samples)
      apply_ctrl_outs = ctrl_outs[-len(res_samples):]
      ctrl_outs = ctrl_outs[:-len(res_samples)]
      for j in range(len(res_samples)):
        res_samples[j] = res_samples[j] + apply_ctrl_outs[j]
      res_samples = tuple(res_samples)

    # if we have not reached the final block and need to forward the
    # upsample size, we do it here
    if not is_final_block and forward_upsample_size:
      upsample_size = down_block_res_samples[-1].shape[2:]

    if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
      sample = upsample_block(
          hidden_states=sample,
          temb=emb,
          res_hidden_states_tuple=res_samples,
          encoder_hidden_states=encoder_hidden_states,
          upsample_size=upsample_size,
      )
    else:
      sample = upsample_block(
          hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
      )

  # 6. post-process
  sample = unet.conv_norm_out(sample)
  sample = unet.conv_act(sample)
  sample = unet.conv_out(sample)

  return UNet2DConditionOutput(sample=sample)
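Note the bookkeeping that makes the two passes line up for the standard SD U-Net geometry: the control pass returns 13 tensors (12 zero convolutions, matching the `dims` list above, plus `middle_block_out`); the main pass pops the middle output right after its mid block and then folds the remaining 12 into the up-block residuals, three at a time per up block.
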
train_db.py (38 lines changed)
@@ -115,32 +115,12 @@ def train(args):

   # prepare the classes needed for training
   print("prepare optimizer, data loader etc.")

-  # use 8-bit Adam
-  if args.use_8bit_adam:
-    try:
-      import bitsandbytes as bnb
-    except ImportError:
-      raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
-    print("use 8-bit Adam optimizer")
-    optimizer_class = bnb.optim.AdamW8bit
-  elif args.use_lion_optimizer:
-    try:
-      import lion_pytorch
-    except ImportError:
-      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print("use Lion optimizer")
-    optimizer_class = lion_pytorch.Lion
-  else:
-    optimizer_class = torch.optim.AdamW
-
   if train_text_encoder:
     trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
   else:
     trainable_params = unet.parameters()

-  # beta and weight decay seem to use the default values in both diffusers DreamBooth and DreamBooth SD, so the options are omitted for now
-  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)

   # prepare the dataloader
   # DataLoader num_workers: 0 means the main process
@@ -156,9 +136,10 @@ def train(args):
   if args.stop_text_encoder_training is None:
     args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end

-  # prepare the lr scheduler
-  lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
+  # prepare the lr scheduler  TODO something may be wrong with the handling of gradient_accumulation_steps; check later
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

   # experimental feature: perform fp16 training including gradients (make the whole model fp16)
   if args.full_fp16:
@@ -281,12 +262,12 @@ def train(args):
         loss = loss.mean()  # this is the mean, so no need to divide by batch_size

         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           if train_text_encoder:
             params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
           else:
             params_to_clip = unet.parameters()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

         optimizer.step()
         lr_scheduler.step()
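With this change, passing ``--max_grad_norm=0.0`` skips gradient clipping entirely; this matches the AdaFactor guidance in the documentation further below, which recommends not clipping the gradient norm when ``relative_step=False``.
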
@@ -299,7 +280,9 @@ def train(args):

       current_loss = loss.detach().item()
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d'] * lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)

       if epoch == 0:
@@ -352,6 +335,7 @@ if __name__ == '__main__':
   train_util.add_dataset_arguments(parser, True, False, True)
   train_util.add_training_arguments(parser, True)
   train_util.add_sd_saving_arguments(parser)
+  train_util.add_optimizer_arguments(parser)

   parser.add_argument("--no_token_padding", action="store_true",
                       help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")

train_network.py (135 lines changed)
@@ -1,8 +1,5 @@
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
-from torch.optim import Optimizer
-from torch.cuda.amp import autocast
 from torch.nn.parallel import DistributedDataParallel as DDP
 from typing import Optional, Union
 import importlib
 import argparse
 import gc
@@ -26,83 +23,24 @@ def collate_fn(examples):
   return examples[0]


 # TODO share this with the other scripts / 他のスクリプトと共通化する
 def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
   logs = {"loss/current": current_loss, "loss/average": avr_loss}

   if args.network_train_unet_only:
-    logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
+    logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
   elif args.network_train_text_encoder_only:
-    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
+    logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
   else:
-    logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
-    logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be the same as textencoder
+    logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
+    logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1])  # may be the same as textencoder
+
+  if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value of unet.
+    logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]['d'] * lr_scheduler.optimizers[-1].param_groups[0]['lr']

   return logs


-# Monkeypatch newer get_scheduler() function overriding the current version of diffusers.optimizer.get_scheduler
-# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
-# which is a newer release of diffusers than currently packaged with sd-scripts
-# This code can be removed when a newer diffusers version (v0.12.1 or greater) is tested and implemented in sd-scripts
-
-
-def get_scheduler_fix(
-    name: Union[str, SchedulerType],
-    optimizer: Optimizer,
-    num_warmup_steps: Optional[int] = None,
-    num_training_steps: Optional[int] = None,
-    num_cycles: int = 1,
-    power: float = 1.0,
-):
-  """
-  Unified API to get any scheduler from its name.
-  Args:
-    name (`str` or `SchedulerType`):
-      The name of the scheduler to use.
-    optimizer (`torch.optim.Optimizer`):
-      The optimizer that will be used during training.
-    num_warmup_steps (`int`, *optional*):
-      The number of warmup steps to do. This is not required by all schedulers (hence the argument being
-      optional), the function will raise an error if it's unset and the scheduler type requires it.
-    num_training_steps (`int`, *optional*):
-      The number of training steps to do. This is not required by all schedulers (hence the argument being
-      optional), the function will raise an error if it's unset and the scheduler type requires it.
-    num_cycles (`int`, *optional*):
-      The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
-    power (`float`, *optional*, defaults to 1.0):
-      Power factor. See the `POLYNOMIAL` scheduler.
-    last_epoch (`int`, *optional*, defaults to -1):
-      The index of the last epoch when resuming training.
-  """
-  name = SchedulerType(name)
-  schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-  if name == SchedulerType.CONSTANT:
-    return schedule_func(optimizer)
-
-  # All other schedulers require `num_warmup_steps`
-  if num_warmup_steps is None:
-    raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
-
-  if name == SchedulerType.CONSTANT_WITH_WARMUP:
-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
-
-  # All other schedulers require `num_training_steps`
-  if num_training_steps is None:
-    raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
-
-  if name == SchedulerType.COSINE_WITH_RESTARTS:
-    return schedule_func(
-        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
-    )
-
-  if name == SchedulerType.POLYNOMIAL:
-    return schedule_func(
-        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
-    )
-
-  return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
-
-
 def train(args):
   session_id = random.randint(0, 2**32)
   training_started_at = time.time()
@@ -161,7 +99,7 @@ def train(args):
   if args.lowram:
     text_encoder.to("cuda")
     unet.to("cuda")


   # incorporate xformers or memory efficient attention into the model
   train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

@@ -208,30 +146,8 @@ def train(args):
   # prepare the classes needed for training
   print("prepare optimizer, data loader etc.")

-  # use 8-bit Adam
-  if args.use_8bit_adam:
-    try:
-      import bitsandbytes as bnb
-    except ImportError:
-      raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
-    print("use 8-bit Adam optimizer")
-    optimizer_class = bnb.optim.AdamW8bit
-  elif args.use_lion_optimizer:
-    try:
-      import lion_pytorch
-    except ImportError:
-      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print("use Lion optimizer")
-    optimizer_class = lion_pytorch.Lion
-  else:
-    optimizer_class = torch.optim.AdamW
-
-  optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__
-
   trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)

-  # beta and weight decay seem to use the default values in both diffusers DreamBooth and DreamBooth SD, so the options are omitted for now
-  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
+  optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

   # prepare the dataloader
   # DataLoader num_workers: 0 means the main process
@@ -245,11 +161,9 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # prepare the lr scheduler
-  # lr_scheduler = diffusers.optimization.get_scheduler(
-  lr_scheduler = get_scheduler_fix(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-      num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
-      num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

   # experimental feature: perform fp16 training including gradients (make the whole model fp16)
   if args.full_fp16:
@@ -361,9 +275,11 @@ def train(args):
       "ss_shuffle_caption": bool(args.shuffle_caption),
       "ss_cache_latents": bool(args.cache_latents),
       "ss_enable_bucket": bool(train_dataset.enable_bucket),
+      "ss_bucket_no_upscale": bool(train_dataset.bucket_no_upscale),
       "ss_min_bucket_reso": train_dataset.min_bucket_reso,
       "ss_max_bucket_reso": train_dataset.max_bucket_reso,
       "ss_seed": args.seed,
+      "ss_lowram": args.lowram,
       "ss_keep_tokens": args.keep_tokens,
       "ss_noise_offset": args.noise_offset,
       "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
@@ -372,7 +288,13 @@ def train(args):
       "ss_bucket_info": json.dumps(train_dataset.bucket_info),
       "ss_training_comment": args.training_comment,  # will not be updated after training
       "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
-      "ss_optimizer": optimizer_name
+      "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
+      "ss_max_grad_norm": args.max_grad_norm,
+      "ss_caption_dropout_rate": args.caption_dropout_rate,
+      "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
+      "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
+      "ss_face_crop_aug_range": args.face_crop_aug_range,
+      "ss_prior_loss_weight": args.prior_loss_weight,
     }

   # uncomment if another network is added
@@ -447,7 +369,7 @@ def train(args):
           noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

           # Predict the noise residual
-          with autocast():
+          with accelerator.autocast():
             noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

           if args.v_parameterization:
@@ -465,9 +387,9 @@ def train(args):
           loss = loss.mean()  # this is the mean, so no need to divide by batch_size

           accelerator.backward(loss)
-          if accelerator.sync_gradients:
+          if accelerator.sync_gradients and args.max_grad_norm != 0.0:
             params_to_clip = network.get_trainable_params()
-            accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

           optimizer.step()
           lr_scheduler.step()
@@ -508,6 +430,7 @@ def train(args):
       def save_func():
         ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
         ckpt_file = os.path.join(args.output_dir, ckpt_name)
+        metadata["ss_training_finished_at"] = str(time.time())
         print(f"saving checkpoint: {ckpt_file}")
         unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)

@@ -525,6 +448,7 @@ def train(args):
     # end of epoch

   metadata["ss_epoch"] = str(num_train_epochs)
+  metadata["ss_training_finished_at"] = str(time.time())

   is_main_process = accelerator.is_main_process
   if is_main_process:
@@ -555,6 +479,7 @@ if __name__ == '__main__':
   train_util.add_sd_models_arguments(parser)
   train_util.add_dataset_arguments(parser, True, True, True)
   train_util.add_training_arguments(parser, True)
+  train_util.add_optimizer_arguments(parser)

   parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
   parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
@@ -562,10 +487,6 @@ if __name__ == '__main__':

   parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
   parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
-  parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
-                      help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
-  parser.add_argument("--lr_scheduler_power", type=float, default=1,
-                      help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")

   parser.add_argument("--network_weights", type=str, default=None,
                       help="pretrained weights for network / 学習するネットワークの初期重み")

@@ -50,11 +50,13 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
     --train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
     --reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
     --resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
-    --max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16
+    --max_train_steps=400 --optimizer_type=AdamW8bit --xformers --mixed_precision=fp16
     --save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
     --network_module=networks.lora
 ```

+(2023/2/22: the way the optimizer is specified has changed; see [the optimizer section](#オプティマイザの指定について) below.)
+
 The LoRA model is saved to the folder specified by the --output_dir option.

 The other options that can be specified are listed below.

@@ -76,6 +78,42 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py

 When neither --network_train_unet_only nor --network_train_text_encoder_only is specified (the default), the LoRA modules of both the Text Encoder and the U-Net are enabled.

+## Specifying the optimizer (オプティマイザの指定について)
+
+Specify the optimizer type with the --optimizer_type option. The following can be specified:
+
+- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
+  - Same behavior as past versions when no option was specified
+- AdamW8bit : same arguments as above
+  - Same as specifying --use_8bit_adam in past versions
+- Lion : https://github.com/lucidrains/lion-pytorch
+  - Same as specifying --use_lion_optimizer in past versions
+- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
+- SGDNesterov8bit : same arguments as above
+- DAdaptation : https://github.com/facebookresearch/dadaptation
+- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
+- any other optimizer
+
+Pass optional optimizer arguments with the --optimizer_args option, in key=value format. Multiple key=value pairs can be given, and a value can itself be a comma-separated list. For example, to pass arguments to the AdamW optimizer: ``--optimizer_args weight_decay=0.01 betas=.9,.999``.
+
+When passing optional arguments, check the specification of each optimizer.
+
+Some optimizers have required arguments that are added automatically when omitted (such as momentum for SGDNesterov); check the console output.
+
+The D-Adaptation optimizer adjusts the learning rate automatically. The value given to the learning rate option is not the learning rate itself but the rate at which the learning rate determined by D-Adaptation is applied, so normally specify 1.0. To give the Text Encoder half the learning rate of the U-Net, specify ``--text_encoder_lr=0.5 --unet_lr=1.0``.
+
+The AdaFactor optimizer can adjust the learning rate automatically when relative_step=True is specified (this is added by default when omitted). When adjusting automatically, the adafactor_scheduler is forcibly used as the lr scheduler; it also seems good to specify scale_parameter and warmup_init.
+
+The options for automatic adjustment look like ``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"``.
+
+To not adjust the learning rate automatically, add the optional argument ``relative_step=False``. In that case, constant_with_warmup is recommended as the lr scheduler, and it also seems recommended not to clip the gradient norm, so the arguments look like ``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0``.
+
+### Using an arbitrary optimizer
+
+To use an optimizer from ``torch.optim``, specify only the class name (such as ``--optimizer_type=RMSprop``); to use an optimizer from another module, specify "module name.class name" (such as ``--optimizer_type=bitsandbytes.optim.lamb.LAMB``).
+
+(Internally this is just an importlib call and its operation is unconfirmed; install the package if necessary.)
+
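For reference, the key=value format could be parsed along these lines (a sketch of the expected format only; the actual parsing is presumably handled inside train_util.get_optimizer):

```python
# Illustrative parser for "key=value" optimizer arguments (not the repo's implementation).
import ast

def parse_optimizer_args(arg_strings):
  kwargs = {}
  for s in arg_strings:
    key, value = s.split("=", 1)
    # a comma-separated value becomes a tuple, e.g. "betas=.9,.999" -> (0.9, 0.999)
    parts = [ast.literal_eval(p) for p in value.split(",")]
    kwargs[key] = parts[0] if len(parts) == 1 else tuple(parts)
  return kwargs

print(parse_optimizer_args(["weight_decay=0.01", "betas=.9,.999"]))
# -> {'weight_decay': 0.01, 'betas': (0.9, 0.999)}
```
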
 ## Merge scripts

 With merge_lora.py you can merge the result of LoRA training into a Stable Diffusion model, or merge multiple LoRA models.

@@ -198,29 +198,8 @@ def train(args):

   # prepare the classes needed for training
   print("prepare optimizer, data loader etc.")

-  # use 8-bit Adam
-  if args.use_8bit_adam:
-    try:
-      import bitsandbytes as bnb
-    except ImportError:
-      raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
-    print("use 8-bit Adam optimizer")
-    optimizer_class = bnb.optim.AdamW8bit
-  elif args.use_lion_optimizer:
-    try:
-      import lion_pytorch
-    except ImportError:
-      raise ImportError("No lion_pytorch / lion_pytorch がインストールされていないようです")
-    print("use Lion optimizer")
-    optimizer_class = lion_pytorch.Lion
-  else:
-    optimizer_class = torch.optim.AdamW
-
   trainable_params = text_encoder.get_input_embeddings().parameters()

-  # beta and weight decay seem to use the default values in both diffusers DreamBooth and DreamBooth SD, so the options are omitted for now
-  optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
+  _, _, optimizer = train_util.get_optimizer(args, trainable_params)

   # prepare the dataloader
   # DataLoader num_workers: 0 means the main process
@@ -234,8 +213,9 @@ def train(args):
     print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

   # prepare the lr scheduler
-  lr_scheduler = diffusers.optimization.get_scheduler(
-      args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+  lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
+                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

   # accelerator seems to take care of things nicely
   text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -357,9 +337,9 @@ def train(args):
         loss = loss.mean()  # this is the mean, so no need to divide by batch_size

         accelerator.backward(loss)
-        if accelerator.sync_gradients:
+        if accelerator.sync_gradients and args.max_grad_norm != 0.0:
           params_to_clip = text_encoder.get_input_embeddings().parameters()
-          accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)
+          accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

         optimizer.step()
         lr_scheduler.step()
@@ -376,7 +356,9 @@ def train(args):

       current_loss = loss.detach().item()
       if args.logging_dir is not None:
-        logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+        logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
+        if args.optimizer_type.lower() == "DAdaptation".lower():  # tracking d*lr value
+          logs["lr/d*lr"] = lr_scheduler.optimizers[0].param_groups[0]['d'] * lr_scheduler.optimizers[0].param_groups[0]['lr']
         accelerator.log(logs, step=global_step)

       loss_total += current_loss
@@ -491,6 +473,7 @@ if __name__ == '__main__':
   train_util.add_sd_models_arguments(parser)
   train_util.add_dataset_arguments(parser, True, True, False)
   train_util.add_training_arguments(parser, True)
+  train_util.add_optimizer_arguments(parser)

   parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
                       help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")