mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
23 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | cae42728ab | | |
| | 50f65d683d | | |
| | 0fc1cc8076 | | |
| | 943eae1211 | | |
| | 4c928c8d12 | | |
| | 687044519b | | |
| | 758323532b | | |
| | 8bd844cdc1 | | |
| | 4d4ebf600e | | |
| | e6a8c9d269 | | |
| | 3eb8fb1875 | | |
| | fda66db0d8 | | |
| | 3815b82bef | | |
| | 37fbefb3cd | | |
| | c6e28faa57 | | |
| | a888223869 | | |
| | d30ea7966d | | |
| | df9cb2f11c | | |
| | 8544e219b0 | | |
| | f2f2ce0d7d | | |
| | c9fda104b4 | | |
| | aa40cb9345 | | |
| | b8734405c6 | | |

@@ -99,7 +99,7 @@ accelerate configの質問には以下のように答えてください。(bf1
 ```

 ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
-``What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
+``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)

 ## アップグレード

144
README.md
@@ -2,23 +2,63 @@ This repository contains training, generation and utility scripts for Stable Dif

 ## Updates

-- January 14, 2023, 2023/1/14
-  - Fix a failure when loading some VAEs or .safetensors files as a VAE with the ``--vae`` option. Thanks to Fannovel16!
-  - Add negative prompt scaling to ``gen_img_diffusers.py``. You can set a separate conditioning scale for the negative prompt with the ``--negative_scale`` option, or with ``--nl`` in the prompt. Thanks to laksjdjf!
-  - ``--vae`` オプションに一部のVAEや .safetensors 形式のモデルを指定するとエラーになる不具合を修正しました。Fannovel16氏に感謝します。
-  - ``gen_img_diffusers.py`` に、ネガティブプロンプトに異なる guidance scale を設定できる ``--negative_scale`` オプションを追加しました。プロンプトからは ``--nl`` で指定できます。laksjdjf氏に感謝します。
-- January 12, 2023, 2023/1/12
-  - Metadata is now saved in the model (.safetensors only): model name, VAE name, training steps, learning rate, etc. The metadata will be viewable in the sd-webui-additional-networks extension in the near future. If you do not want to save it, specify the ``no_metadata`` option.
-  - メタデータが保存されるようになりました( .safetensors 形式の場合のみ)(モデル名、VAE 名、ステップ数、学習率など)。近日中に拡張から確認できるようになる予定です。メタデータを保存したくない場合は ``no_metadata`` オプションを指定してください。

-**January 9, 2023: Important information about the update can be found at [the end of the page](#updates-jan-9-2023).**
+- 19 Jan. 2023, 2023/1/19
+  - Fix a bug where some LoRA modules were not trained when ``gradient_checkpointing`` is enabled.
+  - Add the ``--save_last_n_epochs_state`` option. It sets how many state folders to keep, independently of how many models are kept. Thanks to shirayu!
+  - Fix a bug in ``train_db.py`` where Text Encoder training stopped at ``max_train_steps`` even if ``max_train_epochs`` was set.
+  - Add a script to check LoRA weights. Run it with ``python networks\check_lora_weights.py <model file>``. If some modules are not trained, their value is ``0.0``, as in the following example.
+    - ``lora_te_text_model_encoder_layers_11_*`` is not trained with ``clip_skip=2``, so ``0.0`` is okay for these modules.
+  - 一部のLoRAモジュールが ``gradient_checkpointing`` を有効にすると学習されない不具合を修正しました。ご不便をおかけしました。
+  - ``--save_last_n_epochs_state`` オプションを追加しました。モデルの保存数とは別に、stateフォルダの保存数を指定できます。shirayu氏に感謝します。
+  - ``train_db.py`` で、``max_train_epochs`` を指定していても、``max_train_steps`` のステップでText Encoderの学習が停止してしまう不具合を修正しました。
+  - LoRAの重みをチェックするスクリプトを追加してあります。``python networks\check_lora_weights.py <model file>`` のように実行してください。学習していない重みがあると、値が 下のように ``0.0`` になります。
+    - ``lora_te_text_model_encoder_layers_11_`` で始まる部分は ``clip_skip=2`` の場合は学習されないため、``0.0`` で正常です。

-**2023/1/9: 更新情報が[ページ末尾](#更新情報-202319)にありますのでご覧ください。**
+  - example result of ``check_lora_weights.py``, Text Encoder and a part of U-Net are not trained:
+```
+number of LoRA-up modules: 264
+lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight,0.0
+lora_te_text_model_encoder_layers_0_mlp_fc2.lora_up.weight,0.0
+lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_up.weight,0.0
+:
+lora_unet_down_blocks_2_attentions_1_transformer_blocks_0_ff_net_0_proj.lora_up.weight,0.0
+lora_unet_down_blocks_2_attentions_1_transformer_blocks_0_ff_net_2.lora_up.weight,0.0
+lora_unet_mid_block_attentions_0_proj_in.lora_up.weight,0.003503334941342473
+lora_unet_mid_block_attentions_0_proj_out.lora_up.weight,0.004308608360588551
+:
+```

-[日本語版README](./README-ja.md)
+  - all modules are trained:
+```
+number of LoRA-up modules: 264
+lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight,0.0028684409335255623
+lora_te_text_model_encoder_layers_0_mlp_fc2.lora_up.weight,0.0029794853180646896
+lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_up.weight,0.002507600700482726
+lora_te_text_model_encoder_layers_0_self_attn_out_proj.lora_up.weight,0.002639499492943287
+:
+```

+- 17 Jan. 2023, 2023/1/17
+  - __Important Notice__
+    It seems that only some LoRA modules are trained when ``gradient_checkpointing`` is enabled. The cause is under investigation; for the time being, please train without ``gradient_checkpointing``. __The issue is fixed now.__
+  - __重要なお知らせ__
+    ``gradient_checkpointing`` を有効にすると LoRA モジュールの一部しか学習されないようです。原因は調査中ですが当面は ``gradient_checkpointing`` を指定せずに学習してください。__問題は修正されました。__

+- 15 Jan. 2023, 2023/1/15
+  - Added the ``--max_train_epochs`` and ``--max_data_loader_n_workers`` options to each training script.
+  - If you specify the number of training epochs with ``--max_train_epochs``, the number of steps is calculated automatically from the number of epochs.
+  - You can set the number of DataLoader workers with ``--max_data_loader_n_workers`` (default is 8). A lower number may reduce main memory usage and the time between epochs, but may slow down data loading (and therefore training).
+  - ``--max_train_epochs`` と ``--max_data_loader_n_workers`` のオプションが学習スクリプトに追加されました。
+  - ``--max_train_epochs`` で学習したいエポック数を指定すると、必要なステップ数が自動的に計算され設定されます。
+  - ``--max_data_loader_n_workers`` で DataLoader の worker 数が指定できます(デフォルトは8)。値を小さくするとメインメモリの使用量が減り、エポック間の待ち時間も短くなるようです。ただしデータ読み込み(学習時間)は長くなる可能性があります。

+Please read [release version 0.3.0](https://github.com/kohya-ss/sd-scripts/releases/tag/v0.3.0) for recent updates.
+最近の更新情報は [release version 0.3.0](https://github.com/kohya-ss/sd-scripts/releases/tag/v0.3.0) をご覧ください。

+##

+[日本語版README](./README-ja.md)

 For easier use (GUI and PowerShell scripts etc...), please visit [the repository maintained by bmaltais](https://github.com/bmaltais/kohya_ss). Thanks to @bmaltais!

 This repository contains the scripts for:
@@ -94,8 +134,8 @@ Answers to accelerate config:
 - fp16
 ```

-note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occured in training. In this case, answer `0` for the 6th question:
-``What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:``
+note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occurred in training. In this case, answer `0` for the 6th question:
+``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``

 (Single GPU with id `0` will be used.)

@@ -125,79 +165,3 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT

 [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
-
-
-# Updates: Jan 9. 2023
-
-All training scripts are updated.
-
-## Breaking Changes
-
-- The ``fine_tuning`` option in ``train_db.py`` is removed. Please use DreamBooth with captions or ``fine_tune.py``.
-- The Hypernet feature in ``fine_tune.py`` is removed; it will be implemented in ``train_network.py`` in the future.
-
-## Features, Improvements and Bug Fixes
-
-### For all scripts: train_db.py, fine_tune.py and train_network.py
-
-- Added the ``output_name`` option, which sets the name of the output file.
-  - With ``--output_name style1``, the output file is named like ``style1_000001.ckpt`` (or ``.safetensors``) for each epoch and ``style1.ckpt`` for the last one.
-  - If omitted (default), the names are the same as before: ``epoch-000001.ckpt`` and ``last.ckpt``.
-- Added the ``save_last_n_epochs`` option. Only the latest n files are kept for the checkpoints and the states; older files are removed. (Thanks to shirayu!)
-  - If the options are ``--save_every_n_epochs=2 --save_last_n_epochs=3``, at the end of epoch 8 ``epoch-000008.ckpt`` is created and ``epoch-000002.ckpt`` is removed.
-
-### train_db.py
-
-- Added the ``max_token_length`` option. Captions can have more than 75 tokens.
-
-### fine_tune.py
-
-- The script now works without .npz files. If no .npz file is found, the script gets the latents from the VAE.
-  - You can omit ``prepare_buckets_latents.py`` in preprocessing. However, running it is recommended if you train for more than 1 or 2 epochs.
-  - In this case the ``--resolution`` option is required to specify the training resolution.
-- Added the ``cache_latents`` and ``color_aug`` options.
-
-### train_network.py
-
-- ``--gradient_checkpointing`` is now effective for U-Net and Text Encoder.
-  - Memory usage is reduced. A larger batch size becomes available, but training will be slower.
-  - Training might be possible with 6GB VRAM for dimension=4 with batch size=1.
-
-The documentation has not been updated yet; I will update it piece by piece.
-
-# 更新情報 (2023/1/9)
-
-学習スクリプトを更新しました。
-
-## 削除された機能
-- ``train_db.py`` の ``fine_tuning`` は削除されました。キャプション付きの DreamBooth または ``fine_tune.py`` を使ってください。
-- ``fine_tune.py`` の Hypernet学習の機能は削除されました。将来的に``train_network.py``に追加される予定です。
-
-## その他の機能追加、バグ修正など
-
-### 学習スクリプトに共通: train_db.py, fine_tune.py and train_network.py
-
-- ``output_name``オプションを追加しました。保存されるモデルファイルの名前を指定できます。
-  - ``--output_name style1``と指定すると、エポックごとに保存されるファイル名は``style1_000001.ckpt`` (または ``.safetensors``) に、最後に保存されるファイル名は``style1.ckpt``になります。
-  - 省略時は今までと同じです(``epoch-000001.ckpt``および``last.ckpt``)。
-- ``save_last_n_epochs``オプションを追加しました。最新の n ファイル、stateだけ保存し、古いものは削除します。(shirayu氏に感謝します。)
-  - たとえば``--save_every_n_epochs=2 --save_last_n_epochs=3``と指定した時、8エポック目の終了時には、``epoch-000008.ckpt``が保存され``epoch-000002.ckpt``が削除されます。
-
-### train_db.py
-
-- ``max_token_length``オプションを追加しました。75文字を超えるキャプションが使えるようになります。
-
-### fine_tune.py
-
-- .npzファイルがなくても動作するようになりました。.npzファイルがない場合、VAEからlatentsを取得して動作します。
-  - ``prepare_buckets_latents.py``を前処理で実行しなくても良くなります。ただし事前取得をしておいたほうが、2エポック以上学習する場合にはトータルで高速です。
-  - この場合、解像度を指定するために``--resolution``オプションが必要です。
-- ``cache_latents``と``color_aug``オプションを追加しました。
-
-### train_network.py
-
-- ``--gradient_checkpointing``がU-NetとText Encoderにも有効になりました。
-  - メモリ消費が減ります。バッチサイズを大きくできますが、トータルでの学習時間は長くなるかもしれません。
-  - dimension=4のLoRAはバッチサイズ1で6GB VRAMで学習できるかもしれません。
-
-ドキュメントは未更新ですが少しずつ更新の予定です。

@@ -161,10 +161,15 @@ def train(args):

     # prepare the dataloader
     # number of DataLoader processes: 0 means loading in the main process
-    n_workers = min(8, os.cpu_count() - 1)  # cpu_count-1, but at most 8
+    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but at most the specified number
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)

+    # calculate the number of training steps
+    if args.max_train_epochs is not None:
+        args.max_train_steps = args.max_train_epochs * len(train_dataloader)
+        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+
     # prepare the lr scheduler
     lr_scheduler = diffusers.optimization.get_scheduler(
         args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
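
The worker-count expression changed in this hunk can be looked at in isolation. The following is a minimal sketch, not code from the repository: the helper name `resolve_n_workers` is hypothetical, and the clamp to zero plus the fallback for `os.cpu_count()` returning `None` are assumptions added for illustration (the diff itself uses the bare `min(...)` expression).

```python
import os


def resolve_n_workers(max_workers, cpu_count=None):
    """Return cpu_count - 1, capped at max_workers (hypothetical helper)."""
    if cpu_count is None:
        # Assumption: os.cpu_count() may return None, so fall back to 1.
        cpu_count = os.cpu_count() or 1
    # Assumption: clamp at 0 so a single-core machine loads in the main process.
    return max(0, min(max_workers, cpu_count - 1))


if __name__ == "__main__":
    for cores in (1, 4, 16):
        print(cores, resolve_n_workers(8, cores))
```

With the default of 8, a 16-core machine gets 8 workers and a 4-core machine gets 3; passing a smaller `max_workers` lowers the cap, which is the trade-off the option's help text describes.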

@@ -2518,9 +2518,9 @@ if __name__ == '__main__':
     parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
     parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
     parser.add_argument("--diffusers_xformers", action='store_true',
-                        help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
+                        help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
     parser.add_argument("--opt_channels_last", action='store_true',
-                        help='set channels last option to model / モデルにchannles lastを指定し最適化する')
+                        help='set channels last option to model / モデルにchannels lastを指定し最適化する')
     parser.add_argument("--network_module", type=str, default=None, nargs='*',
                         help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
     parser.add_argument("--network_weights", type=str, default=None, nargs='*',

@@ -1029,6 +1029,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument("--save_every_n_epochs", type=int, default=None,
                         help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
     parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
+    parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
     parser.add_argument("--save_state", action="store_true",
                         help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
     parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
@@ -1047,6 +1048,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:

     parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
     parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
+    parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
+    parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
     parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
     parser.add_argument("--gradient_checkpointing", action="store_true",
                         help="enable gradient checkpointing / grandient checkpointingを有効にする")
@@ -1296,7 +1299,6 @@ def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):

 def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
     saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
-    remove_epoch_no = None
     if saving:
         os.makedirs(args.output_dir, exist_ok=True)
         save_func()
@@ -1304,7 +1306,7 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc
         if args.save_last_n_epochs is not None:
             remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
             remove_old_func(remove_epoch_no)
-    return saving, remove_epoch_no
+    return saving


 def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
@@ -1344,15 +1346,18 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
         save_func = save_du
         remove_old_func = remove_du

-    saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
+    saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
     if saving and args.save_state:
-        save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no)
+        save_state_on_epoch_end(args, accelerator, model_name, epoch_no)


-def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no):
+def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
     print("saving state.")
     accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
-    if remove_epoch_no is not None:
+
+    last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
+    if last_n_epochs is not None:
+        remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
         state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
         if os.path.exists(state_dir_old):
             print(f"removing old state: {state_dir_old}")
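
The retention arithmetic in `save_on_epoch_end` and `save_state_on_epoch_end` can be checked without running a training job. The sketch below is an illustration only: the function `epochs_to_remove` is a hypothetical name, and it simply mirrors the two expressions from the hunks above, returning which epoch's checkpoint and which epoch's state folder would be targeted for removal.

```python
def epochs_to_remove(epoch_no, save_every_n_epochs,
                     save_last_n_epochs=None, save_last_n_epochs_state=None):
    """Return (checkpoint_epoch, state_epoch) to remove, or None where nothing applies."""
    ckpt_epoch = None
    if save_last_n_epochs is not None:
        ckpt_epoch = epoch_no - save_every_n_epochs * save_last_n_epochs

    # The state count falls back to the checkpoint count when not given,
    # matching the `x if x else y` expression in the diff.
    last_n_state = save_last_n_epochs_state if save_last_n_epochs_state else save_last_n_epochs
    state_epoch = None
    if last_n_state is not None:
        state_epoch = epoch_no - save_every_n_epochs * last_n_state

    return ckpt_epoch, state_epoch


if __name__ == "__main__":
    # --save_every_n_epochs=2 --save_last_n_epochs=3, end of epoch 8
    print(epochs_to_remove(8, 2, save_last_n_epochs=3))
    # same, but keep only one state folder
    print(epochs_to_remove(8, 2, save_last_n_epochs=3, save_last_n_epochs_state=1))
```

A computed epoch number can be zero or negative early in training; the code in the diff handles that by only removing a state folder when `os.path.exists` reports that it is there.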

31
networks/check_lora_weights.py
Normal file
@@ -0,0 +1,31 @@
+import argparse
+import os
+import torch
+from safetensors.torch import load_file
+
+
+def main(file):
+    print(f"loading: {file}")
+    if os.path.splitext(file)[1] == '.safetensors':
+        sd = load_file(file)
+    else:
+        sd = torch.load(file, map_location='cpu')
+
+    values = []
+
+    keys = list(sd.keys())
+    for key in keys:
+        if 'lora_up' in key:
+            values.append((key, sd[key]))
+    print(f"number of LoRA-up modules: {len(values)}")
+
+    for key, value in values:
+        print(f"{key},{torch.mean(torch.abs(value))}")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
+    args = parser.parse_args()
+
+    main(args.file)
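
The script prints one `key,value` line per LoRA-up module, and the README in this comparison says that a value of ``0.0`` marks an untrained module, except for ``lora_te_text_model_encoder_layers_11_*`` under ``clip_skip=2``. The following is a minimal sketch of scanning that output; `find_untrained` and its `ignore_prefixes` parameter are hypothetical names, not part of the repository.

```python
def find_untrained(lines, ignore_prefixes=()):
    """Return the module keys whose printed mean absolute weight is exactly 0.0."""
    untrained = []
    for line in lines:
        key, sep, value = line.strip().rpartition(",")
        if not sep or "lora_up" not in key:
            continue  # skip the "loading:", "number of ..." and ":" lines
        if float(value) == 0.0 and not key.startswith(tuple(ignore_prefixes)):
            untrained.append(key)
    return untrained


if __name__ == "__main__":
    sample = [
        "number of LoRA-up modules: 264",
        "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight,0.0",
        "lora_te_text_model_encoder_layers_11_mlp_fc1.lora_up.weight,0.0",
        "lora_unet_mid_block_attentions_0_proj_in.lora_up.weight,0.003503334941342473",
        ":",
    ]
    for key in find_untrained(sample, ["lora_te_text_model_encoder_layers_11_"]):
        print(key)
```

Passing the layer-11 prefix in `ignore_prefixes` keeps the expected ``clip_skip=2`` zeros out of the report, so anything that remains is worth investigating.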

15
train_db.py
@@ -92,10 +92,7 @@ def train(args):
         gc.collect()

     # prepare for training: put the models into the proper state
-    if args.stop_text_encoder_training is None:
-        args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
-
-    train_text_encoder = args.stop_text_encoder_training >= 0
+    train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
     unet.requires_grad_(True)  # added just in case
     text_encoder.requires_grad_(train_text_encoder)
     if not train_text_encoder:
@@ -134,10 +131,18 @@ def train(args):

     # prepare the dataloader
     # number of DataLoader processes: 0 means loading in the main process
-    n_workers = min(8, os.cpu_count() - 1)  # cpu_count-1, but at most 8
+    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but at most the specified number
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)

+    # calculate the number of training steps
+    if args.max_train_epochs is not None:
+        args.max_train_steps = args.max_train_epochs * len(train_dataloader)
+        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+
+    if args.stop_text_encoder_training is None:
+        args.stop_text_encoder_training = args.max_train_steps + 1  # do not stop until end
+
     # prepare the lr scheduler
     lr_scheduler = diffusers.optimization.get_scheduler(
         args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
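
The two `train_db.py` hunks move the default for `stop_text_encoder_training` to after the point where `max_train_epochs` overrides `max_train_steps`. The sketch below illustrates why the order matters; `resolve_steps` and `batches_per_epoch` are hypothetical names standing in for the inline code and `len(train_dataloader)`.

```python
from argparse import Namespace


def resolve_steps(args, batches_per_epoch):
    """Apply the epoch override first, then default the Text Encoder stop step."""
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * batches_per_epoch
    if args.stop_text_encoder_training is None:
        # do not stop until the end of training
        args.stop_text_encoder_training = args.max_train_steps + 1
    return args


if __name__ == "__main__":
    args = Namespace(max_train_steps=1600, max_train_epochs=10,
                     stop_text_encoder_training=None)
    resolve_steps(args, batches_per_epoch=500)
    print(args.max_train_steps, args.stop_text_encoder_training)
```

If the default were computed before the override, as in the removed lines, the stop step in this example would be 1601 while training runs for 5000 steps, which matches the README note about Text Encoder training stopping at ``max_train_steps``.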

@@ -126,10 +126,15 @@ def train(args):

     # prepare the dataloader
     # number of DataLoader processes: 0 means loading in the main process
-    n_workers = min(8, os.cpu_count() - 1)  # cpu_count-1, but at most 8
+    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, but at most the specified number
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)

+    # calculate the number of training steps
+    if args.max_train_epochs is not None:
+        args.max_train_steps = args.max_train_epochs * len(train_dataloader)
+        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+
     # prepare the lr scheduler
     lr_scheduler = diffusers.optimization.get_scheduler(
         args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
@@ -161,6 +166,9 @@ def train(args):
     if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
         unet.train()
         text_encoder.train()
+
+        # set requires_grad = True on the top-level parameters so that gradient checkpointing works
+        text_encoder.text_model.embeddings.requires_grad_(True)
     else:
         unet.eval()
         text_encoder.eval()
@@ -359,9 +367,9 @@ def train(args):
                     print(f"removing old checkpoint: {old_ckpt_file}")
                     os.remove(old_ckpt_file)

-            saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
+            saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
             if saving and args.save_state:
-                train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no)
+                train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

         # end of epoch