Compare commits

...

57 Commits

Author SHA1 Message Date
Kohya S
25c8279f26 Merge pull request #437 from kohya-ss/dev
fix tensorboard logging without log_with
2023-04-23 19:19:48 +09:00
Kohya S
05c57b9c7b update readme 2023-04-23 19:18:04 +09:00
Kohya S
46cbae088e fix to log with logging_dir without log_with 2023-04-23 19:15:48 +09:00
Kohya S
b824bbfce6 Merge pull request #436 from kohya-ss/dev
automatic wandb login with api key
2023-04-22 20:20:01 +09:00
Kohya S
9ba4c3edca update readme/comment 2023-04-22 20:18:25 +09:00
Kohya S
ed2eef1625 Merge pull request #435 from Linaqruf/main
Add new args to pass API key for wandb
2023-04-22 20:06:12 +09:00
Linaqruf
e9a641bde7 Merge branch 'main' of https://github.com/Linaqruf/sd-scripts 2023-04-22 16:17:22 +07:00
Linaqruf
ae3965a2a7 feat: add arguments to set \--wandb_api_key\ before training 2023-04-22 16:14:14 +07:00
Kohya S
700af1c96d Merge pull request #434 from kohya-ss/dev
fix no logging causes error
2023-04-22 18:10:19 +09:00
Kohya S
66edc5af7b invert condition for checking log_with 2023-04-22 18:05:19 +09:00
Kohya S
ed15f6808b Merge pull request #433 from sALTaccount/main
fix no logging command line arg
2023-04-22 18:00:27 +09:00
saltacc
dc37fd2ff6 fix no logging command line arg 2023-04-22 01:26:31 -07:00
Kohya S
f256660780 Merge pull request #431 from kohya-ss/dev
wandb logging, improve debug_dataset on non-windows
2023-04-22 10:52:40 +09:00
Kohya S
23b261de3f update readme 2023-04-22 10:48:58 +09:00
Kohya S
884e6bff5d fix face_crop_aug not working on finetune method, prepare upscaler 2023-04-22 10:41:36 +09:00
Kohya S
220436244c some minor fixes 2023-04-22 09:55:04 +09:00
Kohya S
c430cf481a Merge pull request #428 from p1atdev/dev
Add WandB logging support
2023-04-22 09:39:01 +09:00
Kohya S
9f8f27fbad Merge pull request #429 from tsukimiya/hotfix/debug_dataset_linux_support
Refixed --debug_dataset option to work in non-Windows environments
2023-04-22 08:57:06 +09:00
tsukimiya
e746829b5f おそらくlibgtk2がインストールされていない環境でcv2.waitKey() および cv2.destroyAllWindows() が動作しないので除外 2023-04-20 06:20:02 +09:00
Plat
a69b24a069 fix: tensorboard not working 2023-04-20 05:33:32 +09:00
Plat
12567f55cd chore: add wandb to gitignore 2023-04-20 05:16:43 +09:00
Plat
8090daca40 fix: wandb not working without logging_dir 2023-04-20 05:14:28 +09:00
Plat
27ffd9fe3d feat: support wandb logging 2023-04-20 01:41:12 +09:00
Kohya S
ee5cec7530 Merge pull request #427 from kohya-ss/dev
fix lora_interrogator, wd14 tagger for '^_^' etc
2023-04-19 21:57:34 +09:00
Kohya S
589a90bfbc update readme 2023-04-19 21:54:55 +09:00
Kohya S
314a364f61 restore sd_model arg for backward compat 2023-04-19 21:11:12 +09:00
Kohya S
f770cd96c6 Merge pull request #392 from A2va/fix
Lora interrogator fixes
2023-04-19 20:28:27 +09:00
Kohya S
01df1c0cc4 don't replace underscore in emoji tags like ^_^ 2023-04-19 12:42:20 +09:00
Kohya S
334589af4e Merge pull request #424 from kohya-ss/dev
recursive support for finetune scripts
2023-04-17 22:06:33 +09:00
Kohya S
43ef635be3 update readme 2023-04-17 22:03:57 +09:00
Kohya S
47d61e2c02 format by black 2023-04-17 22:00:26 +09:00
Kohya S
8f6fc8daa1 add ja comment, restore arg for backward compat 2023-04-17 21:58:55 +09:00
Kohya S
01ebfc41f3 Merge pull request #400 from Linaqruf/main
Recursive support for captioning/tagging scripts
2023-04-17 21:20:57 +09:00
A2va
87163cff8b Fix missing pretrained_model_name_or_path 2023-04-17 09:16:07 +02:00
Kohya S
6d5f847edc Merge pull request #413 from kohya-ss/dev
fix dylora loading
2023-04-14 22:17:08 +09:00
Kohya S
afb8700a95 update readme 2023-04-14 22:15:31 +09:00
Kohya S
e60d18cfb3 Merge branch 'main' into dev 2023-04-14 22:13:49 +09:00
Kohya S
92332eb96e fix load_state_dict failed in dylora 2023-04-14 22:13:26 +09:00
Linaqruf
d5263d442f Merge branch 'kohya-ss:main' into main 2023-04-14 05:46:59 +07:00
Kohya S
7ad7cac0c2 Merge pull request #408 from kohya-ss/dev
DyLoRA, latents disk cache etc.
2023-04-13 22:31:27 +09:00
Kohya S
06a9f51431 fix link in readme 2023-04-13 22:27:00 +09:00
Kohya S
849bc24d20 update readme 2023-04-13 22:24:47 +09:00
Kohya S
423e6c229c support metadata json+.npz caching (no prepare) 2023-04-13 22:12:13 +09:00
Kohya S
9fc27403b2 support disk cache: same as #164, might fix #407 2023-04-13 21:40:34 +09:00
Kohya S
2de9a51591 fix typos 2023-04-13 21:18:18 +09:00
Kohya S
a8632b7329 fix latents disk cache 2023-04-13 21:14:39 +09:00
Kohya S
9ff32fd4c0 fix parameters are not freezed 2023-04-13 21:14:20 +09:00
Kohya S
a097c42579 update docs 2023-04-13 21:07:22 +09:00
Kohya S
68e0767404 add comment about scaling 2023-04-12 23:40:10 +09:00
Kohya S
e09966024c delete unnecessary lines 2023-04-12 23:16:47 +09:00
Kohya S
893c2fc08a add DyLoRA (experimental) 2023-04-12 23:14:09 +09:00
Kohya S
2e9f7b5f91 cache latents to disk in dreambooth method 2023-04-12 23:10:39 +09:00
Linaqruf
7f8e05ccad feat: add --save_precision args 2023-04-12 05:43:19 +07:00
Linaqruf
c316c63dff fix: bring positional args back, add recursive to blip etc 2023-04-12 05:41:28 +07:00
A2va
683680e5c8 Fixes 2023-04-09 21:52:02 +02:00
Linaqruf
bf8088e225 Merge branch 'kohya-ss:main' into main 2023-04-09 22:34:31 +07:00
Linaqruf
07aa000750 feat: added 7 new functionalities including recursive 2023-04-07 16:51:43 +07:00
21 changed files with 2289 additions and 759 deletions

3
.gitignore vendored
View File

@@ -4,4 +4,5 @@ wd14_tagger_model
venv
*.egg-info
build
.vscode
.vscode
wandb

121
README.md
View File

