Compare commits

83 Commits

Author SHA1 Message Date
Kohya S
5050971ac6 Merge pull request #388 from kohya-ss/dev
add weighted caption for training
2023-04-08 22:00:46 +09:00
Kohya S
08c54dcf22 update readme 2023-04-08 21:58:22 +09:00
Kohya S
6a5f87d874 disable weighted captions in TI/XTI training 2023-04-08 21:45:57 +09:00
Kohya S
a876f2d3fb format by black 2023-04-08 21:36:35 +09:00
Kohya S
a75f5898e6 Merge pull request #336 from AI-Casanova/weighted_captions
Proof of Concept: Weighted captions
2023-04-08 21:31:05 +09:00
AI-Casanova
dbab72153f Clean up custom_train_functions.py
Removed commented out lines from earlier bugfix.
2023-04-08 00:44:56 -05:00
AI-Casanova
0d54609435 Merge branch 'kohya-ss:main' into weighted_captions 2023-04-07 14:55:40 -05:00
Kohya S
b5c60d7d62 Merge pull request #381 from kohya-ss/dev
feature to upload to huggingface etc.
2023-04-06 08:20:07 +09:00
Kohya S
defefd79c5 Merge branch 'main' into dev 2023-04-06 08:16:31 +09:00
Kohya S
27834df444 update readme 2023-04-06 08:16:02 +09:00
Kohya S
5c020bed49 Add attension couple+reginal LoRA 2023-04-06 08:11:54 +09:00
Kohya S
c775ec1255 Add about using LoRA with Diffusers standard pipe 2023-04-06 08:10:41 +09:00
AI-Casanova
7527436549 Merge branch 'kohya-ss:main' into weighted_captions 2023-04-05 17:07:15 -05:00
Kohya S
541539a144 change method name, repo is private in default etc 2023-04-05 23:16:49 +09:00
Kohya S
74220bb52c Merge pull request #348 from ddPn08/dev
Added a function to upload to Huggingface and resume from Huggingface.
2023-04-05 21:47:36 +09:00
Kohya S
8eb60baf3a Merge pull request #374 from kohya-ss/dev
block learning rate, block dim(rank) etc.
2023-04-04 08:33:18 +09:00
Kohya S
4b47e8ecb0 update readme 2023-04-04 08:27:30 +09:00
Kohya S
76bac2c1c5 add backward compatiblity 2023-04-04 08:27:11 +09:00
Kohya S
0fcdda7175 Merge pull request #373 from rockerBOO/meta-min_snr_gamma
Add min_snr_gamma to metadata
2023-04-04 07:57:50 +09:00
Kohya S
e4eb3e63e6 improve compatibility 2023-04-04 07:48:48 +09:00
rockerBOO
626d4b433a Add min_snr_gamma to metadata 2023-04-03 12:38:20 -04:00
Kohya S
83c7e03d05 Fix network_weights not working in train_network 2023-04-03 22:45:28 +09:00
Kohya S
959561473c Merge branch 'main' into dev 2023-04-03 22:09:17 +09:00
Kohya S
7209eb74cc update readme 2023-04-03 22:08:58 +09:00
Kohya S
53cc3583df fix potential issue with dtype 2023-04-03 21:46:12 +09:00
Kohya S
82c2553f07 Merge pull request #353 from Riyaaaaa/patch-1
fix typo
2023-04-03 21:45:03 +09:00
Kohya S
6f6f9b537f Merge pull request #364 from shirayu/check_needless_num_warmup_steps
Check needless num_warmup_steps
2023-04-03 21:38:52 +09:00
Kohya S
f407f5a686 Merge pull request #352 from rockerBOO/dataset-config
Open dataset_config json file before load
2023-04-03 21:31:55 +09:00
Kohya S
6134619998 Add block dim(rank) feature 2023-04-03 21:19:49 +09:00
Kohya S
817a9268ff update readme for block weight lr 2023-04-03 08:43:26 +09:00
Kohya S
3beddf341e Suppor LR graphs for each block, base lr 2023-04-03 08:43:11 +09:00
AI-Casanova
1892c82a60 Reinstantiate weighted captions after a necessary revert to Main 2023-04-02 19:43:34 +00:00
ddPn08
3f339cda6f small fix 2023-04-02 23:21:17 +09:00
ddPn08
16ba1cec69 change async uploading to optional 2023-04-02 17:45:26 +09:00
ddPn08
8bfa50e283 small fix 2023-04-02 17:39:23 +09:00
ddPn08
c4a11e5a5a fix help 2023-04-02 17:39:23 +09:00
ddPn08
3cc4939dd3 Implement huggingface upload for all scripts 2023-04-02 17:39:22 +09:00
ddPn08
b5c7937f8d don't run when not needed 2023-04-02 17:39:21 +09:00
ddPn08
b5ff4e816f resume from huggingface repository 2023-04-02 17:39:21 +09:00
ddPn08
a7d302e196 write a random seed to metadata 2023-04-02 17:39:20 +09:00
ddPn08
45381b188c small fix 2023-04-02 17:39:20 +09:00
ddPn08
054fb3308c use access token 2023-04-02 17:39:19 +09:00
ddPn08
d42431d73a Added feature to upload to huggingface 2023-04-02 17:39:10 +09:00
Kohya S
c639cb7d5d support older type hint 2023-04-02 16:18:04 +09:00
Kohya S
97e65bf93f change 'stratify' to 'block', add en message 2023-04-02 16:10:09 +09:00
Kohya S
36c8a4aee7 Merge pull request #355 from u-haru/feature/stratified_lr
Implement per-layer learning rates for LoRA, remove device specification when loading state_dict, fix typos
2023-04-02 15:25:17 +09:00
u-haru
19340d82e6 Group params together when stratified learning rates are not used 2023-04-02 12:57:55 +09:00
u-haru
058e442072 Change the number of layers (based on hako-mikan/sd-webui-lora-block-weight) 2023-04-02 04:02:34 +09:00
Yuta Hayashibe
9577a9f38d Check needless num_warmup_steps 2023-04-01 20:33:20 +09:00
u-haru
786971d443 Merge branch 'dev' into feature/stratified_lr 2023-04-01 15:08:41 +09:00
Kohya S
f037b09c2d Merge pull request #360 from kohya-ss/dev
fix for merge_lora.py
2023-04-01 09:25:57 +09:00
Kohya S
18d69d8e5e update readme 2023-04-01 09:21:27 +09:00
Kohya S
770a56193e fix conv2d3x3 is not merged 2023-04-01 09:17:37 +09:00
Kohya S
4627b389ff fix device not specified in merge_lora.py 2023-04-01 09:15:57 +09:00
Kohya S
1cd07770a4 format by black 2023-04-01 09:13:47 +09:00
u-haru
1e164b6ec3 specify device when loading state_dict 2023-03-31 12:52:39 +09:00
u-haru
41ecccb2a9 Merge branch 'kohya-ss:main' into feature/stratified_lr 2023-03-31 12:47:56 +09:00
Kohya S
c93cbbc373 Merge pull request #357 from kohya-ss/dev
Fix device issue in load_file, reduce vram usage
2023-03-31 09:07:49 +09:00
Kohya S
8cecc676cf Fix device issue in load_file, reduce vram usage 2023-03-31 09:05:51 +09:00
u-haru
94441fa746 Fix name display for directories without a repeat count 2023-03-31 02:26:54 +09:00
Atsumu Ono
ccb0ef518a fix typo 2023-03-31 01:45:49 +09:00
u-haru
3032a47af4 Change cosine to reversed sine 2023-03-31 01:42:57 +09:00
u-haru
1b75dbd4f2 Add _lr to argument names 2023-03-31 01:40:29 +09:00
u-haru
dade23a414 Rename to stratified_zero_threshold 2023-03-31 01:14:03 +09:00
rockerBOO
313f3e8286 Open dataset_config json file before load 2023-03-30 12:08:04 -04:00
u-haru
4dacc52bde implement stratified_lr 2023-03-31 00:39:35 +09:00
u-haru
b1dffe8d9a Fix bug where files could not be loaded (Exception: device cuda is invalid) 2023-03-31 00:11:11 +09:00
Kohya S
ea1cf4acee Merge pull request #350 from kohya-ss/dev
fix gen not working
2023-03-30 22:30:47 +09:00
Kohya S
cd5e3baace Merge branch 'main' into dev 2023-03-30 22:29:19 +09:00
Kohya S
e76ea7cd7d fix not working 2023-03-30 22:28:55 +09:00
Kohya S
d68ba2f9de Merge pull request #349 from kohya-ss/dev
P+, reduce ram usage etc.
2023-03-30 22:07:03 +09:00
Kohya S
5fc80b7a5b update readme 2023-03-30 22:03:13 +09:00
Kohya S
31069e1dc5 add comments about debice for clarify 2023-03-30 21:44:40 +09:00
Kohya S
6c28dfb417 Merge pull request #332 from guaneec/ddp-lowram
Reduce peak RAM usage
2023-03-30 21:37:37 +09:00
Kohya S
2d6faa9860 support LoRA merge in advance 2023-03-30 21:34:36 +09:00
Kohya S
cb53a77334 show warning message for sample images in XTI 2023-03-30 21:33:57 +09:00
Kohya S
4d91dc0d30 Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev 2023-03-30 21:23:18 +09:00
Kohya S
935d4774a9 Merge pull request #327 from jakaline-dev/main
P+: Extended Textual Conditioning in Text-to-Image Generation
2023-03-30 19:44:57 +09:00
Jakaline-dev
24e3d4b464 disabled sampling (for now) 2023-03-30 02:20:04 +09:00
Jakaline-dev
b0c33a4294 Merge remote-tracking branch 'upstream/main' 2023-03-30 01:35:38 +09:00
Kohya S
bf3674c1db format by black 2023-03-29 21:23:27 +09:00
guaneec
3cdae0cbd2 Reduce peak RAM usage 2023-03-27 14:34:17 +08:00
Jakaline-dev
a35d7ef227 Implement XTI 2023-03-26 05:26:10 +09:00
20 changed files with 3848 additions and 1575 deletions

@@ -127,58 +127,56 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
- 28 Mar. 2023, 2023/3/28:
- Fix an issue that the training script crashes when `max_data_loader_n_workers` is 0.
- `max_data_loader_n_workers` が0の時に学習スクリプトがエラーとなる不具合を修正しました。
### 8 Apr. 2023, 2023/4/8:
- 27 Mar. 2023, 2023/3/27:
- Fix issues when `--persistent_data_loader_workers` is specified.
- The batch members of the bucket are not shuffled.
- `--caption_dropout_every_n_epochs` does not work.
- These issues occurred because the epoch transition was not recognized correctly. Thanks to u-haru for reporting the issue.
- Fix an issue that images are loaded twice in Windows environment.
- Add Min-SNR Weighting strategy. Details are in [#308](https://github.com/kohya-ss/sd-scripts/pull/308). Thank you to AI-Casanova for this great work!
- Add `--min_snr_gamma` option to training scripts; the paper recommends 5.
- Added support for training with weighted captions. Thanks to AI-Casanova for the great contribution!
- Please refer to the PR for details: [PR #336](https://github.com/kohya-ss/sd-scripts/pull/336)
- Specify the `--weighted_captions` option. It is available for all training scripts except Textual Inversion and XTI.
- This option is also applicable to token strings of the DreamBooth method.
- The syntax for weighted captions is almost the same as the Web UI, and you can use things like `(abc)`, `[abc]`, and `(abc:1.23)`. Nesting is also possible.
- If you include a comma in the parentheses, the parentheses will not be properly matched in the prompt shuffle/dropout, so do not include a comma in the parentheses.
- Add tag warmup. Details are in [#322](https://github.com/kohya-ss/sd-scripts/pull/322). Thanks to u-haru!
- Add `token_warmup_min` and `token_warmup_step` to dataset settings.
- Gradually increase the number of tokens from `token_warmup_min` to `token_warmup_step`.
- For example, if `token_warmup_min` is `3` and `token_warmup_step` is `10`, the first step will use the first 3 tokens, and the 10th step will use all tokens.
- Fix a bug in `resize_lora.py`. Thanks to mgz-dev! [#328](https://github.com/kohya-ss/sd-scripts/pull/328)
- Add `--debug_dataset` option to step to the next step with `S` key and to the next epoch with `E` key.
- Fix other bugs.
- 重みづけキャプションによる学習に対応しました。 AI-Casanova 氏の素晴らしい貢献に感謝します。
- 詳細はこちらをご確認ください。[PR #336](https://github.com/kohya-ss/sd-scripts/pull/336)
- `--weighted_captions` オプションを指定してください。Textual InversionおよびXTIを除く学習スクリプトで使用可能です。
- キャプションだけでなく DreamBooth 手法の token string でも有効です。
- 重みづけキャプションの記法はWeb UIとほぼ同じで、`(abc)`や`[abc]`、`(abc:1.23)`などが使用できます。入れ子も可能です。
- 括弧内にカンマを含めるとプロンプトのshuffle/dropoutで括弧の対応付けがおかしくなるため、括弧内にはカンマを含めないでください。
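The bracket syntax described above can be illustrated with a minimal sketch. It assumes the Web UI convention that `(abc)` multiplies the weight by 1.1 and `[abc]` divides it by 1.1; the function name `parse_caption_weights` is hypothetical, and the parser used by the training scripts may handle escaping and other details that are omitted here.

```python
import re

# Tokens: opening brackets, an explicit ":weight)" suffix, closing brackets, plain text.
_TOKEN_RE = re.compile(r"\(|\[|:([+-]?[.\d]+)\)|\)|\]|[^()\[\]:]+|:")


def parse_caption_weights(caption):
    """Split a caption into (text, weight) pairs (simplified, no escape handling)."""
    result = []  # list of [text, weight]
    round_starts = []  # indices into `result` where each "(" was opened
    square_starts = []  # indices into `result` where each "[" was opened

    def scale_from(start, factor):
        # Multiply the weight of every segment added since the bracket was opened.
        for segment in result[start:]:
            segment[1] *= factor

    for match in _TOKEN_RE.finditer(caption):
        token, explicit = match.group(0), match.group(1)
        if token == "(":
            round_starts.append(len(result))
        elif token == "[":
            square_starts.append(len(result))
        elif explicit is not None and round_starts:
            scale_from(round_starts.pop(), float(explicit))  # (abc:1.23)
        elif token == ")" and round_starts:
            scale_from(round_starts.pop(), 1.1)  # (abc), assumed factor
        elif token == "]" and square_starts:
            scale_from(square_starts.pop(), 1 / 1.1)  # [abc], assumed factor
        else:
            result.append([token, 1.0])
    return [(text, weight) for text, weight in result]


if __name__ == "__main__":
    for text, weight in parse_caption_weights("a (b) [c] (d:1.5) ((e))"):
        print(repr(text), round(weight, 3))
```

Nested brackets multiply, so `((e))` ends up with a weight of roughly 1.21 under these assumptions.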
### 6 Apr. 2023, 2023/4/6:
- There may be bugs because many changes were made. If you cannot revert the script to the previous version when a problem occurs, please hold off on updating for a while.
- `--persistent_data_loader_workers` を指定した時の各種不具合を修正しました。
- `--caption_dropout_every_n_epochs` が効かない。
- バケットのバッチメンバーがシャッフルされない。
- エポックの遷移が正しく認識されないために発生していました。ご指摘いただいたu-haru氏に感謝します。
- Windows環境で画像が二重に読み込まれる不具合を修正しました。
- Min-SNR Weighting strategyを追加しました。 詳細は [#308](https://github.com/kohya-ss/sd-scripts/pull/308) をご参照ください。AI-Casanova氏の素晴らしい貢献に感謝します。
- `--min_snr_gamma` オプションを学習スクリプトに追加しました。論文では5が推奨されています。
- タグのウォームアップを追加しました。詳細は [#322](https://github.com/kohya-ss/sd-scripts/pull/322) をご参照ください。u-haru氏に感謝します。
- データセット設定に `token_warmup_min` と `token_warmup_step` を追加しました。
- `token_warmup_min` で指定した数のトークン(カンマ区切りの文字列)から、`token_warmup_step` で指定したステップまで、段階的にトークンを増やしていきます。
- たとえば `token_warmup_min`に `3` を、`token_warmup_step` に `10` を指定すると、最初のステップでは最初から3個のトークンが使われ、10ステップ目では全てのトークンが使われます。
- `resize_lora.py` の不具合を修正しました。mgz-dev氏に感謝します。[#328](https://github.com/kohya-ss/sd-scripts/pull/328)
- `--debug_dataset` オプションで、`S`キーで次のステップへ、`E`キーで次のエポックへ進めるようにしました。
- その他の不具合を修正しました。
- Added a feature to upload model and state to HuggingFace. Thanks to ddPn08 for the contribution! [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348)
- When `--huggingface_repo_id` is specified, the model is uploaded to HuggingFace at the same time as saving the model.
- Please handle the access token with care. Please refer to the [HuggingFace documentation](https://huggingface.co/docs/hub/security-tokens).
- For example, specify other arguments as follows.
- `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
- If `public` is specified for `--huggingface_repo_visibility`, the repository will be public. If the option is omitted or `private` (or anything other than `public`) is specified, it will be private.
- If you specify `--save_state` and `--save_state_to_huggingface`, the state will also be uploaded.
- If you specify `--resume` and `--resume_from_huggingface`, the state will be downloaded from HuggingFace and resumed.
- In this case, the `--resume` option is `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`. For example: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
- If you specify `--async_upload`, the upload will be done asynchronously.
- Added the documentation for applying LoRA to generate with the standard pipeline of Diffusers. [training LoRA](./train_network_README-ja.md#diffusersのpipelineで生成する) (Japanese only)
- Support for Attention Couple and regional LoRA in `gen_img_diffusers.py`.
- If you use ` AND ` to separate the prompts, each sub-prompt is sequentially applied to LoRA. `--mask_path` is treated as a mask image. The number of sub-prompts and the number of LoRA must match.
- 21 Mar. 2023, 2023/3/21:
- Add `--vae_batch_size` for faster latents caching to each training script. This batches VAE calls.
- Please start with `2` or `4` depending on the size of VRAM.
- Fix the number of training steps when `--gradient_accumulation_steps` and `--max_train_epochs` are specified. Thanks to tsukimiya!
- Extract parser setup into functions so that external scripts can use it. Thanks to robertsmieja!
- Fix an issue in training when `--full_path` is specified and the `.npz` file does not exist.
- Support uppercase image file extensions on non-Windows environments.
- Fix `resize_lora.py` to work with LoRA with dynamic rank (including `conv_dim != network_dim`). Thanks to toshiaki!
- latentsのキャッシュを高速化する`--vae_batch_size` オプションを各学習スクリプトに追加しました。VAE呼び出しをバッチ化します。
-VRAMサイズに応じて、`2` か `4` 程度から試してください
- `--gradient_accumulation_steps` と `--max_train_epochs` を指定した時、当該のepochで学習が止まらない不具合を修正しました。tsukimiya氏に感謝します。
- 外部のスクリプト用に引数parserの構築が関数化されました。robertsmieja氏に感謝します。
- 学習時、`--full_path` 指定時に `.npz` が存在しない場合の不具合を解消しました。
- Windows以外の環境向けに、画像ファイルの大文字の拡張子をサポートしました。
- `resize_lora.py` を dynamic rank rankが各LoRAモジュールで異なる場合、`conv_dim` が `network_dim` と異なる場合も含むの時に正しく動作しない不具合を修正しました。toshiaki氏に感謝します。
- 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。
- モデルおよびstateをHuggingFaceにアップロードする機能を各スクリプトに追加しました。 [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348) ddPn08 氏の貢献に感謝します。
- `--huggingface_repo_id`が指定されているとモデル保存時に同時にHuggingFaceにアップロードします。
- アクセストークンの取り扱いに注意してください。[HuggingFaceのドキュメント](https://huggingface.co/docs/hub/security-tokens)を参照してください。
- 他の引数をたとえば以下のように指定してください。
- `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
- `--huggingface_repo_visibility`に`public`を指定するとリポジトリが公開されます。省略時または`private`(など`public`以外)を指定すると非公開になります。
- `--save_state`オプション指定時に`--save_state_to_huggingface`を指定するとstateもアップロードします。
- `--resume`オプション指定時に`--resume_from_huggingface`を指定するとHuggingFaceからstateをダウンロードして再開します
- その時の `--resume`オプションは `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`になります。例: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
- `--async_upload`オプションを指定するとアップロードを非同期で行います。
- [LoRAの文書](./train_network_README-ja.md#diffusersのpipelineで生成する)に、LoRAを適用してDiffusersの標準的なパイプラインで生成する方法を追記しました。
- `gen_img_diffusers.py` で Attention Couple および領域別LoRAに対応しました。
- プロンプトを` AND `で区切ると各サブプロンプトが順にLoRAに適用されます。`--mask_path` がマスク画像として扱われます。サブプロンプトの数とLoRAの数は一致している必要があります。
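The `--resume` format used with `--resume_from_huggingface`, `{repo_id}/{path_in_repo}:{revision}:{repo_type}`, can be taken apart as in the following sketch. The names `HubStateRef` and `parse_resume_spec` are hypothetical, and the fallback values used when the revision or repository type is omitted (`main` and `model`) are assumptions for illustration, not confirmed behavior of the scripts.

```python
from typing import NamedTuple


class HubStateRef(NamedTuple):
    repo_id: str
    path_in_repo: str
    revision: str
    repo_type: str


def parse_resume_spec(spec: str) -> HubStateRef:
    """Parse '{repo_id}/{path_in_repo}:{revision}:{repo_type}'."""
    location, _, rest = spec.partition(":")
    parts = location.split("/")
    if len(parts) < 3:
        raise ValueError(f"expected 'user/repo/path', got {location!r}")
    # A repo id is 'user/repo'; everything after it is the path inside the repo.
    repo_id = "/".join(parts[:2])
    path_in_repo = "/".join(parts[2:])
    revision, _, repo_type = rest.partition(":")
    # Assumed defaults when the optional parts are left out.
    return HubStateRef(repo_id, path_in_repo, revision or "main", repo_type or "model")


if __name__ == "__main__":
    print(parse_resume_spec("your-hf-name/your-model/path/test-000002-state:main:model"))
```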
## Sample image generation during training
A prompt file might look like this, for example

XTI_hijack.py Normal file

@@ -0,0 +1,209 @@
import logging
from typing import Union, List, Optional, Dict, Any, Tuple

import torch
import torch.utils.checkpoint  # used by the block-level forwards below
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

# `logger` is used in unet_forward_XTI, so it has to be defined at module level.
logger = logging.getLogger(__name__)


def unet_forward_XTI(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    Args:
        sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
        timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
        encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

    Returns:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
        [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
        returning a tuple, the first element is the sample tensor.
    """
    # By default samples have to be at least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    if self.config.num_class_embeds is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")
        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    # 2. pre-process
    sample = self.conv_in(sample)

    # 3. down
    # XTI: encoder_hidden_states is indexed per layer; each cross-attention down block takes two entries.
    down_block_res_samples = (sample,)
    down_i = 0
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states[down_i:down_i+2],
            )
            down_i += 2
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        down_block_res_samples += res_samples

    # 4. mid
    sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6])

    # 5. up
    # XTI: each cross-attention up block takes three entries, starting at index 7.
    up_i = 7
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states[up_i:up_i+3],
                upsample_size=upsample_size,
            )
            up_i += 3
        else:
            sample = upsample_block(
                hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
            )

    # 6. post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)


def downblock_forward_XTI(
    self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
    output_states = ()
    i = 0

    for resnet, attn in zip(self.resnets, self.attentions):
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample

        output_states += (hidden_states,)
        i += 1

    if self.downsamplers is not None:
        for downsampler in self.downsamplers:
            hidden_states = downsampler(hidden_states)

        output_states += (hidden_states,)

    return hidden_states, output_states


def upblock_forward_XTI(
    self,
    hidden_states,
    res_hidden_states_tuple,
    temb=None,
    encoder_hidden_states=None,
    upsample_size=None,
):
    i = 0
    for resnet, attn in zip(self.resnets, self.attentions):
        # pop res hidden states
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    if return_dict is not None:
                        return module(*inputs, return_dict=return_dict)
                    else:
                        return module(*inputs)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i]
            )[0]
        else:
            hidden_states = resnet(hidden_states, temb)
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample

        i += 1

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states
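
The index arithmetic in `unet_forward_XTI` (two entries per cross-attention down block, index 6 for the mid block, three entries per cross-attention up block starting at 7) can be made visible with a small sketch. It uses the sixteen layer names that appear in the `gen_img_diffusers.py` changes of this comparison; the helper name `group_xti_layers` and the assumption of three cross-attention down blocks and three cross-attention up blocks are illustrative, not taken from the file itself.

```python
XTI_LAYERS = [
    "IN01", "IN02", "IN04", "IN05", "IN07", "IN08",
    "MID",
    "OUT03", "OUT04", "OUT05", "OUT06", "OUT07", "OUT08", "OUT09", "OUT10", "OUT11",
]


def group_xti_layers(layers, num_cross_down=3, num_cross_up=3):
    """Mirror the slicing in unet_forward_XTI on a list of layer names."""
    groups = {}
    down_i = 0
    for block in range(num_cross_down):
        groups[f"down{block}"] = layers[down_i:down_i + 2]  # two attentions per block
        down_i += 2
    groups["mid"] = [layers[6]]  # the mid block uses a single entry
    up_i = 7
    for block in range(num_cross_up):
        groups[f"up{block}"] = layers[up_i:up_i + 3]  # three attentions per block
        up_i += 3
    return groups


if __name__ == "__main__":
    for name, members in group_xti_layers(XTI_LAYERS).items():
        print(name, members)
```

Under these assumptions every one of the sixteen entries is consumed exactly once, which is why the stacked embeddings must follow this order.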

@@ -21,7 +21,7 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
def train(args):
@@ -231,9 +231,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume training from a saved state
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# compute the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -275,7 +273,7 @@ def train(args):
with accelerator.accumulate(training_models[0]): # does not seem to support multiple models, but keep it this way for now
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
else:
# convert to latents
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -284,10 +282,19 @@ def train(args):
with torch.set_grad_enabled(args.train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
@@ -427,4 +434,4 @@ if __name__ == "__main__":
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
train(args)
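
The weighted-captions branch above passes `args.max_token_length // 75 if args.max_token_length else 1` as the number of embedding chunks. A minimal sketch of that arithmetic follows; the function name `embedding_multiples` is hypothetical, and the sample values are only examples rather than a list of supported settings.

```python
from typing import Optional


def embedding_multiples(max_token_length: Optional[int]) -> int:
    """Number of 75-token chunks, falling back to 1 when no length is set."""
    return max_token_length // 75 if max_token_length else 1


if __name__ == "__main__":
    for length in (None, 75, 150, 225):
        print(length, "->", embedding_multiples(length))
```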

@@ -92,9 +92,12 @@ from PIL.PngImagePlugin import PngInfo
import library.model_util as model_util
import library.train_util as train_util
from networks.lora import LoRANetwork
import tools.original_control_net as original_control_net
from tools.original_control_net import ControlNetInfo
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
# Tokenizer: use the pre-provided one instead of loading it from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from here
@@ -491,6 +494,9 @@ class PipelineLike:
# Textual Inversion
self.token_replacements = {}
# XTI
self.token_replacements_XTI = {}
# CLIP guidance
self.clip_guidance_scale = clip_guidance_scale
self.clip_image_guidance_scale = clip_image_guidance_scale
@@ -514,15 +520,26 @@ class PipelineLike:
def add_token_replacement(self, target_token_id, rep_token_ids):
self.token_replacements[target_token_id] = rep_token_ids
def replace_token(self, tokens):
def replace_token(self, tokens, layer=None):
new_tokens = []
for token in tokens:
if token in self.token_replacements:
new_tokens.extend(self.token_replacements[token])
replacer_ = self.token_replacements[token]
if layer:
replacer = []
for r in replacer_:
if r in self.token_replacements_XTI:
replacer.append(self.token_replacements_XTI[r][layer])
else:
replacer = replacer_
new_tokens.extend(replacer)
else:
new_tokens.append(token)
return new_tokens
def add_token_replacement_XTI(self, target_token_id, rep_token_ids):
self.token_replacements_XTI[target_token_id] = rep_token_ids
def set_control_nets(self, ctrl_nets):
self.control_nets = ctrl_nets
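
The layer-aware token replacement in this hunk can be read as a standalone function. The sketch below assumes the indentation shown (the `else` pairs with `if layer:`), since the whitespace was lost in this view, and it assumes each XTI entry maps a layer name to a token id. The function name `replace_tokens` and the sample ids are hypothetical.

```python
def replace_tokens(tokens, token_replacements, token_replacements_xti, layer=None):
    """Expand placeholder tokens, optionally choosing per-layer XTI token ids."""
    new_tokens = []
    for token in tokens:
        if token not in token_replacements:
            new_tokens.append(token)
            continue
        replacement_ids = token_replacements[token]
        if layer:
            # Keep only ids that have an XTI entry, and pick the id for this layer.
            replacement_ids = [
                token_replacements_xti[r][layer] for r in replacement_ids if r in token_replacements_xti
            ]
        new_tokens.extend(replacement_ids)
    return new_tokens


if __name__ == "__main__":
    ti = {100: [100, 101]}  # placeholder token 100 expands to two tokens
    xti = {100: {"IN01": 200, "MID": 201}, 101: {"IN01": 300, "MID": 301}}
    print(replace_tokens([1, 100, 2], ti, xti))
    print(replace_tokens([1, 100, 2], ti, xti, layer="IN01"))
```

Calling it once per layer name yields one token sequence per layer, which matches how the pipeline later builds one embedding per layer.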
@@ -618,6 +635,7 @@ class PipelineLike:
img2img_noise=None,
clip_prompts=None,
clip_guide_images=None,
networks: Optional[List[LoRANetwork]] = None,
**kwargs,
):
r"""
@@ -701,6 +719,7 @@ class PipelineLike:
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
reginonal_network = " AND " in prompt[0]
vae_batch_size = (
batch_size
@@ -744,14 +763,15 @@ class PipelineLike:
" the batch size of `prompt`."
)
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
**kwargs,
)
if not self.token_replacements_XTI:
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
**kwargs,
)
if negative_scale is not None:
_, real_uncond_embeddings, _ = get_weighted_text_embeddings(
@@ -763,11 +783,47 @@ class PipelineLike:
**kwargs,
)
if do_classifier_free_guidance:
if negative_scale is None:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
if self.token_replacements_XTI:
text_embeddings_concat = []
for layer in [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]:
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
pipe=self,
prompt=prompt,
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
max_embeddings_multiples=max_embeddings_multiples,
clip_skip=self.clip_skip,
layer=layer,
**kwargs,
)
if do_classifier_free_guidance:
if negative_scale is None:
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings]))
else:
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]))
text_embeddings = torch.stack(text_embeddings_concat)
else:
if do_classifier_free_guidance:
if negative_scale is None:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
else:
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
# get the embeddings used for CLIP guidance
if self.clip_guidance_scale > 0:
@@ -957,6 +1013,11 @@ class PipelineLike:
# predict the noise residual
if self.control_nets:
if reginonal_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
else:
text_emb_last = text_embeddings
noise_pred = original_control_net.call_unet_and_control_net(
i,
num_latent_input,
@@ -966,7 +1027,7 @@ class PipelineLike:
i / len(timesteps),
latent_model_input,
t,
text_embeddings,
text_emb_last,
).sample
else:
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -1675,7 +1736,7 @@ def parse_prompt_attention(text):
return res
def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int):
def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length: int, layer=None):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
@@ -1691,7 +1752,7 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
token = pipe.replace_token(token)
token = pipe.replace_token(token, layer=layer)
text_token += token
# copy the weight by length of token
@@ -1807,6 +1868,7 @@ def get_weighted_text_embeddings(
skip_parsing: Optional[bool] = False,
skip_weighting: Optional[bool] = False,
clip_skip=None,
layer=None,
**kwargs,
):
r"""
@@ -1836,12 +1898,18 @@ def get_weighted_text_embeddings(
if isinstance(prompt, str):
prompt = [prompt]
# split the prompts with "AND". each prompt must have the same number of splits
new_prompts = []
for p in prompt:
new_prompts.extend(p.split(" AND "))
prompt = new_prompts
if not skip_parsing:
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
if uncond_prompt is not None:
if isinstance(uncond_prompt, str):
uncond_prompt = [uncond_prompt]
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2, layer=layer)
else:
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
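
The ` AND ` splitting added in this hunk flattens every prompt into its sub-prompts. A minimal sketch follows; the function name `split_sub_prompts` is hypothetical, and the explicit check that all prompts yield the same number of sub-prompts is an addition for illustration, based on the comment in the hunk rather than on code shown here.

```python
from typing import List


def split_sub_prompts(prompts: List[str]) -> List[str]:
    """Flatten prompts into sub-prompts separated by ' AND '."""
    flattened = []
    counts = set()
    for prompt in prompts:
        parts = prompt.split(" AND ")
        counts.add(len(parts))
        flattened.extend(parts)
    if len(counts) > 1:
        # Added for illustration: the hunk only states this requirement in a comment.
        raise ValueError("each prompt must have the same number of sub-prompts")
    return flattened


if __name__ == "__main__":
    print(split_sub_prompts(["a cat AND a dog", "red sky AND blue sea"]))
```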
@@ -2005,6 +2073,7 @@ class BatchDataExt(NamedTuple):
negative_scale: float
strength: float
network_muls: Tuple[float]
num_sub_prompts: int
class BatchData(NamedTuple):
@@ -2221,21 +2290,31 @@ def main(args):
if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}")
network = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
network, weights_sd = imported_module.create_network_from_weights(
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
)
else:
raise ValueError("No weight. Weight is required.")
if network is None:
return
network.apply_to(text_encoder, unet)
mergiable = hasattr(network, "merge_to")
if args.network_merge and not mergiable:
print("network is not mergiable. ignore merge option.")
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
if not args.network_merge or not mergiable:
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, False) # it would be better to use network.load_weights here
print(f"weights are loaded: {info}")
if args.opt_channels_last:
network.to(memory_format=torch.channels_last)
network.to(dtype).to(device)
networks.append(network)
else:
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
networks.append(network)
else:
networks = []
@@ -2289,7 +2368,12 @@ def main(args):
if args.diffusers_xformers:
pipe.enable_xformers_memory_efficient_attention()
# process Textual Inversion
# process Extended Textual Inversion and Textual Inversion
if args.XTI_embeddings:
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
if args.textual_inversion_embeddings:
token_ids_embeds = []
for embeds_file in args.textual_inversion_embeddings:
@@ -2335,6 +2419,71 @@ def main(args):
for token_id, embed in zip(token_ids, embeds):
token_embeds[token_id] = embed
if args.XTI_embeddings:
XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
token_ids_embeds_XTI = []
for embeds_file in args.XTI_embeddings:
if model_util.is_safetensors(embeds_file):
from safetensors.torch import load_file
data = load_file(embeds_file)
else:
data = torch.load(embeds_file, map_location="cpu")
if set(data.keys()) != set(XTI_layers):
raise ValueError("NOT XTI")
embeds = torch.concat(list(data.values()))
num_vectors_per_token = data["MID"].size()[0]
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
# add new word to tokenizer, count is num_vectors_per_token
num_added_tokens = tokenizer.add_tokens(token_strings)
assert (
num_added_tokens == num_vectors_per_token
), f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"XTI embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
# if num_vectors_per_token > 1:
pipe.add_token_replacement(token_ids[0], token_ids)
token_strings_XTI = []
for layer_name in XTI_layers:
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
tokenizer.add_tokens(token_strings_XTI)
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
token_ids_embeds_XTI.append((token_ids_XTI, embeds))
for t in token_ids:
t_XTI_dic = {}
for i, layer_name in enumerate(XTI_layers):
t_XTI_dic[layer_name] = t + (i + 1) * num_added_tokens
pipe.add_token_replacement_XTI(t, t_XTI_dic)
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
for token_ids, embeds in token_ids_embeds_XTI:
for token_id, embed in zip(token_ids, embeds):
token_embeds[token_id] = embed
# get the prompts
if args.from_file is not None:
print(f"reading prompts from {args.from_file}")
@@ -2428,16 +2577,22 @@ def main(args):
print(f"resize img2img mask images to {args.W}*{args.H}")
mask_images = resize_images(mask_images, (args.W, args.H))
regional_network = False
if networks and mask_images:
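# reuse the mask as region information; currently only one image is supported
# TODO consider the case where multiple network classes are mixed
# reuse the mask as region information; currently only one image per command invocation is supported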
regional_network = True
print("use mask as region")
# import cv2
# for i in range(3):
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
# cv2.waitKey()
# cv2.destroyAllWindows()
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
size = None
for i, network in enumerate(networks):
if i < 3:
np_mask = np.array(mask_images[0])
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
network.set_region(i, i == len(networks) - 1, mask)
mask_images = None
prev_image = None # for VGG16 guided
@@ -2493,7 +2648,14 @@ def main(args):
height_1st = height_1st - height_1st % 32
ext_1st = BatchDataExt(
width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls
width_1st,
height_1st,
args.highres_fix_steps,
ext.scale,
ext.negative_scale,
ext.strength,
ext.network_muls,
ext.num_sub_prompts,
)
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
images_1st = process_batch(batch_1st, True, True)
@@ -2521,7 +2683,7 @@ def main(args):
(
return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image),
(width, height, steps, scale, negative_scale, strength, network_muls),
(width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
@@ -2613,8 +2775,11 @@ def main(args):
# generate
if networks:
shared = {}
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m)
if regional_network:
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
images = pipe(
prompts,
@@ -2839,11 +3004,26 @@ def main(args):
print("Use previous image as guide image.")
guide_image = prev_image
if regional_network:
num_sub_prompts = len(prompt.split(" AND "))
assert (
len(networks) <= num_sub_prompts
), "Number of networks must be less than or equal to number of sub prompts."
else:
num_sub_prompts = None
b1 = BatchData(
False,
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
BatchDataExt(
width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None
width,
height,
steps,
scale,
negative_scale,
strength,
tuple(network_muls) if network_muls else None,
num_sub_prompts,
),
)
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # does the batch need to be split?
@@ -2983,6 +3163,7 @@ def setup_parser() -> argparse.ArgumentParser:
"--network_args", type=str, default=None, nargs="*", help="additional argmuments for network (key=value) / ネットワークへの追加の引数"
)
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
parser.add_argument(
"--textual_inversion_embeddings",
type=str,
@@ -2990,6 +3171,13 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*",
help="Embeddings files of Textual Inversion / Textual Inversionのembeddings",
)
parser.add_argument(
"--XTI_embeddings",
type=str,
default=None,
nargs="*",
help="Embeddings files of Extended Textual Inversion / Extended Textual Inversionのembeddings",
)
parser.add_argument("--clip_skip", type=int, default=None, help="layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う")
parser.add_argument(
"--max_embeddings_multiples",
@@ -3059,6 +3247,9 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*",
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
)
# parser.add_argument(
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
# )
return parser

@@ -445,7 +445,7 @@ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str]
try:
n_repeats = int(tokens[0])
except ValueError as e:
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
return 0, ""
caption_by_folder = '_'.join(tokens[1:])
return n_repeats, caption_by_folder
@@ -486,7 +486,8 @@ def load_user_config(file: str) -> dict:
if file.name.lower().endswith('.json'):
try:
config = json.load(file)
with open(file, 'r') as f:
config = json.load(f)
except Exception:
print(f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}")
raise

@@ -1,18 +1,344 @@
import torch
import argparse
import re
from typing import List, Optional, Union
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr)*gamma,snr)
snr_weight = torch.minimum(gamma_over_snr,torch.ones_like(gamma_over_snr)).float() #from paper
loss = loss * snr_weight
return loss
def add_custom_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--min_snr_gamma", type=float, default=None, help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨")
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
snr = torch.stack([all_snr[t] for t in timesteps])
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() # from paper
loss = loss * snr_weight
return loss
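
The weighting in apply_snr_weight above can be sketched without torch. This is a minimal illustration, assuming a plain Python list stands in for noise_scheduler.alphas_cumprod; the helper name min_snr_weights and the toy schedule values are hypothetical and not part of the repository. Since alpha = sqrt(alphas_cumprod) and sigma = sqrt(1 - alphas_cumprod), the SNR reduces to alphas_cumprod / (1 - alphas_cumprod), and the weight is min(gamma / SNR, 1).

```python
def min_snr_weights(alphas_cumprod, timesteps, gamma):
    """Return min(gamma / SNR(t), 1) for each sampled timestep (hypothetical helper)."""
    weights = []
    for t in timesteps:
        a = alphas_cumprod[t]
        snr = a / (1.0 - a)  # equals (sqrt(a) / sqrt(1 - a)) ** 2
        weights.append(min(gamma / snr, 1.0))
    return weights


# toy schedule, values chosen only for illustration
alphas_cumprod = [0.99, 0.9, 0.5, 0.1]
weights = min_snr_weights(alphas_cumprod, [0, 1, 2, 3], gamma=5.0)
print(weights)
```

Timesteps with high SNR (little noise) get a weight below 1, while noisy timesteps with SNR at or below gamma keep a weight of exactly 1.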
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
parser.add_argument(
"--min_snr_gamma",
type=float,
default=None,
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
)
if support_weighted_captions:
parser.add_argument(
"--weighted_captions",
action="store_true",
default=False,
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
)
re_attention = re.compile(
r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
re.X,
)
def parse_prompt_attention(text):
"""
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
>>> parse_prompt_attention('normal text')
[['normal text', 1.0]]
>>> parse_prompt_attention('an (important) word')
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
>>> parse_prompt_attention('(unbalanced')
[['unbalanced', 1.1]]
>>> parse_prompt_attention('\(literal\]')
[['(literal]', 1.0]]
>>> parse_prompt_attention('(unnecessary)(parens)')
[['unnecessaryparens', 1.1]]
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
[['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]]
"""
res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith("\\"):
res.append([text[1:], 1.0])
elif text == "(":
round_brackets.append(len(res))
elif text == "[":
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ")" and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == "]" and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
# merge runs of identical weights
i = 0
while i + 1 < len(res):
if res[i][1] == res[i + 1][1]:
res[i][0] += res[i + 1][0]
res.pop(i + 1)
else:
i += 1
return res
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
r"""
Tokenize a list of prompts and return its tokens with weights of each token.
No padding, starting or ending token is included.
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
text_weight = []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
r"""
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
"""
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
for i in range(len(tokens)):
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
if no_boseos_middle:
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
else:
w = []
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
return tokens, weights
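
The padded length passed to pad_tokens_and_weights is derived in get_weighted_text_embeddings from the longest prompt in the batch. A minimal sketch of that arithmetic follows, assuming the CLIP tokenizer's model_max_length of 77; the function name padded_length is hypothetical and only mirrors the calculation shown in this diff.

```python
def padded_length(num_tokens, model_max_length=77, max_embeddings_multiples=3):
    """Round a token count up to whole chunks and add BOS/EOS (hypothetical helper)."""
    chunk = model_max_length - 2  # usable tokens per chunk, without BOS and EOS
    multiples = min(max_embeddings_multiples, (num_tokens - 1) // chunk + 1)
    multiples = max(1, multiples)  # an empty prompt still occupies one chunk
    return chunk * multiples + 2


lengths = [padded_length(n) for n in (0, 10, 75, 76, 400)]
print(lengths)
```

A prompt of up to 75 tokens fits in one 77-token sequence, 76 tokens spill into a second chunk, and anything beyond three chunks is capped by max_embeddings_multiples.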
def get_unweighted_text_embeddings(
tokenizer,
text_encoder,
text_input: torch.Tensor,
chunk_length: int,
clip_skip: int,
eos: int,
pad: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
it should be split into chunks and sent to the text encoder individually.
"""
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
if max_embeddings_multiples > 1:
text_embeddings = []
for i in range(max_embeddings_multiples):
# extract the i-th chunk
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
if pad == eos: # v1
text_input_chunk[:, -1] = text_input[0, -1]
else: # v2
for j in range(len(text_input_chunk)):
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # the chunk ends with an ordinary token
text_input_chunk[j, -1] = eos
if text_input_chunk[j, 1] == pad: # only BOS, the rest is PAD
text_input_chunk[j, 1] = eos
if clip_skip is None or clip_skip == 1:
text_embedding = text_encoder(text_input_chunk)[0]
else:
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
text_embedding = enc_out["hidden_states"][-clip_skip]
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
# cover the head and the tail by the starting and the ending tokens
text_input_chunk[:, 0] = text_input[0, 0]
text_input_chunk[:, -1] = text_input[0, -1]
text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
if no_boseos_middle:
if i == 0:
# discard the ending token
text_embedding = text_embedding[:, :-1]
elif i == max_embeddings_multiples - 1:
# discard the starting token
text_embedding = text_embedding[:, 1:]
else:
# discard both starting and ending tokens
text_embedding = text_embedding[:, 1:-1]
text_embeddings.append(text_embedding)
text_embeddings = torch.concat(text_embeddings, axis=1)
else:
text_embeddings = text_encoder(text_input)[0]
return text_embeddings
def get_weighted_text_embeddings(
tokenizer,
text_encoder,
prompt: Union[str, List[str]],
device,
max_embeddings_multiples: Optional[int] = 3,
no_boseos_middle: Optional[bool] = False,
clip_skip=None,
):
r"""
Prompts can be assigned with local weights using brackets. For example,
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
The max multiple length of prompt embeddings compared to the max output length of text encoder.
no_boseos_middle (`bool`, *optional*, defaults to `False`):
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
ending token in each of the chunk in the middle.
skip_parsing (`bool`, *optional*, defaults to `False`):
Skip the parsing of brackets.
skip_weighting (`bool`, *optional*, defaults to `False`):
Skip the weighting. When the parsing is skipped, it is forced True.
"""
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
if isinstance(prompt, str):
prompt = [prompt]
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
# round up the longest length of tokens to a multiple of (model_max_length - 2)
max_length = max([len(token) for token in prompt_tokens])
max_embeddings_multiples = min(
max_embeddings_multiples,
(max_length - 1) // (tokenizer.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
# pad the length of tokens and weights
bos = tokenizer.bos_token_id
eos = tokenizer.eos_token_id
pad = tokenizer.pad_token_id
prompt_tokens, prompt_weights = pad_tokens_and_weights(
prompt_tokens,
prompt_weights,
max_length,
bos,
eos,
no_boseos_middle=no_boseos_middle,
chunk_length=tokenizer.model_max_length,
)
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
tokenizer,
text_encoder,
prompt_tokens,
tokenizer.model_max_length,
clip_skip,
eos,
pad,
no_boseos_middle=no_boseos_middle,
)
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
# assign weights to the prompts and normalize in the sense of mean
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
return text_embeddings
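
The last step of get_weighted_text_embeddings scales each token embedding by its weight and then rescales the result so the overall mean matches the unweighted embedding. The sketch below shows that step for a single prompt with plain Python lists instead of tensors; apply_token_weights and the sample values are hypothetical, and the sketch assumes the weighted mean is non-zero.

```python
def apply_token_weights(embeddings, weights):
    """Weight each token vector, then restore the original mean (hypothetical helper).

    embeddings: list of token vectors (list of lists of floats)
    weights: one float per token
    """

    def mean(rows):
        flat = [v for row in rows for v in row]
        return sum(flat) / len(flat)

    previous_mean = mean(embeddings)
    weighted = [[v * w for v in row] for row, w in zip(embeddings, weights)]
    current_mean = mean(weighted)  # assumed non-zero in this sketch
    scale = previous_mean / current_mean
    return [[v * scale for v in row] for row in weighted]


emb = [[1.0, 3.0], [2.0, 2.0]]
out = apply_token_weights(emb, [1.0, 1.5])
out_mean = sum(v for row in out for v in row) / 4
print(out, out_mean)
```

The relative emphasis between tokens is kept (the second token is still 1.5 times stronger than it was relative to the first), while the mean over the whole embedding stays at its original value.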

@@ -0,0 +1,78 @@
from typing import *
from huggingface_hub import HfApi
from pathlib import Path
import argparse
import os
from library.utils import fire_in_thread
def exists_repo(
repo_id: str, repo_type: str, revision: str = "main", token: str = None
):
api = HfApi(
token=token,
)
try:
api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
return True
except:
return False
def upload(
args: argparse.Namespace,
src: Union[str, Path, bytes, BinaryIO],
dest_suffix: str = "",
force_sync_upload: bool = False,
):
repo_id = args.huggingface_repo_id
repo_type = args.huggingface_repo_type
token = args.huggingface_token
path_in_repo = args.huggingface_path_in_repo + dest_suffix
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
api = HfApi(token=token)
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
is_folder = (type(src) == str and os.path.isdir(src)) or (
isinstance(src, Path) and src.is_dir()
)
def uploader():
if is_folder:
api.upload_folder(
repo_id=repo_id,
repo_type=repo_type,
folder_path=src,
path_in_repo=path_in_repo,
)
else:
api.upload_file(
repo_id=repo_id,
repo_type=repo_type,
path_or_fileobj=src,
path_in_repo=path_in_repo,
)
if args.async_upload and not force_sync_upload:
fire_in_thread(uploader)
else:
uploader()
def list_dir(
repo_id: str,
subfolder: str,
repo_type: str,
revision: str = "main",
token: str = None,
):
api = HfApi(
token=token,
)
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
file_list = [
file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
]
return file_list

@@ -2,6 +2,7 @@
import argparse
import ast
import asyncio
import importlib
import json
import pathlib
@@ -49,6 +50,7 @@ from diffusers import (
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
)
from huggingface_hub import hf_hub_download
import albumentations as albu
import numpy as np
from PIL import Image
@@ -58,6 +60,7 @@ from torch import einsum
import safetensors.torch
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util
# Tokenizer: use the pre-provided one instead of loading it from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -404,6 +407,8 @@ class BaseDataset(torch.utils.data.Dataset):
self.token_padding_disabled = False
self.tag_frequency = {}
self.XTI_layers = None
self.token_strings = None
self.enable_bucket = False
self.bucket_manager: BucketManager = None # not initialized
@@ -464,6 +469,10 @@ class BaseDataset(torch.utils.data.Dataset):
def disable_token_padding(self):
self.token_padding_disabled = True
def enable_XTI(self, layers=None, token_strings=None):
self.XTI_layers = layers
self.token_strings = token_strings
def add_replacement(self, str_from, str_to):
self.replacements[str_from] = str_to
@@ -481,7 +490,7 @@ class BaseDataset(torch.utils.data.Dataset):
else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")]
if subset.token_warmup_step < 1: # overwrite on the first call
if subset.token_warmup_step < 1: # overwrite on the first call
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
@@ -909,9 +918,22 @@ class BaseDataset(torch.utils.data.Dataset):
latents_list.append(latents)
caption = self.process_caption(subset, image_info.caption)
captions.append(caption)
if self.XTI_layers:
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
else:
captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
input_ids_list.append(self.get_input_ids(caption))
if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer)
else:
token_caption = self.get_input_ids(caption)
input_ids_list.append(token_caption)
example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
@@ -931,10 +953,10 @@ class BaseDataset(torch.utils.data.Dataset):
example["images"] = images
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
example["captions"] = captions
return example
@@ -1314,6 +1336,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
# for dataset in self.datasets:
# dataset.make_buckets()
def enable_XTI(self, *args, **kwargs):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1):
for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]")
@@ -1418,7 +1444,6 @@ def glob_images_pathlib(dir_path, recursive):
# endregion
# region module replacement
"""
Module replacement for speed-up
@@ -1873,6 +1898,38 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
parser.add_argument("--output_name", type=str, default=None, help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名")
parser.add_argument(
"--huggingface_repo_id", type=str, default=None, help="huggingface repo name to upload / huggingfaceにアップロードするリポジトリ名"
)
parser.add_argument(
"--huggingface_repo_type", type=str, default=None, help="huggingface repo type to upload / huggingfaceにアップロードするリポジトリの種類"
)
parser.add_argument(
"--huggingface_path_in_repo",
type=str,
default=None,
help="huggingface model path to upload files / huggingfaceにアップロードするファイルのパス",
)
parser.add_argument("--huggingface_token", type=str, default=None, help="huggingface token / huggingfaceのトークン")
parser.add_argument(
"--huggingface_repo_visibility",
type=str,
default=None,
help="huggingface repository visibility ('public' for public, 'private' or None for private) / huggingfaceにアップロードするリポジトリの公開設定'public'で公開、'private'またはNoneで非公開",
)
parser.add_argument(
"--save_state_to_huggingface", action="store_true", help="save state to huggingface / huggingfaceにstateを保存する"
)
parser.add_argument(
"--resume_from_huggingface",
action="store_true",
help="resume from huggingface (ex: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type}) / huggingfaceから学習を再開する(例: --resume {repo_id}/{path_in_repo}:{revision}:{repo_type})",
)
parser.add_argument(
"--async_upload",
action="store_true",
help="upload to huggingface asynchronously / huggingfaceに非同期でアップロードする",
)
parser.add_argument(
"--save_precision",
type=str,
@@ -2238,6 +2295,57 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
# region utils
def resume_from_local_or_hf_if_specified(accelerator, args):
if not args.resume:
return
if not args.resume_from_huggingface:
print(f"resume training from local state: {args.resume}")
accelerator.load_state(args.resume)
return
print(f"resume training from huggingface state: {args.resume}")
repo_id = args.resume.split("/")[0] + "/" + args.resume.split("/")[1]
path_in_repo = "/".join(args.resume.split("/")[2:])
revision = None
repo_type = None
if ":" in path_in_repo:
divided = path_in_repo.split(":")
if len(divided) == 2:
path_in_repo, revision = divided
repo_type = "model"
else:
path_in_repo, revision, repo_type = divided
print(f"Downloading state from huggingface: {repo_id}/{path_in_repo}@{revision}")
list_files = huggingface_util.list_dir(
repo_id=repo_id,
subfolder=path_in_repo,
revision=revision,
token=args.huggingface_token,
repo_type=repo_type,
)
async def download(filename) -> str:
def task():
return hf_hub_download(
repo_id=repo_id,
filename=filename,
revision=revision,
repo_type=repo_type,
token=args.huggingface_token,
)
return await asyncio.get_event_loop().run_in_executor(None, task)
loop = asyncio.get_event_loop()
results = loop.run_until_complete(asyncio.gather(*[download(filename=filename.rfilename) for filename in list_files]))
if len(results) == 0:
raise ValueError("No files found in the specified repo id/path/revision / 指定されたリポジトリID/パス/リビジョンにファイルが見つかりませんでした")
dirname = os.path.dirname(results[0])
accelerator.load_state(dirname)
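
The --resume value accepted with --resume_from_huggingface has the form {repo_id}/{path_in_repo}:{revision}:{repo_type}. The sketch below isolates the string handling from resume_from_local_or_hf_if_specified so the splitting rules are easier to see. The function name parse_hf_resume and the sample repository names are hypothetical; no network access is involved.

```python
def parse_hf_resume(resume):
    """Split a resume specifier the same way as the function above (hypothetical helper)."""
    parts = resume.split("/")
    repo_id = parts[0] + "/" + parts[1]  # first two components: owner/name
    path_in_repo = "/".join(parts[2:])
    revision = None
    repo_type = None
    if ":" in path_in_repo:
        divided = path_in_repo.split(":")
        if len(divided) == 2:
            path_in_repo, revision = divided
            repo_type = "model"  # default when only a revision is given
        else:
            path_in_repo, revision, repo_type = divided
    return repo_id, path_in_repo, revision, repo_type


full = parse_hf_resume("user/repo/states/last-state:main:dataset")
short = parse_hf_resume("user/repo/states/last-state:dev")
bare = parse_hf_resume("user/repo/states/last-state")
print(full, short, bare)
```

Note that a specifier without any colon leaves both revision and repo_type as None, and a specifier with more than two colons would fail to unpack, exactly as in the code shown in the diff.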
def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
@@ -2437,7 +2545,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
Unified API to get any scheduler from its name.
"""
name = args.lr_scheduler
num_warmup_steps = args.lr_warmup_steps
num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes * args.gradient_accumulation_steps
num_cycles = args.lr_scheduler_num_cycles
power = args.lr_scheduler_power
@@ -2461,6 +2569,11 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
lr_scheduler_kwargs[key] = value
def wrap_check_needless_num_warmup_steps(return_vals):
if num_warmup_steps is not None and num_warmup_steps != 0:
raise ValueError(f"{name} does not require `num_warmup_steps`. Set None or 0.")
return return_vals
# using any lr_scheduler from other library
if args.lr_scheduler_type:
lr_scheduler_type = args.lr_scheduler_type
@@ -2473,7 +2586,7 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
lr_scheduler_type = values[-1]
lr_scheduler_class = getattr(lr_scheduler_module, lr_scheduler_type)
lr_scheduler = lr_scheduler_class(optimizer, **lr_scheduler_kwargs)
return lr_scheduler
return wrap_check_needless_num_warmup_steps(lr_scheduler)
if name.startswith("adafactor"):
assert (
@@ -2481,12 +2594,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
), f"adafactor scheduler must be used with Adafactor optimizer / adafactor schedulerはAdafactorオプティマイザと同時に使ってください"
initial_lr = float(name.split(":")[1])
# print("adafactor scheduler init lr", initial_lr)
return transformers.optimization.AdafactorSchedule(optimizer, initial_lr)
return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr))
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer)
return wrap_check_needless_num_warmup_steps(schedule_func(optimizer))
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
@@ -2617,14 +2730,15 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype, save_dtype
def load_target_model(args: argparse.Namespace, weight_dtype):
def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
name_or_path = args.pretrained_model_name_or_path
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers
if load_stable_diffusion_format:
print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path)
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, name_or_path, device)
else:
# Diffusers model is loaded to CPU
print("load Diffusers pretrained models")
try:
pipe = StableDiffusionPipeline.from_pretrained(name_or_path, tokenizer=None, safety_checker=None)
@@ -2743,6 +2857,8 @@ def save_sd_model_on_epoch_end(
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_sd(old_epoch_no):
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
@@ -2762,6 +2878,8 @@ def save_sd_model_on_epoch_end(
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name)
def remove_du(old_epoch_no):
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
@@ -2779,7 +2897,11 @@ def save_sd_model_on_epoch_end(
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
state_dir = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))
accelerator.save_state(state_dir)
if args.save_state_to_huggingface:
print("uploading state to huggingface.")
huggingface_util.upload(args, state_dir, "/" + EPOCH_STATE_NAME.format(model_name, epoch_no))
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
if last_n_epochs is not None:
@@ -2790,6 +2912,17 @@ def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, e
shutil.rmtree(state_dir_old)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True)
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
state_dir = os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))
accelerator.save_state(state_dir)
if args.save_state_to_huggingface:
print("uploading last state to huggingface.")
huggingface_util.upload(args, state_dir, "/" + LAST_STATE_NAME.format(model_name))
def save_sd_model_on_train_end(
args: argparse.Namespace,
src_path: str,
@@ -2814,6 +2947,8 @@ def save_sd_model_on_train_end(
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
else:
out_dir = os.path.join(args.output_dir, model_name)
os.makedirs(out_dir, exist_ok=True)
@@ -2822,13 +2957,8 @@ def save_sd_model_on_train_end(
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
def save_state_on_train_end(args: argparse.Namespace, accelerator):
print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True)
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
# scheduler:
@@ -3060,7 +3190,7 @@ class collater_class:
def __init__(self, epoch, step, dataset):
self.current_epoch = epoch
self.current_step = step
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing
def __call__(self, examples):
worker_info = torch.utils.data.get_worker_info()
@@ -3073,4 +3203,4 @@ class collater_class:
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
return examples[0]

library/utils.py Normal file

@@ -0,0 +1,6 @@
import threading
from typing import *
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
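
A minimal usage sketch for the helper above. `fire_in_thread` returns no thread handle, so a caller that needs to know when the work finished has to bring its own signal. The `fake_upload` function and the `threading.Event` are illustrative assumptions and do not appear in the repository.

```python
import threading


def fire_in_thread(f, *args, **kwargs):
    # Same one-liner as library/utils.py: start a thread and return immediately.
    threading.Thread(target=f, args=args, kwargs=kwargs).start()


def fake_upload(path, done, results):
    # Hypothetical stand-in for a slow upload; not part of the repository.
    results.append(f"uploaded {path}")
    done.set()


done = threading.Event()
results = []
fire_in_thread(fake_upload, "model.safetensors", done=done, results=results)

# The caller gets no handle back, so it waits on its own signal instead.
finished = done.wait(timeout=5)
```

Because the thread is not a daemon thread, the interpreter still waits for it at exit even if nothing waits on the event.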


@@ -145,8 +145,8 @@ def svd(args):
lora_sd[lora_name + '.alpha'] = torch.tensor(down_weight.size()[0])
# load state dict to LoRA and save it
lora_network_save = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoder_o, unet_o, weights_sd=lora_sd)
lora_network_save.apply_to(text_encoder_o, unet_o) # create internal module references for state_dict
info = lora_network_save.load_state_dict(lora_sd)
print(f"Loading extracted LoRA weights: {info}")

File diff suppressed because it is too large


@@ -1,4 +1,3 @@
import math
import argparse
import os
@@ -9,216 +8,236 @@ import lora
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
if os.path.splitext(file_name)[1] == ".safetensors":
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location="cpu")
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
if os.path.splitext(file_name)[1] == ".safetensors":
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
text_encoder.to(merge_dtype)
unet.to(merge_dtype)
text_encoder.to(merge_dtype)
unet.to(merge_dtype)
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder, unet]):
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
# print(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder, unet]):
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = (
lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
)
module.weight = torch.nn.Parameter(weight)
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[: key.index("lora_down")] + "alpha"
# find original module for this lora
module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
# print(module_name, down_weight.size(), up_weight.size())
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
weight
+ ratio
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
# print(conved.size(), weight.size(), module.stride, module.padding)
weight = weight + ratio * conved * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
base_alphas = {} # alpha for merged model
base_dims = {}
base_alphas = {} # alpha for merged model
base_dims = {}
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
# get alpha and dim
alphas = {} # alpha for current model
dims = {} # dims for current model
for key in lora_sd.keys():
if 'alpha' in key:
lora_module_name = key[:key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
elif "lora_down" in key:
lora_module_name = key[:key.rfind(".lora_down")]
dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim
if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim
# get alpha and dim
alphas = {} # alpha for current model
dims = {} # dims for current model
for key in lora_sd.keys():
if "alpha" in key:
lora_module_name = key[: key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
elif "lora_down" in key:
lora_module_name = key[: key.rfind(".lora_down")]
dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim
if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim
for lora_module_name in dims.keys():
if lora_module_name not in alphas:
alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
for lora_module_name in dims.keys():
if lora_module_name not in alphas:
alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
continue
# merge
print(f"merging...")
for key in lora_sd.keys():
if "alpha" in key:
continue
lora_module_name = key[:key.rfind(".lora_")]
lora_module_name = key[: key.rfind(".lora_")]
base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
scale = math.sqrt(alpha / base_alpha) * ratio
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
if key in merged_sd:
assert (
merged_sd[key].size() == lora_sd[key].size()
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else:
merged_sd[key] = lora_sd[key] * scale
# set alpha to sd
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
# set alpha to sd
for lora_module_name, alpha in base_alphas.items():
key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
return merged_sd
return merged_sd
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
def str_to_dtype(p):
if p == "float":
return torch.float
if p == "fp16":
return torch.float16
if p == "bf16":
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
parser.add_argument(
"--save_precision",
type=str,
default=None,
choices=[None, "float", "fp16", "bf16"],
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ",
)
parser.add_argument(
"--precision",
type=str,
default="float",
choices=["float", "fp16", "bf16"],
help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨",
)
parser.add_argument(
"--sd_model",
type=str,
default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする",
)
parser.add_argument(
"--save_to", type=str, default=None, help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors"
)
parser.add_argument(
"--models", type=str, nargs="*", help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors"
)
parser.add_argument("--ratios", type=float, nargs="*", help="ratios for each model / それぞれのLoRAモデルの比率")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
merge(args)
args = parser.parse_args()
merge(args)
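
A scalar sketch of the scaling in `merge_lora_models`, using made-up values. The loop multiplies every non-alpha key, `lora_up` and `lora_down` alike, by `math.sqrt(alpha / base_alpha) * ratio`. For a single model the product of the two factors therefore appears to carry `alpha / base_alpha` once and `ratio` squared. The function name `merge_scale` is an assumption for illustration.

```python
import math


def merge_scale(alpha, base_alpha, ratio):
    # Mirrors the per-key factor computed in merge_lora_models.
    return math.sqrt(alpha / base_alpha) * ratio


# Scalar stand-ins for one module's lora_up and lora_down weights (illustrative values).
up, down = 2.0, 3.0
alpha, base_alpha, ratio = 4.0, 8.0, 0.5

scale = merge_scale(alpha, base_alpha, ratio)
scaled_up, scaled_down = up * scale, down * scale

# Both factors carry `scale`, so their product carries scale**2.
product = scaled_up * scaled_down
expected = up * down * (alpha / base_alpha) * ratio**2
```

With several models the scaled factors are summed key by key, so the merged product also contains cross terms between models. This sketch covers only the single-model case.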


@@ -21,6 +21,6 @@ fairscale==0.4.13
# for WD14 captioning
# tensorflow<2.11
tensorflow==2.10.1
huggingface-hub==0.12.0
huggingface-hub==0.13.3
# for kohya_ss library
.


@@ -801,7 +801,7 @@ model_dirオプションでモデルの保存先フォルダを指定できま
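To put captions into the metadata, run the following in the working folder (this is not needed if captions are not used for training). The command is actually written on a single line; the same applies below. Specify the `--full_path` option to store the image file locations in the metadata as full paths. If this option is omitted, relative paths are recorded, but the folder must then be specified separately in the `.toml` file.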
```
python merge_captions_to_metadata.py --full_apth <training data folder>
python merge_captions_to_metadata.py --full_path <training data folder>
  --in_json <metadata file to read> <metadata file name>
```


@@ -23,8 +23,7 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
def train(args):
train_util.verify_training_args(args)
@@ -202,9 +201,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -273,10 +270,19 @@ def train(args):
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@@ -426,4 +432,4 @@ if __name__ == "__main__":
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
train(args)
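
A small sketch of the expression passed as the fifth argument to `get_weighted_text_embeddings` in the hunk above. It presumably gives the number of 75-token chunks, but that function's body is not shown in this diff, so the interpretation is an assumption. The helper name `chunk_count` and the sample inputs are illustrative.

```python
def chunk_count(max_token_length):
    # Same expression as the call site: None (or 0) falls back to a single chunk.
    return max_token_length // 75 if max_token_length else 1


counts = {value: chunk_count(value) for value in (None, 75, 150, 225)}
```

Integer division means a value that is not a multiple of 75 is rounded down.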


@@ -24,24 +24,40 @@ from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
# TODO share this with the other scripts
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if args.network_train_unet_only:
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[0])
elif args.network_train_text_encoder_only:
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
else:
logs["lr/textencoder"] = float(lr_scheduler.get_last_lr()[0])
logs["lr/unet"] = float(lr_scheduler.get_last_lr()[-1]) # may be same to textencoder
lrs = lr_scheduler.get_last_lr()
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block)
if args.network_train_unet_only:
logs["lr/unet"] = float(lrs[0])
elif args.network_train_text_encoder_only:
logs["lr/textencoder"] = float(lrs[0])
else:
logs["lr/textencoder"] = float(lrs[0])
logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value of unet.
logs["lr/d*lr"] = lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
else:
idx = 0
if not args.network_train_unet_only:
logs["lr/textencoder"] = float(lrs[0])
idx = 1
for i in range(idx, len(lrs)):
logs[f"lr/group{i}"] = float(lrs[i])
if args.optimizer_type.lower() == "DAdaptation".lower():
logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
)
return logs
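
A simplified sketch of which log keys the new `generate_step_logs` branching produces. It takes a plain list of learning rates in place of `lr_scheduler.get_last_lr()` and omits the D-Adaptation `d*lr` entries. The function name `lr_log_keys` and its keyword arguments are assumptions for illustration.

```python
def lr_log_keys(lrs, train_unet_only=False, train_text_encoder_only=False):
    # Simplified mirror of the branching in generate_step_logs (D-Adaptation logging omitted).
    logs = {}
    if train_text_encoder_only or len(lrs) <= 2:
        # No block learning rates (or a single block): the original two keys.
        if train_unet_only:
            logs["lr/unet"] = float(lrs[0])
        elif train_text_encoder_only:
            logs["lr/textencoder"] = float(lrs[0])
        else:
            logs["lr/textencoder"] = float(lrs[0])
            logs["lr/unet"] = float(lrs[-1])
    else:
        # Block learning rates: one key per parameter group after the text encoder.
        idx = 0
        if not train_unet_only:
            logs["lr/textencoder"] = float(lrs[0])
            idx = 1
        for i in range(idx, len(lrs)):
            logs[f"lr/group{i}"] = float(lrs[i])
    return logs


two_groups = lr_log_keys([1e-5, 1e-4])
block_lr = lr_log_keys([1e-5, 1e-4, 2e-4, 3e-4])
```

With more than two groups the `lr/unet` key is no longer emitted, so dashboards keyed on it would need to read `lr/group{i}` instead.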
@@ -56,8 +72,9 @@ def train(args):
use_dreambooth_method = args.in_json is None
use_user_config = args.dataset_config is not None
if args.seed is not None:
set_seed(args.seed)
if args.seed is None:
args.seed = random.randint(0, 2**32)
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
@@ -99,10 +116,10 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group)
@@ -127,12 +144,24 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# load the model
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
for pi in range(accelerator.state.num_processes):
# TODO: modify other training scripts as well
if pi == accelerator.state.local_process_index:
print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
# work on low-ram device
if args.lowram:
text_encoder.to("cuda")
unet.to("cuda")
text_encoder, vae, unet, _ = train_util.load_target_model(
args, weight_dtype, accelerator.device if args.lowram else "cpu"
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
# incorporate xformers or memory efficient attention into the model
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -166,15 +195,18 @@ def train(args):
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
if network is None:
return
if args.network_weights is not None:
print("load network weights from:", args.network_weights)
network.load_weights(args.network_weights)
if hasattr(network, "prepare_network"):
network.prepare_network(args)
train_unet = not args.network_train_text_encoder_only
train_text_encoder = not args.network_train_unet_only
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.network_weights is not None:
info = network.load_weights(args.network_weights)
print(f"load network weights from {args.network_weights}: {info}")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
@@ -183,13 +215,19 @@ def train(args):
# prepare the classes needed for training
print("prepare optimizer, data loader etc.")
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
# ensure backward compatibility
try:
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
except TypeError:
print("Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)")
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare the dataloader
# with 0 DataLoader workers, loading runs in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count - 1, capped at the specified maximum
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
@@ -201,7 +239,9 @@ def train(args):
# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
@@ -270,9 +310,7 @@ def train(args):
train_util.patch_accelerator_for_fp16_training(accelerator)
# resume
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -333,6 +371,7 @@ def train(args):
"ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
"ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma,
}
if use_user_config:
@@ -461,8 +500,6 @@ def train(args):
# add extra args
if args.network_args:
metadata["ss_network_args"] = json.dumps(net_kwargs)
# for key, value in net_kwargs.items():
# metadata["ss_arg_" + key] = value
# model name and hash
if args.pretrained_model_name_or_path is not None:
@@ -505,7 +542,7 @@ def train(args):
for epoch in range(num_train_epochs):
if is_main_process:
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch+1
current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1)
@@ -525,9 +562,17 @@ def train(args):
with torch.set_grad_enabled(train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
@@ -556,9 +601,9 @@ def train(args):
loss_weights = batch["loss_weights"] # weight for each sample
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = loss.mean() # this is a mean, so there is no need to divide by batch_size
@@ -613,6 +658,8 @@ def train(args):
metadata["ss_training_finished_at"] = str(time.time())
print(f"saving checkpoint: {ckpt_file}")
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -652,6 +699,8 @@ def train(args):
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.")
@@ -708,4 +757,4 @@ if __name__ == "__main__":
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
train(args)


@@ -188,6 +188,73 @@ gen_img_diffusers.pyに、--network_module、--network_weightsの各オプショ
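Specifying a value from 0 to 1.0 with the --network_mul option changes how strongly the LoRA is applied.
## Generating with a Diffusers pipeline
Refer to the example below. The only file required is networks/lora.py. It may not work with Diffusers versions other than 0.10.2.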
```python
import torch
from diffusers import StableDiffusionPipeline
from networks.lora import LoRAModule, create_network_from_weights
from safetensors.torch import load_file
# if the ckpt is CompVis based, convert it to Diffusers beforehand with tools/convert_diffusers20_original_sd.py. See --help for more details.
model_id_or_dir = r"model_id_on_hugging_face_or_dir"
device = "cuda"
# create pipe
print(f"creating pipe from {model_id_or_dir}...")
pipe = StableDiffusionPipeline.from_pretrained(model_id_or_dir, revision="fp16", torch_dtype=torch.float16)
pipe = pipe.to(device)
vae = pipe.vae
text_encoder = pipe.text_encoder
unet = pipe.unet
# load lora networks
print(f"loading lora networks...")
lora_path1 = r"lora1.safetensors"
sd = load_file(lora_path1) # If the file is .ckpt, use torch.load instead.
network1, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
network1.apply_to(text_encoder, unet)
network1.load_state_dict(sd)
network1.to(device, dtype=torch.float16)
# Alternatively, merge the weights instead of calling apply_to + load_state_dict.
# Note that set_multiplier has no effect once the weights are merged.
# network1.merge_to(text_encoder, unet, sd)
lora_path2 = r"lora2.safetensors"
sd = load_file(lora_path2)
network2, sd = create_network_from_weights(0.7, None, vae, text_encoder,unet, sd)
network2.apply_to(text_encoder, unet)
network2.load_state_dict(sd)
network2.to(device, dtype=torch.float16)
lora_path3 = r"lora3.safetensors"
sd = load_file(lora_path3)
network3, sd = create_network_from_weights(0.5, None, vae, text_encoder,unet, sd)
network3.apply_to(text_encoder, unet)
network3.load_state_dict(sd)
network3.to(device, dtype=torch.float16)
# prompts
prompt = "masterpiece, best quality, 1girl, in white shirt, looking at viewer"
negative_prompt = "bad quality, worst quality, bad anatomy, bad hands"
# exec pipe
print("generating image...")
with torch.autocast("cuda"):
    image = pipe(prompt, guidance_scale=7.5, negative_prompt=negative_prompt).images[0]
# if not merged, you can use set_multiplier
# network1.set_multiplier(0.8)
# and generate image again...
# save image
image.save(r"by_diffusers.png")
```
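## Creating a LoRA model from the difference between two models
This was implemented with reference to [this discussion](https://github.com/cloneofsimo/lora/discussions/56). The formulas are used as-is (I do not fully understand them, but singular value decomposition appears to be used for the approximation).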
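The idea of approximating a weight difference with singular value decomposition can be sketched as follows. This is a simplified illustration using NumPy, not the repository's extraction script: the function name `extract_lora`, the matrix sizes, and the random test weights are all assumptions made for the example.

```python
import numpy as np


def extract_lora(w_base, w_tuned, rank):
    """Factor the weight difference into two rank-limited matrices via SVD."""
    delta = w_tuned - w_base                   # shape (out, in)
    u, s, vh = np.linalg.svd(delta, full_matrices=False)
    lora_up = u[:, :rank] * s[:rank]           # (out, rank), singular values folded in
    lora_down = vh[:rank, :]                   # (rank, in)
    return lora_up, lora_down


rng = np.random.default_rng(0)
w_base = rng.normal(size=(8, 6))
# build a tuned weight whose difference from the base is exactly rank 2
true_delta = rng.normal(size=(8, 2)) @ rng.normal(size=(2, 6))
w_tuned = w_base + true_delta

lora_up, lora_down = extract_lora(w_base, w_tuned, rank=2)
max_error = np.abs(lora_up @ lora_down - true_delta).max()
print(lora_up.shape, lora_down.shape, max_error)
```

Keeping only the largest `rank` singular values gives the best rank-limited approximation of the difference in the least-squares sense, which is why a small rank can still capture most of the change between two models.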

View File

@@ -13,6 +13,7 @@ import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
@@ -304,9 +305,7 @@ def train(args):
text_encoder.to(weight_dtype)
# resume training if requested
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -452,6 +451,8 @@ def train(args):
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
@@ -492,6 +493,8 @@ def train(args):
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.")
@@ -546,7 +549,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
parser.add_argument(
"--save_model_as",

View File

@@ -0,0 +1,647 @@
import importlib
import argparse
import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
imagenet_templates_small = [
"a photo of a {}",
"a rendering of a {}",
"a cropped photo of the {}",
"the photo of a {}",
"a photo of a clean {}",
"a photo of a dirty {}",
"a dark photo of the {}",
"a photo of my {}",
"a photo of the cool {}",
"a close-up photo of a {}",
"a bright photo of the {}",
"a cropped photo of a {}",
"a photo of the {}",
"a good photo of the {}",
"a photo of one {}",
"a close-up photo of the {}",
"a rendition of the {}",
"a photo of the clean {}",
"a rendition of a {}",
"a photo of a nice {}",
"a good photo of a {}",
"a photo of the nice {}",
"a photo of the small {}",
"a photo of the weird {}",
"a photo of the large {}",
"a photo of a cool {}",
"a photo of a small {}",
]
imagenet_style_templates_small = [
"a painting in the style of {}",
"a rendering in the style of {}",
"a cropped painting in the style of {}",
"the painting in the style of {}",
"a clean painting in the style of {}",
"a dirty painting in the style of {}",
"a dark painting in the style of {}",
"a picture in the style of {}",
"a cool painting in the style of {}",
"a close-up painting in the style of {}",
"a bright painting in the style of {}",
"a cropped painting in the style of {}",
"a good painting in the style of {}",
"a close-up painting in the style of {}",
"a rendition in the style of {}",
"a nice painting in the style of {}",
"a small painting in the style of {}",
"a weird painting in the style of {}",
"a large painting in the style of {}",
]
def train(args):
if args.output_name is None:
args.output_name = args.token_string
use_template = args.use_object_template or args.use_style_template
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
if args.sample_every_n_steps is not None or args.sample_every_n_epochs is not None:
print(
"sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
)
cache_latents = args.cache_latents
if args.seed is not None:
set_seed(args.seed)
tokenizer = train_util.load_tokenizer(args)
# prepare the accelerator
print("prepare accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
# prepare dtypes for mixed precision and cast as needed
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# load the models
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
# Convert the init_word to token_id
if args.init_word is not None:
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
print(
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}"
)
else:
init_token_ids = None
# add new word to tokenizer, count is num_vectors_per_token
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
num_added_tokens = tokenizer.add_tokens(token_strings)
assert (
num_added_tokens == args.num_vectors_per_token
), f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
print(f"tokens are added: {token_ids}")
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, "token ids are not ordered"
assert len(tokenizer) - 1 == token_ids[-1], f"token ids are not at the end of the tokenizer: {len(tokenizer)}"
token_strings_XTI = []
XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
for layer_name in XTI_layers:
token_strings_XTI += [f"{t}_{layer_name}" for t in token_strings]
tokenizer.add_tokens(token_strings_XTI)
token_ids_XTI = tokenizer.convert_tokens_to_ids(token_strings_XTI)
print(f"tokens are added (XTI): {token_ids_XTI}")
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
if init_token_ids is not None:
for i, token_id in enumerate(token_ids_XTI):
token_embeds[token_id] = token_embeds[init_token_ids[(i // 16) % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights
if args.weights is not None:
embeddings = load_weights(args.weights)
# load_weights concatenates all 16 XTI layers, so compare against the XTI token ids
assert len(token_ids_XTI) == len(
embeddings
), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
# print(token_ids, embeddings.size())
for token_id, embedding in zip(token_ids_XTI, embeddings):
token_embeds[token_id] = embedding
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
print("weights loaded")
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
# prepare the dataset
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
if args.dataset_config is not None:
print(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
use_dreambooth_method = args.in_json is None
if use_dreambooth_method:
print("Use DreamBooth method.")
user_config = {
"datasets": [
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
]
}
else:
print("Train with captions.")
user_config = {
"datasets": [
{
"subsets": [
{
"image_dir": args.train_data_dir,
"metadata_file": args.in_json,
}
]
}
]
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
# make captions: a very crude implementation that rewrites the token string to "tokenstring tokenstring1 tokenstring2 ... tokenstringn"
if use_template:
print(f"use template for training captions. is object: {args.use_object_template}")
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
replace_to = " ".join(token_strings)
captions = []
for tmpl in templates:
captions.append(tmpl.format(replace_to))
train_dataset_group.add_replacement("", captions)
if args.num_vectors_per_token > 1:
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None
else:
if args.num_vectors_per_token > 1:
replace_to = " ".join(token_strings)
train_dataset_group.add_replacement(args.token_string, replace_to)
prompt_replacement = (args.token_string, replace_to)
else:
prompt_replacement = None
if args.debug_dataset:
train_util.debug_dataset(train_dataset_group, show_input_ids=True)
return
if len(train_dataset_group) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
if cache_latents:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
# patch xformers or memory efficient attention into the model
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
# prepare for training
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
# prepare the classes needed for training
print("prepare optimizer, data loader etc.")
trainable_params = text_encoder.get_input_embeddings().parameters()
_, _, optimizer = train_util.get_optimizer(args, trainable_params)
# prepare the dataloader
# a DataLoader worker count of 0 means loading happens in the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count - 1, capped at the specified maximum
train_dataloader = torch.utils.data.DataLoader(
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# calculate the number of training steps
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# also pass the number of training steps to the dataset
train_dataset_group.set_max_train_steps(args.max_train_steps)
# prepare the lr scheduler
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
# accelerator apparently takes care of wrapping these for us
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
)
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder
text_encoder.requires_grad_(True)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
else:
unet.eval()
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
text_encoder.to(weight_dtype)
# resume training if requested
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
# calculate the number of epochs
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# train
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# convert to latents
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
encoder_hidden_states = torch.stack(
[
train_util.get_hidden_states(args, s, tokenizer, text_encoder, weight_dtype)
for s in torch.split(input_ids, 1, dim=1)
]
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss_weights = batch["loss_weights"] # weight for each sample
loss = loss * loss_weights
loss = loss.mean() # already a mean, so no need to divide by batch_size
accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = text_encoder.get_input_embeddings().parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Let's make sure we don't update any embedding weights besides the newly added token
with torch.no_grad():
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[
index_no_updates
]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
# TODO: fix sample_images
# train_util.sample_images(
# accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
# )
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
if args.optimizer_type.lower() == "DAdaptation".lower(): # tracking d*lr value
logs["lr/d*lr"] = (
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
)
accelerator.log(logs, step=global_step)
loss_total += current_loss
avr_loss = loss_total / (step + 1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
def save_func():
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"saving checkpoint: {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
def remove_old_func(old_epoch_no):
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + "." + args.save_model_as
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
if os.path.exists(old_ckpt_file):
print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
# TODO: fix sample_images
# train_util.sample_images(
# accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
# )
# end of epoch
is_main_process = accelerator.is_main_process
if is_main_process:
text_encoder = unwrap_model(text_encoder)
accelerator.end_training()
if args.save_state:
train_util.save_state_on_train_end(args, accelerator)
updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
del accelerator # free this because memory is needed afterwards
if is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
ckpt_name = model_name + "." + args.save_model_as
ckpt_file = os.path.join(args.output_dir, ckpt_name)
print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
print("model saved.")
def save_weights(file, updated_embs, save_dtype):
updated_embs = updated_embs.reshape(16, -1, updated_embs.shape[-1])
updated_embs = updated_embs.chunk(16)
XTI_layers = [
"IN01",
"IN02",
"IN04",
"IN05",
"IN07",
"IN08",
"MID",
"OUT03",
"OUT04",
"OUT05",
"OUT06",
"OUT07",
"OUT08",
"OUT09",
"OUT10",
"OUT11",
]
state_dict = {}
for i, layer_name in enumerate(XTI_layers):
state_dict[layer_name] = updated_embs[i].squeeze(0).detach().clone().to("cpu").to(save_dtype)
# if save_dtype is not None:
# for key in list(state_dict.keys()):
# v = state_dict[key]
# v = v.detach().clone().to("cpu").to(save_dtype)
# state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
save_file(state_dict, file)
else:
torch.save(state_dict, file) # can be loaded in Web UI
def load_weights(file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
data = load_file(file)
else:
raise ValueError(f"NOT XTI: {file}")
if len(data.values()) != 16:
raise ValueError(f"NOT XTI: {file}")
emb = torch.concat([x for x in data.values()])
return emb
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
custom_train_functions.add_custom_train_arguments(parser, False)
parser.add_argument(
"--save_model_as",
type=str,
default="pt",
choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt",
)
parser.add_argument("--weights", type=str, default=None, help="embedding weights to initialize / 学習するネットワークの初期重み")
parser.add_argument(
"--num_vectors_per_token", type=int, default=1, help="number of vectors per token / トークンに割り当てるembeddingsの要素数"
)
parser.add_argument(
"--token_string",
type=str,
default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
)
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument(
"--use_object_template",
action="store_true",
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する",
)
parser.add_argument(
"--use_style_template",
action="store_true",
help="ignore caption and use default templates for style / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)