mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-06 21:52:27 +00:00
Compare commits
61 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4cabb37977 | ||
|
|
86eba1d2cf | ||
|
|
05940940c0 | ||
|
|
6bbb4d426e | ||
|
|
7817e95a86 | ||
|
|
443ce7a30b | ||
|
|
0fef7b4684 | ||
|
|
67e698af67 | ||
|
|
7c35aee042 | ||
|
|
481823796e | ||
|
|
835b0d54cd | ||
|
|
505768ea86 | ||
|
|
1614d30d1b | ||
|
|
25566182a8 | ||
|
|
6dffc88b44 | ||
|
|
3fb12e41b7 | ||
|
|
591e3c1813 | ||
|
|
b5ba463512 | ||
|
|
e0d7f1d99d | ||
|
|
a68501bede | ||
|
|
c425afb08b | ||
|
|
46029b2707 | ||
|
|
02acae8e1d | ||
|
|
91a50ea637 | ||
|
|
9f644d8dc3 | ||
|
|
36dc97c841 | ||
|
|
e6bad080cb | ||
|
|
7f17237ada | ||
|
|
ebd3ea380c | ||
|
|
bf3a13bb4e | ||
|
|
1a170c4762 | ||
|
|
552cdbd6d8 | ||
|
|
a86514f1ad | ||
|
|
66051883fb | ||
|
|
f7fbdc4b2a | ||
|
|
ebdb624d29 | ||
|
|
93df55d597 | ||
|
|
56bc806d52 | ||
|
|
25f8ac731f | ||
|
|
4ba1667978 | ||
|
|
0ca064287e | ||
|
|
a3171714ce | ||
|
|
4a1668fe37 | ||
|
|
4eb356f165 | ||
|
|
a7218574f2 | ||
|
|
ddfe94b33b | ||
|
|
8746188ed7 | ||
|
|
1bfcf164f1 | ||
|
|
5e817e4343 | ||
|
|
b4636d4185 | ||
|
|
22ee0ac467 | ||
|
|
17089b1287 | ||
|
|
7ee808d5d7 | ||
|
|
9ff26af68b | ||
|
|
7dbcef745a | ||
|
|
da48f74e7b | ||
|
|
e5d9f483f0 | ||
|
|
303c3410e2 | ||
|
|
de1dde1a06 | ||
|
|
186a2665ad | ||
|
|
c1b14fcdd6 |
17
README-ja.md
17
README-ja.md
@@ -1,7 +1,7 @@
|
||||
## リポジトリについて
|
||||
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
|
||||
|
||||
[README in English](./README.md)
|
||||
[README in English](./README.md) ←更新情報はこちらにあります
|
||||
|
||||
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
|
||||
|
||||
@@ -16,9 +16,11 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
|
||||
|
||||
当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください(将来的にはすべてこちらへ移すかもしれません)。
|
||||
|
||||
* note.com [環境整備とDreamBooth学習スクリプトについて](https://note.com/kohya_ss/n/nba4eceaa4594)
|
||||
* [DreamBoothの学習について](./train_db_README-ja.md)
|
||||
* [fine-tuningのガイド](./fine_tune_README_ja.md):
|
||||
BLIPによるキャプショニングと、DeepDanbooruまたはWD14 taggerによるタグ付けを含みます
|
||||
* [LoRAの学習について](./train_network_README-ja.md)
|
||||
* [Textual Inversionの学習について](./train_ti_README-ja.md)
|
||||
* note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
|
||||
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
|
||||
|
||||
@@ -44,12 +46,11 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
|
||||
|
||||
通常の(管理者ではない)PowerShellを開き以下を順に実行します。
|
||||
|
||||
|
||||
```powershell
|
||||
git clone https://github.com/kohya-ss/sd-scripts.git
|
||||
cd sd-scripts
|
||||
|
||||
python -m venv --system-site-packages venv
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
@@ -70,7 +71,7 @@ accelerate config
|
||||
git clone https://github.com/kohya-ss/sd-scripts.git
|
||||
cd sd-scripts
|
||||
|
||||
python -m venv --system-site-packages venv
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
@@ -84,6 +85,8 @@ copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cud
|
||||
accelerate config
|
||||
```
|
||||
|
||||
(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。)
|
||||
|
||||
accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。)
|
||||
|
||||
※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます(……)。数字キーの0、1、2……で選択できますので、そちらを使ってください。
|
||||
@@ -101,6 +104,10 @@ accelerate configの質問には以下のように答えてください。(bf1
|
||||
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
|
||||
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
|
||||
|
||||
### PyTorchとxformersのバージョンについて
|
||||
|
||||
他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください。
|
||||
|
||||
## アップグレード
|
||||
|
||||
新しいリリースがあった場合、以下のコマンドで更新できます。
|
||||
|
||||
45
README.md
45
README.md
@@ -2,11 +2,35 @@ This repository contains training, generation and utility scripts for Stable Dif
|
||||
|
||||
## Updates
|
||||
|
||||
- 22 Jan. 2023, 2023/1/22
|
||||
- Fix script to check LoRA weights ``check_lora_weights.py``. Some layer weights were shown as ``0.0`` even if the layer is trained, because of the overflow of ``torch.mean``. Sorry for the confusion.
|
||||
- Noe the script shows the mean of the absolute values of the weights, and the minimum of the absolute values of the weights.
|
||||
- LoRAの重みをチェックするスクリプト ``check_lora_weights.py`` を修正しました。一部のレイヤーで学習されているにもかかわらず重みが ``0.0`` と表示されていました。混乱を招き申し訳ありません。
|
||||
- スクリプトを「重みの絶対の平均」と「重みの絶対値の最小値」を表示するよう修正しました。
|
||||
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for great work!!!
|
||||
|
||||
Note: The LoRA models for SD 2.x is not supported too in Web UI.
|
||||
|
||||
- 29 Jan. 2023, 2023/1/29
|
||||
- Add ``--lr_scheduler_num_cycles`` and ``--lr_scheduler_power`` options for ``train_network.py`` for cosine_with_restarts and polynomial learning rate schedulers. Thanks to mgz-dev!
|
||||
- Fixed U-Net ``sample_size`` parameter to ``64`` when converting from SD to Diffusers format, in ``convert_diffusers20_original_sd.py``
|
||||
- ``--lr_scheduler_num_cycles`` と ``--lr_scheduler_power`` オプションを ``train_network.py`` に追加しました。前者は cosine_with_restarts、後者は polynomial の学習率スケジューラに有効です。mgz-dev氏に感謝します。
|
||||
- ``convert_diffusers20_original_sd.py`` で SD 形式から Diffusers に変換するときの U-Net の ``sample_size`` パラメータを ``64`` に修正しました。
|
||||
- 26 Jan. 2023, 2023/1/26
|
||||
- Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.)
|
||||
- Textual Inversionの学習をサポートしました。ドキュメントは[こちら](./train_ti_README-ja.md)。
|
||||
- 24 Jan. 2023, 2023/1/24
|
||||
- Change the default save format to ``.safetensors`` for ``train_network.py``.
|
||||
- Add ``--save_n_epoch_ratio`` option to specify how often to save. Thanks to forestsource!
|
||||
- For example, if 5 is specified, 5 (or 6) files will be saved in training.
|
||||
- Add feature to pre-calculate hash to reduce loading time in the extension. Thanks to space-nuko!
|
||||
- Add bucketing metadata. Thanks to space-nuko!
|
||||
- Fix an error with bf16 model in ``gen_img_diffusers.py``.
|
||||
- ``train_network.py`` のモデル保存形式のデフォルトを ``.safetensors`` に変更しました。
|
||||
- モデルを保存する頻度を指定する ``--save_n_epoch_ratio`` オプションが追加されました。forestsource氏に感謝します。
|
||||
- たとえば 5 を指定すると、学習終了までに合計で5個(または6個)のファイルが保存されます。
|
||||
- 拡張でモデル読み込み時間を短縮するためのハッシュ事前計算の機能を追加しました。space-nuko氏に感謝します。
|
||||
- メタデータにbucket情報が追加されました。space-nuko氏に感謝します。
|
||||
- ``gen_img_diffusers.py`` でbf16形式のモデルを読み込んだときのエラーを修正しました。
|
||||
|
||||
Stable Diffusion web UI本体で当リポジトリで学習したLoRAモデルによる画像生成がサポートされたようです。
|
||||
|
||||
注:SD2.x用のLoRAモデルはサポートされないようです。
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
@@ -39,6 +63,7 @@ All documents are in Japanese currently, and CUI based.
|
||||
* [Step by Step fine-tuning guide](./fine_tune_README_ja.md):
|
||||
Including BLIP captioning and tagging by DeepDanbooru or WD14 tagger
|
||||
* [training LoRA](./train_network_README-ja.md)
|
||||
* [training Textual Inversion](./train_ti_README-ja.md)
|
||||
* note.com [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
|
||||
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
|
||||
|
||||
@@ -63,7 +88,7 @@ Open a regular Powershell terminal and type the following inside:
|
||||
git clone https://github.com/kohya-ss/sd-scripts.git
|
||||
cd sd-scripts
|
||||
|
||||
python -m venv --system-site-packages venv
|
||||
python -m venv venv
|
||||
.\venv\Scripts\activate
|
||||
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
|
||||
@@ -75,9 +100,10 @@ cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\ce
|
||||
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
|
||||
|
||||
accelerate config
|
||||
|
||||
```
|
||||
|
||||
update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python).
|
||||
|
||||
Answers to accelerate config:
|
||||
|
||||
```txt
|
||||
@@ -95,6 +121,11 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
|
||||
|
||||
(Single GPU with id `0` will be used.)
|
||||
|
||||
### about PyTorch and xformers
|
||||
|
||||
Other versions of PyTorch and xformers seem to have problems with training.
|
||||
If there is no other reason, please install the specified version.
|
||||
|
||||
## Upgrade
|
||||
|
||||
When a new release comes out you can upgrade your repo with the following command:
|
||||
|
||||
@@ -200,6 +200,8 @@ def train(args):
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -324,7 +324,7 @@ __※引数を都度書き換えて、別のメタデータファイルに書き
|
||||
## 学習の実行
|
||||
たとえば以下のように実行します。以下は省メモリ化のための設定です。
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
|
||||
accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
|
||||
--pretrained_model_name_or_path=model.ckpt
|
||||
--in_json meta_lat.json
|
||||
--train_data_dir=train_data
|
||||
@@ -336,7 +336,7 @@ accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
|
||||
--save_every_n_epochs=4
|
||||
```
|
||||
|
||||
accelerateのnum_cpu_threads_per_processにはCPUのコア数を指定するとよいようです。
|
||||
accelerateのnum_cpu_threads_per_processには通常は1を指定するとよいようです。
|
||||
|
||||
pretrained_model_name_or_pathに学習対象のモデルを指定します(Stable DiffusionのcheckpointかDiffusersのモデル)。Stable Diffusionのcheckpointは.ckptと.safetensorsに対応しています(拡張子で自動判定)。
|
||||
|
||||
|
||||
@@ -470,6 +470,9 @@ class PipelineLike():
|
||||
self.scheduler = scheduler
|
||||
self.safety_checker = None
|
||||
|
||||
# Textual Inversion
|
||||
self.token_replacements = {}
|
||||
|
||||
# CLIP guidance
|
||||
self.clip_guidance_scale = clip_guidance_scale
|
||||
self.clip_image_guidance_scale = clip_image_guidance_scale
|
||||
@@ -484,6 +487,19 @@ class PipelineLike():
|
||||
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
||||
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
||||
|
||||
# Textual Inversion
|
||||
def add_token_replacement(self, target_token_id, rep_token_ids):
|
||||
self.token_replacements[target_token_id] = rep_token_ids
|
||||
|
||||
def replace_token(self, tokens):
|
||||
new_tokens = []
|
||||
for token in tokens:
|
||||
if token in self.token_replacements:
|
||||
new_tokens.extend(self.token_replacements[token])
|
||||
else:
|
||||
new_tokens.append(token)
|
||||
return new_tokens
|
||||
|
||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
@@ -1507,6 +1523,9 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
|
||||
for word, weight in texts_and_weights:
|
||||
# tokenize and discard the starting and the ending token
|
||||
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||
|
||||
token = pipe.replace_token(token)
|
||||
|
||||
text_token += token
|
||||
# copy the weight by length of token
|
||||
text_weight += [weight] * len(token)
|
||||
@@ -1981,7 +2000,6 @@ def main(args):
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
@@ -1992,22 +2010,22 @@ def main(args):
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
if args.network_weights and i < len(args.network_weights):
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if os.path.splitext(network_weight)[1] == '.safetensors':
|
||||
if model_util.is_safetensors(network_weight):
|
||||
from safetensors.torch import safe_open
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None:
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network.load_weights(network_weight)
|
||||
network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
|
||||
else:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
if network is None:
|
||||
return
|
||||
|
||||
network.apply_to(text_encoder, unet)
|
||||
|
||||
@@ -2040,6 +2058,44 @@ def main(args):
|
||||
if args.diffusers_xformers:
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# Textual Inversionを処理する
|
||||
if args.textual_inversion_embeddings:
|
||||
token_ids_embeds = []
|
||||
for embeds_file in args.textual_inversion_embeddings:
|
||||
if model_util.is_safetensors(embeds_file):
|
||||
from safetensors.torch import load_file
|
||||
data = load_file(embeds_file)
|
||||
else:
|
||||
data = torch.load(embeds_file, map_location="cpu")
|
||||
|
||||
embeds = next(iter(data.values()))
|
||||
if type(embeds) != torch.Tensor:
|
||||
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
|
||||
|
||||
num_vectors_per_token = embeds.size()[0]
|
||||
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
||||
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
|
||||
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||
assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
||||
|
||||
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
|
||||
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||
|
||||
if num_vectors_per_token > 1:
|
||||
pipe.add_token_replacement(token_ids[0], token_ids)
|
||||
|
||||
token_ids_embeds.append((token_ids, embeds))
|
||||
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
for token_ids, embeds in token_ids_embeds:
|
||||
for token_id, embed in zip(token_ids, embeds):
|
||||
token_embeds[token_id] = embed
|
||||
|
||||
# promptを取得する
|
||||
if args.from_file is not None:
|
||||
print(f"reading prompts from {args.from_file}")
|
||||
@@ -2158,8 +2214,8 @@ def main(args):
|
||||
os.makedirs(args.outdir, exist_ok=True)
|
||||
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
||||
|
||||
for iter in range(args.n_iter):
|
||||
print(f"iteration {iter+1}/{args.n_iter}")
|
||||
for gen_iter in range(args.n_iter):
|
||||
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||
iter_seed = random.randint(0, 0x7fffffff)
|
||||
|
||||
# バッチ処理の関数
|
||||
@@ -2526,10 +2582,10 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
||||
help='Hypernetwork weights to load / Hypernetworkの重み')
|
||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
|
||||
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
||||
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
||||
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
||||
|
||||
@@ -16,7 +16,7 @@ BETA_END = 0.0120
|
||||
UNET_PARAMS_MODEL_CHANNELS = 320
|
||||
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
|
||||
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
|
||||
UNET_PARAMS_IMAGE_SIZE = 32 # unused
|
||||
UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32`
|
||||
UNET_PARAMS_IN_CHANNELS = 4
|
||||
UNET_PARAMS_OUT_CHANNELS = 4
|
||||
UNET_PARAMS_NUM_RES_BLOCKS = 2
|
||||
|
||||
@@ -11,6 +11,8 @@ import glob
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import hashlib
|
||||
from io import BytesIO
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -24,6 +26,7 @@ from PIL import Image
|
||||
import cv2
|
||||
from einops import rearrange
|
||||
from torch import einsum
|
||||
import safetensors.torch
|
||||
|
||||
import library.model_util as model_util
|
||||
|
||||
@@ -79,6 +82,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.debug_dataset = debug_dataset
|
||||
self.random_crop = random_crop
|
||||
self.token_padding_disabled = False
|
||||
self.dataset_dirs_info = {}
|
||||
self.reg_dataset_dirs_info = {}
|
||||
self.enable_bucket = False
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_info = None
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
@@ -104,9 +113,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.image_data: dict[str, ImageInfo] = {}
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
def disable_token_padding(self):
|
||||
self.token_padding_disabled = True
|
||||
|
||||
def add_replacement(self, str_from, str_to):
|
||||
self.replacements[str_from] = str_to
|
||||
|
||||
def process_caption(self, caption):
|
||||
if self.shuffle_caption:
|
||||
tokens = caption.strip().split(",")
|
||||
@@ -119,6 +133,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
random.shuffle(tokens)
|
||||
tokens = keep_tokens + tokens
|
||||
caption = ",".join(tokens).strip()
|
||||
|
||||
for str_from, str_to in self.replacements.items():
|
||||
if str_from == "":
|
||||
# replace all
|
||||
if type(str_to) == list:
|
||||
caption = random.choice(str_to)
|
||||
else:
|
||||
caption = str_to
|
||||
else:
|
||||
caption = caption.replace(str_from, str_to)
|
||||
|
||||
return caption
|
||||
|
||||
def get_input_ids(self, caption):
|
||||
@@ -211,11 +236,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.buckets[bucket_index].append(image_info.image_key)
|
||||
|
||||
if self.enable_bucket:
|
||||
self.bucket_info = {"buckets": {}}
|
||||
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
||||
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
|
||||
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
|
||||
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}")
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
|
||||
# 参照用indexを作る
|
||||
self.buckets_indices: list(BucketBatchIndex) = []
|
||||
@@ -463,6 +494,8 @@ class DreamBoothDataset(BaseDataset):
|
||||
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
else:
|
||||
self.bucket_resos = [(self.width, self.height)]
|
||||
self.bucket_aspect_ratios = [self.width / self.height]
|
||||
@@ -523,6 +556,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
for img_path, caption in zip(img_paths, captions):
|
||||
info = ImageInfo(img_path, n_repeats, caption, False, img_path)
|
||||
self.register_image(info)
|
||||
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||
print(f"{num_train_images} train images with repeating.")
|
||||
self.num_train_images = num_train_images
|
||||
|
||||
@@ -539,6 +573,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
for img_path, caption in zip(img_paths, captions):
|
||||
info = ImageInfo(img_path, n_repeats, caption, True, img_path)
|
||||
reg_infos.append(info)
|
||||
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
|
||||
|
||||
print(f"{num_reg_images} reg images.")
|
||||
if num_train_images < num_reg_images:
|
||||
@@ -589,7 +624,7 @@ class FineTuningDataset(BaseDataset):
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
abs_path = glob_images(train_data_dir, image_key)
|
||||
assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}"
|
||||
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
||||
abs_path = abs_path[0]
|
||||
|
||||
caption = img_md.get('caption')
|
||||
@@ -611,6 +646,8 @@ class FineTuningDataset(BaseDataset):
|
||||
self.num_train_images = len(metadata) * dataset_repeats
|
||||
self.num_reg_images = 0
|
||||
|
||||
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||
|
||||
# check existence of all npz files
|
||||
if not self.color_aug:
|
||||
npz_any = False
|
||||
@@ -653,6 +690,8 @@ class FineTuningDataset(BaseDataset):
|
||||
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
|
||||
(self.width, self.height), min_bucket_reso, max_bucket_reso)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
else:
|
||||
self.bucket_resos = [(self.width, self.height)]
|
||||
self.bucket_aspect_ratios = [self.width / self.height]
|
||||
@@ -665,6 +704,9 @@ class FineTuningDataset(BaseDataset):
|
||||
self.bucket_resos.sort()
|
||||
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
|
||||
|
||||
self.min_bucket_reso = min([min(reso) for reso in resos])
|
||||
self.max_bucket_reso = max([max(reso) for reso in resos])
|
||||
|
||||
def image_key_to_npz_file(self, image_key):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
npz_file_norm = base_name + '.npz'
|
||||
@@ -689,15 +731,17 @@ class FineTuningDataset(BaseDataset):
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
|
||||
def debug_dataset(train_dataset):
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
k = 0
|
||||
for example in train_dataset:
|
||||
if example['latents'] is not None:
|
||||
print("sample has latents from npz file")
|
||||
for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])):
|
||||
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
||||
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
|
||||
if show_input_ids:
|
||||
print(f"input ids: {iid}")
|
||||
if example['images'] is not None:
|
||||
im = example['images'][j]
|
||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||
@@ -749,9 +793,9 @@ def default(val, d):
|
||||
|
||||
|
||||
def model_hash(filename):
|
||||
"""Old model hash used by stable-diffusion-webui"""
|
||||
try:
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
m = hashlib.sha256()
|
||||
|
||||
file.seek(0x100000)
|
||||
@@ -761,6 +805,61 @@ def model_hash(filename):
|
||||
return 'NOFILE'
|
||||
|
||||
|
||||
def calculate_sha256(filename):
|
||||
"""New model hash used by stable-diffusion-webui"""
|
||||
hash_sha256 = hashlib.sha256()
|
||||
blksize = 1024 * 1024
|
||||
|
||||
with open(filename, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(blksize), b""):
|
||||
hash_sha256.update(chunk)
|
||||
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
|
||||
def precalculate_safetensors_hashes(tensors, metadata):
|
||||
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
||||
save time on indexing the model later."""
|
||||
|
||||
# Because writing user metadata to the file can change the result of
|
||||
# sd_models.model_hash(), only retain the training metadata for purposes of
|
||||
# calculating the hash, as they are meant to be immutable
|
||||
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
||||
|
||||
bytes = safetensors.torch.save(tensors, metadata)
|
||||
b = BytesIO(bytes)
|
||||
|
||||
model_hash = addnet_hash_safetensors(b)
|
||||
legacy_hash = addnet_hash_legacy(b)
|
||||
return model_hash, legacy_hash
|
||||
|
||||
|
||||
def addnet_hash_legacy(b):
|
||||
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
||||
m = hashlib.sha256()
|
||||
|
||||
b.seek(0x100000)
|
||||
m.update(b.read(0x10000))
|
||||
return m.hexdigest()[0:8]
|
||||
|
||||
|
||||
def addnet_hash_safetensors(b):
|
||||
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
||||
hash_sha256 = hashlib.sha256()
|
||||
blksize = 1024 * 1024
|
||||
|
||||
b.seek(0)
|
||||
header = b.read(8)
|
||||
n = int.from_bytes(header, "little")
|
||||
|
||||
offset = n + 8
|
||||
b.seek(offset)
|
||||
for chunk in iter(lambda: b.read(blksize), b""):
|
||||
hash_sha256.update(chunk)
|
||||
|
||||
return hash_sha256.hexdigest()
|
||||
|
||||
|
||||
# flash attention forwards and backwards
|
||||
|
||||
# https://arxiv.org/abs/2205.14135
|
||||
@@ -1028,8 +1127,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||
parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
|
||||
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
|
||||
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
||||
parser.add_argument("--save_last_n_epochs_state", type=int, default=None, help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
||||
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
|
||||
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
||||
parser.add_argument("--save_state", action="store_true",
|
||||
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
|
||||
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
|
||||
@@ -1048,8 +1150,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
|
||||
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
|
||||
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
|
||||
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
||||
parser.add_argument("--max_train_epochs", type=int, default=None,
|
||||
help="training epochs (overrides max_train_steps) / 学習エポック数(max_train_stepsを上書きします)")
|
||||
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
|
||||
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数(小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)")
|
||||
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true",
|
||||
help="enable gradient checkpointing / grandient checkpointingを有効にする")
|
||||
|
||||
@@ -44,9 +44,9 @@ def svd(args):
|
||||
print(f"loading SD model : {args.model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
||||
|
||||
# create LoRA network to extract weights
|
||||
lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o)
|
||||
lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t)
|
||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
|
||||
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
|
||||
assert len(lora_network_o.text_encoder_loras) == len(
|
||||
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||
|
||||
@@ -77,10 +77,10 @@ def svd(args):
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight - module_o.weight
|
||||
diff = diff.float()
|
||||
|
||||
|
||||
if args.device:
|
||||
diff = diff.to(args.device)
|
||||
|
||||
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# make LoRA with svd
|
||||
@@ -116,6 +116,9 @@ def svd(args):
|
||||
print(f"LoRA has {len(lora_sd)} weights.")
|
||||
|
||||
for key in list(lora_sd.keys()):
|
||||
if "alpha" in key:
|
||||
continue
|
||||
|
||||
lora_name = key.split('.')[0]
|
||||
i = 0 if "lora_up" in key else 1
|
||||
|
||||
@@ -124,7 +127,7 @@ def svd(args):
|
||||
if len(lora_sd[key].size()) == 4:
|
||||
weights = weights.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
assert weights.size() == lora_sd[key].size()
|
||||
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
|
||||
lora_sd[key] = weights
|
||||
|
||||
# load state dict to LoRA and save it
|
||||
@@ -135,7 +138,10 @@ def svd(args):
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
|
||||
lora_network_o.save_weights(args.save_to, save_dtype, {})
|
||||
# minimum metadata
|
||||
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
||||
|
||||
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
|
||||
print(f"LoRA weights are saved to: {args.save_to}")
|
||||
|
||||
|
||||
@@ -151,8 +157,8 @@ if __name__ == '__main__':
|
||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数(デフォルト4)")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う")
|
||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
|
||||
args = parser.parse_args()
|
||||
svd(args)
|
||||
|
||||
@@ -7,15 +7,19 @@ import math
|
||||
import os
|
||||
import torch
|
||||
|
||||
from library import train_util
|
||||
|
||||
|
||||
class LoRAModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4):
|
||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
||||
""" if alpha == 0 or None, alpha is rank (no scaling). """
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if org_module.__class__.__name__ == 'Conv2d':
|
||||
in_dim = org_module.in_channels
|
||||
@@ -28,6 +32,12 @@ class LoRAModule(torch.nn.Module):
|
||||
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
# same as microsoft's
|
||||
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||
torch.nn.init.zeros_(self.lora_up.weight)
|
||||
@@ -41,13 +51,37 @@ class LoRAModule(torch.nn.Module):
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
||||
|
||||
def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs):
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim)
|
||||
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
||||
return network
|
||||
|
||||
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
|
||||
if os.path.splitext(file)[1] == '.safetensors':
|
||||
from safetensors.torch import load_file, safe_open
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location='cpu')
|
||||
|
||||
# get dim (rank)
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
for key, value in weights_sd.items():
|
||||
if network_alpha is None and 'alpha' in key:
|
||||
network_alpha = value
|
||||
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
|
||||
network_dim = value.size()[0]
|
||||
|
||||
if network_alpha is None:
|
||||
network_alpha = network_dim
|
||||
|
||||
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
|
||||
network.weights_sd = weights_sd
|
||||
return network
|
||||
|
||||
|
||||
@@ -57,10 +91,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
LORA_PREFIX_UNET = 'lora_unet'
|
||||
LORA_PREFIX_TEXT_ENCODER = 'lora_te'
|
||||
|
||||
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None:
|
||||
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
self.lora_dim = lora_dim
|
||||
self.alpha = alpha
|
||||
|
||||
# create module instances
|
||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
|
||||
@@ -71,7 +106,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim)
|
||||
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
|
||||
loras.append(lora)
|
||||
return loras
|
||||
|
||||
@@ -149,21 +184,21 @@ class LoRANetwork(torch.nn.Module):
|
||||
return params
|
||||
|
||||
self.requires_grad_(True)
|
||||
params = []
|
||||
all_params = []
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {'params': enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data['lr'] = text_encoder_lr
|
||||
params.append(param_data)
|
||||
all_params.append(param_data)
|
||||
|
||||
if self.unet_loras:
|
||||
param_data = {'params': enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data['lr'] = unet_lr
|
||||
params.append(param_data)
|
||||
all_params.append(param_data)
|
||||
|
||||
return params
|
||||
return all_params
|
||||
|
||||
def prepare_grad_etc(self, text_encoder, unet):
|
||||
self.requires_grad_(True)
|
||||
@@ -188,6 +223,14 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
if os.path.splitext(file)[1] == '.safetensors':
|
||||
from safetensors.torch import save_file
|
||||
|
||||
# Precalculate model hashes to save time on indexing
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
@@ -61,6 +61,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
for key in lora_sd.keys():
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
alpha_key = key[:key.index("lora_down")] + 'alpha'
|
||||
|
||||
# find original module for this lora
|
||||
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
||||
@@ -73,14 +74,18 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
# W <- W + U * D
|
||||
weight = module.weight
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + ratio * (up_weight @ down_weight)
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
else:
|
||||
# conv2d
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
@@ -88,20 +93,35 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
def merge_lora_models(models, ratios, merge_dtype):
|
||||
merged_sd = {}
|
||||
|
||||
alpha = None
|
||||
dim = None
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key].size() == lora_sd[key].size(
|
||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
|
||||
if 'alpha' in key:
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
|
||||
else:
|
||||
alpha = lora_sd[key].detach().numpy()
|
||||
merged_sd[key] = lora_sd[key]
|
||||
else:
|
||||
merged_sd[key] = lora_sd[key] * ratio
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key].size() == lora_sd[key].size(
|
||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
|
||||
else:
|
||||
if "lora_down" in key:
|
||||
dim = lora_sd[key].size()[0]
|
||||
merged_sd[key] = lora_sd[key] * ratio
|
||||
|
||||
return merged_sd
|
||||
print(f"dim (rank): {dim}, alpha: {alpha}")
|
||||
if alpha is None:
|
||||
alpha = dim
|
||||
|
||||
return merged_sd, dim, alpha
|
||||
|
||||
|
||||
def merge(args):
|
||||
@@ -132,7 +152,7 @@ def merge(args):
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||
args.sd_model, 0, 0, save_dtype, vae)
|
||||
else:
|
||||
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
@@ -145,7 +165,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
||||
parser.add_argument("--precision", type=str, default="float",
|
||||
choices=["float", "fp16", "bf16"], help="precision in merging / マージの計算時の精度")
|
||||
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
||||
parser.add_argument("--sd_model", type=str, default=None,
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
|
||||
# v1: initial version
|
||||
# v2: support safetensors
|
||||
# v3: fix to support another format
|
||||
# v4: support safetensors in Diffusers
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
@@ -176,6 +176,8 @@ def train(args):
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -72,7 +72,7 @@ identifierとclassを使い、たとえば「shs dog」などでモデルを学
|
||||
※LoRA等の追加ネットワークを学習する場合のコマンドは ``train_db.py`` ではなく ``train_network.py`` となります。また追加でnetwork_\*オプションが必要となりますので、LoRAのガイドを参照してください。
|
||||
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 8 train_db.py
|
||||
accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
||||
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
|
||||
--train_data_dir=<学習用データのディレクトリ>
|
||||
--reg_data_dir=<正則化画像のディレクトリ>
|
||||
@@ -89,7 +89,7 @@ accelerate launch --num_cpu_threads_per_process 8 train_db.py
|
||||
--gradient_checkpointing
|
||||
```
|
||||
|
||||
num_cpu_threads_per_processにはCPUコア数を指定するとよいようです。
|
||||
num_cpu_threads_per_processには通常は1を指定するとよいようです。
|
||||
|
||||
pretrained_model_name_or_pathに追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。学習後のモデルの保存形式はデフォルトでは元のモデルと同じになります(save_model_asオプションで変更できます)。
|
||||
|
||||
@@ -159,7 +159,7 @@ v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述
|
||||
|
||||

|
||||
|
||||
各yamlファイルは[https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion](Stability AIのSD2.0のリポジトリ)にあります。
|
||||
各yamlファイルは[Stability AIのSD2.0のリポジトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にあります。
|
||||
|
||||
# その他の学習オプション
|
||||
|
||||
|
||||
143
train_network.py
143
train_network.py
@@ -3,6 +3,9 @@ import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
@@ -18,7 +21,89 @@ def collate_fn(examples):
|
||||
return examples[0]
|
||||
|
||||
|
||||
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
|
||||
logs = {"loss/current": current_loss, "loss/average": avr_loss}
|
||||
|
||||
if args.network_train_unet_only:
|
||||
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
|
||||
elif args.network_train_text_encoder_only:
|
||||
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
||||
else:
|
||||
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
|
||||
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
|
||||
|
||||
return logs
|
||||
|
||||
|
||||
# Monkeypatch newer get_scheduler() function overridng current version of diffusers.optimizer.get_scheduler
|
||||
# code is taken from https://github.com/huggingface/diffusers diffusers.optimizer, commit d87cc15977b87160c30abaace3894e802ad9e1e6
|
||||
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
||||
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
|
||||
|
||||
from typing import Optional, Union
|
||||
from torch.optim import Optimizer
|
||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||
|
||||
def get_scheduler_fix(
|
||||
name: Union[str, SchedulerType],
|
||||
optimizer: Optimizer,
|
||||
num_warmup_steps: Optional[int] = None,
|
||||
num_training_steps: Optional[int] = None,
|
||||
num_cycles: int = 1,
|
||||
power: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Unified API to get any scheduler from its name.
|
||||
Args:
|
||||
name (`str` or `SchedulerType`):
|
||||
The name of the scheduler to use.
|
||||
optimizer (`torch.optim.Optimizer`):
|
||||
The optimizer that will be used during training.
|
||||
num_warmup_steps (`int`, *optional*):
|
||||
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_training_steps (`int``, *optional*):
|
||||
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||
num_cycles (`int`, *optional*):
|
||||
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
|
||||
power (`float`, *optional*, defaults to 1.0):
|
||||
Power factor. See `POLYNOMIAL` scheduler
|
||||
last_epoch (`int`, *optional*, defaults to -1):
|
||||
The index of the last epoch when resuming training.
|
||||
"""
|
||||
name = SchedulerType(name)
|
||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||
if name == SchedulerType.CONSTANT:
|
||||
return schedule_func(optimizer)
|
||||
|
||||
# All other schedulers require `num_warmup_steps`
|
||||
if num_warmup_steps is None:
|
||||
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||
|
||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
||||
|
||||
# All other schedulers require `num_training_steps`
|
||||
if num_training_steps is None:
|
||||
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
||||
|
||||
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
||||
)
|
||||
|
||||
if name == SchedulerType.POLYNOMIAL:
|
||||
return schedule_func(
|
||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
||||
)
|
||||
|
||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
||||
|
||||
|
||||
def train(args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
training_started_at = time.time()
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
@@ -88,7 +173,8 @@ def train(args):
|
||||
key, value = net_arg.split('=')
|
||||
net_kwargs[key] = value
|
||||
|
||||
network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs)
|
||||
# if a new network is added in future, add if ~ then blocks for each network (;'∀')
|
||||
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
@@ -136,8 +222,11 @@ def train(args):
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||
# lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
lr_scheduler = get_scheduler_fix(
|
||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||
if args.full_fp16:
|
||||
@@ -192,6 +281,8 @@ def train(args):
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
@@ -206,21 +297,26 @@ def train(args):
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
metadata = {
|
||||
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
||||
"ss_training_started_at": training_started_at, # unix timestamp
|
||||
"ss_output_name": args.output_name,
|
||||
"ss_learning_rate": args.learning_rate,
|
||||
"ss_text_encoder_lr": args.text_encoder_lr,
|
||||
"ss_unet_lr": args.unet_lr,
|
||||
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data
|
||||
"ss_num_train_images": train_dataset.num_train_images, # includes repeating
|
||||
"ss_num_reg_images": train_dataset.num_reg_images,
|
||||
"ss_num_batches_per_epoch": len(train_dataloader),
|
||||
"ss_num_epochs": num_train_epochs,
|
||||
"ss_batch_size_per_device": args.train_batch_size,
|
||||
"ss_total_batch_size": total_batch_size,
|
||||
"ss_gradient_checkpointing": args.gradient_checkpointing,
|
||||
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
|
||||
"ss_max_train_steps": args.max_train_steps,
|
||||
"ss_lr_warmup_steps": args.lr_warmup_steps,
|
||||
"ss_lr_scheduler": args.lr_scheduler,
|
||||
"ss_network_module": args.network_module,
|
||||
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
||||
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
|
||||
"ss_network_alpha": args.network_alpha, # some networks may not use this value
|
||||
"ss_mixed_precision": args.mixed_precision,
|
||||
"ss_full_fp16": bool(args.full_fp16),
|
||||
"ss_v2": bool(args.v2),
|
||||
@@ -232,10 +328,15 @@ def train(args):
|
||||
"ss_random_crop": bool(args.random_crop),
|
||||
"ss_shuffle_caption": bool(args.shuffle_caption),
|
||||
"ss_cache_latents": bool(args.cache_latents),
|
||||
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT
|
||||
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset
|
||||
"ss_max_bucket_reso": args.max_bucket_reso,
|
||||
"ss_seed": args.seed
|
||||
"ss_enable_bucket": bool(train_dataset.enable_bucket),
|
||||
"ss_min_bucket_reso": train_dataset.min_bucket_reso,
|
||||
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
|
||||
"ss_seed": args.seed,
|
||||
"ss_keep_tokens": args.keep_tokens,
|
||||
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
||||
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
||||
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
||||
"ss_training_comment": args.training_comment # will not be updated after training
|
||||
}
|
||||
|
||||
# uncomment if another network is added
|
||||
@@ -246,6 +347,7 @@ def train(args):
|
||||
sd_model_name = args.pretrained_model_name_or_path
|
||||
if os.path.exists(sd_model_name):
|
||||
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
|
||||
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
|
||||
sd_model_name = os.path.basename(sd_model_name)
|
||||
metadata["ss_sd_model_name"] = sd_model_name
|
||||
|
||||
@@ -253,6 +355,7 @@ def train(args):
|
||||
vae_name = args.vae
|
||||
if os.path.exists(vae_name):
|
||||
metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
|
||||
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
|
||||
vae_name = os.path.basename(vae_name)
|
||||
metadata["ss_vae_name"] = vae_name
|
||||
|
||||
@@ -333,20 +436,20 @@ def train(args):
|
||||
global_step += 1
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / (step+1)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"epoch_loss": loss_total / len(train_dataloader)}
|
||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||
accelerator.log(logs, step=epoch+1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
@@ -406,22 +509,30 @@ if __name__ == '__main__':
|
||||
train_util.add_training_arguments(parser, True)
|
||||
|
||||
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
||||
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
||||
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")
|
||||
|
||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
||||
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
||||
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
||||
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
||||
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
||||
|
||||
parser.add_argument("--network_weights", type=str, default=None,
|
||||
help="pretrained weights for network / 学習するネットワークの初期重み")
|
||||
parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
|
||||
parser.add_argument("--network_dim", type=int, default=None,
|
||||
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
||||
parser.add_argument("--network_alpha", type=float, default=1,
|
||||
help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
|
||||
parser.add_argument("--network_train_text_encoder_only", action="store_true",
|
||||
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
|
||||
parser.add_argument("--training_comment", type=str, default=None,
|
||||
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
|
||||
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
||||
|
||||
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extention](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
||||
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
||||
|
||||
## 学習方法
|
||||
|
||||
@@ -24,7 +24,7 @@ DreamBoothの手法(identifier(sksなど)とclass、オプションで正
|
||||
|
||||
[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。
|
||||
|
||||
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。
|
||||
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション(``network_dim``や``network_alpha``など)を追加してください。
|
||||
|
||||
ほぼすべてのオプション(Stable Diffusionのモデル保存関係を除く)が使えますが、stop_text_encoder_trainingはサポートしていません。
|
||||
|
||||
@@ -32,7 +32,7 @@ DreamBoothの手法(identifier(sksなど)とclass、オプションで正
|
||||
|
||||
[fine-tuningのガイド](./fine_tune_README_ja.md) を参照し、各手順を実行してください。
|
||||
|
||||
学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプション(モデル保存関係を除く)がそのまま使えます。
|
||||
学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプション(モデル保存関係を除く)がそのまま使えます。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション(``network_dim``や``network_alpha``など)を追加してください。
|
||||
|
||||
なお「latentsの事前取得」は行わなくても動作します。VAEから学習時(またはキャッシュ時)にlatentを取得するため学習速度は遅くなりますが、代わりにcolor_augが使えるようになります。
|
||||
|
||||
@@ -45,7 +45,7 @@ train_network.pyでは--network_moduleオプションに、学習対象のモジ
|
||||
以下はコマンドラインの例です(DreamBooth手法)。
|
||||
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 12 train_network.py
|
||||
accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
--pretrained_model_name_or_path=..\models\model.ckpt
|
||||
--train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
|
||||
--reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
|
||||
@@ -60,7 +60,9 @@ accelerate launch --num_cpu_threads_per_process 12 train_network.py
|
||||
その他、以下のオプションが指定できます。
|
||||
|
||||
* --network_dim
|
||||
* LoRAの次元数を指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
|
||||
* LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
|
||||
* --network_alpha
|
||||
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
|
||||
* --network_weights
|
||||
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
|
||||
* --network_train_unet_only
|
||||
@@ -126,7 +128,7 @@ python networks\merge_lora.py
|
||||
|
||||
--ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
|
||||
|
||||
v1で学習したLoRAとv2で学習したLoRA、次元数の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
|
||||
v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha``の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
|
||||
|
||||
|
||||
### その他のオプション
|
||||
@@ -138,7 +140,7 @@ v1で学習したLoRAとv2で学習したLoRA、次元数の異なるLoRAはマ
|
||||
|
||||
## 当リポジトリ内の画像生成スクリプトで生成する
|
||||
|
||||
gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim(省略可)の各オプションを追加してください。意味は学習時と同様です。
|
||||
gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
|
||||
|
||||
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
|
||||
|
||||
|
||||
498
train_textual_inversion.py
Normal file
498
train_textual_inversion.py
Normal file
@@ -0,0 +1,498 @@
|
||||
import importlib
|
||||
import argparse
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from accelerate.utils import set_seed
|
||||
import diffusers
|
||||
from diffusers import DDPMScheduler
|
||||
|
||||
import library.train_util as train_util
|
||||
from library.train_util import DreamBoothDataset, FineTuningDataset
|
||||
|
||||
imagenet_templates_small = [
|
||||
"a photo of a {}",
|
||||
"a rendering of a {}",
|
||||
"a cropped photo of the {}",
|
||||
"the photo of a {}",
|
||||
"a photo of a clean {}",
|
||||
"a photo of a dirty {}",
|
||||
"a dark photo of the {}",
|
||||
"a photo of my {}",
|
||||
"a photo of the cool {}",
|
||||
"a close-up photo of a {}",
|
||||
"a bright photo of the {}",
|
||||
"a cropped photo of a {}",
|
||||
"a photo of the {}",
|
||||
"a good photo of the {}",
|
||||
"a photo of one {}",
|
||||
"a close-up photo of the {}",
|
||||
"a rendition of the {}",
|
||||
"a photo of the clean {}",
|
||||
"a rendition of a {}",
|
||||
"a photo of a nice {}",
|
||||
"a good photo of a {}",
|
||||
"a photo of the nice {}",
|
||||
"a photo of the small {}",
|
||||
"a photo of the weird {}",
|
||||
"a photo of the large {}",
|
||||
"a photo of a cool {}",
|
||||
"a photo of a small {}",
|
||||
]
|
||||
|
||||
imagenet_style_templates_small = [
|
||||
"a painting in the style of {}",
|
||||
"a rendering in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"the painting in the style of {}",
|
||||
"a clean painting in the style of {}",
|
||||
"a dirty painting in the style of {}",
|
||||
"a dark painting in the style of {}",
|
||||
"a picture in the style of {}",
|
||||
"a cool painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a bright painting in the style of {}",
|
||||
"a cropped painting in the style of {}",
|
||||
"a good painting in the style of {}",
|
||||
"a close-up painting in the style of {}",
|
||||
"a rendition in the style of {}",
|
||||
"a nice painting in the style of {}",
|
||||
"a small painting in the style of {}",
|
||||
"a weird painting in the style of {}",
|
||||
"a large painting in the style of {}",
|
||||
]
|
||||
|
||||
|
||||
def collate_fn(examples):
|
||||
return examples[0]
|
||||
|
||||
|
||||
def train(args):
|
||||
if args.output_name is None:
|
||||
args.output_name = args.token_string
|
||||
use_template = args.use_object_template or args.use_style_template
|
||||
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, True)
|
||||
|
||||
cache_latents = args.cache_latents
|
||||
use_dreambooth_method = args.in_json is None
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer = train_util.load_tokenizer(args)
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
||||
|
||||
# Convert the init_word to token_id
|
||||
if args.init_word is not None:
|
||||
init_token_id = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||
assert len(
|
||||
init_token_id) == 1, f"init word {args.init_word} is not converted to single token / 初期化単語が二つ以上のトークンに変換されます。別の単語を使ってください"
|
||||
init_token_id = init_token_id[0]
|
||||
else:
|
||||
init_token_id = None
|
||||
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||
assert num_added_tokens == args.num_vectors_per_token, f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
||||
|
||||
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||
print(f"tokens are added: {token_ids}")
|
||||
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
if init_token_id is not None:
|
||||
for token_id in token_ids:
|
||||
token_embeds[token_id] = token_embeds[init_token_id]
|
||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
|
||||
# load weights
|
||||
if args.weights is not None:
|
||||
embeddings = load_weights(args.weights)
|
||||
assert len(token_ids) == len(
|
||||
embeddings), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
||||
# print(token_ids, embeddings.size())
|
||||
for token_id, embedding in zip(token_ids, embeddings):
|
||||
token_embeds[token_id] = embedding
|
||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
print(f"weighs loaded")
|
||||
|
||||
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||
|
||||
# データセットを準備する
|
||||
if use_dreambooth_method:
|
||||
print("Use DreamBooth method.")
|
||||
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
||||
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
||||
else:
|
||||
print("Train with captions.")
|
||||
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||
args.dataset_repeats, args.debug_dataset)
|
||||
|
||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||
if use_template:
|
||||
print("use template for training captions. is object: {args.use_object_template}")
|
||||
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
||||
replace_to = " ".join(token_strings)
|
||||
captions = []
|
||||
for tmpl in templates:
|
||||
captions.append(tmpl.format(replace_to))
|
||||
train_dataset.add_replacement("", captions)
|
||||
elif args.num_vectors_per_token > 1:
|
||||
replace_to = " ".join(token_strings)
|
||||
train_dataset.add_replacement(args.token_string, replace_to)
|
||||
|
||||
train_dataset.make_buckets()
|
||||
|
||||
if args.debug_dataset:
|
||||
train_util.debug_dataset(train_dataset, show_input_ids=True)
|
||||
return
|
||||
if len(train_dataset) == 0:
|
||||
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
||||
return
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
train_dataset.cache_latents(vae)
|
||||
vae.to("cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
print("prepare optimizer, data loader etc.")
|
||||
|
||||
# 8-bit Adamを使う
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
||||
print("use 8-bit Adam optimizer")
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
trainable_params = text_encoder.get_input_embeddings().parameters()
|
||||
|
||||
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
||||
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
||||
|
||||
# dataloaderを準備する
|
||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||
|
||||
# 学習ステップ数を計算する
|
||||
if args.max_train_epochs is not None:
|
||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||
print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
text_encoder.requires_grad_(True)
|
||||
text_encoder.text_model.encoder.requires_grad_(False)
|
||||
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
||||
unet.train()
|
||||
else:
|
||||
unet.eval()
|
||||
|
||||
if not cache_latents:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
text_encoder.to(weight_dtype)
|
||||
|
||||
# resumeする
|
||||
if args.resume is not None:
|
||||
print(f"resume training from state: {args.resume}")
|
||||
accelerator.load_state(args.resume)
|
||||
|
||||
# epoch数を計算する
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||
|
||||
# 学習する
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
print("running training / 学習開始")
|
||||
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
||||
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000, clip_sample=False)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("textual_inversion")
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
|
||||
text_encoder.train()
|
||||
|
||||
loss_total = 0
|
||||
bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(text_encoder):
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
else:
|
||||
# latentに変換
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) # weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
||||
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
with torch.no_grad():
|
||||
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
loss_total += current_loss
|
||||
avr_loss = loss_total / (step+1)
|
||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||
accelerator.log(logs, step=epoch+1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
d = updated_embs - bef_epo_embs
|
||||
print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||
|
||||
def save_func():
|
||||
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||
|
||||
# end of epoch
|
||||
|
||||
is_main_process = accelerator.is_main_process
|
||||
if is_main_process:
|
||||
text_encoder = unwrap_model(text_encoder)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
ckpt_name = model_name + '.' + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"save trained model to {ckpt_file}")
|
||||
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
def save_weights(file, updated_embs, save_dtype):
|
||||
state_dict = {"emb_params": updated_embs}
|
||||
|
||||
if save_dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == '.safetensors':
|
||||
from safetensors.torch import save_file
|
||||
save_file(state_dict, file)
|
||||
else:
|
||||
torch.save(state_dict, file) # can be loaded in Web UI
|
||||
|
||||
|
||||
def load_weights(file):
|
||||
if os.path.splitext(file)[1] == '.safetensors':
|
||||
from safetensors.torch import load_file
|
||||
data = load_file(file)
|
||||
else:
|
||||
# compatible to Web UI's file format
|
||||
data = torch.load(file, map_location='cpu')
|
||||
if type(data) != dict:
|
||||
raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
|
||||
|
||||
if 'string_to_param' in data: # textual inversion embeddings
|
||||
data = data['string_to_param']
|
||||
if hasattr(data, '_parameters'): # support old PyTorch?
|
||||
data = getattr(data, '_parameters')
|
||||
|
||||
emb = next(iter(data.values()))
|
||||
if type(emb) != torch.Tensor:
|
||||
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")
|
||||
|
||||
if len(emb.size()) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
|
||||
return emb
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
|
||||
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
||||
|
||||
parser.add_argument("--weights", type=str, default=None,
|
||||
help="embedding weights to initialize / 学習するネットワークの初期重み")
|
||||
parser.add_argument("--num_vectors_per_token", type=int, default=1,
|
||||
help='number of vectors per token / トークンに割り当てるembeddingsの要素数')
|
||||
parser.add_argument("--token_string", type=str, default=None,
|
||||
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること")
|
||||
parser.add_argument("--init_word", type=str, default=None,
|
||||
help="word to initialize vector / ベクトルを初期化に使用する単語、tokenizerで一語になること")
|
||||
parser.add_argument("--use_object_template", action='store_true',
|
||||
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する")
|
||||
parser.add_argument("--use_style_template", action='store_true',
|
||||
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する")
|
||||
|
||||
args = parser.parse_args()
|
||||
train(args)
|
||||
63
train_ti_README-ja.md
Normal file
63
train_ti_README-ja.md
Normal file
@@ -0,0 +1,63 @@
|
||||
## Textual Inversionの学習について
|
||||
|
||||
[Textual Inversion](https://textual-inversion.github.io/)です。実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。
|
||||
|
||||
学習したモデルはWeb UIでもそのまま使えます。
|
||||
|
||||
なお恐らくSD2.xにも対応していますが現時点では未テストです。
|
||||
|
||||
## 学習方法
|
||||
|
||||
``train_textual_inversion.py`` を用います。
|
||||
|
||||
データの準備については ``train_network.py`` と全く同じですので、[そちらのドキュメント](./train_network_README-ja.md)を参照してください。
|
||||
|
||||
## オプション
|
||||
|
||||
以下はコマンドラインの例です(DreamBooth手法)。
|
||||
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py
|
||||
--pretrained_model_name_or_path=..\models\model.ckpt
|
||||
--train_data_dir=..\data\db\char1 --output_dir=..\ti_train1
|
||||
--resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
|
||||
--max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16
|
||||
--save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
|
||||
--token_string=mychar4 --init_word=cute --num_vectors_per_token=4
|
||||
```
|
||||
|
||||
``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてください(token_stringがmychar4なら、``mychar4 1girl`` など)__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。
|
||||
|
||||
プロンプトにトークン文字列が含まれているかどうかは、``--debug_dataset`` で置換後のtoken idが表示されますので、以下のように ``49408`` 以降のtokenが存在するかどうかで確認できます。
|
||||
|
||||
```
|
||||
input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407,
|
||||
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||
49407, 49407, 49407, 49407, 49407, 49407, 49407]])
|
||||
```
|
||||
|
||||
tokenizerがすでに持っている単語(一般的な単語)は使用できません。
|
||||
|
||||
``--init_word`` にembeddingsを初期化するときのコピー元トークンの文字列を指定します。学ばせたい概念が近いものを選ぶとよいようです。二つ以上のトークンになる文字列は指定できません。
|
||||
|
||||
``--num_vectors_per_token`` にいくつのトークンをこの学習で使うかを指定します。多いほうが表現力が増しますが、その分多くのトークンを消費します。たとえばnum_vectors_per_token=8の場合、指定したトークン文字列は(一般的なプロンプトの77トークン制限のうち)8トークンを消費します。
|
||||
|
||||
|
||||
その他、以下のオプションが指定できます。
|
||||
|
||||
* --weights
|
||||
* 学習前に学習済みのembeddingsを読み込み、そこから追加で学習します。
|
||||
* --use_object_template
|
||||
* キャプションではなく既定の物体用テンプレート文字列(``a photo of a {}``など)で学習します。公式実装と同じになります。キャプションは無視されます。
|
||||
* --use_style_template
|
||||
* キャプションではなく既定のスタイル用テンプレート文字列で学習します(``a painting in the style of {}``など)。公式実装と同じになります。キャプションは無視されます。
|
||||
|
||||
## 当リポジトリ内の画像生成スクリプトで生成する
|
||||
|
||||
gen_img_diffusers.pyに、``--textual_inversion_embeddings`` オプションで学習したembeddingsファイルを指定してください(複数可)。プロンプトでembeddingsファイルのファイル名(拡張子を除く)を使うと、そのembeddingsが適用されます。
|
||||
|
||||
Reference in New Issue
Block a user