@@ -127,56 +127,93 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
### 8 Apr. 2021, 2021/4/8:
### 23 Apr. 2023, 2023/4/23:
- Added support for training with weighted captions. Thanks to AI-Casanova for the great contribution!
- Please refer to the PR for details: [PR #336](https://github.com/kohya-ss/sd-scripts/pull/336)
- Specify the `--weighted_captions` option. It is available for all training scripts except Textual Inversion and XTI.
- This option is also applicable to token strings of the DreamBooth method.
- The syntax for weighted captions is almost the same as the Web UI, and you can use things like `(abc)`, `[abc]`, and `(abc:1.23)`. Nesting is also possible.
- If you include a comma in the parentheses, the parentheses will not be properly matched in the prompt shuffle/dropout, so do not include a comma in the parentheses.
- Fixed to log to TensorBoard when `--logging_dir` is specified and `--log_with` is not specified.
- `--logging_dir`を指定し`--log_with`を指定しない場合に、以前と同様にTensorBoardへログ出力するよう修正しました。
- 重みづけキャプションによる学習に対応しました。 AI-Casanova 氏の素晴らしい貢献に感謝します。
- 詳細はこちらをご確認ください。[PR #336](https://github.com/kohya-ss/sd-scripts/pull/336)
- `--weighted_captions` オプションを指定してください。Textual InversionおよびXTIを除く学習スクリプトで使用可能です。
- キャプションだけでなく DreamBooth 手法の token string でも有効です。
- 重みづけキャプションの記法はWeb UIとほぼ同じで、`(abc)`や`[abc]`、`(abc:1.23)`などが使用できます。入れ子も可能です。
- 括弧内にカンマを含めるとプロンプトのshuffle/dropoutで括弧の対応付けがおかしくなるため、括弧内にはカンマを含めないでください。
### 22 Apr. 2023, 2023/4/22:
- Added support for logging to wandb. Please refer to [PR #428](https://github.com/kohya-ss/sd-scripts/pull/428). Thank you p1atdev!
- `wandb` installation is required. Please install it with `pip install wandb`. Login to wandb with `wandb login` command, or set `--wandb_api_key` option for automatic login.
- Please let me know if you find any bugs as the test is not complete.
- You can automatically login to wandb by setting the `--wandb_api_key` option. Please be careful with the handling of API Key. [PR #435](https://github.com/kohya-ss/sd-scripts/pull/435) Thank you Linaqruf!
- Improved the behavior of `--debug_dataset` on non-Windows environments. [PR #429](https://github.com/kohya-ss/sd-scripts/pull/429) Thank you tsukimiya!
- Fixed `--face_crop_aug` option not working in Fine tuning method.
- Prepared code to use any upscaler in `gen_img_diffusers.py`.
- wandbへのロギングをサポートしました。詳細は [PR #428](https://github.com/kohya-ss/sd-scripts/pull/428)をご覧ください。p1atdev氏に感謝します。
- `wandb` のインストールが別途必要です。`pip install wandb` でインストールしてください。また `wandb login` でログインしてください(学習スクリプト内でログインする場合は `--wandb_api_key` オプションを設定してください)。
- テスト未了のため不具合等ありましたらご連絡ください。
- wandbへのロギング時に `--wandb_api_key` オプションを設定することで自動ログインできます。API Keyの扱いにご注意ください。 [PR #435](https://github.com/kohya-ss/sd-scripts/pull/435) Linaqruf氏に感謝します。
- Windows以外の環境での`--debug_dataset` の動作を改善しました。[PR #429](https://github.com/kohya-ss/sd-scripts/pull/429) tsukimiya氏に感謝します。
- `--face_crop_aug`オプションがFine tuning方式で動作しなかったのを修正しました。
- `gen_img_diffusers.py`に任意のupscalerを利用するためのコード準備を行いました。
### 6 Apr. 2023, 2023/4/6:
- There may be bugs because I changed a lot. If you cannot revert the script to the previous version when a problem occurs, please wait for the update for a while.
### 19 Apr. 2023, 2023/4/19:
- Fixed `lora_interrogator.py` not working. Please refer to [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) for details. Thank you A2va and heyalexchoi!
- Fixed the handling of tags containing `_` in `tag_images_by_wd14_tagger.py`.
- `lora_interrogator.py`が動作しなくなっていたのを修正しました。詳細は [PR #392](https://github.com/kohya-ss/sd-scripts/pull/392) をご参照ください。A2va氏およびheyalexchoi氏に感謝します。
- `tag_images_by_wd14_tagger.py`で`_`を含むタグの取り扱いを修正しました。
- Added a feature to upload model and state to HuggingFace. Thanks to ddPn08 for the contribution! [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348)
- When `--huggingface_repo_id` is specified, the model is uploaded to HuggingFace at the same time as saving the model.
- Please note that the access token is handled with caution. Please refer to the [HuggingFace documentation](https://huggingface.co/docs/hub/security-tokens).
- For example, specify other arguments as follows.
- `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
- If `public` is specified for `--huggingface_repo_visibility`, the repository will be public. If the option is omitted or `private` (or anything other than `public`) is specified, it will be private.
- If you specify `--save_state` and `--save_state_to_huggingface`, the state will also be uploaded.
- If you specify `--resume` and `--resume_from_huggingface`, the state will be downloaded from HuggingFace and resumed.
- In this case, the `--resume` option is `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`. For example: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
- If you specify `--async_upload`, the upload will be done asynchronously.
- Added the documentation for applying LoRA to generate with the standard pipeline of Diffusers. [training LoRA](./train_network_README-ja.md#diffusersのpipelineで生成する) (Japanese only)
- Support for Attention Couple and regional LoRA in `gen_img_diffusers.py`.
- If you use ` AND ` to separate the prompts, each sub-prompt is sequentially applied to LoRA. `--mask_path` is treated as a mask image. The number of sub-prompts and the number of LoRA must match.
### Naming of LoRA
The LoRA supported by `train_network.py` has been named to avoid confusion. The documentation has been updated. The following are the names of LoRA types in this repository.
- 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers)
- モデルおよびstateをHuggingFaceにアップロードする機能を各スクリプトに追加しました。 [PR #348](https://github.com/kohya-ss/sd-scripts/pull/348) ddPn08 氏の貢献に感謝します。
- `--huggingface_repo_id`が指定されているとモデル保存時に同時にHuggingFaceにアップロードします。
- アクセストークンの取り扱いに注意してください。[HuggingFaceのドキュメント](https://huggingface.co/docs/hub/security-tokens)を参照してください。
- 他の引数をたとえば以下のように指定してください。
- `--huggingface_repo_id "your-hf-name/your-model" --huggingface_path_in_repo "path" --huggingface_repo_type model --huggingface_repo_visibility private --huggingface_token hf_YourAccessTokenHere`
- `--huggingface_repo_visibility`に`public`を指定するとリポジトリが公開されます。省略時または`private`(など`public`以外)を指定すると非公開になります。
- `--save_state`オプション指定時に`--save_state_to_huggingface`を指定するとstateもアップロードします。
- `--resume`オプション指定時に`--resume_from_huggingface`を指定するとHuggingFaceからstateをダウンロードして再開します。
- その時の `--resume`オプションは `--resume {repo_id}/{path_in_repo}:{revision}:{repo_type}`になります。例: `--resume_from_huggingface --resume your-hf-name/your-model/path/test-000002-state:main:model`
- `--async_upload`オプションを指定するとアップロードを非同期で行います。
- [LoRAの文書](./train_network_README-ja.md#diffusersのpipelineで生成する)に、LoRAを適用してDiffusersの標準的なパイプラインで生成する方法を追記しました。
- `gen_img_diffusers.py` で Attention Couple および領域別LoRAに対応しました。
- プロンプトを` AND `で区切ると各サブプロンプトが順にLoRAに適用されます。`--mask_path` がマスク画像として扱われます。サブプロンプトの数とLoRAの数は一致している必要があります。
LoRA for Linear layers and Conv2d layers with 1x1 kernel
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers)
In addition to 1., LoRA for Conv2d layers with 3x3 kernel
LoRA-LierLa is the default LoRA type for `train_network.py` (without `conv_dim` network arg). LoRA-LierLa can be used with [our extension](https://github.com/kohya-ss/sd-webui-additional-networks) for AUTOMATIC1111's Web UI, or with the built-in LoRA feature of the Web UI.
To use LoRA-C3Liar with Web UI, please use our extension.
### LoRAの名称について
`train_network.py` がサポートするLoRAについて、混乱を避けるため名前を付けました。ドキュメントは更新済みです。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear 層およびカーネルサイズ 1x1 の Conv2d 層に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d 層に適用されるLoRA
LoRA-LierLa は[Web UI向け拡張](https://github.com/kohya-ss/sd-webui-additional-networks)、またはAUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます。
LoRA-C3Liarを使いWeb UIで生成するには拡張を使用してください。
### 17 Apr. 2023, 2023/4/17:
- Added the `--recursive` option to each script in the `finetune` folder to process folders recursively. Please refer to [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) for details. Thanks to Linaqruf!
- `finetune`フォルダ内の各スクリプトに再起的にフォルダを処理するオプション`--recursive`を追加しました。詳細は [PR #400](https://github.com/kohya-ss/sd-scripts/pull/400/) を参照してください。Linaqruf 氏に感謝します。
### 14 Apr. 2023, 2023/4/14:
- Fixed a bug that caused an error when loading DyLoRA with the `--network_weight` option in `train_network.py`.
- `train_network.py`で、DyLoRAを`--network_weight`オプションで読み込むとエラーになる不具合を修正しました。
### 13 Apr. 2023, 2023/4/13:
- Added support for DyLoRA in `train_network.py`. Please refer to [here](./train_network_README-ja.md#dylora) for details (currently only in Japanese).
- Added support for caching latents to disk in each training script. Please specify __both__ `--cache_latents` and `--cache_latents_to_disk` options.
- The files are saved in the same folder as the images with the extension `.npz`. If you specify the `--flip_aug` option, the files with `_flip.npz` will also be saved.
- Multi-GPU training has not been tested.
- This feature is not tested with all combinations of datasets and training scripts, so there may be bugs.
- Added workaround for an error that occurs when training with `fp16` or `bf16` in `fine_tune.py`.
- `train_network.py`でDyLoRAをサポートしました。詳細は[こちら](./train_network_README-ja.md#dylora)をご覧ください。
- 各学習スクリプトでlatentのディスクへのキャッシュをサポートしました。`--cache_latents`オプションに __加えて__、`--cache_latents_to_disk`オプションを指定してください。
- 画像と同じフォルダに、拡張子 `.npz` で保存されます。`--flip_aug`オプションを指定した場合、`_flip.npz`が付いたファイルにも保存されます。
- マルチGPUでの学習は未テストです。
- すべてのDataset、学習スクリプトの組み合わせでテストしたわけではないため、不具合があるかもしれません。
- `fine_tune.py`で、`fp16`および`bf16`の学習時にエラーが出る不具合に対して対策を行いました。
## Sample image generation during training
A prompt file might look like this, for example

View File

@@ -142,12 +142,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
training_models = []
if args.gradient_checkpointing:
@@ -258,7 +260,7 @@ def train(args):
)
if accelerator.is_main_process:
accelerator.init_trackers("finetuning")
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name)
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
@@ -273,7 +275,7 @@ def train(args):
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -311,7 +313,8 @@ def train(args):
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training

View File

@@ -4,6 +4,7 @@ import os
import json
import random
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import numpy as np
@@ -13,156 +14,185 @@ from torchvision.transforms.functional import InterpolationMode
from blip.blip import blip_decoder
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = 384
# 正方形でいいのか? という気がするがソースがそうなので
IMAGE_TRANSFORM = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])
IMAGE_TRANSFORM = transforms.Compose(
[
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]
)
# 共通化したいが微妙に処理が異なる……
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
def __getitem__(self, idx):
img_path = self.images[idx]
try:
image = Image.open(img_path).convert("RGB")
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
try:
image = Image.open(img_path).convert("RGB")
# convert to tensor temporarily so dataloader will accept it
tensor = IMAGE_TRANSFORM(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# fix the seed for reproducibility
seed = args.seed # + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# fix the seed for reproducibility
seed = args.seed # + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if not os.path.exists("blip"):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
if not os.path.exists("blip"):
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
cwd = os.getcwd()
print('Current Working Directory is: ', cwd)
os.chdir('finetune')
cwd = os.getcwd()
print("Current Working Directory is: ", cwd)
os.chdir("finetune")
print(f"load images from {args.train_data_dir}")
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
print(f"loading BLIP caption: {args.caption_weights}")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit='large', med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")
print(f"loading BLIP caption: {args.caption_weights}")
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
model.eval()
model = model.to(DEVICE)
print("BLIP loaded")
# captioningする
def run_batch(path_imgs):
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
# captioningする
def run_batch(path_imgs):
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
with torch.no_grad():
if args.beam_search:
captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
max_length=args.max_length, min_length=args.min_length)
else:
captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
with torch.no_grad():
if args.beam_search:
captions = model.generate(
imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length
)
else:
captions = model.generate(
imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length
)
for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingTransformDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingTransformDataset(image_paths)
data = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
else:
data = [[(None, ip)] for ip in image_paths]
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
img_tensor, image_path = data
if img_tensor is None:
try:
raw_image = Image.open(image_path)
if raw_image.mode != 'RGB':
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
img_tensor, image_path = data
if img_tensor is None:
try:
raw_image = Image.open(image_path)
if raw_image.mode != "RGB":
raw_image = raw_image.convert("RGB")
img_tensor = IMAGE_TRANSFORM(raw_image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, img_tensor))
if len(b_imgs) >= args.batch_size:
b_imgs.append((image_path, img_tensor))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
print("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_weights", type=str, default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--beam_search", action="store_true",
help="use beam search (default Nucleus sampling) / beam searchを使うこのオプション未指定時はNucleus sampling")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
parser.add_argument("--debug", action="store_true", help="debug mode")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"--caption_weights",
type=str,
default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)",
)
parser.add_argument(
"--caption_extention",
type=str,
default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります",
)
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument(
"--beam_search",
action="store_true",
help="use beam search (default Nucleus sampling) / beam searchを使うこのオプション未指定時はNucleus sampling",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数多いと精度が上がるが時間がかかる")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed")
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
args = parser.parse_args()
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)
main(args)

View File

@@ -2,6 +2,7 @@ import argparse
import os
import re
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import torch
@@ -11,141 +12,161 @@ from transformers.generation.utils import GenerationMixin
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PATTERN_REPLACE = [
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
re.compile(r'with the number \d+ on (it|\w+ \w+)'),
re.compile(r"with the number \d+ on (it|\w+ \w+)"),
re.compile(r'with the words "'),
re.compile(r'word \w+ on it'),
re.compile(r'that says the word \w+ on it'),
re.compile('that says\'the word "( on it)?'),
re.compile(r"word \w+ on it"),
re.compile(r"that says the word \w+ on it"),
re.compile("that says'the word \"( on it)?"),
]
# 誤検知しまくりの with the word xxxx を消す
def remove_words(captions, debug):
removed_caps = []
for caption in captions:
cap = caption
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
removed_caps.append(cap)
return removed_caps
removed_caps = []
for caption in captions:
cap = caption
for pat in PATTERN_REPLACE:
cap = pat.sub("", cap)
if debug and cap != caption:
print(caption)
print(cap)
removed_caps.append(cap)
return removed_caps
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
# input_idsがバッチサイズと同じ件数である必要があるバッチサイズはこの関数から参照できないので外から渡す
# ここより上で置き換えようとするとすごく大変
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
if input_ids.size()[0] != curr_batch_size[0]:
input_ids = input_ids.repeat(curr_batch_size[0], 1)
return input_ids
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
# input_idsがバッチサイズと同じ件数である必要があるバッチサイズはこの関数から参照できないので外から渡す
# ここより上で置き換えようとするとすごく大変
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
if input_ids.size()[0] != curr_batch_size[0]:
input_ids = input_ids.repeat(curr_batch_size[0], 1)
return input_ids
print(f"load images from {args.train_data_dir}")
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
# できればcacheに依存せず明示的にダウンロードしたい
print(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
# captioningする
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]
# できればcacheに依存せず明示的にダウンロードしたい
print(f"loading GIT: {args.model_id}")
git_processor = AutoProcessor.from_pretrained(args.model_id)
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
print("GIT loaded")
curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
# captioningする
def run_batch(path_imgs):
imgs = [im for _, im in path_imgs]
if args.remove_words:
captions = remove_words(captions, args.debug)
curr_batch_size[0] = len(path_imgs)
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
if args.remove_words:
captions = remove_words(captions, args.debug)
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
for (image_path, _), caption in zip(path_imgs, captions):
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(caption + "\n")
if args.debug:
print(image_path, caption)
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
else:
data = [[(None, ip)] for ip in image_paths]
image, image_path = data
if image is None:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
image, image_path = data
if image is None:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
print("done!")
print("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--model_id", type=str, default="microsoft/git-large-textcaps",
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
parser.add_argument("--remove_words", action="store_true",
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
parser.add_argument("--debug", action="store_true", help="debug mode")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument(
"--model_id",
type=str,
default="microsoft/git-large-textcaps",
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID",
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
parser.add_argument(
"--remove_words",
action="store_true",
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
main(args)
args = parser.parse_args()
main(args)

View File

@@ -2,6 +2,8 @@ import argparse
import os
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
import numpy as np
from PIL import Image
@@ -12,7 +14,7 @@ from torchvision import transforms
import library.model_util as model_util
import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_TRANSFORMS = transforms.Compose(
[
@@ -23,245 +25,299 @@ IMAGE_TRANSFORMS = transforms.Compose(
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def get_latents(vae, images, weight_dtype):
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
img_tensors = torch.stack(img_tensors)
img_tensors = img_tensors.to(DEVICE, weight_dtype)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
return latents
img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
img_tensors = torch.stack(img_tensors)
img_tensors = img_tensors.to(DEVICE, weight_dtype)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
return latents
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
else:
base_name = image_key
if flip:
base_name += '_flip'
return os.path.join(data_dir, base_name)
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
else:
base_name = image_key
relative_path = ""
if flip:
base_name += "_flip"
if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name)
else:
return os.path.join(data_dir, base_name)
def main(args):
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
# assert args.bucket_reso_steps % 8 == 0, f"bucket_reso_steps must be divisible by 8 / bucket_reso_stepは8で割り切れる必要があります"
if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
print(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
vae.eval()
vae.to(DEVICE, dtype=weight_dtype)
# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_manager = train_util.BucketManager(args.bucket_no_upscale, max_reso,
args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps)
if not args.bucket_no_upscale:
bucket_manager.make_buckets()
else:
print("min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます")
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
img_ar_errors = []
def process_batch(is_last):
for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, \
f"latent shape {latents.shape}, {bucket[0][1].shape}"
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
np.savez(npz_file_name, latent)
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
np.savez(npz_file_name, latent)
else:
# remove existing flipped npz
for image_key, _ in bucket:
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
if os.path.isfile(npz_file_name):
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
os.remove(npz_file_name)
bucket.clear()
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
bucket_counts = {}
for data_entry in tqdm(data, smoothing=0.0):
if data_entry[0] is None:
continue
img_tensor, image_path = data_entry[0]
if img_tensor is not None:
image = transforms.functional.to_pil_image(img_tensor)
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding="utf-8") as f:
metadata = json.load(f)
else:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
if image_key not in metadata:
metadata[image_key] = {}
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
vae.eval()
vae.to(DEVICE, dtype=weight_dtype)
reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
img_ar_errors.append(abs(ar_error))
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
# bucketのサイズを計算する
max_reso = tuple([int(t) for t in args.max_resolution.split(",")])
assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
bucket_manager = train_util.BucketManager(
args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps
)
if not args.bucket_no_upscale:
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
assert resized_size[0] == reso[0] or resized_size[1] == reso[
1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
bucket_manager.make_buckets()
else:
print(
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
)
assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
1], f"internal error resized size is small: {resized_size}, {reso}"
# 画像をひとつずつ適切なbucketに割り当てながらlatentを計算する
img_ar_errors = []
# 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
if args.flip_aug:
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
def process_batch(is_last):
for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, img in bucket], weight_dtype)
assert (
latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8
), f"latent shape {latents.shape}, {bucket[0][1].shape}"
found = True
for npz_file in npz_files:
if not os.path.exists(npz_file):
found = False
break
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
np.savez(npz_file_name, latent)
dat = np.load(npz_file)['arr_0']
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False
break
if found:
continue
# flip
if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(
args.train_data_dir, image_key, args.full_path, True, args.recursive
)
np.savez(npz_file_name, latent)
else:
# remove existing flipped npz
for image_key, _ in bucket:
npz_file_name = (
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
)
if os.path.isfile(npz_file_name):
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
os.remove(npz_file_name)
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size//2:trim_size//2 + reso[0]]
bucket.clear()
if resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size//2:trim_size//2 + reso[1]]
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
else:
data = [[(None, ip)] for ip in image_paths]
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
bucket_counts = {}
for data_entry in tqdm(data, smoothing=0.0):
if data_entry[0] is None:
continue
# # debug
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
img_tensor, image_path = data_entry[0]
if img_tensor is not None:
image = transforms.functional.to_pil_image(img_tensor)
else:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
# バッチへ追加
bucket_manager.add_image(reso, (image_key, image))
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
if image_key not in metadata:
metadata[image_key] = {}
# バッチを推論するか判定して推論する
process_batch(False)
# 本当はこのあとの部分もDataSetに持っていけば高速化できるがいろいろ大変
# 残りを処理する
process_batch(True)
reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
img_ar_errors.append(abs(ar_error))
bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
bucket_manager.sort()
for i, reso in enumerate(bucket_manager.resos):
count = bucket_counts.get(reso, 0)
if count > 0:
print(f"bucket {i} {reso}: {count}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error: {np.mean(img_ar_errors)}")
# メタデータに記録する解像度はlatent単位とするので、8単位で切り捨て
metadata[image_key]["train_resolution"] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding='utf-8') as f:
json.dump(metadata, f, indent=2)
print("done!")
if not args.bucket_no_upscale:
# upscaleを行わないときには、resize後のサイズは、bucketのサイズと、縦横どちらかが同じであることを確認する
assert (
resized_size[0] == reso[0] or resized_size[1] == reso[1]
), f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
assert (
resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
), f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
assert (
resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
), f"internal error resized size is small: {resized_size}, {reso}"
# 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
if args.flip_aug:
npz_files.append(
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
)
found = True
for npz_file in npz_files:
if not os.path.exists(npz_file):
found = False
break
dat = np.load(npz_file)["arr_0"]
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False
break
if found:
continue
# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size // 2 : trim_size // 2 + reso[0]]
if resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size // 2 : trim_size // 2 + reso[1]]
assert (
image.shape[0] == reso[1] and image.shape[1] == reso[0]
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
# # debug
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
# バッチへ追加
bucket_manager.add_image(reso, (image_key, image))
# バッチを推論するか判定して推論する
process_batch(False)
# 残りを処理する
process_batch(True)
bucket_manager.sort()
for i, reso in enumerate(bucket_manager.resos):
count = bucket_counts.get(reso, 0)
if count > 0:
print(f"bucket {i} {reso}: {count}")
img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error: {np.mean(img_ar_errors)}")
# metadataを書き出して終わり
print(f"writing metadata: {args.out_json}")
with open(args.out_json, "wt", encoding="utf-8") as f:
json.dump(metadata, f, indent=2)
print("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action='store_true',
help='not used (for backward compatibility) / 使用されません(互換性のため残してあります)')
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--max_resolution", type=str, default="512,512",
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します")
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
parser.add_argument("--bucket_reso_steps", type=int, default=64,
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
parser.add_argument("--bucket_no_upscale", action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
parser.add_argument("--mixed_precision", type=str, default="no",
choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精")
parser.add_argument("--full_path", action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
parser.add_argument("--flip_aug", action="store_true",
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
parser.add_argument("--skip_existing", action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ")
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument(
"--max_resolution",
type=str,
default="512,512",
help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します",
)
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像")
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
parser.add_argument(
"--bucket_reso_steps",
type=int,
default=64,
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します",
)
parser.add_argument(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
)
parser.add_argument(
"--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度"
)
parser.add_argument(
"--full_path",
action="store_true",
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)",
)
parser.add_argument(
"--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する"
)
parser.add_argument(
"--skip_existing",
action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
)
parser.add_argument(
"--recursive",
action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す",
)
return parser
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
main(args)
args = parser.parse_args()
main(args)

View File

@@ -10,6 +10,7 @@ import numpy as np
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import torch
from pathlib import Path
import library.train_util as train_util
@@ -17,7 +18,7 @@ import library.train_util as train_util
IMAGE_SIZE = 448
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
DEFAULT_WD14_TAGGER_REPO = 'SmilingWolf/wd-v1-4-convnext-tagger-v2'
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
SUB_DIR = "variables"
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
@@ -25,182 +26,273 @@ CSV_FILE = FILES[-1]
def preprocess_image(image):
image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR
image = np.array(image)
image = image[:, :, ::-1] # RGB->BGR
# pad to square
size = max(image.shape[0:2])
pad_x = size - image.shape[1]
pad_y = size - image.shape[0]
pad_l = pad_x // 2
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode='constant', constant_values=255)
# pad to square
size = max(image.shape[0:2])
pad_x = size - image.shape[1]
pad_y = size - image.shape[0]
pad_l = pad_x // 2
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
image = image.astype(np.float32)
return image
image = image.astype(np.float32)
return image
class ImageLoadingPrepDataset(torch.utils.data.Dataset):
def __init__(self, image_paths):
self.images = image_paths
def __init__(self, image_paths):
self.images = image_paths
def __len__(self):
return len(self.images)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = self.images[idx]
def __getitem__(self, idx):
img_path = str(self.images[idx])
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
tensor = torch.tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
try:
image = Image.open(img_path).convert("RGB")
image = preprocess_image(image)
tensor = torch.tensor(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
return None
return (tensor, img_path)
return (tensor, img_path)
def collate_fn_remove_corrupted(batch):
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
"""Collate function that allows to remove corrupted examples in the
dataloader. It expects that the dataloader returns 'None' when that occurs.
The 'None's in the batch are removed.
"""
# Filter out all the Nones (corrupted examples)
batch = list(filter(lambda x: x is not None, batch))
return batch
def main(args):
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(args.repo_id, file, subfolder=SUB_DIR, cache_dir=os.path.join(
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
else:
print("using existing wd14 tagger model")
# hf_hub_downloadをそのまま使うとsymlink関係で問題があるらしいので、キャッシュディレクトリとforce_filenameを指定してなんとかする
# depreacatedの警告が出るけどなくなったらその時
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
if not os.path.exists(args.model_dir) or args.force_download:
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
for file in FILES:
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
subfolder=SUB_DIR,
cache_dir=os.path.join(args.model_dir, SUB_DIR),
force_download=True,
force_filename=file,
)
else:
print("using existing wd14 tagger model")
# 画像を読み込む
image_paths = train_util.glob_images(args.train_data_dir)
print(f"found {len(image_paths)} images.")
# 画像を読み込む
model = load_model(args.model_dir)
print("loading model and labels")
model = load_model(args.model_dir)
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
# 依存ライブラリを増やしたくないので自力で読むよ
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
with open(os.path.join(args.model_dir, CSV_FILE), "r", encoding="utf-8") as f:
reader = csv.reader(f)
l = [row for row in reader]
header = l[0] # tag_id,name,category,count
rows = l[1:]
assert header[0] == "tag_id" and header[1] == "name" and header[2] == "category", f"unexpected csv format: {header}"
tags = [row[1] for row in rows[1:] if row[2] == '0'] # categoryが0、つまり通常のタグのみ
general_tags = [row[1] for row in rows[1:] if row[2] == "0"]
character_tags = [row[1] for row in rows[1:] if row[2] == "4"]
# 推論する
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
# 画像を読み込む
probs = model(imgs, training=False)
probs = probs.numpy()
train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.")
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = label_names[:4]
# rating_index = ratings_names["probs"].argmax()
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
tag_freq = {}
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
# Everything else is tags: pick any where prediction confidence > threshold
tag_text = ""
for i, p in enumerate(prob[4:]): # numpyとか使うのが良いけど、まあそれほど数も多くないのでループで
if p >= args.thresh and i < len(tags):
tag_text += ", " + tags[i]
undesired_tags = set(args.undesired_tags.split(","))
if len(tag_text) > 0:
tag_text = tag_text[2:] # 最初の ", " を消す
def run_batch(path_imgs):
imgs = np.array([im for _, im in path_imgs])
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
f.write(tag_text + '\n')
if args.debug:
print(image_path, tag_text)
probs = model(imgs, training=False)
probs = probs.numpy()
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else:
data = [[(None, ip)] for ip in image_paths]
for (image_path, _), prob in zip(path_imgs, probs):
# 最初の4つはratingなので無視する
# # First 4 labels are actually ratings: pick one with argmax
# ratings_names = label_names[:4]
# rating_index = ratings_names["probs"].argmax()
# found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
# それ以降はタグなのでconfidenceがthresholdより高いものを追加する
# Everything else is tags: pick any where prediction confidence > threshold
combined_tags = []
general_tag_text = ""
character_tag_text = ""
for i, p in enumerate(prob[4:]):
if i < len(general_tags) and p >= args.general_threshold:
tag_name = general_tags[i]
if args.remove_underscore and len(tag_name) > 3: # ignore emoji tags like >_< and ^_^
tag_name = tag_name.replace("_", " ")
image, image_path = data
if image is not None:
image = image.detach().numpy()
else:
try:
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
general_tag_text += ", " + tag_name
combined_tags.append(tag_name)
elif i >= len(general_tags) and p >= args.character_threshold:
tag_name = character_tags[i - len(general_tags)]
if args.remove_underscore and len(tag_name) > 3:
tag_name = tag_name.replace("_", " ")
if len(b_imgs) >= args.batch_size:
if tag_name not in undesired_tags:
tag_freq[tag_name] = tag_freq.get(tag_name, 0) + 1
character_tag_text += ", " + tag_name
combined_tags.append(tag_name)
# 先頭のカンマを取る
if len(general_tag_text) > 0:
general_tag_text = general_tag_text[2:]
if len(character_tag_text) > 0:
character_tag_text = character_tag_text[2:]
tag_text = ", ".join(combined_tags)
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
f.write(tag_text + "\n")
if args.debug:
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
# 読み込みの高速化のためにDataLoaderを使うオプション
if args.max_data_loader_n_workers is not None:
dataset = ImageLoadingPrepDataset(image_paths)
data = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
else:
data = [[(None, ip)] for ip in image_paths]
b_imgs = []
for data_entry in tqdm(data, smoothing=0.0):
for data in data_entry:
if data is None:
continue
image, image_path = data
if image is not None:
image = image.detach().numpy()
else:
try:
image = Image.open(image_path)
if image.mode != "RGB":
image = image.convert("RGB")
image = preprocess_image(image)
except Exception as e:
print(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
continue
b_imgs.append((image_path, image))
if len(b_imgs) >= args.batch_size:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
b_imgs = [(str(image_path), image) for image_path, image in b_imgs] # Convert image_path to string
run_batch(b_imgs)
b_imgs.clear()
if len(b_imgs) > 0:
run_batch(b_imgs)
if args.frequency_tags:
sorted_tags = sorted(tag_freq.items(), key=lambda x: x[1], reverse=True)
print("\nTag frequencies:")
for tag, freq in sorted_tags:
print(f"{tag}: {freq}")
print("done!")
print("done!")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ")
parser.add_argument("--force_download", action='store_true',
help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument("--max_data_loader_n_workers", type=int, default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化")
parser.add_argument("--caption_extention", type=str, default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--debug", action="store_true", help="debug mode")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"--repo_id",
type=str,
default=DEFAULT_WD14_TAGGER_REPO,
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID",
)
parser.add_argument(
"--model_dir",
type=str,
default="wd14_tagger_model",
help="directory to store wd14 tagger model / wd14 taggerのモデルを格納するディレクトリ",
)
parser.add_argument(
"--force_download", action="store_true", help="force downloading wd14 tagger models / wd14 taggerのモデルを再ダウンロードします"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
parser.add_argument(
"--max_data_loader_n_workers",
type=int,
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する読み込みを高速化",
)
parser.add_argument(
"--caption_extention",
type=str,
default=None,
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
)
parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
parser.add_argument(
"--general_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for general category, same as --thresh if omitted / generalカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
)
parser.add_argument(
"--character_threshold",
type=float,
default=None,
help="threshold of confidence to add a tag for character category, same as --thres if omitted / characterカテゴリのタグを追加するための確信度の閾値、省略時は --thresh と同じ",
)
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
parser.add_argument(
"--remove_underscore",
action="store_true",
help="replace underscores with spaces in the output tags / 出力されるタグのアンダースコアをスペースに置き換える",
)
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument(
"--undesired_tags",
type=str,
default="",
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
)
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
return parser
args = parser.parse_args()
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
if __name__ == '__main__':
parser = setup_parser()
if args.general_threshold is None:
args.general_threshold = args.thresh
if args.character_threshold is None:
args.character_threshold = args.thresh
args = parser.parse_args()
# スペルミスしていたオプションを復元する
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
main(args)
main(args)

View File

@@ -945,7 +945,7 @@ class PipelineLike:
# encode the init image into latents and scale the latents
init_image = init_image.to(device=self.device, dtype=latents_dtype)
if init_image.size()[2:] == (height // 8, width // 8):
if init_image.size()[1:] == (height // 8, width // 8):
init_latents = init_image
else:
if vae_batch_size >= batch_size:
@@ -1015,7 +1015,7 @@ class PipelineLike:
if self.control_nets:
if reginonal_network:
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts] # last subprompt
else:
text_emb_last = text_embeddings
noise_pred = original_control_net.call_unet_and_control_net(
@@ -2318,6 +2318,22 @@ def main(args):
else:
networks = []
# upscalerの指定があれば取得する
upscaler = None
if args.highres_fix_upscaler:
print("import upscaler module:", args.highres_fix_upscaler)
imported_module = importlib.import_module(args.highres_fix_upscaler)
us_kwargs = {}
if args.highres_fix_upscaler_args:
for net_arg in args.highres_fix_upscaler_args.split(";"):
key, value = net_arg.split("=")
us_kwargs[key] = value
print("create upscaler")
upscaler = imported_module.create_upscaler(**us_kwargs)
upscaler.to(dtype).to(device)
# ControlNetの処理
control_nets: List[ControlNetInfo] = []
if args.control_net_models:
@@ -2590,7 +2606,7 @@ def main(args):
np_mask = np_mask[:, :, i]
size = np_mask.shape
else:
np_mask = np.full(size, 255, dtype=np.uint8)
np_mask = np.full(size, 255, dtype=np.uint8)
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
network.set_region(i, i == len(networks) - 1, mask)
mask_images = None
@@ -2639,6 +2655,8 @@ def main(args):
# highres_fixの処理
if highres_fix and not highres_1st:
# 1st stageのバッチを作成して呼び出すサイズを小さくして呼び出す
is_1st_latent = upscaler.support_latents() if upscaler else args.highres_fix_latents_upscaling
print("process 1st stage")
batch_1st = []
for _, base, ext in batch:
@@ -2657,12 +2675,32 @@ def main(args):
ext.network_muls,
ext.num_sub_prompts,
)
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
batch_1st.append(BatchData(is_1st_latent, base, ext_1st))
images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する
print("process 2nd stage")
if args.highres_fix_latents_upscaling:
width_2nd, height_2nd = batch[0].ext.width, batch[0].ext.height
if upscaler:
# upscalerを使って画像を拡大する
lowreso_imgs = None if is_1st_latent else images_1st
lowreso_latents = None if not is_1st_latent else images_1st
# 戻り値はPIL.Image.Imageかtorch.Tensorのlatents
batch_size = len(images_1st)
vae_batch_size = (
batch_size
if args.vae_batch_size is None
else (max(1, int(batch_size * args.vae_batch_size)) if args.vae_batch_size < 1 else args.vae_batch_size)
)
vae_batch_size = int(vae_batch_size)
images_1st = upscaler.upscale(
vae, lowreso_imgs, lowreso_latents, dtype, width_2nd, height_2nd, batch_size, vae_batch_size
)
elif args.highres_fix_latents_upscaling:
# latentを拡大する
org_dtype = images_1st.dtype
if images_1st.dtype == torch.bfloat16:
images_1st = images_1st.to(torch.float) # interpolateがbf16をサポートしていない
@@ -2671,10 +2709,12 @@ def main(args):
) # , antialias=True)
images_1st = images_1st.to(org_dtype)
else:
# 画像をLANCZOSで拡大する
images_1st = [image.resize((width_2nd, height_2nd), resample=PIL.Image.LANCZOS) for image in images_1st]
batch_2nd = []
for i, (bd, image) in enumerate(zip(batch, images_1st)):
if not args.highres_fix_latents_upscaling:
image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS) # img2imgとして設定
bd_2nd = BatchData(False, BatchDataBase(*bd.base[0:3], bd.base.seed + 1, image, None, *bd.base[6:]), bd.ext)
batch_2nd.append(bd_2nd)
batch = batch_2nd
@@ -3229,6 +3269,16 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="use latents upscaling for highres fix / highres fixでlatentで拡大する",
)
parser.add_argument(
"--highres_fix_upscaler", type=str, default=None, help="upscaler module for highres fix / highres fixで使うupscalerのモジュール名"
)
parser.add_argument(
"--highres_fix_upscaler_args",
type=str,
default=None,
help="additional argmuments for upscaler (key=value) / upscalerへの追加の引数",
)
parser.add_argument(
"--negative_scale", type=float, default=None, help="set another guidance scale for negative prompt / ネガティブプロンプトのscaleを指定する"
)

View File

@@ -722,7 +722,7 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def cache_latents(self, vae, vae_batch_size=1):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
# ちょっと速くした
print("caching latents.")
@@ -740,11 +740,38 @@ class BaseDataset(torch.utils.data.Dataset):
if info.latents_npz is not None:
info.latents = self.load_latents_from_npz(info, False)
info.latents = torch.FloatTensor(info.latents)
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
# might be None, but that's ok because check is done in dataset
info.latents_flipped = self.load_latents_from_npz(info, True)
if info.latents_flipped is not None:
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue
# check disk cache exists and size of latents
if cache_to_disk:
# TODO: refactor to unify with FineTuningDataset
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
if not is_main_process:
continue
cache_available = False
expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意
if os.path.exists(info.latents_npz):
cached_latents = np.load(info.latents_npz)["arr_0"]
if cached_latents.shape[1:3] == expected_latents_size:
cache_available = True
if subset.flip_aug:
cache_available = False
if os.path.exists(info.latents_npz_flipped):
cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"]
if cached_latents_flipped.shape[1:3] == expected_latents_size:
cache_available = True
if cache_available:
continue
# if last member of batch has different resolution, flush the batch
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch)
@@ -760,6 +787,9 @@ class BaseDataset(torch.utils.data.Dataset):
if len(batch) > 0:
batches.append(batch)
if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only
return
# iterate batches
for batch in tqdm(batches, smoothing=1, total=len(batches)):
images = []
@@ -773,14 +803,21 @@ class BaseDataset(torch.utils.data.Dataset):
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents = latent
if cache_to_disk:
np.savez(info.latents_npz, latent.float().numpy())
else:
info.latents = latent
if subset.flip_aug:
img_tensors = torch.flip(img_tensors, dims=[3])
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents_flipped = latent
if cache_to_disk:
np.savez(info.latents_npz_flipped, latent.float().numpy())
else:
info.latents_flipped = latent
def get_image_size(self, image_path):
image = Image.open(image_path)
@@ -808,9 +845,10 @@ class BaseDataset(torch.utils.data.Dataset):
# 画像サイズはsizeより大きいのでリサイズする
face_size = max(face_w, face_h)
size = min(self.height, self.width) # 短いほう
min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率)
min_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
max_scale = min(1.0, max(min_scale, self.size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
min_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1]))) # 指定した顔最小サイズ
max_scale = min(1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0]))) # 指定した顔最大サイズ
if min_scale >= max_scale: # range指定がmin==max
scale = min_scale
else:
@@ -835,7 +873,7 @@ class BaseDataset(torch.utils.data.Dataset):
else:
# range指定があるときのみ、すこしだけランダムにわりと適当
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
if face_size > self.size // 10 and face_size >= 40:
if face_size > size // 10 and face_size >= 40:
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
p1 = max(0, min(p1, length - target_size))
@@ -873,10 +911,10 @@ class BaseDataset(torch.utils.data.Dataset):
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
# image/latentsを処理する
if image_info.latents is not None:
if image_info.latents is not None: # cache_latents=Trueの場合
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
image = None
elif image_info.latents_npz is not None:
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
latents = torch.FloatTensor(latents)
image = None
@@ -1163,19 +1201,27 @@ class FineTuningDataset(BaseDataset):
tags_list = []
for image_key, img_md in metadata.items():
# path情報を作る
abs_path = None
# まず画像を優先して探す
if os.path.exists(image_key):
abs_path = image_key
elif os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
abs_path = npz_path
# わりといい加減だがいい方法が思いつかん
paths = glob_images(subset.image_dir, image_key)
if len(paths) > 0:
abs_path = paths[0]
# なければnpzを探す
if abs_path is None:
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
# わりといい加減だがいい方法が思いつかん
abs_path = glob_images(subset.image_dir, image_key)
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
abs_path = abs_path[0]
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
abs_path = npz_path
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
caption = img_md.get("caption")
tags = img_md.get("tags")
@@ -1340,10 +1386,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size)
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -1400,8 +1446,8 @@ def debug_dataset(train_dataset, show_input_ids=False):
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27 or k == ord("s") or k == ord("e"):
break
steps += 1
@@ -2022,7 +2068,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
)
parser.add_argument(
"--log_with",
type=str,
default=None,
choices=["tensorboard", "wandb", "all"],
help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
)
parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument(
"--log_tracker_name",
type=str,
default=None,
help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
)
parser.add_argument(
"--wandb_api_key",
type=str,
default=None,
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインするオプション",
)
parser.add_argument(
"--noise_offset",
type=float,
@@ -2144,9 +2209,14 @@ def add_dataset_arguments(
parser.add_argument(
"--cache_latents",
action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可",
help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheするaugmentationは使用不可 ",
)
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
parser.add_argument(
"--cache_latents_to_disk",
action="store_true",
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheするaugmentationは使用不可",
)
parser.add_argument(
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
)
@@ -2238,7 +2308,7 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
args_dict = vars(args)
# remove unnecessary keys
for key in ["config_file", "output_config"]:
for key in ["config_file", "output_config", "wandb_api_key"]:
if key in args_dict:
del args_dict[key]
@@ -2682,13 +2752,32 @@ def load_tokenizer(args: argparse.Namespace):
def prepare_accelerator(args: argparse.Namespace):
if args.logging_dir is None:
log_with = None
logging_dir = None
else:
log_with = "tensorboard"
log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
if args.log_with is None:
if logging_dir is not None:
log_with = "tensorboard"
else:
log_with = None
else:
log_with = args.log_with
if log_with in ["tensorboard", "all"]:
if logging_dir is None:
raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
if log_with in ["wandb", "all"]:
try:
import wandb
except ImportError:
raise ImportError("No wandb / wandb がインストールされていないようです")
if logging_dir is not None:
os.makedirs(logging_dir, exist_ok=True)
os.environ["WANDB_DIR"] = logging_dir
if args.wandb_api_key is not None:
wandb.login(key=args.wandb_api_key)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
@@ -3147,6 +3236,18 @@ def sample_images(
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
@@ -3203,4 +3304,4 @@ class collater_class:
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
return examples[0]

450
networks/dylora.py Normal file
View File

@@ -0,0 +1,450 @@
# some codes are copied from:
# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
# Changes made to the original code:
# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
import os
import random
from typing import List, Tuple, Union
import torch
from torch import nn
class DyLoRAModule(torch.nn.Module):
"""
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""
# NOTE: support dropout in future
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
super().__init__()
self.lora_name = lora_name
self.lora_dim = lora_dim
self.unit = unit
assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"
if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)
if self.is_conv2d and self.is_conv2d_3x3:
kernel_size = org_module.kernel_size
self.stride = org_module.stride
self.padding = org_module.padding
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
else:
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
# same as microsoft's
for lora in self.lora_A:
torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
for lora in self.lora_B:
torch.nn.init.zeros_(lora)
self.multiplier = multiplier
self.org_module = org_module # remove in applying
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
result = self.org_forward(x)
# specify the dynamic rank
trainable_rank = random.randint(0, self.lora_dim - 1)
trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
# 一部のパラメータを固定して、残りのパラメータを学習する
for i in range(0, trainable_rank):
self.lora_A[i].requires_grad = False
self.lora_B[i].requires_grad = False
for i in range(trainable_rank, trainable_rank + self.unit):
self.lora_A[i].requires_grad = True
self.lora_B[i].requires_grad = True
for i in range(trainable_rank + self.unit, self.lora_dim):
self.lora_A[i].requires_grad = False
self.lora_B[i].requires_grad = False
lora_A = torch.cat(tuple(self.lora_A), dim=0)
lora_B = torch.cat(tuple(self.lora_B), dim=1)
# calculate with lora_A and lora_B
if self.is_conv2d_3x3:
ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
ab = torch.nn.functional.conv2d(ab, lora_B)
else:
ab = x
if self.is_conv2d:
ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
ab = torch.nn.functional.linear(ab, lora_A)
ab = torch.nn.functional.linear(ab, lora_B)
if self.is_conv2d:
ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W)
# 最後の項は、低rankをより大きくするためのスケーリングじゃないかな
result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
# NOTE weightに加算してからlinear/conv2dを呼んだほうが速いかも
return result
def state_dict(self, destination=None, prefix="", keep_vars=False):
# state dictを通常のLoRAと同じにする:
# nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える
sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
if self.is_conv2d and not self.is_conv2d_3x3:
lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
if self.is_conv2d and not self.is_conv2d_3x3:
lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
i = 0
while True:
key_a = f"{self.lora_name}.lora_A.{i}"
key_b = f"{self.lora_name}.lora_B.{i}"
if key_a in sd:
sd.pop(key_a)
sd.pop(key_b)
else:
break
i += 1
return sd
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
# 通常のLoRAと同じstate dictを読み込めるようにするこの方法はchatGPTに聞いた
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
if lora_A_weight is None or lora_B_weight is None:
if strict:
raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
else:
return
if self.is_conv2d and not self.is_conv2d_3x3:
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
state_dict.update(
{f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
)
state_dict.update(
{f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
network_alpha = 1.0
# extract dim/alpha for conv2d, and block dim
conv_dim = kwargs.get("conv_dim", None)
conv_alpha = kwargs.get("conv_alpha", None)
unit = kwargs.get("unit", None)
if conv_dim is not None:
conv_dim = int(conv_dim)
assert conv_dim == network_dim, "conv_dim must be same as network_dim"
if conv_alpha is None:
conv_alpha = 1.0
else:
conv_alpha = float(conv_alpha)
if unit is not None:
unit = int(unit)
else:
unit = 1
network = DyLoRANetwork(
text_encoder,
unet,
multiplier=multiplier,
lora_dim=network_dim,
alpha=network_alpha,
apply_to_conv=conv_dim is not None,
unit=unit,
varbose=True,
)
return network
# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
# get dim/alpha mapping
modules_dim = {}
modules_alpha = {}
for key, value in weights_sd.items():
if "." not in key:
continue
lora_name = key.split(".")[0]
if "alpha" in key:
modules_alpha[lora_name] = value
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# support old LoRA without alpha
for key in modules_dim.keys():
if key not in modules_alpha:
modules_alpha = modules_dim[key]
module_class = DyLoRAModule
network = DyLoRANetwork(
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
)
return network, weights_sd
class DyLoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
def __init__(
self,
text_encoder,
unet,
multiplier=1.0,
lora_dim=4,
alpha=1,
apply_to_conv=False,
modules_dim=None,
modules_alpha=None,
unit=1,
module_class=DyLoRAModule,
varbose=False,
) -> None:
super().__init__()
self.multiplier = multiplier
self.lora_dim = lora_dim
self.alpha = alpha
self.apply_to_conv = apply_to_conv
if modules_dim is not None:
print(f"create LoRA network from weights")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
if self.apply_to_conv:
print(f"apply LoRA to Conv2d with kernel size (3,3).")
# create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
loras = []
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
is_linear = child_module.__class__.__name__ == "Linear"
is_conv2d = child_module.__class__.__name__ == "Conv2d"
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
if is_linear or is_conv2d:
lora_name = prefix + "." + name + "." + child_name
lora_name = lora_name.replace(".", "_")
dim = None
alpha = None
if modules_dim is not None:
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
if is_linear or is_conv2d_1x1 or apply_to_conv:
dim = self.lora_dim
alpha = self.alpha
if dim is None or dim == 0:
continue
# dropout and fan_in_fan_out is default
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
loras.append(lora)
return loras
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
if modules_dim is not None or self.apply_to_conv:
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras = create_modules(True, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location="cpu")
info = self.load_state_dict(weights_sd, False)
return info
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()
self.add_module(lora.lora_name, lora)
"""
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
apply_text_encoder = apply_unet = False
for key in weights_sd.keys():
if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
apply_text_encoder = True
elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
else:
self.unet_loras = []
for lora in self.text_encoder_loras + self.unet_loras:
sd_for_lora = {}
for key in weights_sd.keys():
if key.startswith(lora.lora_name):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
"""
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []
def enumerate_params(loras):
params = []
for lora in loras:
params.extend(lora.parameters())
return params
if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
if self.unet_loras:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
return all_params
def enable_gradient_checkpointing(self):
# not supported
pass
def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)
def on_epoch_start(self, text_encoder, unet):
self.train()
def get_trainable_params(self):
return self.parameters()
def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
state_dict = self.state_dict()
if dtype is not None:
for key in list(state_dict.keys()):
v = state_dict[key]
v = v.detach().clone().to("cpu").to(dtype)
state_dict[key] = v
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import save_file
from library import train_util
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)
# mask is a tensor with values from 0 to 1
def set_region(self, sub_prompt_index, is_last_network, mask):
pass
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
pass

View File

@@ -0,0 +1,125 @@
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
# Thanks to cloneofsimo
import argparse
import math
import os
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np
def load_state_dict(file_name):
if model_util.is_safetensors(file_name):
sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else:
sd = torch.load(file_name, map_location="cpu")
metadata = None
return sd, metadata
def save_to_file(file_name, model, metadata):
if model_util.is_safetensors(file_name):
save_file(model, file_name, metadata)
else:
torch.save(model, file_name)
def split_lora_model(lora_sd, unit):
max_rank = 0
# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if "lora_down" in key:
rank = value.size()[0]
if rank > max_rank:
max_rank = rank
print(f"Max rank: {max_rank}")
rank = unit
split_models = []
new_alpha = None
while rank < max_rank:
print(f"Splitting rank {rank}")
new_sd = {}
for key, value in lora_sd.items():
if "lora_down" in key:
new_sd[key] = value[:rank].contiguous()
elif "lora_up" in key:
new_sd[key] = value[:, :rank].contiguous()
else:
# なぜかscaleするとおかしくなる……
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
# scale = math.sqrt(this_rank / rank) # rank is > unit
# print(key, value.size(), this_rank, rank, value, scale)
# new_alpha = value * scale # always same
# new_sd[key] = new_alpha
new_sd[key] = value
split_models.append((new_sd, rank, new_alpha))
rank += unit
return max_rank, split_models
def split(args):
print("loading Model...")
lora_sd, metadata = load_state_dict(args.model)
print("Splitting Model...")
original_rank, split_models = split_lora_model(lora_sd, args.unit)
comment = metadata.get("ss_training_comment", "")
for state_dict, new_rank, new_alpha in split_models:
# update metadata
if metadata is None:
new_metadata = {}
else:
new_metadata = metadata.copy()
new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
new_metadata["ss_network_dim"] = str(new_rank)
# new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
filename, ext = os.path.splitext(args.save_to)
model_file_name = filename + f"-{new_rank:04d}{ext}"
print(f"saving model to: {model_file_name}")
save_to_file(model_file_name, state_dict, new_metadata)
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
parser.add_argument(
"--save_to",
type=str,
default=None,
help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
)
parser.add_argument(
"--model",
type=str,
default=None,
help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
)
return parser
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
split(args)

View File

@@ -2,6 +2,7 @@
from tqdm import tqdm
from library import model_util
import library.train_util as train_util
import argparse
from transformers import CLIPTokenizer
import torch
@@ -16,16 +17,20 @@ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def interrogate(args):
weights_dtype = torch.float16
# いろいろ準備する
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
args.pretrained_model_name_or_path = args.sd_model
args.vae = None
text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)
print(f"loading LoRA: {args.model}")
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# text encoder向けの重みがあるかチェックする本当はlora側でやるのがいい
has_te_weight = False
for key in network.weights_sd.keys():
for key in weights_sd.keys():
if 'lora_te' in key:
has_te_weight = True
break
@@ -40,9 +45,9 @@ def interrogate(args):
else:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
text_encoder.to(DEVICE)
text_encoder.to(DEVICE, dtype=weights_dtype)
text_encoder.eval()
unet.to(DEVICE)
unet.to(DEVICE, dtype=weights_dtype)
unet.eval() # U-Netは呼び出さないので不要だけど
# トークンをひとつひとつ当たっていく
@@ -78,9 +83,14 @@ def interrogate(args):
orig_embs = get_all_embeddings(text_encoder)
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
network.to(DEVICE)
info = network.load_state_dict(weights_sd, strict=False)
print(f"Loading LoRA weights: {info}")
network.to(DEVICE, dtype=weights_dtype)
network.eval()
del unet
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
print("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder)
@@ -107,6 +117,7 @@ def interrogate(args):
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--sd_model", type=str, default=None,

View File

@@ -9,86 +9,122 @@ import library.model_util as model_util
def convert(args):
# 引数を確認する
load_dtype = torch.float16 if args.fp16 else None
# 引数を確認する
load_dtype = torch.float16 if args.fp16 else None
save_dtype = None
if args.fp16:
save_dtype = torch.float16
elif args.bf16:
save_dtype = torch.bfloat16
elif args.float:
save_dtype = torch.float
save_dtype = None
if args.fp16 or args.save_precision_as == "fp16":
save_dtype = torch.float16
elif args.bf16 or args.save_precision_as == "bf16":
save_dtype = torch.bfloat16
elif args.float or args.save_precision_as == "float":
save_dtype = torch.float
is_load_ckpt = os.path.isfile(args.model_to_load)
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
is_load_ckpt = os.path.isfile(args.model_to_load)
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
assert is_save_ckpt or args.reference_model is not None, f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
assert not is_load_ckpt or args.v1 != args.v2, f"v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
assert (
is_save_ckpt or args.reference_model is not None
), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
# モデルを読み込む
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
print(f"loading {msg}: {args.model_to_load}")
# モデルを読み込む
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
print(f"loading {msg}: {args.model_to_load}")
if is_load_ckpt:
v2_model = args.v2
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
else:
pipe = StableDiffusionPipeline.from_pretrained(args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None)
text_encoder = pipe.text_encoder
vae = pipe.vae
unet = pipe.unet
if args.v1 == args.v2:
# 自動判定する
v2_model = unet.config.cross_attention_dim == 1024
print("checking model version: model is " + ('v2' if v2_model else 'v1'))
if is_load_ckpt:
v2_model = args.v2
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(v2_model, args.model_to_load)
else:
v2_model = not args.v1
pipe = StableDiffusionPipeline.from_pretrained(
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None
)
text_encoder = pipe.text_encoder
vae = pipe.vae
unet = pipe.unet
# 変換して保存する
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
print(f"converting and saving as {msg}: {args.model_to_save}")
if args.v1 == args.v2:
# 自動判定する
v2_model = unet.config.cross_attention_dim == 1024
print("checking model version: model is " + ("v2" if v2_model else "v1"))
else:
v2_model = not args.v1
if is_save_ckpt:
original_model = args.model_to_load if is_load_ckpt else None
key_count = model_util.save_stable_diffusion_checkpoint(v2_model, args.model_to_save, text_encoder, unet,
original_model, args.epoch, args.global_step, save_dtype, vae)
print(f"model saved. total converted state_dict keys: {key_count}")
else:
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
model_util.save_diffusers_checkpoint(v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors)
print(f"model saved.")
# 変換して保存する
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
print(f"converting and saving as {msg}: {args.model_to_save}")
if is_save_ckpt:
original_model = args.model_to_load if is_load_ckpt else None
key_count = model_util.save_stable_diffusion_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, original_model, args.epoch, args.global_step, save_dtype, vae
)
print(f"model saved. total converted state_dict keys: {key_count}")
else:
print(f"copy scheduler/tokenizer config from: {args.reference_model}")
model_util.save_diffusers_checkpoint(
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
)
print(f"model saved.")
def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--v1", action='store_true',
help='load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む')
parser.add_argument("--v2", action='store_true',
help='load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む')
parser.add_argument("--fp16", action='store_true',
help='load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込みDiffusers形式のみ対応、保存するcheckpointのみ対応')
parser.add_argument("--bf16", action='store_true', help='save as bf16 (checkpoint only) / bf16形式で保存するcheckpointのみ対応')
parser.add_argument("--float", action='store_true',
help='save as float (checkpoint only) / float(float32)形式で保存するcheckpointのみ対応')
parser.add_argument("--epoch", type=int, default=0, help='epoch to write to checkpoint / checkpointに記録するepoch数の値')
parser.add_argument("--global_step", type=int, default=0,
help='global_step to write to checkpoint / checkpointに記録するglobal_stepの値')
parser.add_argument("--reference_model", type=str, default=None,
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要")
parser.add_argument("--use_safetensors", action='store_true',
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存するcheckpointは拡張子で自動判定")
parser = argparse.ArgumentParser()
parser.add_argument(
"--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む"
)
parser.add_argument(
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
)
parser.add_argument(
"--fp16",
action="store_true",
help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込みDiffusers形式のみ対応、保存するcheckpointのみ対応",
)
parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存するcheckpointのみ対応")
parser.add_argument(
"--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存するcheckpointのみ対応"
)
parser.add_argument(
"--save_precision_as",
type=str,
default="no",
choices=["fp16", "bf16", "float"],
help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください",
)
parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値")
parser.add_argument(
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
)
parser.add_argument(
"--reference_model",
type=str,
default=None,
help="reference model for schduler/tokenizer, required in saving Diffusers, copy schduler/tokenizer from this / scheduler/tokenizerのコピー元のDiffusersモデル、Diffusers形式で保存するときに必要",
)
parser.add_argument(
"--use_safetensors",
action="store_true",
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存するcheckpointは拡張子で自動判定",
)
parser.add_argument("model_to_load", type=str, default=None,
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ")
parser.add_argument("model_to_save", type=str, default=None,
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存")
return parser
parser.add_argument(
"model_to_load",
type=str,
default=None,
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ",
)
parser.add_argument(
"model_to_save",
type=str,
default=None,
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存",
)
return parser
if __name__ == '__main__':
parser = setup_parser()
if __name__ == "__main__":
parser = setup_parser()
args = parser.parse_args()
convert(args)
args = parser.parse_args()
convert(args)

342
tools/latent_upscaler.py Normal file
View File

@@ -0,0 +1,342 @@
# 外部から簡単にupscalerを呼ぶためのスクリプト
# 単体で動くようにモデル定義も含めている
import argparse
import glob
import os
import cv2
from diffusers import AutoencoderKL
from typing import Dict, List
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from PIL import Image
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
super(ResidualBlock, self).__init__()
if out_channels is None:
out_channels = in_channels
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu2 = nn.ReLU(inplace=True) # このReLUはresidualに足す前にかけるほうがいいかも
# initialize weights
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.bn2(out)
out += residual
out = self.relu2(out)
return out
class Upscaler(nn.Module):
def __init__(self):
super(Upscaler, self).__init__()
# define layers
# latent has 4 channels
self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.bn1 = nn.BatchNorm2d(128)
self.relu1 = nn.ReLU(inplace=True)
# resblocks
# 数の暴力で20個次元数を増やすよりもブロックを増やしたほうがreceptive fieldが広がるはずだぞ
self.resblock1 = ResidualBlock(128)
self.resblock2 = ResidualBlock(128)
self.resblock3 = ResidualBlock(128)
self.resblock4 = ResidualBlock(128)
self.resblock5 = ResidualBlock(128)
self.resblock6 = ResidualBlock(128)
self.resblock7 = ResidualBlock(128)
self.resblock8 = ResidualBlock(128)
self.resblock9 = ResidualBlock(128)
self.resblock10 = ResidualBlock(128)
self.resblock11 = ResidualBlock(128)
self.resblock12 = ResidualBlock(128)
self.resblock13 = ResidualBlock(128)
self.resblock14 = ResidualBlock(128)
self.resblock15 = ResidualBlock(128)
self.resblock16 = ResidualBlock(128)
self.resblock17 = ResidualBlock(128)
self.resblock18 = ResidualBlock(128)
self.resblock19 = ResidualBlock(128)
self.resblock20 = ResidualBlock(128)
# last convs
self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.bn2 = nn.BatchNorm2d(64)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
self.bn3 = nn.BatchNorm2d(64)
self.relu3 = nn.ReLU(inplace=True)
# final conv: output 4 channels
self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
# initialize weights
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
# initialize final conv weights to 0: 流行りのzero conv
nn.init.constant_(self.conv_final.weight, 0)
def forward(self, x):
inp = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
# いくつかのresblockを通した後に、residualを足すことで精度向上と学習速度向上が見込めるはず
residual = x
x = self.resblock1(x)
x = self.resblock2(x)
x = self.resblock3(x)
x = self.resblock4(x)
x = x + residual
residual = x
x = self.resblock5(x)
x = self.resblock6(x)
x = self.resblock7(x)
x = self.resblock8(x)
x = x + residual
residual = x
x = self.resblock9(x)
x = self.resblock10(x)
x = self.resblock11(x)
x = self.resblock12(x)
x = x + residual
residual = x
x = self.resblock13(x)
x = self.resblock14(x)
x = self.resblock15(x)
x = self.resblock16(x)
x = x + residual
residual = x
x = self.resblock17(x)
x = self.resblock18(x)
x = self.resblock19(x)
x = self.resblock20(x)
x = x + residual
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.conv3(x)
x = self.bn3(x)
# ここにreluを入れないほうがいい気がする
x = self.conv_final(x)
# network estimates the difference between the input and the output
x = x + inp
return x
def support_latents(self) -> bool:
return False
def upscale(
self,
vae: AutoencoderKL,
lowreso_images: List[Image.Image],
lowreso_latents: torch.Tensor,
dtype: torch.dtype,
width: int,
height: int,
batch_size: int = 1,
vae_batch_size: int = 1,
):
# assertion
assert lowreso_images is not None, "Upscaler requires lowreso image"
# make upsampled image with lanczos4
upsampled_images = []
for lowreso_image in lowreso_images:
upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
upsampled_images.append(upsampled_image)
# convert to tensor: this tensor is too large to be converted to cuda
upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
upsampled_images = torch.stack(upsampled_images, dim=0)
upsampled_images = upsampled_images.to(dtype)
# normalize to [-1, 1]
upsampled_images = upsampled_images / 127.5 - 1.0
# convert upsample images to latents with batch size
# print("Encoding upsampled (LANCZOS4) images...")
upsampled_latents = []
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
with torch.no_grad():
batch = vae.encode(batch).latent_dist.sample()
upsampled_latents.append(batch)
upsampled_latents = torch.cat(upsampled_latents, dim=0)
# upscale (refine) latents with this model with batch size
print("Upscaling latents...")
upscaled_latents = []
for i in range(0, upsampled_latents.shape[0], batch_size):
with torch.no_grad():
upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
upscaled_latents = torch.cat(upscaled_latents, dim=0)
return upscaled_latents * 0.18215
# external interface: returns a model
def create_upscaler(**kwargs):
weights = kwargs["weights"]
model = Upscaler()
print(f"Loading weights from {weights}...")
model.load_state_dict(torch.load(weights, map_location=torch.device("cpu")))
return model
# another interface: upscale images with a model for given images from command line
def upscale_images(args: argparse.Namespace):
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
us_dtype = torch.float16 # TODO: support fp32/bf16
os.makedirs(args.output_dir, exist_ok=True)
# load VAE with Diffusers
assert args.vae_path is not None, "VAE path is required"
print(f"Loading VAE from {args.vae_path}...")
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
vae.to(DEVICE, dtype=us_dtype)
# prepare model
print("Preparing model...")
upscaler: Upscaler = create_upscaler(weights=args.weights)
# print("Loading weights from", args.weights)
# upscaler.load_state_dict(torch.load(args.weights))
upscaler.eval()
upscaler.to(DEVICE, dtype=us_dtype)
# load images
image_paths = glob.glob(args.image_pattern)
images = []
for image_path in image_paths:
image = Image.open(image_path)
image = image.convert("RGB")
# make divisible by 8
width = image.width
height = image.height
if width % 8 != 0:
width = width - (width % 8)
if height % 8 != 0:
height = height - (height % 8)
if width != image.width or height != image.height:
image = image.crop((0, 0, width, height))
images.append(image)
# debug output
if args.debug:
for image, image_path in zip(images, image_paths):
image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
basename = os.path.basename(image_path)
basename_wo_ext, ext = os.path.splitext(basename)
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
image_debug.save(dest_file_name)
# upscale
print("Upscaling...")
upscaled_latents = upscaler.upscale(
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
)
upscaled_latents /= 0.18215
# decode with batch
print("Decoding...")
upscaled_images = []
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
with torch.no_grad():
batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
batch = batch.to("cpu")
upscaled_images.append(batch)
upscaled_images = torch.cat(upscaled_images, dim=0)
# tensor to numpy
upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
upscaled_images = (upscaled_images + 1.0) * 127.5
upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
upscaled_images = upscaled_images[..., ::-1]
# save images
for i, image in enumerate(upscaled_images):
basename = os.path.basename(image_paths[i])
basename_wo_ext, ext = os.path.splitext(basename)
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
cv2.imwrite(dest_file_name, image)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
parser.add_argument("--weights", type=str, default=None, help="Weights path")
parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
parser.add_argument("--debug", action="store_true", help="Debug mode")
args = parser.parse_args()
upscale_images(args)

View File

@@ -2,7 +2,7 @@ __ドキュメント更新中のため記述に誤りがあるかもしれませ
# 学習について、共通編
当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversionの学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。
当リポジトリではモデルのfine tuning、DreamBooth、およびLoRAとTextual Inversion[XTI:P+](https://github.com/kohya-ss/sd-scripts/pull/327)を含む)の学習をサポートします。この文書ではそれらに共通する、学習データの準備方法やオプション等について説明します。
# 概要
@@ -535,7 +535,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
- `--debug_dataset`
このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。
このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。`S`キーで次のステップ(バッチ)、`E`キーで次のエポックに進みます。
※Linux環境Colabを含むでは画像は表示されません。
@@ -545,6 +545,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
DreamBoothおよびfine tuningでは、保存されるモデルはこのVAEを組み込んだものになります。
- `--cache_latents`
使用VRAMを減らすためVAEの出力をメインメモリにキャッシュします。`flip_aug` 以外のaugmentationは使えなくなります。また全体の学習速度が若干速くなります。
- `--min_snr_gamma`
Min-SNR Weighting strategyを指定します。詳細は[こちら](https://github.com/kohya-ss/sd-scripts/pull/308)を参照してください。論文では`5`が推奨されています。
## オプティマイザ関係
@@ -570,7 +577,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
学習率のスケジューラ関連の指定です。
lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから選べます。デフォルトはconstantです。
lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup, 任意のスケジューラから選べます。デフォルトはconstantです。
lr_warmup_stepsでスケジューラのウォームアップだんだん学習率を変えていくステップ数を指定できます。
@@ -578,6 +585,8 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
詳細については各自お調べください。
任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。
### オプティマイザの指定について
オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。

View File

@@ -117,12 +117,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # 念のため追加
@@ -229,7 +231,7 @@ def train(args):
)
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth")
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name)
loss_list = []
loss_total = 0.0

View File

@@ -172,12 +172,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
# prepare network
import sys
@@ -195,7 +197,7 @@ def train(args):
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
if network is None:
return
if hasattr(network, "prepare_network"):
network.prepare_network(args)
@@ -219,7 +221,9 @@ def train(args):
try:
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
except TypeError:
print("Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)")
print(
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
)
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
@@ -534,11 +538,17 @@ def train(args):
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)
if accelerator.is_main_process:
accelerator.init_trackers("network_train")
accelerator.init_trackers("network_train" if args.log_tracker_name is None else args.log_tracker_name)
loss_list = []
loss_total = 0.0
del train_dataset_group
# if hasattr(network, "on_step_start"):
# on_step_start = network.on_step_start
# else:
# on_step_start = lambda *args, **kwargs: None
for epoch in range(num_train_epochs):
if is_main_process:
print(f"epoch {epoch+1}/{num_train_epochs}")
@@ -551,6 +561,8 @@ def train(args):
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(network):
# on_step_start(text_encoder, unet)
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
@@ -563,16 +575,17 @@ def train(args):
with torch.set_grad_enabled(train_text_encoder):
# Get the text embedding for conditioning
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
encoder_hidden_states = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
@@ -757,4 +770,4 @@ if __name__ == "__main__":
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
train(args)

View File

@@ -12,11 +12,31 @@ Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora)
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
# 学習できるLoRAの種類
以下の二種類をサポートします。以下は当リポジトリ内の独自の名称です。
1. __LoRA-LierLa__ : (LoRA for __Li__ n __e__ a __r__ __La__ yers、リエラと読みます)
Linear およびカーネルサイズ 1x1 の Conv2d に適用されるLoRA
2. __LoRA-C3Lier__ : (LoRA for __C__ olutional layers with __3__ x3 Kernel and __Li__ n __e__ a __r__ layers、セリアと読みます)
1.に加え、カーネルサイズ 3x3 の Conv2d に適用されるLoRA
LoRA-LierLaに比べ、LoRA-C3Liarは適用される層が増える分、高い精度が期待できるかもしれません。
また学習時は __DyLoRA__ を使用することもできます(後述します)。
## 学習したモデルに関する注意
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)
LoRA-LierLa は、AUTOMATIC1111氏のWeb UIのLoRA機能で使用することができます
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
LoRA-C3Liarを使いWeb UIで生成するには、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
いずれも学習したLoRAのモデルを、Stable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージすることもできます。
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
# 学習の手順
@@ -31,9 +51,9 @@ WebUI等で画像生成する場合には、学習したLoRAのモデルを学
`train_network.py`を用います。
`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。
`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのは`network.lora`となりますので、それを指定してください。
なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。
なお学習率は通常のDreamBoothやfine tuningよりも高めの、`1e-4``1e-3`程度を指定するとよいようです。
以下はコマンドラインの例です。
@@ -56,6 +76,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
--network_module=networks.lora
```
このコマンドラインでは LoRA-LierLa が学習されます。
`--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。
その他、以下のオプションが指定できます。
@@ -83,22 +105,143 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
`--network_train_unet_only``--network_train_text_encoder_only` の両方とも未指定時デフォルトはText EncoderとU-Netの両方のLoRAモジュールを有効にします。
## LoRA を Conv2d に拡大して適用する
# その他の学習方法
通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
## LoRA-C3Lier を学習する
`--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。
```
--network_args "conv_dim=1" "conv_alpha=1"
--network_args "conv_dim=4" "conv_alpha=1"
```
以下のように alpha 省略時は1になります。
```
--network_args "conv_dim=1"
--network_args "conv_dim=4"
```
## DyLoRA
DyLoRAはこちらの論文で提案されたものです。[DyLoRA: Parameter Efficient Tuning of Pre-trained Models using Dynamic Search-Free Low-Rank Adaptation](https://arxiv.org/abs/2210.07558) 公式実装は[こちら](https://github.com/huawei-noah/KD-NLP/tree/main/DyLoRA)です。
論文によると、LoRAのrankは必ずしも高いほうが良いわけではなく、対象のモデル、データセット、タスクなどにより適切なrankを探す必要があるようです。DyLoRAを使うと、指定したdim(rank)以下のさまざまなrankで同時にLoRAを学習します。これにより最適なrankをそれぞれ学習して探す手間を省くことができます。
当リポジトリの実装は公式実装をベースに独自の拡張を加えています(そのため不具合などあるかもしれません)。
### 当リポジトリのDyLoRAの特徴
学習後のDyLoRAのモデルファイルはLoRAと互換性があります。また、モデルファイルから指定したdim(rank)以下の複数のdimのLoRAを抽出できます。
DyLoRA-LierLa、DyLoRA-C3Lierのどちらも学習できます。
### DyLoRAで学習する
`--network_module=networks.dylora` のように、DyLoRAに対応する`network.dylora`を指定してください。
また `--network_args` に、たとえば`--network_args "unit=4"`のように`unit`を指定します。`unit`はrankを分割する単位です。たとえば`--network_dim=16 --network_args "unit=4"` のように指定します。`unit``network_dim`を割り切れる値(`network_dim``unit`の倍数)としてください。
`unit`を指定しない場合は、`unit=1`として扱われます。
記述例は以下です。
```
--network_module=networks.dylora --network_dim=16 --network_args "unit=4"
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "unit=4"
```
DyLoRA-C3Lierの場合は、`--network_args``"conv_dim=4"`のように`conv_dim`を指定します。通常のLoRAと異なり、`conv_dim``network_dim`と同じ値である必要があります。記述例は以下です。
```
--network_module=networks.dylora --network_dim=16 --network_args "conv_dim=16" "unit=4"
--network_module=networks.dylora --network_dim=32 --network_alpha=16 --network_args "conv_dim=32" "conv_alpha=16" "unit=8"
```
たとえばdim=16、unit=4後述で学習すると、4、8、12、16の4つのrankのLoRAを学習、抽出できます。抽出した各モデルで画像を生成し、比較することで、最適なrankのLoRAを選択できます。
その他のオプションは通常のLoRAと同じです。
`unit`は当リポジトリの独自拡張で、DyLoRAでは同dim(rank)の通常LoRAに比べると学習時間が長くなることが予想されるため、分割単位を大きくしたものです。
### DyLoRAのモデルからLoRAモデルを抽出する
`networks`フォルダ内の `extract_lora_from_dylora.py`を使用します。指定した`unit`単位で、DyLoRAのモデルからLoRAのモデルを抽出します。
コマンドラインはたとえば以下のようになります。
```powershell
python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.safetensors" --save_to "foldername/dylora-model-split.safetensors" --unit 4
```
`--model` にはDyLoRAのモデルファイルを指定します。`--save_to` には抽出したモデルを保存するファイル名を指定しますrankの数値がファイル名に付加されます`--unit` にはDyLoRAの学習時の`unit`を指定します。
## 階層別学習率
詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。
フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。
`--network_args` で以下の引数を指定してください。
- `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。
- ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。
- プリセットからの指定 : `"down_lr_weight=sine"` のように指定しますサインカーブで重みを指定します。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します0.25~1.25になります)。
- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。
- `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。
- 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。
- `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。
### 階層別学習率コマンドライン指定例:
```powershell
--network_args "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5" "mid_lr_weight=2.0" "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5"
--network_args "block_lr_zero_threshold=0.1" "down_lr_weight=sine+.5" "mid_lr_weight=1.5" "up_lr_weight=cosine+.5"
```
### 階層別学習率tomlファイル指定例:
```toml
network_args = [ "down_lr_weight=0.5,0.5,0.5,0.5,1.0,1.0,1.0,1.0,1.5,1.5,1.5,1.5", "mid_lr_weight=2.0", "up_lr_weight=1.5,1.5,1.5,1.5,1.0,1.0,1.0,1.0,0.5,0.5,0.5,0.5",]
network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_lr_weight=1.5", "up_lr_weight=cosine+.5", ]
```
## 階層別dim (rank)
フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。
`--network_args` で以下の引数を指定してください。
- `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。
- `block_alphas` : 各ブロックのalphaを指定します。block_dimsと同様に25個の数値を指定します。省略時はnetwork_alphaの値が使用されます。
- `conv_block_dims` : LoRAをConv2d 3x3に拡張し、各ブロックのdim (rank)を指定します。
- `conv_block_alphas` : LoRAをConv2d 3x3に拡張したときの各ブロックのalphaを指定します。省略時はconv_alphaの値が使用されます。
### 階層別dim (rank)コマンドライン指定例:
```powershell
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2"
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "conv_block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
--network_args "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2" "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"
```
### 階層別dim (rank)tomlファイル指定例:
```toml
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2",]
network_args = [ "block_dims=2,4,4,4,8,8,8,8,12,12,12,12,16,12,12,12,12,8,8,8,8,4,4,4,2", "block_alphas=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2",]
```
# その他のスクリプト
マージ等LoRAに関連するスクリプト群です。
## マージスクリプトについて
merge_lora.pyでStable DiffusionのモデルにLoRAの学習結果をマージしたり、複数のLoRAモデルをマージしたりできます。
@@ -323,14 +466,14 @@ python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256
- 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。
## 追加情報
# 追加情報
### cloneofsimo氏のリポジトリとの違い
## cloneofsimo氏のリポジトリとの違い
2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。
またモジュール入れ替え機構は全く異なります。
### 将来拡張について
## 将来拡張について
LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。

View File

@@ -185,10 +185,10 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
@@ -233,12 +233,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
@@ -262,7 +264,9 @@ def train(args):
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
args.max_train_steps = args.max_train_epochs * math.ceil(
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
@@ -333,11 +337,11 @@ def train(args):
)
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch+1
current_epoch.value = epoch + 1
text_encoder.train()
@@ -357,7 +361,7 @@ def train(args):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
# weight_dtype) use float instead of fp16/bf16 because text encoder is float
# use float instead of fp16/bf16 because text encoder is float
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float)
# Sample noise that we'll add to the latents
@@ -375,7 +379,8 @@ def train(args):
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
@@ -385,9 +390,9 @@ def train(args):
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights

View File

@@ -267,12 +267,14 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size)
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
accelerator.wait_for_everyone()
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
@@ -369,7 +371,7 @@ def train(args):
)
if accelerator.is_main_process:
accelerator.init_trackers("textual_inversion")
accelerator.init_trackers("textual_inversion" if args.log_tracker_name is None else args.log_tracker_name)
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
@@ -416,7 +418,8 @@ def train(args):
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
with accelerator.autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training

View File

@@ -4,7 +4,7 @@
実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。
学習したモデルはWeb UIでもそのまま使えます。なお恐らくSD2.xにも対応していますが現時点では未テストです。
学習したモデルはWeb UIでもそのまま使えます。
# 学習の手順